import torch

from utils.data_utils import collate_fn, get_transform


def test_dataset_len(sample_dataset):
    """Test that the dataset is not empty."""
    # The original note states PennFudanPed has 170 images; only
    # non-emptiness is asserted here.
    assert len(sample_dataset) > 0, "Dataset should not be empty"


def test_dataset_getitem(sample_dataset):
    """Test that __getitem__ returns an (image, target) pair in the expected format."""
    if len(sample_dataset) == 0:
        return  # Skip if no data

    # Get first item
    img, target = sample_dataset[0]

    # Check image
    assert isinstance(img, torch.Tensor), "Image should be a tensor"
    assert img.dim() == 3, "Image should have 3 dimensions (C, H, W)"
    assert img.shape[0] == 3, "Image should have 3 channels (RGB)"

    # Check target
    assert isinstance(target, dict), "Target should be a dictionary"
    assert "boxes" in target, "Target should contain 'boxes'"
    assert "labels" in target, "Target should contain 'labels'"
    assert "masks" in target, "Target should contain 'masks'"
    assert "image_id" in target, "Target should contain 'image_id'"
    assert "area" in target, "Target should contain 'area'"
    assert "iscrowd" in target, "Target should contain 'iscrowd'"

    # Check target values
    assert (
        target["boxes"].shape[1] == 4
    ), "Boxes should have 4 coordinates (x1, y1, x2, y2)"
    assert target["labels"].dim() == 1, "Labels should be a 1D tensor"
    assert target["masks"].dim() == 3, "Masks should be a 3D tensor (N, H, W)"


def test_transforms(sample_dataset):
    """Test that train and eval transforms can be swapped in and applied."""
    if len(sample_dataset) == 0:
        return  # Skip if no data

    # Remember the original transform so it can be put back afterwards
    orig_transforms = sample_dataset.transforms

    # Build the two transform pipelines under test
    train_transforms = get_transform(train=True)
    eval_transforms = get_transform(train=False)

    try:
        # Test that we can switch transforms
        sample_dataset.transforms = train_transforms
        img_train, target_train = sample_dataset[0]

        sample_dataset.transforms = eval_transforms
        img_eval, target_eval = sample_dataset[0]
    finally:
        # Restore even when __getitem__ raises, so a failure here does not
        # leave the dataset object with swapped transforms for other users.
        sample_dataset.transforms = orig_transforms

    # Images should be tensors with expected properties
    assert img_train.dim() == img_eval.dim() == 3
    assert img_train.shape[0] == img_eval.shape[0] == 3


def _make_dummy_target(box, image_id):
    """Build a one-instance target dict for a 100x100 dummy image.

    Args:
        box: Four box coordinates, e.g. ``[10, 10, 50, 50]``.
        image_id: Integer stored in the ``image_id`` tensor.

    Returns:
        A dict with the keys ``boxes``, ``labels``, ``masks``, ``image_id``,
        ``area`` and ``iscrowd``.
    """
    return {
        "boxes": torch.tensor([box], dtype=torch.float32),
        "labels": torch.tensor([1], dtype=torch.int64),
        "masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
        "image_id": torch.tensor([image_id]),
        "area": torch.tensor([1600.0], dtype=torch.float32),
        "iscrowd": torch.tensor([0], dtype=torch.uint8),
    }


def test_collate_fn():
    """Test that collate_fn keeps images and targets paired and unchanged."""
    # Create dummy batch data
    dummy_img1 = torch.rand(3, 100, 100)
    dummy_img2 = torch.rand(3, 100, 100)

    dummy_target1 = _make_dummy_target([10, 10, 50, 50], image_id=0)
    dummy_target2 = _make_dummy_target([20, 20, 60, 60], image_id=1)

    batch = [(dummy_img1, dummy_target1), (dummy_img2, dummy_target2)]

    # Apply collate_fn
    images, targets = collate_fn(batch)

    # Check results
    assert len(images) == 2, "Should have 2 images"
    assert len(targets) == 2, "Should have 2 targets"
    assert torch.allclose(images[0], dummy_img1), "First image should match"
    assert torch.allclose(images[1], dummy_img2), "Second image should match"
    assert torch.allclose(
        targets[0]["boxes"], dummy_target1["boxes"]
    ), "First boxes should match"
    assert torch.allclose(
        targets[1]["boxes"], dummy_target2["boxes"]
    ), "Second boxes should match"