import logging
import time

import torch


def evaluate(model, data_loader, device):
    """Performs evaluation on the dataset for one epoch.

    Args:
        model (torch.nn.Module): The model to evaluate.
        data_loader (torch.utils.data.DataLoader): DataLoader for the evaluation data.
        device (torch.device): The device to run evaluation on.

    Returns:
        dict: A dictionary containing evaluation metrics (e.g., average loss).
    """
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    num_batches = len(data_loader)
    eval_start_time = time.time()
    status_interval = max(1, num_batches // 10)  # Log status roughly 10 times

    logging.info("--- Starting Evaluation ---")

    with torch.no_grad():  # Disable gradient calculations
        for i, (images, targets) in enumerate(data_loader):
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # This assumes the model returns a loss dict when targets are supplied.
            # Note: torchvision detection models such as Mask R-CNN return
            # predictions rather than losses in eval mode; if that is the case,
            # the logic here would need to change to process predictions instead.
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()
            total_loss += loss_value

            if (i + 1) % status_interval == 0:
                logging.info(f"  Evaluated batch {i + 1}/{num_batches}")

    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    eval_duration = time.time() - eval_start_time

    logging.info("--- Evaluation Finished ---")
    logging.info(f"  Average Evaluation Loss: {avg_loss:.4f}")
    logging.info(f"  Evaluation Duration: {eval_duration:.2f}s")

    # Return metrics (currently just average loss)
    metrics = {"average_loss": avg_loss}
    return metrics


# Example usage (can be removed or kept for testing):
if __name__ == "__main__":
    # This is a dummy test and requires a model, a data loader, and a device.
    print(
        "This script contains the evaluate function and cannot be run directly for testing without setup."
    )
    # Example:
    # device = torch.device('cpu')
    # # Create dummy model and dataloader
    # model = ...
    # data_loader = ...
    # metrics = evaluate(model, data_loader, device)
    # print(f"Dummy evaluation metrics: {metrics}")
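

# --- Hedged usage sketch (not part of the original script) ---
# _demo_with_stub() below is a minimal, self-contained illustration of the
# interface evaluate() expects. The stub model, synthetic dataset, and
# detection-style collate function are hypothetical stand-ins chosen only so
# the function can be exercised end-to-end: the stub returns a dict of scalar
# loss tensors, which is the contract evaluate() assumes. Call it manually
# (e.g., from a REPL) after importing this module.
def _demo_with_stub():
    from torch.utils.data import DataLoader, Dataset

    logging.basicConfig(level=logging.INFO)

    class StubLossModel(torch.nn.Module):
        """Hypothetical model returning a fixed loss dict, mimicking the expected interface."""

        def forward(self, images, targets=None):
            # Ignore the inputs and return constant scalar losses.
            return {
                "loss_classifier": torch.tensor(0.5),
                "loss_box_reg": torch.tensor(0.25),
            }

    class SyntheticDataset(Dataset):
        """Hypothetical dataset yielding (image, target) pairs in a detection-style format."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            image = torch.rand(3, 64, 64)
            target = {
                "boxes": torch.tensor([[4.0, 4.0, 32.0, 32.0]]),
                "labels": torch.tensor([1], dtype=torch.int64),
            }
            return image, target

    def collate_fn(batch):
        # Keep images and targets as per-sample tuples instead of stacking them.
        return tuple(zip(*batch))

    device = torch.device("cpu")
    model = StubLossModel().to(device)
    data_loader = DataLoader(SyntheticDataset(), batch_size=2, collate_fn=collate_fn)

    metrics = evaluate(model, data_loader, device)
    print(f"Stub evaluation metrics: {metrics}")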