# File: torchvision-vibecoding-project/tests/test_data_utils.py
# 109 lines, 3.8 KiB, Python

import torch
from utils.data_utils import collate_fn, get_transform
def test_dataset_len(sample_dataset):
    """Check that the dataset reports at least one sample.

    Args:
        sample_dataset: Object supporting ``len()``; presumably a pytest
            fixture defined outside this file (e.g. in ``conftest.py``).
    """
    # The original note states that PennFudanPed has 170 images, but only
    # non-emptiness is asserted here, so a smaller subset also passes.
    sample_count = len(sample_dataset)
    assert sample_count > 0, "Dataset should not be empty"
def test_dataset_getitem(sample_dataset):
    """Verify the structure of the first ``(image, target)`` sample.

    The image must be a 3-D ``torch.Tensor`` whose leading dimension is 3,
    and the target must be a dict carrying the keys ``boxes``, ``labels``,
    ``masks``, ``image_id``, ``area`` and ``iscrowd``.

    Returns early, without asserting anything, when the dataset is empty.

    Args:
        sample_dataset: Indexable dataset supporting ``len()``; presumably a
            pytest fixture defined outside this file.
    """
    if not len(sample_dataset):
        return  # Nothing to inspect

    image, annotation = sample_dataset[0]

    # Image checks.
    assert isinstance(image, torch.Tensor), "Image should be a tensor"
    assert image.dim() == 3, "Image should have 3 dimensions (C, H, W)"
    assert image.shape[0] == 3, "Image should have 3 channels (RGB)"

    # Target checks: keys are verified in a fixed order so the first missing
    # key is the one reported.
    assert isinstance(annotation, dict), "Target should be a dictionary"
    expected_keys = ("boxes", "labels", "masks", "image_id", "area", "iscrowd")
    for key in expected_keys:
        assert key in annotation, f"Target should contain '{key}'"

    # Shape checks on the tensors stored in the target.
    boxes = annotation["boxes"]
    assert (
        boxes.shape[1] == 4
    ), "Boxes should have 4 coordinates (x1, y1, x2, y2)"
    labels = annotation["labels"]
    assert labels.dim() == 1, "Labels should be a 1D tensor"
    masks = annotation["masks"]
    assert masks.dim() == 3, "Masks should be a 3D tensor (N, H, W)"
def test_transforms(sample_dataset):
    """Check that train and eval transforms can be swapped on the dataset.

    The dataset's ``transforms`` attribute is temporarily replaced, first
    with ``get_transform(train=True)`` and then with
    ``get_transform(train=False)``, and the first sample is fetched under
    each. Both resulting images must be 3-D with a leading dimension of 3.

    Returns early, without asserting anything, when the dataset is empty.

    Args:
        sample_dataset: Indexable dataset exposing a writable ``transforms``
            attribute; presumably a pytest fixture defined outside this file.
    """
    if len(sample_dataset) == 0:
        return  # Skip if no data

    orig_transforms = sample_dataset.transforms
    train_transforms = get_transform(train=True)
    eval_transforms = get_transform(train=False)

    try:
        sample_dataset.transforms = train_transforms
        img_train, _ = sample_dataset[0]

        sample_dataset.transforms = eval_transforms
        img_eval, _ = sample_dataset[0]
    finally:
        # Restore even when indexing raises, so a failure here cannot leave
        # the fixture with swapped transforms for whoever uses it next.
        sample_dataset.transforms = orig_transforms

    # Images should be tensors with expected properties.
    assert (
        img_train.dim() == img_eval.dim() == 3
    ), "Train and eval images should both have 3 dimensions"
    assert (
        img_train.shape[0] == img_eval.shape[0] == 3
    ), "Train and eval images should both have 3 channels"
def test_collate_fn():
    """Check that ``collate_fn`` keeps a two-sample batch intact.

    A batch of two ``(image, target)`` pairs is passed through
    ``collate_fn``. The result must unpack into ``images`` and ``targets``,
    each of length 2, with images and boxes matching the inputs in order.
    """

    def make_target(box, image_id):
        """Build a single-instance target dict for a 100x100 dummy image."""
        return {
            "boxes": torch.tensor([box], dtype=torch.float32),
            "labels": torch.tensor([1], dtype=torch.int64),
            "masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
            "image_id": torch.tensor([image_id]),
            "area": torch.tensor([1600.0], dtype=torch.float32),
            "iscrowd": torch.tensor([0], dtype=torch.uint8),
        }

    # Two random images with one annotated instance each.
    first_image = torch.rand(3, 100, 100)
    second_image = torch.rand(3, 100, 100)
    first_target = make_target([10, 10, 50, 50], 0)
    second_target = make_target([20, 20, 60, 60], 1)

    samples = [(first_image, first_target), (second_image, second_target)]
    images, targets = collate_fn(samples)

    # Batch size is preserved on both outputs.
    assert len(images) == 2, "Should have 2 images"
    assert len(targets) == 2, "Should have 2 targets"

    # Contents come back in input order.
    assert torch.allclose(images[0], first_image), "First image should match"
    assert torch.allclose(images[1], second_image), "Second image should match"
    assert torch.allclose(
        targets[0]["boxes"], first_target["boxes"]
    ), "First boxes should match"
    assert torch.allclose(
        targets[1]["boxes"], second_target["boxes"]
    ), "Second boxes should match"