Pausing for now; the train loop should now work. Adding some tests.

Craig
2025-04-12 12:01:13 +01:00
parent 2b38c04a57
commit be70c4e160
13 changed files with 967 additions and 58 deletions


@@ -18,80 +18,91 @@ class PennFudanDataset(torch.utils.data.Dataset):
        self.masks = sorted(os.listdir(os.path.join(root, "PedMasks")))

    def __getitem__(self, idx):
        """Get a sample from the dataset.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (image, target) where target is a dictionary of
                object annotations.
        """
        # Load the image and its instance-segmentation mask
        img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
        mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
        # Use PIL to open images lazily (more memory efficient)
        img = Image.open(img_path).convert("RGB")
        # Note: the mask is not converted to RGB; its pixel values are
        # instance indices
        mask = Image.open(mask_path)
        # Convert the mask PIL image to a numpy array
        mask = np.array(mask)
        # Instances are encoded as distinct pixel values in the mask;
        # value 0 is the background
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[1:]  # Remove background (id=0)
        # Split the color-encoded mask into one binary mask per instance
        masks = mask == obj_ids[:, None, None]
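        # Example: with obj_ids == array([1, 2]) and a mask of shape (H, W),
        # obj_ids[:, None, None] has shape (2, 1, 1), so the broadcast
        # comparison yields a (2, H, W) boolean array, one binary mask
        # per instance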
        # Compute a bounding box for each instance mask
        num_objs = len(obj_ids)
        boxes = []
        valid = []  # Indices of instances that yield a usable box
        for i in range(num_objs):
            pos = np.where(masks[i])
            if len(pos[0]) == 0 or len(pos[1]) == 0:  # Skip empty masks
                continue
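            # np.where on a 2-D array returns (row_indices, col_indices),
            # so pos[0] holds y coordinates and pos[1] holds x coordinates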
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            # Skip degenerate boxes with zero width or height
            if xmax <= xmin or ymax <= ymin:
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            valid.append(i)
        # Convert everything to tensors
        if boxes:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # All objects are pedestrians (the only foreground class, label 1)
            labels = torch.ones((len(boxes),), dtype=torch.int64)
            # Keep only the masks whose boxes survived filtering, so boxes,
            # labels, and masks stay aligned
            masks = torch.as_tensor(masks[valid], dtype=torch.uint8)
            # Area of each box: height * width
            area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
            # Assume no instance is marked as crowd
            iscrowd = torch.zeros((len(boxes),), dtype=torch.uint8)
            # Create the target dictionary
            target = {
                "boxes": boxes,
                "labels": labels,
                "masks": masks,
                "image_id": torch.tensor([idx]),
                "area": area,
                "iscrowd": iscrowd,
            }
        else:
            # Handle the rare case where no valid objects were found
            target = {
                "boxes": torch.zeros((0, 4), dtype=torch.float32),
                "labels": torch.zeros((0,), dtype=torch.int64),
                "masks": torch.zeros(
                    (0, mask.shape[0], mask.shape[1]), dtype=torch.uint8
                ),
                "image_id": torch.tensor([idx]),
                "area": torch.zeros((0,), dtype=torch.float32),
                "iscrowd": torch.zeros((0,), dtype=torch.uint8),
            }
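        # torchvision detection models such as Mask R-CNN expect this layout:
        # "boxes" (N, 4) float32 in xyxy format, "labels" (N,) int64, and
        # "masks" (N, H, W) uint8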
        # Apply paired transforms if provided; torchvision v2 transforms
        # handle target dicts automatically
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target
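
As a rough illustration of the train loop this commit targets, a minimal sketch follows. It assumes the tutorial-style constructor PennFudanDataset(root, transforms) and a hypothetical get_model helper; the data root and hyperparameters are placeholders. The zip-based collate_fn is the usual torchvision idiom for detection datasets, whose targets differ in size from image to image.

# Hypothetical training-loop sketch; get_model, the data root, and the
# hyperparameters are assumptions, not part of this commit
import torch
from torch.utils.data import DataLoader

dataset = PennFudanDataset("data/PennFudanPed", get_transform(train=True))
loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=lambda batch: tuple(zip(*batch)),  # Targets vary per image
)
model = get_model(num_classes=2)  # Assumed helper: background + pedestrian
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
for images, targets in loader:
    loss_dict = model(list(images), list(targets))  # Dict of per-task losses
    loss = sum(loss_dict.values())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()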
@@ -117,15 +128,18 @@ def get_transform(train):
    # Convert the PIL image to a tensor image (scaling to [0, 1] happens
    # in ToDtype below)
    transforms.append(T.ToImage())
    # Resize images to control memory usage; training is more
    # memory-intensive because gradients are kept, so use a smaller size
    if train:
        transforms.append(T.Resize(700))
    else:
        transforms.append(T.Resize(800))  # Eval can afford a larger size
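    # Note: T.Resize with an int scales the shorter edge to that value,
    # preserving aspect ratio, so per-image memory use still varies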
    transforms.append(T.ToDtype(torch.float32, scale=True))
    # Data augmentation for training only
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
        # More augmentations could be added here if desired
    return T.Compose(transforms)
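
A quick way to sanity-check the composed pipeline at eval time; the input size is illustrative, and T is assumed to be torchvision.transforms.v2:

# Illustrative check of the eval pipeline (input size is an example)
from PIL import Image

tfm = get_transform(train=False)
img = Image.new("RGB", (559, 536))  # PIL size is (width, height)
out = tfm(img)
# Resize(800) maps the shorter edge (536) to 800 and scales the longer
# edge proportionally: 559 * 800 / 536 ≈ 834
print(out.shape, out.dtype)  # e.g. torch.Size([3, 800, 834]) torch.float32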