import torchvision

# Import weights enums for clarity
from torchvision.models import ResNet50_Weights
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_maskrcnn_model(
    num_classes: int,
    pretrained: bool = True,
    pretrained_backbone: bool = True,
    *,
    hidden_layer: int = 256,
) -> torchvision.models.detection.MaskRCNN:
    """Build a Mask R-CNN (ResNet-50-FPN v2) with new box and mask heads.

    The model is created with ``maskrcnn_resnet50_fpn_v2``. Its box predictor
    and mask predictor are then replaced with freshly initialized heads sized
    for ``num_classes``. The heads are replaced even when pretrained weights
    are loaded, so only the backbone/RPN/ROI feature layers keep pretrained
    values.

    Args:
        num_classes: Number of output classes, including background.
        pretrained: If True, load the default COCO-trained Mask R-CNN v2
            weights for the whole model.
        pretrained_backbone: If True and ``pretrained`` is False, load the
            default ImageNet-trained ResNet-50 weights for the backbone only.
            Ignored when ``pretrained`` is True, because the full-model weights
            already include the backbone.
        hidden_layer: Number of channels in the new mask predictor's hidden
            (transposed-conv) layer. Defaults to 256, the previous hard-coded
            value.

    Returns:
        The modified Mask R-CNN model.

    Raises:
        ValueError: If ``num_classes`` or ``hidden_layer`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if hidden_layer < 1:
        raise ValueError(f"hidden_layer must be >= 1, got {hidden_layer}")

    # Pick weights: the full-model weights take precedence over backbone-only.
    if pretrained:
        weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        weights_backbone = None  # Backbone weights are included in MaskRCNN weights
    elif pretrained_backbone:
        weights = None
        weights_backbone = ResNet50_Weights.DEFAULT
    else:
        weights = None
        weights_backbone = None

    # Use maskrcnn_resnet50_fpn_v2 so the architecture matches the V2 weights.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
        weights=weights, weights_backbone=weights_backbone
    )

    # 1. Replace the box predictor; its input width comes from the existing
    #    classifier layer so it stays in sync with the ROI box head.
    in_features_box = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features_box, num_classes)

    # 2. Replace the mask predictor; its input channels come from the existing
    #    predictor's first (transposed) conv layer.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    return model