Refactor: Update Penn-Fudan Mask R-CNN configuration and data transformation logic for memory optimization

This commit is contained in:
Craig
2025-04-12 11:13:22 +01:00
parent 217cfba9ba
commit 2b38c04a57
2 changed files with 31 additions and 22 deletions

View File

@@ -2,20 +2,30 @@
"""
Configuration for training Mask R-CNN on the Penn-Fudan dataset.
"""
from configs.base_config import base_config

# Create a copy of the base configuration so the shared defaults stay untouched
config = base_config.copy()

# Update specific values for this experiment
config.update(
    {
        # Core configuration
        "config_name": "pennfudan_maskrcnn_v1",  # Unique name for this experiment run
        "data_root": "data/PennFudanPed",  # Dataset root directory
        "num_classes": 2,  # background + pedestrian
        # Training parameters - modified for memory constraints
        "batch_size": 1,  # Reduced from 2 to 1 to save memory
        "num_epochs": 10,
        # Optimizer settings
        "lr": 0.002,  # Slightly reduced learning rate for smaller batch size
        "momentum": 0.9,
        "weight_decay": 0.0005,
        # Memory optimization settings
        "pin_memory": False,  # Set to False to reduce memory pressure
        "num_workers": 2,  # Reduced from 4 to 2
        # Device settings
        "device": "cuda",
    }
)

View File

@@ -104,29 +104,28 @@ class PennFudanDataset(torch.utils.data.Dataset):
def get_transform(train):
    """Get the transformations for the dataset.

    Args:
        train (bool): Whether to get transforms for training or evaluation.

    Returns:
        torchvision.transforms.Compose: The composed transforms.
    """
    transforms = []
    # Convert the image to a tensor image type (pixel values scaled below)
    transforms.append(T.ToImage())
    # Resize to reduce memory usage. NOTE: T.Resize(800) scales the *smaller*
    # edge to 800px (preserving aspect ratio), it does not cap the max size.
    transforms.append(T.Resize(800))
    # Convert to float32 and scale pixel values to [0, 1]
    transforms.append(T.ToDtype(torch.float32, scale=True))
    # Data augmentation for training only; v2 transforms flip the
    # bounding boxes / masks together with the image.
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
        # Could add more augmentations here if desired
    return T.Compose(transforms)