format

todo.md
@@ -20,12 +20,12 @@ This list outlines the steps required to complete the Torchvision Finetuning pro

## Phase 2: Data Handling & Model

- [ ] Implement `PennFudanDataset` class in `utils/data_utils.py`.
  - [ ] `__init__`: Load image and mask paths.
  - [ ] `__getitem__`: Load image/mask, parse masks, generate targets (boxes, labels, masks, image_id, area, iscrowd), apply transforms.
  - [ ] `__len__`: Return dataset size.
- [ ] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
- [ ] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
- [x] Implement `PennFudanDataset` class in `utils/data_utils.py`.
  - [x] `__init__`: Load image and mask paths.
  - [x] `__getitem__`: Load image/mask, parse masks, generate targets (boxes, labels, masks, image_id, area, iscrowd), apply transforms.
  - [x] `__len__`: Return dataset size.
- [x] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
- [x] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
- [ ] Implement `get_maskrcnn_model(num_classes, ...)` function in `models/detection.py`.
  - [ ] Load pre-trained Mask R-CNN (`maskrcnn_resnet50_fpn_v2`).
  - [ ] Replace box predictor head (`FastRCNNPredictor`).
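The unchecked items above call for a `get_maskrcnn_model(num_classes, ...)` helper in `models/detection.py`. A rough sketch of what that helper might look like (not part of this commit; the signature, the `hidden_layer` default, and the mask-head replacement beyond the listed box-head swap are assumptions based on the usual torchvision finetuning pattern):

```python
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor


def get_maskrcnn_model(num_classes, hidden_layer=256):
    """Load a pre-trained Mask R-CNN and resize its heads to `num_classes`."""
    # Load Mask R-CNN pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights="DEFAULT")

    # Replace the box predictor head (FastRCNNPredictor) for our class count
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor head as well (assumption: standard tutorial pattern)
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask, hidden_layer, num_classes
    )

    return model
```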
@@ -0,0 +1,146 @@
import os

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms.v2 as T
from PIL import Image


class PennFudanDataset(torch.utils.data.Dataset):
    """Dataset class for the Penn-Fudan Pedestrian Detection dataset."""

    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms
        # Load all image files, sorting them to ensure alignment
        self.imgs = sorted(list(os.listdir(os.path.join(root, "PNGImages"))))
        self.masks = sorted(list(os.listdir(os.path.join(root, "PedMasks"))))

    def __getitem__(self, idx):
        # Load images and masks
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        img = Image.open(img_path).convert("RGB")
        # Note: the mask is not converted to RGB; its pixel values are instance indices
        mask = Image.open(mask_path)

        # Convert mask to numpy array
        mask = np.array(mask)
        # Instances are encoded as different colors
        obj_ids = np.unique(mask)
        # First id is the background, so remove it
        obj_ids = obj_ids[1:]

        # Split the color-encoded mask into a set of binary masks
        binary_masks = mask == obj_ids[:, None, None]

        # Get bounding box coordinates for each mask, dropping any mask that
        # would produce a degenerate (zero-area) box along the way
        num_objs = len(obj_ids)
        boxes = []
        valid_idxs = []
        for i in range(num_objs):
            pos = np.where(binary_masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            # Keep only boxes with positive width and height
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                valid_idxs.append(i)

        # Keep only the masks whose boxes were kept, so boxes, labels and
        # masks stay aligned after filtering
        binary_masks = binary_masks[valid_idxs]
        num_objs = len(valid_idxs)

        if not boxes:
            # No valid boxes found: return an empty but well-formed target
            print(
                f"Warning: No valid boxes found for image {idx}. Returning empty target."
            )
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            binary_masks = torch.zeros(
                (0, mask.shape[0], mask.shape[1]), dtype=torch.uint8
            )
            image_id = torch.tensor([idx])
            area = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.uint8)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # There is only one foreground class (pedestrian)
            labels = torch.ones((num_objs,), dtype=torch.int64)
            binary_masks = torch.as_tensor(binary_masks, dtype=torch.uint8)
            image_id = torch.tensor([idx])
            # Box area = (x2 - x1) * (y2 - y1)
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            # Assume all instances are not crowd
            iscrowd = torch.zeros((num_objs,), dtype=torch.uint8)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = binary_masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            # Apply transforms to the (image, target) pair; torchvision v2
            # transforms accept this structure, but geometric transforms only
            # update boxes/masks when they are wrapped as tv_tensors
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)


# --- Utility Functions --- #


def get_transform(train):
    """Gets the appropriate set of transforms.

    Args:
        train (bool): Whether to apply training augmentations.

    Returns:
        torchvision.transforms.v2.Compose: A composed Torchvision transform.
    """
    transforms = []
    # Always convert image to PyTorch tensor and scale to [0, 1]
    transforms.append(T.ToImage())
    transforms.append(T.ToDtype(torch.float32, scale=True))

    if train:
        # Add simple data augmentation for training
        transforms.append(T.RandomHorizontalFlip(p=0.5))
        # Add other augmentations here if needed,
        # e.g. T.ColorJitter(...) or T.RandomResizedCrop(...), ensuring
        # bounding boxes/masks are handled correctly by v2 transforms.

    # Note: Normalization (e.g., T.Normalize) is often applied, but
    # pre-trained models in torchvision usually handle this internally
    # or expect inputs in the [0, 1] range.

    return T.Compose(transforms)


def collate_fn(batch):
    """Custom collate function for object detection models.

    It aggregates images and targets into separate tuples instead of stacking
    them, which is necessary because targets can contain varying numbers of
    objects and cannot be batched into a single tensor.

    Args:
        batch (list): A list of (image, target) tuples.

    Returns:
        tuple: A tuple containing a tuple of images and a tuple of targets.
    """
    return tuple(zip(*batch))
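For reference, a minimal sketch of how the pieces added in this commit would typically be wired together (the dataset root and batch size are placeholders; note that for the v2 geometric augmentations to also move boxes and masks, the targets would additionally need to be wrapped as `tv_tensors`):

```python
import torch
from utils.data_utils import PennFudanDataset, get_transform, collate_fn

# Hypothetical dataset root; the Penn-Fudan archive extracts to PennFudanPed/
dataset = PennFudanDataset("data/PennFudanPed", transforms=get_transform(train=True))

# collate_fn keeps images/targets as per-sample sequences, since each image
# can contain a different number of pedestrians
loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
)

images, targets = next(iter(loader))
print(len(images), targets[0]["boxes"].shape, targets[0]["masks"].shape)
```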