"""Shared pytest fixtures for the detection model and dataset tests."""

import os
import sys
from pathlib import Path

import pytest
import torch

# Add the project root to sys.path so the project packages (``models``,
# ``utils``) are importable. ``resolve()`` makes the root absolute, so it does
# not depend on the directory pytest was launched from.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))

from models.detection import get_maskrcnn_model  # noqa: E402
from utils.data_utils import PennFudanDataset, get_transform  # noqa: E402

# Resolve the dataset location against the project root, not the current
# working directory. A bare relative "data/PennFudanPed" made data-dependent
# tests skip whenever pytest ran from any directory other than the root.
DATA_ROOT = project_root / "data" / "PennFudanPed"


@pytest.fixture
def device():
    """Return the CPU device so tests behave the same on every machine."""
    return torch.device("cpu")


@pytest.fixture
def test_config():
    """Return a minimal configuration dictionary for tests.

    ``data_root`` is an absolute path string under the project root, so
    consumers locate the dataset independently of the working directory.
    """
    return {
        "data_root": str(DATA_ROOT),
        "num_classes": 2,
        "batch_size": 1,
        "device": "cpu",
        "output_dir": "test_outputs",
        "config_name": "test_run",
    }


@pytest.fixture
def small_model(device):
    """Return a 2-class Mask R-CNN model built without pretrained weights.

    Both ``pretrained`` and ``pretrained_backbone`` are disabled, so no weight
    files are requested. The model is moved to the ``device`` fixture's device.
    """
    model = get_maskrcnn_model(
        num_classes=2, pretrained=False, pretrained_backbone=False
    )
    model.to(device)
    return model


@pytest.fixture
def sample_dataset():
    """Return the PennFudanPed dataset built with ``get_transform(train=False)``.

    The requesting test is skipped when the dataset directory is not present
    under the project's ``data`` folder.
    """
    if not os.path.isdir(DATA_ROOT):
        pytest.skip("Test dataset not available")

    transforms = get_transform(train=False)
    # Pass a plain string, matching what the dataset received originally.
    dataset = PennFudanDataset(root=str(DATA_ROOT), transforms=transforms)
    return dataset