diff --git a/todo.md b/todo.md
index cf98bb7..5dc937a 100644
--- a/todo.md
+++ b/todo.md
@@ -20,12 +20,12 @@ This list outlines the steps required to complete the Torchvision Finetuning pro
 
 ## Phase 2: Data Handling & Model
 
-- [ ] Implement `PennFudanDataset` class in `utils/data_utils.py`.
-  - [ ] `__init__`: Load image and mask paths.
-  - [ ] `__getitem__`: Load image/mask, parse masks, generate targets (boxes, labels, masks, image_id, area, iscrowd), apply transforms.
-  - [ ] `__len__`: Return dataset size.
-- [ ] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
-- [ ] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
+- [x] Implement `PennFudanDataset` class in `utils/data_utils.py`.
+  - [x] `__init__`: Load image and mask paths.
+  - [x] `__getitem__`: Load image/mask, parse masks, generate targets (boxes, labels, masks, image_id, area, iscrowd), apply transforms.
+  - [x] `__len__`: Return dataset size.
+- [x] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
+- [x] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
 - [ ] Implement `get_maskrcnn_model(num_classes, ...)` function in `models/detection.py`.
   - [ ] Load pre-trained Mask R-CNN (`maskrcnn_resnet50_fpn_v2`).
   - [ ] Replace box predictor head (`FastRCNNPredictor`).
diff --git a/utils/data_utils.py b/utils/data_utils.py
index e69de29..322c54d 100644
--- a/utils/data_utils.py
+++ b/utils/data_utils.py
@@ -0,0 +1,144 @@
+import os
+
+import numpy as np
+import torch
+import torch.utils.data
+import torchvision.transforms.v2 as T
+from PIL import Image
+from torchvision import tv_tensors
+
+
+class PennFudanDataset(torch.utils.data.Dataset):
+    """Dataset class for the Penn-Fudan Pedestrian Detection dataset."""
+
+    def __init__(self, root, transforms):
+        self.root = root
+        self.transforms = transforms
+        # Load all image and mask file names, sorted so that images and
+        # masks stay aligned by index
+        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
+        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))
+
+    def __getitem__(self, idx):
+        # Load the image and its instance mask
+        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
+        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
+        img = Image.open(img_path).convert("RGB")
+        # Note: the mask is not converted to RGB; its pixel values are
+        # instance indices (0 is the background)
+        mask = Image.open(mask_path)
+
+        # Convert the mask to a numpy array
+        mask = np.array(mask)
+        # Instances are encoded as distinct pixel values
+        obj_ids = np.unique(mask)
+        # The first id is the background, so remove it
+        obj_ids = obj_ids[1:]
+
+        # Split the index-encoded mask into a set of binary masks
+        binary_masks = mask == obj_ids[:, None, None]
+
+        # Get the bounding box for each mask, dropping degenerate
+        # (zero-width or zero-height) instances so that boxes, labels
+        # and masks stay aligned
+        boxes = []
+        valid_masks = []
+        for i in range(len(obj_ids)):
+            pos = np.where(binary_masks[i])
+            xmin = np.min(pos[1])
+            xmax = np.max(pos[1])
+            ymin = np.min(pos[0])
+            ymax = np.max(pos[0])
+            if xmax > xmin and ymax > ymin:
+                boxes.append([xmin, ymin, xmax, ymax])
+                valid_masks.append(binary_masks[i])
+
+        num_objs = len(boxes)
+        image_id = torch.tensor([idx])
+
+        if num_objs == 0:
+            # No valid instances found: return empty target tensors so the
+            # sample can still be collated (this should be rare for Penn-Fudan)
+            print(f"Warning: no valid boxes found for image {idx}.")
+            boxes = torch.zeros((0, 4), dtype=torch.float32)
+            labels = torch.zeros((0,), dtype=torch.int64)
+            masks = torch.zeros(
+                (0, mask.shape[0], mask.shape[1]), dtype=torch.uint8
+            )
+            area = torch.zeros((0,), dtype=torch.float32)
+            iscrowd = torch.zeros((0,), dtype=torch.uint8)
+        else:
+            boxes = torch.as_tensor(boxes, dtype=torch.float32)
+            # There is only one foreground class (pedestrian)
+            labels = torch.ones((num_objs,), dtype=torch.int64)
+            masks = torch.as_tensor(np.stack(valid_masks), dtype=torch.uint8)
+            # Box areas, used by COCO-style evaluation
+            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
+            # Assume all instances are not crowd
+            iscrowd = torch.zeros((num_objs,), dtype=torch.uint8)
+
+        # Wrap boxes and masks as tv_tensors so that the v2 transforms
+        # (e.g. RandomHorizontalFlip) are applied to them as well
+        target = {}
+        target["boxes"] = tv_tensors.BoundingBoxes(
+            boxes, format="XYXY", canvas_size=mask.shape[:2]
+        )
+        target["labels"] = labels
+        target["masks"] = tv_tensors.Mask(masks)
+        target["image_id"] = image_id
+        target["area"] = area
+        target["iscrowd"] = iscrowd
+
+        if self.transforms is not None:
+            # v2 transforms accept and return (image, target) pairs
+            img, target = self.transforms(img, target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.imgs)
+
+
+# --- Utility Functions --- #
+
+
+def get_transform(train):
+    """Gets the appropriate set of transforms.
+
+    Args:
+        train (bool): Whether to apply training augmentations.
+
+    Returns:
+        torchvision.transforms.v2.Compose: A composed Torchvision transform.
+    """
+    transforms = []
+    # Always convert the image to a tensor and scale it to [0, 1]
+    transforms.append(T.ToImage())
+    transforms.append(T.ToDtype(torch.float32, scale=True))
+
+    if train:
+        # Simple data augmentation for training
+        transforms.append(T.RandomHorizontalFlip(p=0.5))
+        # Other augmentations (e.g. T.ColorJitter, T.RandomResizedCrop) can
+        # be added here; the v2 transforms keep boxes and masks consistent.
+
+    # Note: normalization (e.g. T.Normalize) is not added here because the
+    # pre-trained torchvision detection models normalize internally and
+    # expect inputs in the [0, 1] range.
+
+    return T.Compose(transforms)
+
+
+def collate_fn(batch):
+    """Custom collate function for object detection models.
+
+    It aggregates images into a tuple and targets into a tuple, which is
+    necessary because each target can contain a different number of objects.
+
+    Args:
+        batch (list): A list of (image, target) tuples.
+
+    Returns:
+        tuple: A tuple of images and a tuple of targets.
+    """
+    return tuple(zip(*batch))
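
Usage sketch (not part of the diff): a minimal example of how `PennFudanDataset`, `get_transform`, and `collate_fn` are expected to fit together with a `DataLoader`. The dataset path `data/PennFudanPed` and the batch size are assumptions for illustration; the real values belong to the Phase 3 training script.

    from torch.utils.data import DataLoader

    from utils.data_utils import PennFudanDataset, collate_fn, get_transform

    # Assumed local path to the extracted Penn-Fudan dataset
    dataset = PennFudanDataset("data/PennFudanPed", transforms=get_transform(train=True))

    # Detection targets vary in size per image, so the custom collate_fn keeps
    # images and targets as per-sample tuples instead of stacking them
    data_loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=collate_fn,
    )

    images, targets = next(iter(data_loader))  # tuple of images, tuple of target dicts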