Files
torchvision-vibecoding-project/tests/conftest.py

57 lines
1.4 KiB
Python

import os
import sys
from pathlib import Path
import pytest
import torch
# Prepend the repository root (the directory containing tests/) to sys.path so
# the project-local packages imported below resolve without an install step.
project_root = Path(__file__).parent.parent
sys.path[0:0] = [os.fspath(project_root)]
from models.detection import get_maskrcnn_model # noqa: E402
from utils.data_utils import PennFudanDataset, get_transform # noqa: E402
@pytest.fixture
def device():
    """Provide the CPU ``torch.device`` so tests behave the same on every host."""
    cpu_device = torch.device("cpu")
    return cpu_device
@pytest.fixture
def test_config():
    """Provide a minimal configuration mapping for tests.

    A new dict is built on every call, so a test may mutate its copy freely.
    """
    config = dict(
        data_root="data/PennFudanPed",
        num_classes=2,
        batch_size=1,
        device="cpu",
        output_dir="test_outputs",
        config_name="test_run",
    )
    return config
@pytest.fixture
def small_model(device):
    """Build a two-class Mask R-CNN and place it on the ``device`` fixture.

    ``pretrained`` and ``pretrained_backbone`` are both passed as False
    (presumably so no weight files are fetched -- confirm in
    get_maskrcnn_model).
    """
    net = get_maskrcnn_model(
        num_classes=2,
        pretrained=False,
        pretrained_backbone=False,
    )
    net.to(device)
    return net
@pytest.fixture
def sample_dataset():
    """Return the PennFudanPed dataset built with evaluation transforms.

    The data directory is located relative to the project root rather than
    the current working directory, so it is found wherever pytest is
    launched from. The requesting test is skipped via ``pytest.skip`` when
    the directory is missing.
    """
    # Anchor on the module-level project_root instead of the CWD. A bare
    # relative path made this fixture silently skip every dataset test
    # whenever pytest was run from outside the repository root.
    data_root = project_root / "data" / "PennFudanPed"
    # Require a directory: a stray file with this name is not a dataset.
    if not data_root.is_dir():
        pytest.skip("Test dataset not available")
    transforms = get_transform(train=False)
    # Pass a plain str, matching the type the original code handed over.
    return PennFudanDataset(root=str(data_root), transforms=transforms)