Pausing for now, train loop should now work and adding some tests
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Add project root to the path to enable imports
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from models.detection import get_maskrcnn_model # noqa: E402
|
||||
from utils.data_utils import PennFudanDataset, get_transform # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device():
|
||||
"""Return CPU device for consistent testing."""
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_config():
|
||||
"""Return a minimal config dictionary for testing."""
|
||||
return {
|
||||
"data_root": "data/PennFudanPed",
|
||||
"num_classes": 2,
|
||||
"batch_size": 1,
|
||||
"device": "cpu",
|
||||
"output_dir": "test_outputs",
|
||||
"config_name": "test_run",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def small_model(device):
|
||||
"""Return a small Mask R-CNN model for testing."""
|
||||
model = get_maskrcnn_model(
|
||||
num_classes=2, pretrained=False, pretrained_backbone=False
|
||||
)
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset():
|
||||
"""Return a small dataset for testing if available."""
|
||||
data_root = "data/PennFudanPed"
|
||||
|
||||
# Skip if data is not available
|
||||
if not os.path.exists(data_root):
|
||||
pytest.skip("Test dataset not available")
|
||||
|
||||
transforms = get_transform(train=False)
|
||||
dataset = PennFudanDataset(root=data_root, transforms=transforms)
|
||||
return dataset
|
||||
|
||||
Reference in New Issue
Block a user