57 lines
1.4 KiB
Python
57 lines
1.4 KiB
Python
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
# Make the project root (the parent of this file's directory) importable so
# the `models` and `utils` packages imported below resolve from this project.
# resolve() makes the entry absolute, so it does not depend on the current
# working directory even if __file__ is relative. The root is inserted at the
# front (so it wins over same-named packages elsewhere on sys.path), but only
# when it is not already first, so repeated imports do not stack duplicates.
project_root = Path(__file__).resolve().parent.parent
if not sys.path or sys.path[0] != str(project_root):
    sys.path.insert(0, str(project_root))
|
|
|
|
from models.detection import get_maskrcnn_model # noqa: E402
|
|
from utils.data_utils import PennFudanDataset, get_transform # noqa: E402
|
|
|
|
|
|
@pytest.fixture
def device():
    """Provide the torch device every test runs on.

    This is always the CPU, regardless of what hardware is available, so
    results are consistent between machines.
    """
    cpu_device = torch.device("cpu")
    return cpu_device
|
|
|
|
|
|
@pytest.fixture
def test_config():
    """Provide a minimal configuration mapping for tests.

    Keys: ``data_root``, ``num_classes``, ``batch_size``, ``device``,
    ``output_dir`` and ``config_name``. The path values are plain relative
    strings; nothing here anchors them to a particular directory.
    """
    config = dict(
        data_root="data/PennFudanPed",
        num_classes=2,
        batch_size=1,
        device="cpu",
        output_dir="test_outputs",
        config_name="test_run",
    )
    return config
|
|
|
|
|
|
@pytest.fixture
def small_model(device):
    """Build a two-class Mask R-CNN and move it onto the ``device`` fixture.

    ``get_maskrcnn_model`` is called with both ``pretrained`` and
    ``pretrained_backbone`` disabled. Presumably this means randomly
    initialised weights and no download; confirm in ``models.detection``.
    """
    net = get_maskrcnn_model(
        num_classes=2,
        pretrained=False,
        pretrained_backbone=False,
    )
    # Move the model in place; the same object is returned.
    net.to(device)
    return net
|
|
|
|
|
|
@pytest.fixture
def sample_dataset():
    """Return the PennFudanPed dataset built with ``get_transform(train=False)``.

    The dataset directory is searched for in two places, in order:

    1. ``data/PennFudanPed`` under the project root (the parent of this
       file's directory), so the fixture works regardless of where pytest
       was launched from.
    2. ``data/PennFudanPed`` relative to the current working directory,
       which was the original lookup and is kept as a fallback.

    The requesting test is skipped if neither location is an existing
    directory.
    """
    relative_root = Path("data") / "PennFudanPed"
    candidates = (
        Path(__file__).resolve().parent.parent / relative_root,
        relative_root,
    )
    # is_dir() rather than exists(): a stray *file* at this path should
    # skip the test, not crash inside the dataset constructor.
    data_root = next((path for path in candidates if path.is_dir()), None)

    # Skip if data is not available
    if data_root is None:
        pytest.skip("Test dataset not available")

    transforms = get_transform(train=False)
    # Pass a str to match the type the original passed to PennFudanDataset.
    dataset = PennFudanDataset(root=str(data_root), transforms=transforms)
    return dataset
|