Pausing for now, train loop should now work and adding some tests

This commit is contained in:
Craig
2025-04-12 12:01:13 +01:00
parent 2b38c04a57
commit be70c4e160
13 changed files with 967 additions and 58 deletions

View File

@@ -0,0 +1,56 @@
import os
import sys
from pathlib import Path
import pytest
import torch
# Add project root to the path to enable imports
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from models.detection import get_maskrcnn_model # noqa: E402
from utils.data_utils import PennFudanDataset, get_transform # noqa: E402
@pytest.fixture
def device():
"""Return CPU device for consistent testing."""
return torch.device("cpu")
@pytest.fixture
def test_config():
"""Return a minimal config dictionary for testing."""
return {
"data_root": "data/PennFudanPed",
"num_classes": 2,
"batch_size": 1,
"device": "cpu",
"output_dir": "test_outputs",
"config_name": "test_run",
}
@pytest.fixture
def small_model(device):
"""Return a small Mask R-CNN model for testing."""
model = get_maskrcnn_model(
num_classes=2, pretrained=False, pretrained_backbone=False
)
model.to(device)
return model
@pytest.fixture
def sample_dataset():
"""Return a small dataset for testing if available."""
data_root = "data/PennFudanPed"
# Skip if data is not available
if not os.path.exists(data_root):
pytest.skip("Test dataset not available")
transforms = get_transform(train=False)
dataset = PennFudanDataset(root=data_root, transforms=transforms)
return dataset

108
tests/test_data_utils.py Normal file
View File

@@ -0,0 +1,108 @@
import torch
from utils.data_utils import collate_fn, get_transform
def test_dataset_len(sample_dataset):
"""Test that the dataset has the expected length."""
# PennFudanPed has 170 images
assert len(sample_dataset) > 0, "Dataset should not be empty"
def test_dataset_getitem(sample_dataset):
"""Test that __getitem__ returns expected format."""
if len(sample_dataset) == 0:
return # Skip if no data
# Get first item
img, target = sample_dataset[0]
# Check image
assert isinstance(img, torch.Tensor), "Image should be a tensor"
assert img.dim() == 3, "Image should have 3 dimensions (C, H, W)"
assert img.shape[0] == 3, "Image should have 3 channels (RGB)"
# Check target
assert isinstance(target, dict), "Target should be a dictionary"
assert "boxes" in target, "Target should contain 'boxes'"
assert "labels" in target, "Target should contain 'labels'"
assert "masks" in target, "Target should contain 'masks'"
assert "image_id" in target, "Target should contain 'image_id'"
assert "area" in target, "Target should contain 'area'"
assert "iscrowd" in target, "Target should contain 'iscrowd'"
# Check target values
assert (
target["boxes"].shape[1] == 4
), "Boxes should have 4 coordinates (x1, y1, x2, y2)"
assert target["labels"].dim() == 1, "Labels should be a 1D tensor"
assert target["masks"].dim() == 3, "Masks should be a 3D tensor (N, H, W)"
def test_transforms(sample_dataset):
"""Test that transforms are applied correctly."""
if len(sample_dataset) == 0:
return # Skip if no data
# Get original transform
orig_transforms = sample_dataset.transforms
# Apply different transforms
train_transforms = get_transform(train=True)
eval_transforms = get_transform(train=False)
# Test that we can switch transforms
sample_dataset.transforms = train_transforms
img_train, target_train = sample_dataset[0]
sample_dataset.transforms = eval_transforms
img_eval, target_eval = sample_dataset[0]
# Restore original transforms
sample_dataset.transforms = orig_transforms
# Images should be tensors with expected properties
assert img_train.dim() == img_eval.dim() == 3
assert img_train.shape[0] == img_eval.shape[0] == 3
def test_collate_fn():
"""Test the collate function."""
# Create dummy batch data
dummy_img1 = torch.rand(3, 100, 100)
dummy_img2 = torch.rand(3, 100, 100)
dummy_target1 = {
"boxes": torch.tensor([[10, 10, 50, 50]], dtype=torch.float32),
"labels": torch.tensor([1], dtype=torch.int64),
"masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
"image_id": torch.tensor([0]),
"area": torch.tensor([1600.0], dtype=torch.float32),
"iscrowd": torch.tensor([0], dtype=torch.uint8),
}
dummy_target2 = {
"boxes": torch.tensor([[20, 20, 60, 60]], dtype=torch.float32),
"labels": torch.tensor([1], dtype=torch.int64),
"masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
"image_id": torch.tensor([1]),
"area": torch.tensor([1600.0], dtype=torch.float32),
"iscrowd": torch.tensor([0], dtype=torch.uint8),
}
batch = [(dummy_img1, dummy_target1), (dummy_img2, dummy_target2)]
# Apply collate_fn
images, targets = collate_fn(batch)
# Check results
assert len(images) == 2, "Should have 2 images"
assert len(targets) == 2, "Should have 2 targets"
assert torch.allclose(images[0], dummy_img1), "First image should match"
assert torch.allclose(images[1], dummy_img2), "Second image should match"
assert torch.allclose(
targets[0]["boxes"], dummy_target1["boxes"]
), "First boxes should match"
assert torch.allclose(
targets[1]["boxes"], dummy_target2["boxes"]
), "Second boxes should match"

102
tests/test_model.py Normal file
View File

@@ -0,0 +1,102 @@
import torch
import torchvision
from utils.eval_utils import evaluate
def test_model_creation(small_model):
"""Test that the model is created correctly."""
assert isinstance(small_model, torchvision.models.detection.MaskRCNN)
assert small_model.roi_heads.box_predictor.cls_score.out_features == 2
assert small_model.roi_heads.mask_predictor.mask_fcn_logits.out_channels == 2
def test_model_forward_train_mode(small_model, sample_dataset, device):
"""Test model forward pass in training mode."""
if len(sample_dataset) == 0:
return # Skip if no data
# Set model to training mode
small_model.train()
# Get a batch
img, target = sample_dataset[0]
img = img.to(device)
target = {k: v.to(device) for k, v in target.items()}
# Forward pass with targets should return loss dict in training mode
loss_dict = small_model([img], [target])
# Verify loss dict structure
assert isinstance(loss_dict, dict), "Loss should be a dictionary"
assert "loss_classifier" in loss_dict, "Should have classifier loss"
assert "loss_box_reg" in loss_dict, "Should have box regression loss"
assert "loss_mask" in loss_dict, "Should have mask loss"
assert "loss_objectness" in loss_dict, "Should have objectness loss"
assert "loss_rpn_box_reg" in loss_dict, "Should have RPN box regression loss"
# Verify loss values
for loss_name, loss_value in loss_dict.items():
assert isinstance(loss_value, torch.Tensor), f"{loss_name} should be a tensor"
assert loss_value.dim() == 0, f"{loss_name} should be a scalar tensor"
assert not torch.isnan(loss_value), f"{loss_name} should not be NaN"
assert not torch.isinf(loss_value), f"{loss_name} should not be infinite"
def test_model_forward_eval_mode(small_model, sample_dataset, device):
"""Test model forward pass in evaluation mode."""
if len(sample_dataset) == 0:
return # Skip if no data
# Set model to evaluation mode
small_model.eval()
# Get a batch
img, target = sample_dataset[0]
img = img.to(device)
# Forward pass without targets should return predictions in eval mode
with torch.no_grad():
predictions = small_model([img])
# Verify predictions structure
assert isinstance(predictions, list), "Predictions should be a list"
assert len(predictions) == 1, "Should have predictions for 1 image"
pred = predictions[0]
assert "boxes" in pred, "Predictions should contain 'boxes'"
assert "scores" in pred, "Predictions should contain 'scores'"
assert "labels" in pred, "Predictions should contain 'labels'"
assert "masks" in pred, "Predictions should contain 'masks'"
def test_evaluate_function(small_model, sample_dataset, device):
"""Test the evaluate function."""
if len(sample_dataset) == 0:
return # Skip if no data
# Create a tiny dataloader for testing
from torch.utils.data import DataLoader
from utils.data_utils import collate_fn
# Use only 2 samples for quick testing
small_ds = torch.utils.data.Subset(
sample_dataset, range(min(2, len(sample_dataset)))
)
dataloader = DataLoader(
small_ds, batch_size=1, shuffle=False, collate_fn=collate_fn
)
# Set model to eval mode
small_model.eval()
# Import evaluate function
# Run evaluation
metrics = evaluate(small_model, dataloader, device)
# Check results
assert isinstance(metrics, dict), "Metrics should be a dictionary"
assert "average_loss" in metrics, "Metrics should contain 'average_loss'"
assert metrics["average_loss"] >= 0, "Loss should be non-negative"

View File

@@ -0,0 +1,77 @@
import os
import sys
import matplotlib.pyplot as plt
import torch
# Import visualization functions
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from scripts.visualize_predictions import visualize_prediction
def test_visualize_prediction():
"""Test that the visualization function works."""
# Create a dummy image tensor
image = torch.rand(3, 400, 600)
# Create a dummy prediction dictionary
prediction = {
"boxes": torch.tensor(
[[100, 100, 200, 200], [300, 300, 400, 400]], dtype=torch.float32
),
"scores": torch.tensor([0.9, 0.7], dtype=torch.float32),
"labels": torch.tensor([1, 1], dtype=torch.int64),
"masks": torch.zeros(2, 1, 400, 600, dtype=torch.float32),
}
# Set some pixels in the mask to 1
prediction["masks"][0, 0, 100:200, 100:200] = 1.0
prediction["masks"][1, 0, 300:400, 300:400] = 1.0
# Call the visualization function
fig = visualize_prediction(image, prediction, threshold=0.5)
# Check that a figure was returned
assert isinstance(fig, plt.Figure)
# Check figure properties
assert len(fig.axes) == 1
# Close the figure to avoid memory leaks
plt.close(fig)
def test_visualize_prediction_threshold():
"""Test that the threshold parameter filters predictions correctly."""
# Create a dummy image tensor
image = torch.rand(3, 400, 600)
# Create a dummy prediction dictionary with varying scores
prediction = {
"boxes": torch.tensor(
[[100, 100, 200, 200], [300, 300, 400, 400], [500, 100, 550, 150]],
dtype=torch.float32,
),
"scores": torch.tensor([0.9, 0.7, 0.3], dtype=torch.float32),
"labels": torch.tensor([1, 1, 1], dtype=torch.int64),
"masks": torch.zeros(3, 1, 400, 600, dtype=torch.float32),
}
# Call the visualization function with different thresholds
fig_low = visualize_prediction(image, prediction, threshold=0.2)
fig_med = visualize_prediction(image, prediction, threshold=0.5)
fig_high = visualize_prediction(image, prediction, threshold=0.8)
# Low threshold should show all 3 boxes
assert "Found 3" in fig_low.axes[0].get_xlabel()
# Medium threshold should show 2 boxes
assert "Found 2" in fig_med.axes[0].get_xlabel()
# High threshold should show 1 box
assert "Found 1" in fig_high.axes[0].get_xlabel()
# Close figures
plt.close(fig_low)
plt.close(fig_med)
plt.close(fig_high)