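"""Evaluation utilities: run one pass over a DataLoader and report the average loss."""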
import logging
import time

import torch
def evaluate(model, data_loader, device):
    """Performs evaluation on the dataset for one epoch.

    Args:
        model (torch.nn.Module): The model to evaluate.
        data_loader (torch.utils.data.DataLoader): DataLoader for the evaluation data.
        device (torch.device): The device to run evaluation on.

    Returns:
        dict: A dictionary containing evaluation metrics (e.g., average loss).
    """
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    num_batches = len(data_loader)
    eval_start_time = time.time()
    status_interval = max(1, num_batches // 10)  # Log status roughly 10 times

    logging.info("--- Starting Evaluation ---")

    with torch.no_grad():  # Disable gradient calculations
        for i, (images, targets) in enumerate(data_loader):
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Assumes the model returns a loss dict when given targets.
            # Note: torchvision's Mask R-CNN returns predictions (not losses) once
            # model.eval() is set, even when targets are passed; if that is the
            # model in use, this block would need to process predictions instead
            # (see the alternative sketch below).
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()
            total_loss += loss_value

            if (i + 1) % status_interval == 0:
                logging.info(f" Evaluated batch {i + 1}/{num_batches}")

    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    eval_duration = time.time() - eval_start_time

    logging.info("--- Evaluation Finished ---")
    logging.info(f" Average Evaluation Loss: {avg_loss:.4f}")
    logging.info(f" Evaluation Duration: {eval_duration:.2f}s")

    # Return metrics (currently just average loss)
    metrics = {"average_loss": avg_loss}
    return metrics
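

# The loop above assumes the model keeps returning a loss dict after model.eval().
# torchvision detection models (including Mask R-CNN) instead return predictions in
# eval mode, even when targets are passed. Below is a minimal, hedged sketch of one
# workaround under that assumption: run the forward pass in train mode purely to
# obtain losses, with torch.no_grad() disabling gradient computation.
# Caveat: BatchNorm running statistics still update in train mode; freeze those
# layers beforehand if exact eval-time statistics matter.
def evaluate_train_mode_losses(model, data_loader, device):
    """Alternative evaluation pass that reads losses from a train-mode forward."""
    was_training = model.training
    model.train()
    total_loss = 0.0
    num_batches = len(data_loader)
    with torch.no_grad():
        for images, targets in data_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            total_loss += sum(loss for loss in loss_dict.values()).item()
    model.train(was_training)  # Restore the original train/eval mode
    return {"average_loss": total_loss / num_batches if num_batches > 0 else 0.0}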


# Example usage (can be removed or kept for testing):
if __name__ == "__main__":
    # This is a dummy test and requires a model, dataloader, device
    print(
        "This script contains the evaluate function and cannot be run directly for testing without setup."
    )
    # Example:
    # device = torch.device('cpu')
    # # Create dummy model and dataloader
    # model = ...
    # data_loader = ...
    # metrics = evaluate(model, data_loader, device)
    # print(f"Dummy evaluation metrics: {metrics}")
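    # A minimal smoke test (hypothetical stand-ins, not the real model or data):
    # a stub module returning a fixed loss dict and a tiny list standing in for
    # the DataLoader are enough to exercise evaluate() end to end on CPU.
    # class _StubModel(torch.nn.Module):
    #     def forward(self, images, targets):
    #         return {"loss": torch.tensor(0.5)}
    # stub_loader = [([torch.zeros(3, 8, 8)], [{"boxes": torch.zeros(0, 4)}])] * 2
    # print(evaluate(_StubModel(), stub_loader, torch.device("cpu")))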