147 lines
5.4 KiB
Python
147 lines
5.4 KiB
Python
import os
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.utils.data
|
|
import torchvision.transforms.v2 as T
|
|
from PIL import Image
|
|
|
|
|
|
class PennFudanDataset(torch.utils.data.Dataset):
    """Dataset class for the Penn-Fudan Pedestrian Detection dataset.

    ``root`` must contain a ``PNGImages`` directory (images) and a
    ``PedMasks`` directory (instance-index masks). Images and masks are paired
    by position after sorting each directory listing by filename.

    Each item is ``(img, target)``. ``target`` holds one row per pedestrian
    instance under the keys ``boxes``, ``labels``, ``masks``, ``area`` and
    ``iscrowd``, plus ``image_id``.
    """

    def __init__(self, root, transforms):
        """Index the image and mask files under ``root``.

        Args:
            root (str): Dataset root containing ``PNGImages`` and ``PedMasks``.
            transforms (callable or None): Applied as ``transforms(img, target)``
                in ``__getitem__``; ``None`` skips transformation.
        """
        self.root = root
        self.transforms = transforms
        # Sort both listings so the i-th image lines up with the i-th mask.
        # NOTE(review): pairing relies on both directories containing matching
        # files in the same sorted order; nothing here verifies that.
        self.imgs = sorted(os.listdir(os.path.join(root, "PNGImages")))
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    @staticmethod
    def _instances_from_mask(mask):
        """Split an instance-index mask into per-instance masks and boxes.

        Args:
            mask (np.ndarray): 2-D array of shape ``(H, W)``. The value ``0``
                is treated as background; every other value identifies one
                instance.

        Returns:
            tuple: ``(binary_masks, boxes)``, where ``binary_masks`` is a bool
            array of shape ``(N, H, W)`` and ``boxes`` is a float32 array of
            shape ``(N, 4)`` holding ``[xmin, ymin, xmax, ymax]`` pixel
            coordinates. Instances whose box has zero width or height are
            dropped from *both* arrays, so the two stay index-aligned.
        """
        obj_ids = np.unique(mask)
        # Remove the background id explicitly instead of assuming it is present
        # and sorts first; otherwise a mask without background pixels would
        # lose a real instance.
        obj_ids = obj_ids[obj_ids != 0]
        # One boolean mask per instance via broadcasting: (N, 1, 1) vs (H, W).
        binary_masks = mask == obj_ids[:, None, None]

        boxes = np.zeros((len(obj_ids), 4), dtype=np.float32)
        keep = np.zeros(len(obj_ids), dtype=bool)
        for i, instance in enumerate(binary_masks):
            # Never empty: every id in obj_ids came from np.unique(mask).
            ys, xs = np.nonzero(instance)
            xmin, xmax = xs.min(), xs.max()
            ymin, ymax = ys.min(), ys.max()
            boxes[i] = (xmin, ymin, xmax, ymax)
            # Degenerate (zero-width/height) boxes are filtered out; torchvision
            # detection models reject boxes without positive width and height.
            keep[i] = xmax > xmin and ymax > ymin

        # Filter masks and boxes with the same selector so they stay aligned.
        return binary_masks[keep], boxes[keep]

    def __getitem__(self, idx):
        """Load sample ``idx`` and build its detection/segmentation target.

        Args:
            idx (int): Index into the sorted image/mask file lists.

        Returns:
            tuple: ``(img, target)``. Without transforms, ``img`` is a PIL RGB
            image. ``target["boxes"]`` is a ``tv_tensors.BoundingBoxes``
            (XYXY, float32) and ``target["masks"]`` a ``tv_tensors.Mask``
            (uint8), so torchvision v2 transforms update them together with
            the image. If no valid instance is found, a warning is printed and
            every per-instance field has zero rows.
        """
        # torchvision v2 transforms pass plain tensors through untouched when
        # the sample contains an image, so boxes/masks must be wrapped as
        # tv_tensors for geometric transforms (e.g. flips) to move them.
        from torchvision import tv_tensors

        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        # Masks are not converted to RGB: each pixel holds an instance index.
        with Image.open(mask_path) as im:
            mask = np.array(im)

        binary_masks, boxes = self._instances_from_mask(mask)
        num_objs = len(boxes)
        if num_objs == 0:
            print(
                f"Warning: No valid boxes found for image {idx}. Returning dummy target."
            )

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # Area of each bounding box (not the mask pixel count).
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(
            boxes, format="XYXY", canvas_size=(img.height, img.width)
        )
        # There is only one class (pedestrian).
        target["labels"] = torch.ones((num_objs,), dtype=torch.int64)
        target["masks"] = tv_tensors.Mask(
            torch.as_tensor(binary_masks, dtype=torch.uint8)
        )
        target["image_id"] = torch.tensor([idx])
        target["area"] = area
        # Assume all instances are not crowd.
        target["iscrowd"] = torch.zeros((num_objs,), dtype=torch.uint8)

        if self.transforms is not None:
            # Apply transforms jointly to the image and the target dict.
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of images found under ``root/PNGImages``."""
        return len(self.imgs)
|
|
|
|
|
|
# --- Utility Functions --- #
|
|
|
|
|
|
def get_transform(train):
    """Build the torchvision v2 transform pipeline for a sample.

    The pipeline always converts the image to a float32 tensor scaled to
    ``[0, 1]``. When ``train`` is true it then applies a random horizontal
    flip with probability 0.5. It is meant to be called as
    ``transforms(img, target)`` by the dataset.

    Args:
        train (bool): Whether to append training-time augmentation.

    Returns:
        torchvision.transforms.v2.Compose: The composed pipeline.
    """
    # Shared by train and eval: image -> tensor, then rescale to float in [0, 1].
    pipeline = [T.ToImage(), T.ToDtype(torch.float32, scale=True)]

    # Training-only augmentation. Other v2 augmentations (ColorJitter,
    # RandomResizedCrop, ...) can be added here; geometric ones only move
    # boxes/masks that the dataset wraps as tv_tensors.
    augmentations = [T.RandomHorizontalFlip(p=0.5)] if train else []

    # No T.Normalize step: pre-trained torchvision detection models typically
    # normalise internally and expect inputs in the [0, 1] range.
    return T.Compose(pipeline + augmentations)
|
|
|
|
|
|
def collate_fn(batch):
    """Transpose a batch of samples for detection-style data loading.

    Turns ``[(img0, tgt0), (img1, tgt1), ...]`` into
    ``((img0, img1, ...), (tgt0, tgt1, ...))``. Items are kept in plain tuples
    rather than stacked into tensors, because each target can hold a
    different number of objects.

    Args:
        batch (list): Samples from the dataset, each an ``(image, target)``
            tuple.

    Returns:
        tuple: One tuple per sample field (images, then targets); an empty
        tuple when ``batch`` is empty.
    """
    transposed = zip(*batch)
    return tuple(transposed)
|