Files
torchvision-vibecoding-project/utils/eval_utils.py

import logging
import time
import torch


def evaluate(model, data_loader, device):
    """Performs evaluation on the dataset for one epoch.

    Args:
        model (torch.nn.Module): The model to evaluate.
        data_loader (torch.utils.data.DataLoader): DataLoader for the evaluation data.
        device (torch.device): The device to run evaluation on.

    Returns:
        dict: A dictionary containing evaluation metrics (e.g., average loss).
    """
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    num_batches = len(data_loader)
    eval_start_time = time.time()
    status_interval = max(1, num_batches // 10)  # Log status roughly 10 times
    logging.info("--- Starting Evaluation ---")
    with torch.no_grad():  # Disable gradient calculations
        for i, (images, targets) in enumerate(data_loader):
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
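            # torchvision detection models expect targets as a list of dicts of
            # tensors (e.g. "boxes" and "labels"; "masks" as well for Mask R-CNN).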
            # torchvision detection models such as Mask R-CNN only return a
            # loss dict in training mode (in eval mode they return predictions
            # instead), so briefly switch to train mode to compute the losses,
            # then switch back for the rest of the evaluation.
            model.train()
            loss_dict = model(images, targets)
            model.eval()
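            # NOTE: model.train() re-enables training-time behavior (BatchNorm
            # running-stat updates, Dropout) even under torch.no_grad().
            # torchvision's pretrained detection backbones use FrozenBatchNorm2d,
            # so this is typically harmless here, but it can perturb BatchNorm
            # statistics in custom models.
            # loss_dict maps component names to scalar tensors (for Mask R-CNN:
            # loss_classifier, loss_box_reg, loss_mask, loss_objectness,
            # loss_rpn_box_reg); summing them gives a single total.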
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()
            total_loss += loss_value
            if (i + 1) % status_interval == 0:
                logging.info(f"  Evaluated batch {i + 1}/{num_batches}")
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    eval_duration = time.time() - eval_start_time
    logging.info("--- Evaluation Finished ---")
    logging.info(f"  Average Evaluation Loss: {avg_loss:.4f}")
    logging.info(f"  Evaluation Duration: {eval_duration:.2f}s")
    # Return metrics (currently just average loss)
    metrics = {"average_loss": avg_loss}
    return metrics


# Example usage (can be removed or kept for testing):
if __name__ == "__main__":
    print(
        "evaluate() needs a model, a DataLoader, and a device; "
        "the stub smoke test below runs it without any external setup."
    )
    # Example with real components:
    # device = torch.device('cpu')
    # # Create dummy model and dataloader
    # model = ...
    # data_loader = ...
    # metrics = evaluate(model, data_loader, device)
    # print(f"Dummy evaluation metrics: {metrics}")
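    # --- Hedged smoke test (added illustration, not the project's real setup) ---
    # StubDetector and the hand-built data_loader list below are hypothetical
    # stand-ins that mimic the torchvision detection API just enough to
    # exercise evaluate(): a loss dict in train mode, per-image prediction
    # dicts in eval mode.
    logging.basicConfig(level=logging.INFO)

    class StubDetector(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(1))

        def forward(self, images, targets=None):
            if self.training:
                # Fabricated component losses, shaped like a detection loss dict.
                base = self.weight.sum()
                return {"loss_classifier": base * 0.5, "loss_box_reg": base * 0.25}
            return [{} for _ in images]

    device = torch.device("cpu")
    model = StubDetector().to(device)
    # A plain list of (images, targets) batches works here because evaluate()
    # only needs len() and iteration, like a DataLoader.
    data_loader = [
        (
            [torch.rand(3, 32, 32) for _ in range(2)],
            [
                {"boxes": torch.zeros(0, 4), "labels": torch.zeros(0, dtype=torch.int64)}
                for _ in range(2)
            ],
        )
        for _ in range(3)
    ]
    metrics = evaluate(model, data_loader, device)
    print(f"Stub evaluation metrics: {metrics}")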