# File: torchvision-vibecoding-project/tests/test_model.py
# (file-viewer metadata: 103 lines, 3.6 KiB, Python)

import unittest

import torch
import torchvision

from utils.eval_utils import evaluate
def test_model_creation(small_model):
    """Check the type and head output widths of the fixture model.

    The model must be a torchvision Mask R-CNN whose box classifier emits
    2 logits and whose mask predictor emits 2 channels (presumably
    background plus one foreground class — confirm against the fixture).
    """
    expected_width = 2
    assert isinstance(small_model, torchvision.models.detection.MaskRCNN)
    roi_heads = small_model.roi_heads
    assert roi_heads.box_predictor.cls_score.out_features == expected_width
    assert roi_heads.mask_predictor.mask_fcn_logits.out_channels == expected_width
def test_model_forward_train_mode(small_model, sample_dataset, device):
"""Test model forward pass in training mode."""
if len(sample_dataset) == 0:
return # Skip if no data
# Set model to training mode
small_model.train()
# Get a batch
img, target = sample_dataset[0]
img = img.to(device)
target = {k: v.to(device) for k, v in target.items()}
# Forward pass with targets should return loss dict in training mode
loss_dict = small_model([img], [target])
# Verify loss dict structure
assert isinstance(loss_dict, dict), "Loss should be a dictionary"
assert "loss_classifier" in loss_dict, "Should have classifier loss"
assert "loss_box_reg" in loss_dict, "Should have box regression loss"
assert "loss_mask" in loss_dict, "Should have mask loss"
assert "loss_objectness" in loss_dict, "Should have objectness loss"
assert "loss_rpn_box_reg" in loss_dict, "Should have RPN box regression loss"
# Verify loss values
for loss_name, loss_value in loss_dict.items():
assert isinstance(loss_value, torch.Tensor), f"{loss_name} should be a tensor"
assert loss_value.dim() == 0, f"{loss_name} should be a scalar tensor"
assert not torch.isnan(loss_value), f"{loss_name} should not be NaN"
assert not torch.isinf(loss_value), f"{loss_name} should not be infinite"
def test_model_forward_eval_mode(small_model, sample_dataset, device):
"""Test model forward pass in evaluation mode."""
if len(sample_dataset) == 0:
return # Skip if no data
# Set model to evaluation mode
small_model.eval()
# Get a batch
img, target = sample_dataset[0]
img = img.to(device)
# Forward pass without targets should return predictions in eval mode
with torch.no_grad():
predictions = small_model([img])
# Verify predictions structure
assert isinstance(predictions, list), "Predictions should be a list"
assert len(predictions) == 1, "Should have predictions for 1 image"
pred = predictions[0]
assert "boxes" in pred, "Predictions should contain 'boxes'"
assert "scores" in pred, "Predictions should contain 'scores'"
assert "labels" in pred, "Predictions should contain 'labels'"
assert "masks" in pred, "Predictions should contain 'masks'"
def test_evaluate_function(small_model, sample_dataset, device):
"""Test the evaluate function."""
if len(sample_dataset) == 0:
return # Skip if no data
# Create a tiny dataloader for testing
from torch.utils.data import DataLoader
from utils.data_utils import collate_fn
# Use only 2 samples for quick testing
small_ds = torch.utils.data.Subset(
sample_dataset, range(min(2, len(sample_dataset)))
)
dataloader = DataLoader(
small_ds, batch_size=1, shuffle=False, collate_fn=collate_fn
)
# Set model to eval mode
small_model.eval()
# Import evaluate function
# Run evaluation
metrics = evaluate(small_model, dataloader, device)
# Check results
assert isinstance(metrics, dict), "Metrics should be a dictionary"
assert "average_loss" in metrics, "Metrics should contain 'average_loss'"
assert metrics["average_loss"] >= 0, "Loss should be non-negative"