Pausing for now; the train loop should now work, and this adds some tests.
tests/test_model.py (new file, +102)
@@ -0,0 +1,102 @@
import pytest
import torch
import torchvision

from utils.eval_utils import evaluate


def test_model_creation(small_model):
    """Test that the model is created correctly."""
    assert isinstance(small_model, torchvision.models.detection.MaskRCNN)
    assert small_model.roi_heads.box_predictor.cls_score.out_features == 2
    assert small_model.roi_heads.mask_predictor.mask_fcn_logits.out_channels == 2
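

# Sketch (assumption): the `small_model`, `sample_dataset`, and `device`
# fixtures used throughout this file would normally live in conftest.py, which
# is not shown here. The stand-ins below are hypothetical but consistent with
# the assertions in this file (e.g. num_classes=2 in test_model_creation).
@pytest.fixture
def device():
    # Use the GPU when one is available; otherwise fall back to CPU.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.fixture
def small_model(device):
    # num_classes=2 (background + one foreground class) matches the
    # out_features / out_channels assertions above. The `weights=` keyword
    # assumes torchvision >= 0.13.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        weights=None, num_classes=2
    )
    return model.to(device)


@pytest.fixture
def sample_dataset():
    # Synthetic stand-in: each item is (image_tensor, target_dict) carrying
    # the keys Mask R-CNN needs for its losses ("boxes", "labels", "masks").
    class _SquareDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 2

        def __getitem__(self, idx):
            img = torch.zeros(3, 64, 64)
            img[:, 16:48, 16:48] = 1.0
            mask = torch.zeros(1, 64, 64, dtype=torch.uint8)
            mask[:, 16:48, 16:48] = 1
            target = {
                "boxes": torch.tensor([[16.0, 16.0, 48.0, 48.0]]),
                "labels": torch.tensor([1], dtype=torch.int64),
                "masks": mask,
            }
            return img, target

    return _SquareDataset()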


def test_model_forward_train_mode(small_model, sample_dataset, device):
    """Test model forward pass in training mode."""
    if len(sample_dataset) == 0:
        pytest.skip("sample_dataset is empty")

    # Set model to training mode
    small_model.train()

    # Get a batch
    img, target = sample_dataset[0]
    img = img.to(device)
    target = {k: v.to(device) for k, v in target.items()}

    # Forward pass with targets should return a loss dict in training mode
    loss_dict = small_model([img], [target])

    # Verify loss dict structure
    assert isinstance(loss_dict, dict), "Loss should be a dictionary"
    assert "loss_classifier" in loss_dict, "Should have classifier loss"
    assert "loss_box_reg" in loss_dict, "Should have box regression loss"
    assert "loss_mask" in loss_dict, "Should have mask loss"
    assert "loss_objectness" in loss_dict, "Should have objectness loss"
    assert "loss_rpn_box_reg" in loss_dict, "Should have RPN box regression loss"

    # Verify loss values
    for loss_name, loss_value in loss_dict.items():
        assert isinstance(loss_value, torch.Tensor), f"{loss_name} should be a tensor"
        assert loss_value.dim() == 0, f"{loss_name} should be a scalar tensor"
        assert not torch.isnan(loss_value), f"{loss_name} should not be NaN"
        assert not torch.isinf(loss_value), f"{loss_name} should not be infinite"
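

# Sketch (assumption): the loss dict verified above is exactly what a train
# step consumes; summing loss_dict.values() is the standard way to combine
# the five torchvision detection losses into a single backward pass. The
# helper name is hypothetical (underscore-prefixed so pytest won't collect it).
def _train_step_sketch(model, images, targets, optimizer):
    """Illustrative only: how the loss dict feeds one optimizer step."""
    model.train()
    loss_dict = model(images, targets)
    loss = sum(loss_dict.values())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()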


def test_model_forward_eval_mode(small_model, sample_dataset, device):
    """Test model forward pass in evaluation mode."""
    if len(sample_dataset) == 0:
        pytest.skip("sample_dataset is empty")

    # Set model to evaluation mode
    small_model.eval()

    # Get a batch (the target is not needed for inference)
    img, _ = sample_dataset[0]
    img = img.to(device)

    # Forward pass without targets should return predictions in eval mode
    with torch.no_grad():
        predictions = small_model([img])

    # Verify predictions structure
    assert isinstance(predictions, list), "Predictions should be a list"
    assert len(predictions) == 1, "Should have predictions for 1 image"

    pred = predictions[0]
    assert "boxes" in pred, "Predictions should contain 'boxes'"
    assert "scores" in pred, "Predictions should contain 'scores'"
    assert "labels" in pred, "Predictions should contain 'labels'"
    assert "masks" in pred, "Predictions should contain 'masks'"
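

# Sketch (assumption): per torchvision's documented output format, each eval
# prediction holds boxes [N, 4], scores [N], labels [N], and soft masks
# [N, 1, H, W]; typical postprocessing thresholds the scores and binarizes
# the masks. The helper name is hypothetical.
def _postprocess_sketch(pred, score_threshold=0.5):
    """Illustrative only: keep confident detections, binarize their masks."""
    keep = pred["scores"] > score_threshold
    boxes = pred["boxes"][keep]
    labels = pred["labels"][keep]
    masks = (pred["masks"][keep] > 0.5).squeeze(1)  # -> [K, H, W] bool
    return boxes, labels, masks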


def test_evaluate_function(small_model, sample_dataset, device):
    """Test the evaluate function."""
    if len(sample_dataset) == 0:
        pytest.skip("sample_dataset is empty")

    # Create a tiny dataloader for testing
    from torch.utils.data import DataLoader

    from utils.data_utils import collate_fn

    # Use only 2 samples for quick testing
    small_ds = torch.utils.data.Subset(
        sample_dataset, range(min(2, len(sample_dataset)))
    )
    dataloader = DataLoader(
        small_ds, batch_size=1, shuffle=False, collate_fn=collate_fn
    )

    # Set model to eval mode
    small_model.eval()

    # Run evaluation (evaluate is imported at the top of the file)
    metrics = evaluate(small_model, dataloader, device)

    # Check results
    assert isinstance(metrics, dict), "Metrics should be a dictionary"
    assert "average_loss" in metrics, "Metrics should contain 'average_loss'"
    assert metrics["average_loss"] >= 0, "Loss should be non-negative"
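

# Sketch (assumption): neither `collate_fn` (utils.data_utils) nor `evaluate`
# (utils.eval_utils) is shown in this commit. The stand-ins below are
# hypothetical but consistent with how test_evaluate_function uses them.
def _collate_sketch(batch):
    # Standard torchvision detection collate: a batch of (img, target) pairs
    # becomes (tuple_of_images, tuple_of_targets) of equal length.
    return tuple(zip(*batch))


def _evaluate_sketch(model, dataloader, device):
    """Illustrative only: return {"average_loss": float} over a loader.

    Torchvision detection models only return losses in train mode, so a
    loss-based evaluate has to toggle modes (caveat: for models with live
    BatchNorm, train mode updates running stats even under no_grad).
    """
    was_training = model.training
    model.train()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for images, targets in dataloader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            total_loss += sum(v.item() for v in loss_dict.values())
            num_batches += 1
    model.train(was_training)
    return {"average_loss": total_loss / max(num_batches, 1)}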