Pausing for now, train loop should now work and adding some tests
This commit is contained in:
108
tests/test_data_utils.py
Normal file
108
tests/test_data_utils.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import torch
|
||||
|
||||
from utils.data_utils import collate_fn, get_transform
|
||||
|
||||
|
||||
def test_dataset_len(sample_dataset):
|
||||
"""Test that the dataset has the expected length."""
|
||||
# PennFudanPed has 170 images
|
||||
assert len(sample_dataset) > 0, "Dataset should not be empty"
|
||||
|
||||
|
||||
def test_dataset_getitem(sample_dataset):
|
||||
"""Test that __getitem__ returns expected format."""
|
||||
if len(sample_dataset) == 0:
|
||||
return # Skip if no data
|
||||
|
||||
# Get first item
|
||||
img, target = sample_dataset[0]
|
||||
|
||||
# Check image
|
||||
assert isinstance(img, torch.Tensor), "Image should be a tensor"
|
||||
assert img.dim() == 3, "Image should have 3 dimensions (C, H, W)"
|
||||
assert img.shape[0] == 3, "Image should have 3 channels (RGB)"
|
||||
|
||||
# Check target
|
||||
assert isinstance(target, dict), "Target should be a dictionary"
|
||||
assert "boxes" in target, "Target should contain 'boxes'"
|
||||
assert "labels" in target, "Target should contain 'labels'"
|
||||
assert "masks" in target, "Target should contain 'masks'"
|
||||
assert "image_id" in target, "Target should contain 'image_id'"
|
||||
assert "area" in target, "Target should contain 'area'"
|
||||
assert "iscrowd" in target, "Target should contain 'iscrowd'"
|
||||
|
||||
# Check target values
|
||||
assert (
|
||||
target["boxes"].shape[1] == 4
|
||||
), "Boxes should have 4 coordinates (x1, y1, x2, y2)"
|
||||
assert target["labels"].dim() == 1, "Labels should be a 1D tensor"
|
||||
assert target["masks"].dim() == 3, "Masks should be a 3D tensor (N, H, W)"
|
||||
|
||||
|
||||
def test_transforms(sample_dataset):
|
||||
"""Test that transforms are applied correctly."""
|
||||
if len(sample_dataset) == 0:
|
||||
return # Skip if no data
|
||||
|
||||
# Get original transform
|
||||
orig_transforms = sample_dataset.transforms
|
||||
|
||||
# Apply different transforms
|
||||
train_transforms = get_transform(train=True)
|
||||
eval_transforms = get_transform(train=False)
|
||||
|
||||
# Test that we can switch transforms
|
||||
sample_dataset.transforms = train_transforms
|
||||
img_train, target_train = sample_dataset[0]
|
||||
|
||||
sample_dataset.transforms = eval_transforms
|
||||
img_eval, target_eval = sample_dataset[0]
|
||||
|
||||
# Restore original transforms
|
||||
sample_dataset.transforms = orig_transforms
|
||||
|
||||
# Images should be tensors with expected properties
|
||||
assert img_train.dim() == img_eval.dim() == 3
|
||||
assert img_train.shape[0] == img_eval.shape[0] == 3
|
||||
|
||||
|
||||
def test_collate_fn():
|
||||
"""Test the collate function."""
|
||||
# Create dummy batch data
|
||||
dummy_img1 = torch.rand(3, 100, 100)
|
||||
dummy_img2 = torch.rand(3, 100, 100)
|
||||
|
||||
dummy_target1 = {
|
||||
"boxes": torch.tensor([[10, 10, 50, 50]], dtype=torch.float32),
|
||||
"labels": torch.tensor([1], dtype=torch.int64),
|
||||
"masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
|
||||
"image_id": torch.tensor([0]),
|
||||
"area": torch.tensor([1600.0], dtype=torch.float32),
|
||||
"iscrowd": torch.tensor([0], dtype=torch.uint8),
|
||||
}
|
||||
|
||||
dummy_target2 = {
|
||||
"boxes": torch.tensor([[20, 20, 60, 60]], dtype=torch.float32),
|
||||
"labels": torch.tensor([1], dtype=torch.int64),
|
||||
"masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
|
||||
"image_id": torch.tensor([1]),
|
||||
"area": torch.tensor([1600.0], dtype=torch.float32),
|
||||
"iscrowd": torch.tensor([0], dtype=torch.uint8),
|
||||
}
|
||||
|
||||
batch = [(dummy_img1, dummy_target1), (dummy_img2, dummy_target2)]
|
||||
|
||||
# Apply collate_fn
|
||||
images, targets = collate_fn(batch)
|
||||
|
||||
# Check results
|
||||
assert len(images) == 2, "Should have 2 images"
|
||||
assert len(targets) == 2, "Should have 2 targets"
|
||||
assert torch.allclose(images[0], dummy_img1), "First image should match"
|
||||
assert torch.allclose(images[1], dummy_img2), "Second image should match"
|
||||
assert torch.allclose(
|
||||
targets[0]["boxes"], dummy_target1["boxes"]
|
||||
), "First boxes should match"
|
||||
assert torch.allclose(
|
||||
targets[1]["boxes"], dummy_target2["boxes"]
|
||||
), "Second boxes should match"
|
||||
Reference in New Issue
Block a user