Files
2025-04-12 10:44:52 +01:00

56 lines
2.1 KiB
Python

import torchvision
from torchvision.models import ResNet50_Weights
# Import weights enums for clarity
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
def get_maskrcnn_model(
    num_classes,
    pretrained=True,
    pretrained_backbone=True,
    *,
    hidden_layer=256,
):
    """Build a Mask R-CNN (ResNet-50-FPN v2) whose heads predict ``num_classes``.

    The model is created with ``maskrcnn_resnet50_fpn_v2`` and then both
    ROI heads (box predictor and mask predictor) are replaced with freshly
    initialised ones sized for ``num_classes``, so it can be fine-tuned on
    a custom dataset.

    Args:
        num_classes (int): Number of output classes (including background).
        pretrained (bool): If True, loads weights pre-trained on COCO.
        pretrained_backbone (bool): If True (and pretrained=False), loads
            backbone weights pre-trained on ImageNet. Ignored when
            ``pretrained`` is True.
        hidden_layer (int): Number of channels in the hidden layer of the
            new mask predictor. Keyword-only; defaults to 256.

    Returns:
        torchvision.models.detection.MaskRCNN: The modified Mask R-CNN model.

    Raises:
        ValueError: If ``num_classes`` or ``hidden_layer`` is less than 1.
    """
    # Validate up front so a bad argument fails before the model (and any
    # pre-trained weights) are loaded.
    if num_classes < 1:
        raise ValueError(
            f"num_classes must be >= 1 (including background), got {num_classes}"
        )
    if hidden_layer < 1:
        raise ValueError(f"hidden_layer must be >= 1, got {hidden_layer}")

    # Determine weights based on arguments.
    if pretrained:
        weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        # Backbone weights are included in the full Mask R-CNN weights.
        weights_backbone = None
    elif pretrained_backbone:
        weights = None
        weights_backbone = ResNet50_Weights.DEFAULT
    else:
        weights = None
        weights_backbone = None

    # Load the model structure with the specified weights.
    # Use maskrcnn_resnet50_fpn_v2 for compatibility with the V2 weights.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
        weights=weights, weights_backbone=weights_backbone
    )

    # 1. Replace the box predictor, keeping its input feature size.
    in_features_box = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features_box, num_classes)

    # 2. Replace the mask predictor, keeping its input channel count.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    return model