import os

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.v2 as T
from PIL import Image
from torchvision import tv_tensors

class PennFudanDataset(torch.utils.data.Dataset):
    """Dataset class for the Penn-Fudan Pedestrian Detection dataset."""

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Load all image and mask filenames, sorted so that imgs[i] and
        # masks[i] refer to the same sample
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))
    def __getitem__(self, idx):
        """Get a sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (image, target) where target is a dictionary containing
                the object annotations (boxes, labels, masks, image_id,
                area, iscrowd).
        """
        # Build the paths to the image and its instance mask
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

        # Load both with PIL; the mask is left in its palette mode, where
        # each pixel value is an instance id rather than a color
        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)

        # Convert the mask PIL image to a numpy array of shape (H, W)
        mask = np.array(mask)

        # Find all object instances: each instance has a unique value in
        # the mask, and value 0 is the background
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[1:]  # Remove background (id=0)

        # Split the color-encoded mask into one binary (H, W) mask per
        # instance by broadcasting obj_ids to shape (N, 1, 1)
        masks = mask == obj_ids[:, None, None]
        # Compute a bounding box for each instance mask, tracking which
        # instances survive filtering so boxes and masks stay aligned
        num_objs = len(obj_ids)
        boxes = []
        valid = []

        for i in range(num_objs):
            pos = np.where(masks[i])
            if len(pos[0]) == 0 or len(pos[1]) == 0:  # Skip empty masks
                continue

            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])

            # Skip degenerate boxes with zero area
            if xmax <= xmin or ymax <= ymin:
                continue

            boxes.append([xmin, ymin, xmax, ymax])
            valid.append(i)

        # Convert everything to tensors
        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # All objects are pedestrians (class 1)
            labels = torch.ones((len(boxes),), dtype=torch.int64)
            # Keep only the masks whose boxes were kept above, so that
            # len(masks) == len(boxes)
            masks = torch.as_tensor(masks[valid], dtype=torch.uint8)
            # Calculate the area of each box
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

            # No instance is marked as a COCO-style "crowd" region
            iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)

            # Wrap boxes and masks in tv_tensors so the v2 transforms
            # (Resize, RandomHorizontalFlip) update them along with the image
            target = {
                "boxes": tv_tensors.BoundingBoxes(
                    boxes, format="XYXY", canvas_size=mask.shape[:2]
                ),
                "labels": labels,
                "masks": tv_tensors.Mask(masks),
                "image_id": torch.tensor([idx]),
                "area": area,
                "iscrowd": iscrowd,
            }
        else:
            # Handle the case with no valid objects (rare but possible)
            target = {
                "boxes": tv_tensors.BoundingBoxes(
                    torch.zeros((0, 4), dtype=torch.float32),
                    format="XYXY",
                    canvas_size=mask.shape[:2],
                ),
                "labels": torch.zeros((0,), dtype=torch.int64),
                "masks": tv_tensors.Mask(
                    torch.zeros((0, mask.shape[0], mask.shape[1]), dtype=torch.uint8)
                ),
                "image_id": torch.tensor([idx]),
                "area": torch.zeros((0,), dtype=torch.float32),
                "iscrowd": torch.zeros((0,), dtype=torch.int64),
            }
        # Apply transforms if provided
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)
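
# Example usage (sketch): inspecting a single sample. The "PennFudanPed" root
# below is an assumption; point it at wherever the dataset archive (with its
# PNGImages/ and PedMasks/ folders) was extracted.
#
#   dataset = PennFudanDataset("PennFudanPed", transforms=get_transform(train=False))
#   img, target = dataset[0]
#   img.shape                # e.g. torch.Size([3, H, W]) after the transforms
#   target["boxes"].shape    # torch.Size([N, 4]), one XYXY box per pedestrian
#   target["masks"].shape    # torch.Size([N, H, W]) binary instance masks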


# --- Utility Functions --- #

def get_transform(train):
    """Get the transformations for the dataset.

    Args:
        train (bool): Whether to build transforms for training or evaluation.

    Returns:
        torchvision.transforms.v2.Compose: The composed transforms.
    """
    transforms = []

    # Convert the PIL image to a CHW uint8 tensor (tv_tensors.Image);
    # scaling to float happens in ToDtype below
    transforms.append(T.ToImage())

    # Resize images to control memory usage; training uses a smaller size
    # because gradients make it more memory-intensive
    if train:
        transforms.append(T.Resize(700))
    else:
        transforms.append(T.Resize(800))  # Evaluation can afford a larger size

    # Convert to float32 and scale pixel values to [0, 1]
    transforms.append(T.ToDtype(torch.float32, scale=True))

    # Data augmentation for training
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))

    return T.Compose(transforms)
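

# Sketch of the composed pipeline in use. Because the dataset wraps boxes and
# masks as tv_tensors, the v2 transforms resize and flip them together with
# the image, keeping the annotations aligned:
#
#   tfs = get_transform(train=True)
#   img, target = tfs(img, target)  # boxes/masks follow the image geometry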


def collate_fn(batch):
    """Custom collate function for object detection models.

    Aggregates the images into one tuple and the targets into another.
    This is necessary because each target can hold a different number of
    objects, so the default collate (which stacks tensors) would fail.

    Args:
        batch (list): A list of (image, target) tuples.

    Returns:
        tuple: A tuple of images and a tuple of targets.
    """
    return tuple(zip(*batch))
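

if __name__ == "__main__":
    # Minimal smoke-test sketch. The "PennFudanPed" path is an assumption:
    # point it at the directory where the Penn-Fudan archive was extracted.
    dataset = PennFudanDataset("PennFudanPed", transforms=get_transform(train=True))
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        shuffle=True,
        collate_fn=collate_fn,  # Keep variable-length targets as separate dicts
    )

    images, targets = next(iter(loader))
    print(f"Batch of {len(images)} image(s)")
    for t in targets:
        print(f"  {len(t['boxes'])} pedestrian(s), masks shape {tuple(t['masks'].shape)}")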