Pausing for now; the train loop should now work, and some tests are being added.
This commit is contained in:
@@ -18,80 +18,91 @@ class PennFudanDataset(torch.utils.data.Dataset):
|
||||
self.masks = sorted(list(os.listdir(os.path.join(root, "PedMasks"))))
|
||||
|
||||
def __getitem__(self, idx):
    """Get a sample from the dataset.

    Args:
        idx (int): Index of the sample to retrieve.

    Returns:
        tuple: (image, target) where target is a dict with keys
            "boxes", "labels", "masks", "image_id", "area", and
            "iscrowd". All per-instance tensors share the same first
            dimension (number of kept instances).
    """
    # Load the image and its instance-segmentation mask.
    img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
    mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

    img = Image.open(img_path).convert("RGB")
    # NOTE: the mask is intentionally NOT converted to RGB; its pixel
    # values are instance indices (0 = background).
    mask = np.array(Image.open(mask_path))

    # Each object instance is encoded with a unique value in the mask.
    obj_ids = np.unique(mask)
    obj_ids = obj_ids[1:]  # drop background (id = 0)

    # Split the color-encoded mask into one binary mask per instance.
    binary_masks = mask == obj_ids[:, None, None]

    # Compute a bounding box for each instance. Empty and zero-area
    # masks are skipped, and the corresponding binary mask is dropped
    # as well, so boxes / labels / masks stay the same length.
    boxes = []
    kept_masks = []
    for i in range(len(obj_ids)):
        pos = np.where(binary_masks[i])
        if pos[0].size == 0 or pos[1].size == 0:
            continue  # empty mask: no pixels for this instance
        xmin = np.min(pos[1])
        xmax = np.max(pos[1])
        ymin = np.min(pos[0])
        ymax = np.max(pos[0])
        # Keep only boxes with strictly positive width and height.
        if xmax > xmin and ymax > ymin:
            boxes.append([xmin, ymin, xmax, ymax])
            kept_masks.append(binary_masks[i])

    if boxes:
        boxes_t = torch.as_tensor(boxes, dtype=torch.float32)
        target = {
            "boxes": boxes_t,
            # Single foreground class: every object is a pedestrian (1).
            # Sized by len(boxes), not the pre-filter object count, so
            # it always matches the kept boxes/masks.
            "labels": torch.ones((len(boxes),), dtype=torch.int64),
            "masks": torch.as_tensor(
                np.stack(kept_masks), dtype=torch.uint8
            ),
            "image_id": torch.tensor([idx]),
            # Area of each box: (ymax - ymin) * (xmax - xmin).
            "area": (boxes_t[:, 3] - boxes_t[:, 1])
            * (boxes_t[:, 2] - boxes_t[:, 0]),
            # Assume no instance is a COCO "crowd" region.
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.uint8),
        }
    else:
        # No valid objects after filtering (rare but possible): return
        # empty tensors with the shapes downstream models expect.
        print(
            f"Warning: No valid boxes found for image {idx}. Returning dummy target."
        )
        target = {
            "boxes": torch.zeros((0, 4), dtype=torch.float32),
            "labels": torch.zeros((0,), dtype=torch.int64),
            "masks": torch.zeros(
                (0, mask.shape[0], mask.shape[1]), dtype=torch.uint8
            ),
            "image_id": torch.tensor([idx]),
            "area": torch.zeros((0,), dtype=torch.float32),
            "iscrowd": torch.zeros((0,), dtype=torch.uint8),
        }

    # Apply paired transforms if provided.
    # Note: torchvision v2 transforms handle target dicts automatically.
    if self.transforms is not None:
        img, target = self.transforms(img, target)

    return img, target
|
||||
@@ -117,15 +128,18 @@ def get_transform(train):
|
||||
# Convert to PyTorch tensor and normalize
|
||||
transforms.append(T.ToImage())
|
||||
|
||||
# Add resize transform to reduce memory usage (max size of 800px)
|
||||
transforms.append(T.Resize(800))
|
||||
# Resize images to control memory usage
|
||||
# Use a smaller size for training (more memory-intensive due to gradients)
|
||||
if train:
|
||||
transforms.append(T.Resize(700))
|
||||
else:
|
||||
transforms.append(T.Resize(800)) # Can use larger size for eval
|
||||
|
||||
transforms.append(T.ToDtype(torch.float32, scale=True))
|
||||
|
||||
# Data augmentation for training
|
||||
if train:
|
||||
transforms.append(T.RandomHorizontalFlip(0.5))
|
||||
# Could add more augmentations here if desired
|
||||
|
||||
return T.Compose(transforms)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user