56 lines
2.1 KiB
Python
56 lines
2.1 KiB
Python
import torchvision
|
|
from torchvision.models import ResNet50_Weights
|
|
|
|
# Import weights enums for clarity
|
|
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
|
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
|
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
|
|
|
|
|
def get_maskrcnn_model(
    num_classes,
    pretrained=True,
    pretrained_backbone=True,
    *,
    hidden_layer=256,
    trainable_backbone_layers=None,
):
    """Build a Mask R-CNN (ResNet-50-FPN v2) whose heads predict ``num_classes``.

    The model is created with
    ``torchvision.models.detection.maskrcnn_resnet50_fpn_v2``. Its box and mask
    predictors are then always replaced with freshly initialised heads sized
    for ``num_classes``, so the same code path works with or without
    pre-trained weights.

    Args:
        num_classes (int): Number of output classes, including the background
            class. Must be at least 2 (background plus one object class).
        pretrained (bool): If True, load the full-model weights
            ``MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT``. The box and mask
            predictors are still replaced afterwards.
        pretrained_backbone (bool): Only consulted when ``pretrained`` is
            False. If True, load ``ResNet50_Weights.DEFAULT`` into the backbone
            only.
        hidden_layer (int): Channel width of the new mask predictor's
            intermediate layer (the second ``MaskRCNNPredictor`` argument).
            Defaults to 256, the value previously hard-coded here.
        trainable_backbone_layers (int | None): Forwarded unchanged to
            ``maskrcnn_resnet50_fpn_v2``. ``None`` (the default) keeps
            torchvision's own default behaviour.

    Returns:
        torchvision.models.detection.MaskRCNN: The model with replaced box and
        mask predictors.

    Raises:
        ValueError: If ``num_classes`` is less than 2. A model whose only class
            is background could never produce a detection; this usually means
            the background class was not counted.
    """
    if num_classes < 2:
        raise ValueError(
            "num_classes must be >= 2 (background + at least one object "
            f"class), got {num_classes}"
        )

    # The full-model checkpoint already carries backbone weights, so a separate
    # backbone checkpoint is only requested when the full model is not loaded.
    weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT if pretrained else None
    weights_backbone = (
        ResNet50_Weights.DEFAULT
        if not pretrained and pretrained_backbone
        else None
    )

    # Use the V2 builder so it matches the V2 weight enum selected above.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
        weights=weights,
        weights_backbone=weights_backbone,
        trainable_backbone_layers=trainable_backbone_layers,
    )

    # 1. Replace the box predictor. Its input width is read from the existing
    #    classifier so the new head matches the box head's feature size.
    in_features_box = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features_box, num_classes)

    # 2. Replace the mask predictor, again taking its input channel count from
    #    the existing head rather than hard-coding it.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    return model
|