Pausing for now, train loop should now work and adding some tests
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Add project root to the path to enable imports
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from models.detection import get_maskrcnn_model # noqa: E402
|
||||
from utils.data_utils import PennFudanDataset, get_transform # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device():
|
||||
"""Return CPU device for consistent testing."""
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_config():
|
||||
"""Return a minimal config dictionary for testing."""
|
||||
return {
|
||||
"data_root": "data/PennFudanPed",
|
||||
"num_classes": 2,
|
||||
"batch_size": 1,
|
||||
"device": "cpu",
|
||||
"output_dir": "test_outputs",
|
||||
"config_name": "test_run",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def small_model(device):
|
||||
"""Return a small Mask R-CNN model for testing."""
|
||||
model = get_maskrcnn_model(
|
||||
num_classes=2, pretrained=False, pretrained_backbone=False
|
||||
)
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset():
|
||||
"""Return a small dataset for testing if available."""
|
||||
data_root = "data/PennFudanPed"
|
||||
|
||||
# Skip if data is not available
|
||||
if not os.path.exists(data_root):
|
||||
pytest.skip("Test dataset not available")
|
||||
|
||||
transforms = get_transform(train=False)
|
||||
dataset = PennFudanDataset(root=data_root, transforms=transforms)
|
||||
return dataset
|
||||
|
||||
108
tests/test_data_utils.py
Normal file
108
tests/test_data_utils.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import torch
|
||||
|
||||
from utils.data_utils import collate_fn, get_transform
|
||||
|
||||
|
||||
def test_dataset_len(sample_dataset):
|
||||
"""Test that the dataset has the expected length."""
|
||||
# PennFudanPed has 170 images
|
||||
assert len(sample_dataset) > 0, "Dataset should not be empty"
|
||||
|
||||
|
||||
def test_dataset_getitem(sample_dataset):
|
||||
"""Test that __getitem__ returns expected format."""
|
||||
if len(sample_dataset) == 0:
|
||||
return # Skip if no data
|
||||
|
||||
# Get first item
|
||||
img, target = sample_dataset[0]
|
||||
|
||||
# Check image
|
||||
assert isinstance(img, torch.Tensor), "Image should be a tensor"
|
||||
assert img.dim() == 3, "Image should have 3 dimensions (C, H, W)"
|
||||
assert img.shape[0] == 3, "Image should have 3 channels (RGB)"
|
||||
|
||||
# Check target
|
||||
assert isinstance(target, dict), "Target should be a dictionary"
|
||||
assert "boxes" in target, "Target should contain 'boxes'"
|
||||
assert "labels" in target, "Target should contain 'labels'"
|
||||
assert "masks" in target, "Target should contain 'masks'"
|
||||
assert "image_id" in target, "Target should contain 'image_id'"
|
||||
assert "area" in target, "Target should contain 'area'"
|
||||
assert "iscrowd" in target, "Target should contain 'iscrowd'"
|
||||
|
||||
# Check target values
|
||||
assert (
|
||||
target["boxes"].shape[1] == 4
|
||||
), "Boxes should have 4 coordinates (x1, y1, x2, y2)"
|
||||
assert target["labels"].dim() == 1, "Labels should be a 1D tensor"
|
||||
assert target["masks"].dim() == 3, "Masks should be a 3D tensor (N, H, W)"
|
||||
|
||||
|
||||
def test_transforms(sample_dataset):
|
||||
"""Test that transforms are applied correctly."""
|
||||
if len(sample_dataset) == 0:
|
||||
return # Skip if no data
|
||||
|
||||
# Get original transform
|
||||
orig_transforms = sample_dataset.transforms
|
||||
|
||||
# Apply different transforms
|
||||
train_transforms = get_transform(train=True)
|
||||
eval_transforms = get_transform(train=False)
|
||||
|
||||
# Test that we can switch transforms
|
||||
sample_dataset.transforms = train_transforms
|
||||
img_train, target_train = sample_dataset[0]
|
||||
|
||||
sample_dataset.transforms = eval_transforms
|
||||
img_eval, target_eval = sample_dataset[0]
|
||||
|
||||
# Restore original transforms
|
||||
sample_dataset.transforms = orig_transforms
|
||||
|
||||
# Images should be tensors with expected properties
|
||||
assert img_train.dim() == img_eval.dim() == 3
|
||||
assert img_train.shape[0] == img_eval.shape[0] == 3
|
||||
|
||||
|
||||
def test_collate_fn():
|
||||
"""Test the collate function."""
|
||||
# Create dummy batch data
|
||||
dummy_img1 = torch.rand(3, 100, 100)
|
||||
dummy_img2 = torch.rand(3, 100, 100)
|
||||
|
||||
dummy_target1 = {
|
||||
"boxes": torch.tensor([[10, 10, 50, 50]], dtype=torch.float32),
|
||||
"labels": torch.tensor([1], dtype=torch.int64),
|
||||
"masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
|
||||
"image_id": torch.tensor([0]),
|
||||
"area": torch.tensor([1600.0], dtype=torch.float32),
|
||||
"iscrowd": torch.tensor([0], dtype=torch.uint8),
|
||||
}
|
||||
|
||||
dummy_target2 = {
|
||||
"boxes": torch.tensor([[20, 20, 60, 60]], dtype=torch.float32),
|
||||
"labels": torch.tensor([1], dtype=torch.int64),
|
||||
"masks": torch.zeros(1, 100, 100, dtype=torch.uint8),
|
||||
"image_id": torch.tensor([1]),
|
||||
"area": torch.tensor([1600.0], dtype=torch.float32),
|
||||
"iscrowd": torch.tensor([0], dtype=torch.uint8),
|
||||
}
|
||||
|
||||
batch = [(dummy_img1, dummy_target1), (dummy_img2, dummy_target2)]
|
||||
|
||||
# Apply collate_fn
|
||||
images, targets = collate_fn(batch)
|
||||
|
||||
# Check results
|
||||
assert len(images) == 2, "Should have 2 images"
|
||||
assert len(targets) == 2, "Should have 2 targets"
|
||||
assert torch.allclose(images[0], dummy_img1), "First image should match"
|
||||
assert torch.allclose(images[1], dummy_img2), "Second image should match"
|
||||
assert torch.allclose(
|
||||
targets[0]["boxes"], dummy_target1["boxes"]
|
||||
), "First boxes should match"
|
||||
assert torch.allclose(
|
||||
targets[1]["boxes"], dummy_target2["boxes"]
|
||||
), "Second boxes should match"
|
||||
102
tests/test_model.py
Normal file
102
tests/test_model.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from utils.eval_utils import evaluate
|
||||
|
||||
|
||||
def test_model_creation(small_model):
|
||||
"""Test that the model is created correctly."""
|
||||
assert isinstance(small_model, torchvision.models.detection.MaskRCNN)
|
||||
assert small_model.roi_heads.box_predictor.cls_score.out_features == 2
|
||||
assert small_model.roi_heads.mask_predictor.mask_fcn_logits.out_channels == 2
|
||||
|
||||
|
||||
def test_model_forward_train_mode(small_model, sample_dataset, device):
|
||||
"""Test model forward pass in training mode."""
|
||||
if len(sample_dataset) == 0:
|
||||
return # Skip if no data
|
||||
|
||||
# Set model to training mode
|
||||
small_model.train()
|
||||
|
||||
# Get a batch
|
||||
img, target = sample_dataset[0]
|
||||
img = img.to(device)
|
||||
target = {k: v.to(device) for k, v in target.items()}
|
||||
|
||||
# Forward pass with targets should return loss dict in training mode
|
||||
loss_dict = small_model([img], [target])
|
||||
|
||||
# Verify loss dict structure
|
||||
assert isinstance(loss_dict, dict), "Loss should be a dictionary"
|
||||
assert "loss_classifier" in loss_dict, "Should have classifier loss"
|
||||
assert "loss_box_reg" in loss_dict, "Should have box regression loss"
|
||||
assert "loss_mask" in loss_dict, "Should have mask loss"
|
||||
assert "loss_objectness" in loss_dict, "Should have objectness loss"
|
||||
assert "loss_rpn_box_reg" in loss_dict, "Should have RPN box regression loss"
|
||||
|
||||
# Verify loss values
|
||||
for loss_name, loss_value in loss_dict.items():
|
||||
assert isinstance(loss_value, torch.Tensor), f"{loss_name} should be a tensor"
|
||||
assert loss_value.dim() == 0, f"{loss_name} should be a scalar tensor"
|
||||
assert not torch.isnan(loss_value), f"{loss_name} should not be NaN"
|
||||
assert not torch.isinf(loss_value), f"{loss_name} should not be infinite"
|
||||
|
||||
|
||||
def test_model_forward_eval_mode(small_model, sample_dataset, device):
|
||||
"""Test model forward pass in evaluation mode."""
|
||||
if len(sample_dataset) == 0:
|
||||
return # Skip if no data
|
||||
|
||||
# Set model to evaluation mode
|
||||
small_model.eval()
|
||||
|
||||
# Get a batch
|
||||
img, target = sample_dataset[0]
|
||||
img = img.to(device)
|
||||
|
||||
# Forward pass without targets should return predictions in eval mode
|
||||
with torch.no_grad():
|
||||
predictions = small_model([img])
|
||||
|
||||
# Verify predictions structure
|
||||
assert isinstance(predictions, list), "Predictions should be a list"
|
||||
assert len(predictions) == 1, "Should have predictions for 1 image"
|
||||
|
||||
pred = predictions[0]
|
||||
assert "boxes" in pred, "Predictions should contain 'boxes'"
|
||||
assert "scores" in pred, "Predictions should contain 'scores'"
|
||||
assert "labels" in pred, "Predictions should contain 'labels'"
|
||||
assert "masks" in pred, "Predictions should contain 'masks'"
|
||||
|
||||
|
||||
def test_evaluate_function(small_model, sample_dataset, device):
|
||||
"""Test the evaluate function."""
|
||||
if len(sample_dataset) == 0:
|
||||
return # Skip if no data
|
||||
|
||||
# Create a tiny dataloader for testing
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from utils.data_utils import collate_fn
|
||||
|
||||
# Use only 2 samples for quick testing
|
||||
small_ds = torch.utils.data.Subset(
|
||||
sample_dataset, range(min(2, len(sample_dataset)))
|
||||
)
|
||||
dataloader = DataLoader(
|
||||
small_ds, batch_size=1, shuffle=False, collate_fn=collate_fn
|
||||
)
|
||||
|
||||
# Set model to eval mode
|
||||
small_model.eval()
|
||||
|
||||
# Import evaluate function
|
||||
|
||||
# Run evaluation
|
||||
metrics = evaluate(small_model, dataloader, device)
|
||||
|
||||
# Check results
|
||||
assert isinstance(metrics, dict), "Metrics should be a dictionary"
|
||||
assert "average_loss" in metrics, "Metrics should contain 'average_loss'"
|
||||
assert metrics["average_loss"] >= 0, "Loss should be non-negative"
|
||||
77
tests/test_visualization.py
Normal file
77
tests/test_visualization.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
||||
# Import visualization functions
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from scripts.visualize_predictions import visualize_prediction
|
||||
|
||||
|
||||
def test_visualize_prediction():
|
||||
"""Test that the visualization function works."""
|
||||
# Create a dummy image tensor
|
||||
image = torch.rand(3, 400, 600)
|
||||
|
||||
# Create a dummy prediction dictionary
|
||||
prediction = {
|
||||
"boxes": torch.tensor(
|
||||
[[100, 100, 200, 200], [300, 300, 400, 400]], dtype=torch.float32
|
||||
),
|
||||
"scores": torch.tensor([0.9, 0.7], dtype=torch.float32),
|
||||
"labels": torch.tensor([1, 1], dtype=torch.int64),
|
||||
"masks": torch.zeros(2, 1, 400, 600, dtype=torch.float32),
|
||||
}
|
||||
|
||||
# Set some pixels in the mask to 1
|
||||
prediction["masks"][0, 0, 100:200, 100:200] = 1.0
|
||||
prediction["masks"][1, 0, 300:400, 300:400] = 1.0
|
||||
|
||||
# Call the visualization function
|
||||
fig = visualize_prediction(image, prediction, threshold=0.5)
|
||||
|
||||
# Check that a figure was returned
|
||||
assert isinstance(fig, plt.Figure)
|
||||
|
||||
# Check figure properties
|
||||
assert len(fig.axes) == 1
|
||||
|
||||
# Close the figure to avoid memory leaks
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def test_visualize_prediction_threshold():
|
||||
"""Test that the threshold parameter filters predictions correctly."""
|
||||
# Create a dummy image tensor
|
||||
image = torch.rand(3, 400, 600)
|
||||
|
||||
# Create a dummy prediction dictionary with varying scores
|
||||
prediction = {
|
||||
"boxes": torch.tensor(
|
||||
[[100, 100, 200, 200], [300, 300, 400, 400], [500, 100, 550, 150]],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"scores": torch.tensor([0.9, 0.7, 0.3], dtype=torch.float32),
|
||||
"labels": torch.tensor([1, 1, 1], dtype=torch.int64),
|
||||
"masks": torch.zeros(3, 1, 400, 600, dtype=torch.float32),
|
||||
}
|
||||
|
||||
# Call the visualization function with different thresholds
|
||||
fig_low = visualize_prediction(image, prediction, threshold=0.2)
|
||||
fig_med = visualize_prediction(image, prediction, threshold=0.5)
|
||||
fig_high = visualize_prediction(image, prediction, threshold=0.8)
|
||||
|
||||
# Low threshold should show all 3 boxes
|
||||
assert "Found 3" in fig_low.axes[0].get_xlabel()
|
||||
|
||||
# Medium threshold should show 2 boxes
|
||||
assert "Found 2" in fig_med.axes[0].get_xlabel()
|
||||
|
||||
# High threshold should show 1 box
|
||||
assert "Found 1" in fig_high.axes[0].get_xlabel()
|
||||
|
||||
# Close figures
|
||||
plt.close(fig_low)
|
||||
plt.close(fig_med)
|
||||
plt.close(fig_high)
|
||||
Reference in New Issue
Block a user