import os

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision import tv_tensors


class PennFudanDataset(torch.utils.data.Dataset):
    """Dataset class for the Penn-Fudan Pedestrian Detection dataset.

    Expects ``root`` to contain the ``PNGImages`` and ``PedMasks`` directories
    from the original archive.
    """

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Load all image and mask filenames, sorted so that image i and mask i
        # belong to the same example
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        """Get a sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (image, target) where target is a dictionary containing
                the object annotations (boxes, labels, masks, image_id, area,
                iscrowd).
        """
        # Build the paths to the image and its instance-segmentation mask
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        # Load with PIL; the mask is deliberately *not* converted to RGB,
        # because each pedestrian instance is encoded as a distinct integer
        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)
        # Convert the mask PIL image to a numpy array of instance ids
        mask = np.array(mask)

        # Find all object instances (each instance has a unique value in the
        # mask); value 0 is the background
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[1:]  # Remove background (id=0)

        # Split the label mask into one binary mask per object instance
        masks = mask == obj_ids[:, None, None]

        # Derive a bounding box from each binary mask, skipping instances
        # whose mask is empty or whose box has zero area
        boxes = []
        keep = []
        for i in range(len(obj_ids)):
            pos = np.where(masks[i])
            if len(pos[0]) == 0 or len(pos[1]) == 0:
                continue  # Skip empty masks
            xmin, xmax = np.min(pos[1]), np.max(pos[1])
            ymin, ymax = np.min(pos[0]), np.max(pos[0])
            if xmax <= xmin or ymax <= ymin:
                continue  # Skip boxes with zero area
            boxes.append([xmin, ymin, xmax, ymax])
            keep.append(i)

        # Convert everything to tensors
        canvas_size = (mask.shape[0], mask.shape[1])  # (height, width)
        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # All objects are pedestrians (class 1)
            labels = torch.ones((len(boxes),), dtype=torch.int64)
            # Keep only the masks whose boxes survived the checks above so
            # that boxes, labels and masks stay aligned
            masks = torch.as_tensor(masks[keep], dtype=torch.uint8)
            # Area of each box: (ymax - ymin) * (xmax - xmin)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            # No instance is marked as a crowd region
            iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        else:
            # Handle the case with no valid objects (rare but possible)
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            masks = torch.zeros((0, *canvas_size), dtype=torch.uint8)
            area = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.int64)

        # Wrap boxes and masks as tv_tensors so that the v2 transforms
        # (resize, horizontal flip) update the annotations together with the
        # image instead of silently leaving them untouched
        target = {
            "boxes": tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=canvas_size),
            "labels": labels,
            "masks": tv_tensors.Mask(masks),
            "image_id": torch.tensor([idx]),
            "area": area,
            "iscrowd": iscrowd,
        }

        # Apply transforms if provided
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)
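# An illustrative (non-executed) sketch of the mask-to-boxes conversion done
# in ``__getitem__`` above, assuming a toy 4x4 label mask with two instances:
#
#   mask = np.array([[0, 1, 1, 0],
#                    [0, 1, 1, 0],
#                    [2, 2, 0, 0],
#                    [2, 2, 0, 0]])
#   obj_ids = np.unique(mask)[1:]           # -> array([1, 2])
#   masks = mask == obj_ids[:, None, None]  # -> (2, 4, 4) boolean masks
#   pos = np.where(masks[0])                # row/col indices of instance 1
#   # box = [min(cols), min(rows), max(cols), max(rows)] = [1, 0, 2, 1]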
""" transforms = [] # Convert to PyTorch tensor and normalize transforms.append(T.ToImage()) # Resize images to control memory usage # Use a smaller size for training (more memory-intensive due to gradients) if train: transforms.append(T.Resize(700)) else: transforms.append(T.Resize(800)) # Can use larger size for eval transforms.append(T.ToDtype(torch.float32, scale=True)) # Data augmentation for training if train: transforms.append(T.RandomHorizontalFlip(0.5)) return T.Compose(transforms) def collate_fn(batch): """Custom collate function for object detection models. It aggregates images into a list and targets into a list. Necessary because targets can have varying numbers of objects. Args: batch (list): A list of (image, target) tuples. Returns: tuple: A tuple containing a list of images and a list of targets. """ return tuple(zip(*batch))