Craig
2025-04-12 10:34:10 +01:00
parent c3096f0664
commit 97776a4a82
2 changed files with 152 additions and 6 deletions

todo.md

@@ -20,12 +20,12 @@ This list outlines the steps required to complete the Torchvision Finetuning pro
## Phase 2: Data Handling & Model
-- [ ] Implement `PennFudanDataset` class in `utils/data_utils.py`.
-  - [ ] `__init__`: Load image and mask paths.
-  - [ ] `__getitem__`: Load image/mask, parse masks, generate targets (boxes, labels, masks, image_id, area, iscrowd), apply transforms.
-  - [ ] `__len__`: Return dataset size.
-- [ ] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
-- [ ] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
+- [x] Implement `PennFudanDataset` class in `utils/data_utils.py`.
+  - [x] `__init__`: Load image and mask paths.
+  - [x] `__getitem__`: Load image/mask, parse masks, generate targets (boxes, labels, masks, image_id, area, iscrowd), apply transforms.
+  - [x] `__len__`: Return dataset size.
+- [x] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
+- [x] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
- [ ] Implement `get_maskrcnn_model(num_classes, ...)` function in `models/detection.py`.
- [ ] Load pre-trained Mask R-CNN (`maskrcnn_resnet50_fpn_v2`).
- [ ] Replace box predictor head (`FastRCNNPredictor`).
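The `get_maskrcnn_model` task above is still open; below is a minimal sketch of the intended head swap, assuming the standard torchvision fine-tuning pattern. The exact signature in `models/detection.py` is not settled yet, and the mask-predictor replacement is an assumption beyond the sub-tasks listed above:

```python
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_maskrcnn_model(num_classes):
    # Load a Mask R-CNN model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights="DEFAULT")
    # Replace the box predictor head (FastRCNNPredictor) for num_classes
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # Also resize the mask predictor head (assumption: the project will
    # fine-tune the mask head too, as in the torchvision tutorial)
    in_channels = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_channels, 256, num_classes)
    return model
```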

utils/data_utils.py

@@ -0,0 +1,146 @@
import os
import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision import tv_tensors


class PennFudanDataset(torch.utils.data.Dataset):
"""Dataset class for the Penn-Fudan Pedestrian Detection dataset."""
def __init__(self, root, transforms):
self.root = root
self.transforms = transforms
        # Load all image and mask file names, sorted so each image pairs
        # with its mask ("FudanPed00001.png" with "FudanPed00001_mask.png"
        # in Penn-Fudan's naming scheme)
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

def __getitem__(self, idx):
# Load images and masks
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
img = Image.open(img_path).convert("RGB")
        # Note: the mask is not converted to RGB because each pixel value
        # is an instance index (0 is background)
mask = Image.open(mask_path)
# Convert mask to numpy array
mask = np.array(mask)
# Instances are encoded as different colors
obj_ids = np.unique(mask)
# First id is the background, so remove it
obj_ids = obj_ids[1:]
# Split the color-encoded mask into a set of binary masks
binary_masks = mask == obj_ids[:, None, None]
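        # e.g. a mask with pixel values {0, 1, 2} gives obj_ids == [1, 2] and
        # binary_masks of shape (2, H, W), one boolean channel per pedestrian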
        # Get bounding box coordinates for each mask, keeping only
        # non-degenerate boxes so boxes, labels, and masks stay aligned
        boxes = []
        keep = []
        for i in range(len(obj_ids)):
            pos = np.where(binary_masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            # Skip empty masks or masks whose boxes have zero area
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                keep.append(i)
        # Drop the masks whose boxes were degenerate so every target
        # field below has exactly one entry per surviving instance
        binary_masks = binary_masks[keep]
        num_objs = len(boxes)
        image_id = torch.tensor([idx])
        if num_objs == 0:
            # No valid instances: return an empty but well-formed target
            print(f"Warning: no valid boxes found for image {idx}; returning empty target.")
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            binary_masks = torch.zeros(
                (0, mask.shape[0], mask.shape[1]), dtype=torch.uint8
            )
            area = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.uint8)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # There is only one foreground class (pedestrian)
            labels = torch.ones((num_objs,), dtype=torch.int64)
            binary_masks = torch.as_tensor(binary_masks, dtype=torch.uint8)
            # Box area: (xmax - xmin) * (ymax - ymin)
            area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            # Assume all instances are not crowd
            iscrowd = torch.zeros((num_objs,), dtype=torch.uint8)
        # Wrap boxes and masks as tv_tensors so that v2 geometric transforms
        # (e.g. RandomHorizontalFlip) transform them along with the image;
        # plain tensors in the target dict would pass through untouched
        h, w = mask.shape
        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=(h, w))
        target["labels"] = labels
        target["masks"] = tv_tensors.Mask(binary_masks)
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
        if self.transforms is not None:
            # v2 transforms accept (image, target) pairs and update the
            # tv_tensors in the target dict consistently with the image
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.imgs)


# --- Utility Functions --- #


def get_transform(train):
"""Gets the appropriate set of transforms.
Args:
train (bool): Whether to apply training augmentations.
Returns:
torchvision.transforms.Compose: A composed Torchvision transform.
"""
transforms = []
# Always convert image to PyTorch tensor and scale to [0, 1]
transforms.append(T.ToImage())
transforms.append(T.ToDtype(torch.float32, scale=True))
if train:
# Add simple data augmentation for training
transforms.append(T.RandomHorizontalFlip(p=0.5))
# Add other augmentations here if needed
# e.g., T.ColorJitter(...), T.RandomResizedCrop(...) ensuring
# bounding boxes/masks are handled correctly by v2 transforms.
# Note: Normalization (e.g., T.Normalize) is often applied,
# but pre-trained models in torchvision usually handle this internally
# or expect [0, 1] range inputs.
    return T.Compose(transforms)


def collate_fn(batch):
"""Custom collate function for object detection models.
It aggregates images into a list and targets into a list.
Necessary because targets can have varying numbers of objects.
Args:
batch (list): A list of (image, target) tuples.
Returns:
tuple: A tuple containing a list of images and a list of targets.
"""
return tuple(zip(*batch))
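
To sanity-check these pieces together, here is a hedged usage sketch; the dataset root `data/PennFudanPed`, the batch size, and running from the repo root are assumptions, not fixed by this commit:

```python
import torch
from utils.data_utils import PennFudanDataset, get_transform, collate_fn

# Assumed location of the extracted Penn-Fudan archive
dataset = PennFudanDataset("data/PennFudanPed", get_transform(train=True))
loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
)
images, targets = next(iter(loader))
print(len(images), targets[0]["boxes"].shape)  # e.g. 2, torch.Size([N, 4])
```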