Formatting
This commit is contained in:
@@ -0,0 +1,55 @@
|
||||
import torchvision
|
||||
from torchvision.models import ResNet50_Weights
|
||||
|
||||
# Import weights enums for clarity
|
||||
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
|
||||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
||||
|
||||
|
||||
def get_maskrcnn_model(num_classes, pretrained=True, pretrained_backbone=True):
|
||||
"""Loads a Mask R-CNN model with a ResNet-50-FPN backbone.
|
||||
|
||||
Args:
|
||||
num_classes (int): Number of output classes (including background).
|
||||
pretrained (bool): If True, loads weights pre-trained on COCO.
|
||||
pretrained_backbone (bool): If True (and pretrained=False), loads backbone
|
||||
weights pre-trained on ImageNet.
|
||||
|
||||
Returns:
|
||||
torchvision.models.detection.MaskRCNN: The modified Mask R-CNN model.
|
||||
"""
|
||||
|
||||
# Determine weights based on arguments
|
||||
if pretrained:
|
||||
weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
|
||||
weights_backbone = None # Backbone weights are included in MaskRCNN weights
|
||||
elif pretrained_backbone:
|
||||
weights = None
|
||||
weights_backbone = ResNet50_Weights.DEFAULT
|
||||
else:
|
||||
weights = None
|
||||
weights_backbone = None
|
||||
|
||||
# Load the model structure with specified weights
|
||||
# Use maskrcnn_resnet50_fpn_v2 for compatibility with V2 weights
|
||||
model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
|
||||
weights=weights, weights_backbone=weights_backbone
|
||||
)
|
||||
|
||||
# 1. Replace the box predictor
|
||||
# Get number of input features for the classifier
|
||||
in_features_box = model.roi_heads.box_predictor.cls_score.in_features
|
||||
# Replace the pre-trained head with a new one
|
||||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features_box, num_classes)
|
||||
|
||||
# 2. Replace the mask predictor
|
||||
# Get number of input features for the mask classifier
|
||||
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
|
||||
hidden_layer = 256 # Default value
|
||||
# Replace the mask predictor with a new one
|
||||
model.roi_heads.mask_predictor = MaskRCNNPredictor(
|
||||
in_features_mask, hidden_layer, num_classes
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user