import os

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision import tv_tensors


class PennFudanDataset(torch.utils.data.Dataset):
    """Dataset class for the Penn-Fudan Pedestrian Detection dataset."""

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Load all image and mask file names, sorted so that images and
        # masks stay aligned by index.
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        # Load the image and its instance mask
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # The mask is NOT converted to RGB: each pixel holds an instance
        # index (0 is background, 1..N are the pedestrian instances).
        mask = Image.open(mask_path)
        mask = np.array(mask)

        # Instances are encoded as distinct index values
        obj_ids = np.unique(mask)
        # The first id is the background, so remove it
        obj_ids = obj_ids[1:]

        # Split the index-encoded mask into one binary mask per instance
        binary_masks = mask == obj_ids[:, None, None]

        # Compute a bounding box for each mask, keeping only instances whose
        # box has non-zero area. `keep` records which masks survive so that
        # boxes, labels and masks stay aligned.
        boxes = []
        keep = []
        for i in range(len(obj_ids)):
            pos = np.nonzero(binary_masks[i])
            xmin, xmax = np.min(pos[1]), np.max(pos[1])
            ymin, ymax = np.min(pos[0]), np.max(pos[0])
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                keep.append(i)
        binary_masks = binary_masks[keep]
        num_objs = len(boxes)

        if num_objs == 0:
            # No valid instances: return an empty (zero-instance) target so
            # downstream code can decide how to handle this sample.
            print(f"Warning: no valid boxes found for image {idx}; returning an empty target.")
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            binary_masks = torch.zeros(
                (0, mask.shape[0], mask.shape[1]), dtype=torch.uint8
            )
            area = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.uint8)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # There is only one foreground class (pedestrian)
            labels = torch.ones((num_objs,), dtype=torch.int64)
            binary_masks = torch.as_tensor(binary_masks, dtype=torch.uint8)
            # Box area, used by COCO-style evaluation
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            # Assume no instance is a crowd region
            iscrowd = torch.zeros((num_objs,), dtype=torch.uint8)

        image_id = torch.tensor([idx])

        # Wrap boxes and masks as tv_tensors so that torchvision v2 transforms
        # (e.g. RandomHorizontalFlip) transform them together with the image;
        # plain tensors in the target dict would pass through unchanged.
        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(
            boxes, format="XYXY", canvas_size=mask.shape
        )
        target["labels"] = labels
        target["masks"] = tv_tensors.Mask(binary_masks)
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            # Apply transforms to both the image and the target
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)
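
# The broadcasting trick used in __getitem__ (`mask == obj_ids[:, None, None]`)
# is compact but easy to misread, so the helper below is a small illustrative
# sketch of the same mask-to-boxes logic on a hand-made 3x4 mask. The function
# name `_demo_mask_to_boxes` and the toy array are invented for this example
# and are not part of the dataset itself.
def _demo_mask_to_boxes():
    """Show how an index-encoded mask becomes binary masks and boxes."""
    # A tiny "mask": 0 is background, 1 and 2 are two instances.
    toy_mask = np.array(
        [
            [0, 1, 1, 0],
            [0, 1, 0, 2],
            [0, 0, 2, 2],
        ]
    )
    obj_ids = np.unique(toy_mask)[1:]            # array([1, 2])
    binary = toy_mask == obj_ids[:, None, None]  # shape (2, 3, 4): one mask per id
    for i, obj_id in enumerate(obj_ids):
        ys, xs = np.nonzero(binary[i])
        box = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
        print(f"instance {obj_id}: box (xmin, ymin, xmax, ymax) = {box}")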
""" transforms = [] # Always convert image to PyTorch tensor and scale to [0, 1] transforms.append(T.ToImage()) transforms.append(T.ToDtype(torch.float32, scale=True)) if train: # Add simple data augmentation for training transforms.append(T.RandomHorizontalFlip(p=0.5)) # Add other augmentations here if needed # e.g., T.ColorJitter(...), T.RandomResizedCrop(...) ensuring # bounding boxes/masks are handled correctly by v2 transforms. # Note: Normalization (e.g., T.Normalize) is often applied, # but pre-trained models in torchvision usually handle this internally # or expect [0, 1] range inputs. return T.Compose(transforms) def collate_fn(batch): """Custom collate function for object detection models. It aggregates images into a list and targets into a list. Necessary because targets can have varying numbers of objects. Args: batch (list): A list of (image, target) tuples. Returns: tuple: A tuple containing a list of images and a list of targets. """ return tuple(zip(*batch))