103 lines
3.6 KiB
Python
103 lines
3.6 KiB
Python
import torch
|
|
import torchvision
|
|
|
|
from utils.eval_utils import evaluate
|
|
|
|
|
|
def test_model_creation(small_model):
    """Verify the model fixture's type and the output size of its predictors.

    Args:
        small_model: Model fixture; expected to be a torchvision ``MaskRCNN``
            whose box classifier and mask predictor both have 2 outputs.
    """
    expected_outputs = 2

    assert isinstance(small_model, torchvision.models.detection.MaskRCNN)

    # Box head: the classification layer's output width.
    box_classifier = small_model.roi_heads.box_predictor.cls_score
    assert box_classifier.out_features == expected_outputs

    # Mask head: the final logits layer's channel count.
    mask_logits = small_model.roi_heads.mask_predictor.mask_fcn_logits
    assert mask_logits.out_channels == expected_outputs
|
|
|
|
|
|
def test_model_forward_train_mode(small_model, sample_dataset, device):
    """Check that a training-mode forward pass returns the expected losses.

    Puts ``small_model`` in training mode, feeds it the first sample of
    ``sample_dataset`` together with its target, and verifies that the result
    is a dict holding every expected loss entry as a finite scalar tensor.

    Args:
        small_model: Model fixture; called as ``small_model(images, targets)``.
        sample_dataset: Dataset fixture yielding ``(img, target)`` pairs, where
            ``target`` is a dict whose values support ``.to(device)``.
        device: Device the image and the target values are moved to.

    Raises:
        unittest.SkipTest: If ``sample_dataset`` is empty.
    """
    if len(sample_dataset) == 0:
        # Report an explicit skip rather than returning: a bare ``return``
        # makes the test show up as passed although nothing was checked.
        # pytest reports ``unittest.SkipTest`` as a skipped test.
        raise unittest.SkipTest("sample_dataset is empty")

    small_model.train()

    # Use the first sample as a batch of one.
    img, target = sample_dataset[0]
    img = img.to(device)
    target = {k: v.to(device) for k, v in target.items()}

    # Forward pass with targets should return a loss dict in training mode.
    loss_dict = small_model([img], [target])

    # Verify the loss dict structure.
    assert isinstance(loss_dict, dict), "Loss should be a dictionary"
    assert "loss_classifier" in loss_dict, "Should have classifier loss"
    assert "loss_box_reg" in loss_dict, "Should have box regression loss"
    assert "loss_mask" in loss_dict, "Should have mask loss"
    assert "loss_objectness" in loss_dict, "Should have objectness loss"
    assert "loss_rpn_box_reg" in loss_dict, "Should have RPN box regression loss"

    # Every loss must be a finite scalar tensor.
    for loss_name, loss_value in loss_dict.items():
        assert isinstance(loss_value, torch.Tensor), f"{loss_name} should be a tensor"
        assert loss_value.dim() == 0, f"{loss_name} should be a scalar tensor"
        assert not torch.isnan(loss_value), f"{loss_name} should not be NaN"
        assert not torch.isinf(loss_value), f"{loss_name} should not be infinite"
|
|
|
|
|
|
def test_model_forward_eval_mode(small_model, sample_dataset, device):
    """Check that an eval-mode forward pass returns per-image predictions.

    Puts ``small_model`` in evaluation mode, runs it on the first image of
    ``sample_dataset`` without targets, and verifies that the output is a
    one-element list whose entry contains the expected prediction keys.

    Args:
        small_model: Model fixture; called as ``small_model(images)``.
        sample_dataset: Dataset fixture yielding ``(img, target)`` pairs.
        device: Device the image is moved to.

    Raises:
        unittest.SkipTest: If ``sample_dataset`` is empty.
    """
    if len(sample_dataset) == 0:
        # Report an explicit skip rather than returning: a bare ``return``
        # makes the test show up as passed although nothing was checked.
        # pytest reports ``unittest.SkipTest`` as a skipped test.
        raise unittest.SkipTest("sample_dataset is empty")

    small_model.eval()

    # Only the image is needed here; the target is deliberately discarded.
    img, _ = sample_dataset[0]
    img = img.to(device)

    # Forward pass without targets should return predictions in eval mode.
    with torch.no_grad():
        predictions = small_model([img])

    # Verify the predictions structure.
    assert isinstance(predictions, list), "Predictions should be a list"
    assert len(predictions) == 1, "Should have predictions for 1 image"

    pred = predictions[0]
    assert "boxes" in pred, "Predictions should contain 'boxes'"
    assert "scores" in pred, "Predictions should contain 'scores'"
    assert "labels" in pred, "Predictions should contain 'labels'"
    assert "masks" in pred, "Predictions should contain 'masks'"
|
|
|
|
|
|
def test_evaluate_function(small_model, sample_dataset, device):
    """Check that ``evaluate`` returns a metrics dict with a valid loss.

    Builds a dataloader over at most two samples of ``sample_dataset``, runs
    ``evaluate`` on it with the model in evaluation mode, and verifies that
    the result is a dict with a non-negative ``"average_loss"`` entry.

    Args:
        small_model: Model fixture passed to ``evaluate``.
        sample_dataset: Dataset fixture the evaluation subset is drawn from.
        device: Device passed to ``evaluate``.

    Raises:
        unittest.SkipTest: If ``sample_dataset`` is empty.
    """
    if len(sample_dataset) == 0:
        # Report an explicit skip rather than returning: a bare ``return``
        # makes the test show up as passed although nothing was checked.
        # pytest reports ``unittest.SkipTest`` as a skipped test.
        raise unittest.SkipTest("sample_dataset is empty")

    from torch.utils.data import DataLoader, Subset

    from utils.data_utils import collate_fn

    # Use only 2 samples (or fewer, if the dataset is smaller) for speed.
    small_ds = Subset(sample_dataset, range(min(2, len(sample_dataset))))
    dataloader = DataLoader(
        small_ds, batch_size=1, shuffle=False, collate_fn=collate_fn
    )

    small_model.eval()

    metrics = evaluate(small_model, dataloader, device)

    assert isinstance(metrics, dict), "Metrics should be a dictionary"
    assert "average_loss" in metrics, "Metrics should contain 'average_loss'"
    assert metrics["average_loss"] >= 0, "Loss should be non-negative"
|