109 lines
3.8 KiB
Python
109 lines
3.8 KiB
Python
import torch
|
|
|
|
from utils.data_utils import collate_fn, get_transform
|
|
|
|
|
|
def test_dataset_len(sample_dataset):
|
|
"""Test that the dataset has the expected length."""
|
|
# PennFudanPed has 170 images
|
|
assert len(sample_dataset) > 0, "Dataset should not be empty"
|
|
|
|
|
|
def test_dataset_getitem(sample_dataset):
|
|
"""Test that __getitem__ returns expected format."""
|
|
if len(sample_dataset) == 0:
|
|
return # Skip if no data
|
|
|
|
# Get first item
|
|
img, target = sample_dataset[0]
|
|
|
|
# Check image
|
|
assert isinstance(img, torch.Tensor), "Image should be a tensor"
|
|
assert img.dim() == 3, "Image should have 3 dimensions (C, H, W)"
|
|
assert img.shape[0] == 3, "Image should have 3 channels (RGB)"
|
|
|
|
# Check target
|
|
assert isinstance(target, dict), "Target should be a dictionary"
|
|
assert "boxes" in target, "Target should contain 'boxes'"
|
|
assert "labels" in target, "Target should contain 'labels'"
|
|
assert "masks" in target, "Target should contain 'masks'"
|
|
assert "image_id" in target, "Target should contain 'image_id'"
|
|
assert "area" in target, "Target should contain 'area'"
|
|
assert "iscrowd" in target, "Target should contain 'iscrowd'"
|
|
|
|
# Check target values
|
|
assert (
|
|
target["boxes"].shape[1] == 4
|
|
), "Boxes should have 4 coordinates (x1, y1, x2, y2)"
|
|
assert target["labels"].dim() == 1, "Labels should be a 1D tensor"
|
|
assert target["masks"].dim() == 3, "Masks should be a 3D tensor (N, H, W)"
|
|
|
|
|
|
def test_transforms(sample_dataset):
    """Check that images stay well-formed under both transform pipelines.

    The dataset's ``transforms`` attribute is set in turn to
    ``get_transform(train=True)`` and ``get_transform(train=False)``, and
    sample 0 is fetched under each. The original pipeline is always restored
    afterwards, even if fetching or transforming raises. The fixture's scope
    is defined elsewhere; if it is shared, this keeps later tests from
    seeing a mutated dataset.

    Both resulting images must be 3-D with 3 leading channels. An empty
    dataset makes the function return early without asserting anything.
    """
    if len(sample_dataset) == 0:
        return  # Skip if no data

    # Remember the current pipeline so it can be put back afterwards.
    orig_transforms = sample_dataset.transforms

    train_transforms = get_transform(train=True)
    eval_transforms = get_transform(train=False)

    try:
        sample_dataset.transforms = train_transforms
        img_train, _ = sample_dataset[0]

        sample_dataset.transforms = eval_transforms
        img_eval, _ = sample_dataset[0]
    finally:
        # Restore unconditionally: previously a failure above left the
        # swapped-in pipeline attached to the dataset object.
        sample_dataset.transforms = orig_transforms

    # Both pipelines must produce (C, H, W) images with 3 channels.
    assert img_train.dim() == img_eval.dim() == 3
    assert img_train.shape[0] == img_eval.shape[0] == 3
|
|
|
|
|
|
def test_collate_fn():
    """Check that ``collate_fn`` regroups ``(image, target)`` pairs.

    Two synthetic 100x100 samples are built in memory, so no dataset on disk
    is needed. The test asserts that the collated images and targets come
    back in order and unchanged.
    """

    def _make_target(image_id, box):
        # One float32 box plus matching per-instance fields.
        return {
            "boxes": torch.tensor([box], dtype=torch.float32),
            "labels": torch.tensor([1], dtype=torch.int64),
            "masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
            "image_id": torch.tensor([image_id]),
            "area": torch.tensor([1600.0], dtype=torch.float32),
            "iscrowd": torch.tensor([0], dtype=torch.uint8),
        }

    # Draw the random images first, in the same order as before.
    first_img = torch.rand(3, 100, 100)
    second_img = torch.rand(3, 100, 100)

    first_target = _make_target(0, [10, 10, 50, 50])
    second_target = _make_target(1, [20, 20, 60, 60])

    images, targets = collate_fn(
        [(first_img, first_target), (second_img, second_target)]
    )

    # Sizes first, then per-sample contents.
    assert len(images) == 2, "Should have 2 images"
    assert len(targets) == 2, "Should have 2 targets"
    assert torch.allclose(images[0], first_img), "First image should match"
    assert torch.allclose(images[1], second_img), "Second image should match"
    assert torch.allclose(
        targets[0]["boxes"], first_target["boxes"]
    ), "First boxes should match"
    assert torch.allclose(
        targets[1]["boxes"], second_target["boxes"]
    ), "Second boxes should match"
|