Refactor: Update Penn-Fudan Mask R-CNN configuration and data transformation logic for memory optimization
This commit is contained in:
@@ -104,29 +104,28 @@ class PennFudanDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
def get_transform(train):
|
||||
"""Gets the appropriate set of transforms.
|
||||
"""Get the transformations for the dataset.
|
||||
|
||||
Args:
|
||||
train (bool): Whether to apply training augmentations.
|
||||
train (bool): Whether to get transforms for training or evaluation.
|
||||
|
||||
Returns:
|
||||
torchvision.transforms.Compose: A composed Torchvision transform.
|
||||
torchvision.transforms.Compose: The composed transforms.
|
||||
"""
|
||||
transforms = []
|
||||
# Always convert image to PyTorch tensor and scale to [0, 1]
|
||||
|
||||
# Convert to PyTorch tensor and normalize
|
||||
transforms.append(T.ToImage())
|
||||
|
||||
# Add resize transform to reduce memory usage (max size of 800px)
|
||||
transforms.append(T.Resize(800))
|
||||
|
||||
transforms.append(T.ToDtype(torch.float32, scale=True))
|
||||
|
||||
# Data augmentation for training
|
||||
if train:
|
||||
# Add simple data augmentation for training
|
||||
transforms.append(T.RandomHorizontalFlip(p=0.5))
|
||||
# Add other augmentations here if needed
|
||||
# e.g., T.ColorJitter(...), T.RandomResizedCrop(...) ensuring
|
||||
# bounding boxes/masks are handled correctly by v2 transforms.
|
||||
|
||||
# Note: Normalization (e.g., T.Normalize) is often applied,
|
||||
# but pre-trained models in torchvision usually handle this internally
|
||||
# or expect [0, 1] range inputs.
|
||||
transforms.append(T.RandomHorizontalFlip(0.5))
|
||||
# Could add more augmentations here if desired
|
||||
|
||||
return T.Compose(transforms)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user