import os

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision import tv_tensors


class PennFudanDataset(torch.utils.data.Dataset):
    """Dataset class for the Penn-Fudan Pedestrian Detection dataset."""

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Load all image and mask file names, sorted so that the image at a
        # given index lines up with its mask
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        """Get a sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (image, target) where target is a dictionary of object annotations.
        """
        # Load the image and its instance-segmentation mask
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        mask = np.array(Image.open(mask_path))

        # Each object instance has a unique value in the mask; value 0 is the background
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[1:]  # Drop the background id (0)

        # Split the instance-encoded mask into one binary mask per object
        masks = mask == obj_ids[:, None, None]

        # Compute a bounding box for each mask, skipping degenerate instances.
        # Track which masks survive so that boxes and masks stay aligned.
        boxes = []
        keep = []
        for i in range(len(obj_ids)):
            pos = np.where(masks[i])
            if pos[0].size == 0:  # Skip empty masks
                continue
            xmin, xmax = np.min(pos[1]), np.max(pos[1])
            ymin, ymax = np.min(pos[0]), np.max(pos[0])
            if xmax <= xmin or ymax <= ymin:  # Skip boxes with zero area
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            keep.append(i)

        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # All objects are pedestrians (class 1)
            labels = torch.ones((len(boxes),), dtype=torch.int64)
            masks = torch.as_tensor(masks[keep], dtype=torch.uint8)
            # Box areas, used by COCO-style evaluation
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            # No instance is marked as a crowd region
            iscrowd = torch.zeros((len(boxes),), dtype=torch.uint8)
        else:
            # Handle the (rare but possible) case with no valid objects
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            masks = torch.zeros((0, mask.shape[0], mask.shape[1]), dtype=torch.uint8)
            area = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.uint8)

        # Wrap boxes and masks as tv_tensors so the v2 transforms (Resize,
        # RandomHorizontalFlip) update them together with the image instead of
        # passing them through unchanged.
        target = {
            "boxes": tv_tensors.BoundingBoxes(
                boxes, format="XYXY", canvas_size=mask.shape[-2:]
            ),
            "labels": labels,
            "masks": tv_tensors.Mask(masks),
            "image_id": torch.tensor([idx]),
            "area": area,
            "iscrowd": iscrowd,
        }

        # Apply transforms if provided
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        return len(self.imgs)
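

# A minimal usage sketch of the dataset class. The path below is an assumed
# extraction location for the Penn-Fudan archive, not something this module creates:
#
#     dataset = PennFudanDataset("data/PennFudanPed", transforms=None)
#     img, target = dataset[0]   # img is a PIL.Image when transforms is None
#     target["boxes"].shape      # (N, 4) boxes in XYXY pixel coordinates
#     target["masks"].shape      # (N, H, W) binary instance masks
#     target["labels"]           # tensor of ones: every instance is a pedestrian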


# --- Utility Functions --- #

def get_transform(train):
    """Get the transformations for the dataset.

    Args:
        train (bool): Whether to get transforms for training or evaluation.

    Returns:
        torchvision.transforms.v2.Compose: The composed transforms.
    """
    transforms = []
    # Convert the PIL image to a tv_tensors.Image (uint8, CHW layout)
    transforms.append(T.ToImage())
    # Resize to control memory usage; training uses a smaller size because
    # storing gradients makes it more memory-intensive than evaluation.
    if train:
        transforms.append(T.Resize(700))
    else:
        transforms.append(T.Resize(800))  # Evaluation can afford a larger size
    # Convert to float32 and rescale pixel values to [0, 1]
    transforms.append(T.ToDtype(torch.float32, scale=True))
    # Data augmentation for training
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)


def collate_fn(batch):
    """Custom collate function for object detection models.

    Aggregates the images into one tuple and the targets into another, rather
    than stacking them, because each image can contain a different number of
    objects.

    Args:
        batch (list): A list of (image, target) tuples.

    Returns:
        tuple: An (images, targets) pair, each a tuple over the batch.
    """
    return tuple(zip(*batch))
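

if __name__ == "__main__":
    # Minimal smoke-test sketch: wire the dataset, transforms, and collate_fn
    # into a DataLoader. The dataset path and batch size are illustrative
    # assumptions, not values required by this module.
    dataset = PennFudanDataset("data/PennFudanPed", transforms=get_transform(train=True))
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=collate_fn,  # keep per-image targets separate instead of stacking
    )
    images, targets = next(iter(loader))
    print(f"Batch of {len(images)} images; first image shape: {tuple(images[0].shape)}")
    print(f"First target keys: {sorted(targets[0].keys())}")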