Create eval loop and use full train dataset
This commit is contained in:
65
utils/eval_utils.py
Normal file
65
utils/eval_utils.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def evaluate(model, data_loader, device):
|
||||
"""Performs evaluation on the dataset for one epoch.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to evaluate.
|
||||
data_loader (torch.utils.data.DataLoader): DataLoader for the evaluation data.
|
||||
device (torch.device): The device to run evaluation on.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing evaluation metrics (e.g., average loss).
|
||||
"""
|
||||
model.eval() # Set model to evaluation mode
|
||||
total_loss = 0.0
|
||||
num_batches = len(data_loader)
|
||||
eval_start_time = time.time()
|
||||
status_interval = max(1, num_batches // 10) # Log status roughly 10 times
|
||||
|
||||
logging.info("--- Starting Evaluation --- ")
|
||||
|
||||
with torch.no_grad(): # Disable gradient calculations
|
||||
for i, (images, targets) in enumerate(data_loader):
|
||||
images = list(image.to(device) for image in images)
|
||||
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
||||
|
||||
# In eval mode with targets, Mask R-CNN should still return losses
|
||||
# If it returned predictions, logic here would change to process predictions
|
||||
loss_dict = model(images, targets)
|
||||
losses = sum(loss for loss in loss_dict.values())
|
||||
loss_value = losses.item()
|
||||
total_loss += loss_value
|
||||
|
||||
if (i + 1) % status_interval == 0:
|
||||
logging.info(f" Evaluated batch {i + 1}/{num_batches}")
|
||||
|
||||
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
||||
eval_duration = time.time() - eval_start_time
|
||||
|
||||
logging.info("--- Evaluation Finished ---")
|
||||
logging.info(f" Average Evaluation Loss: {avg_loss:.4f}")
|
||||
logging.info(f" Evaluation Duration: {eval_duration:.2f}s")
|
||||
|
||||
# Return metrics (currently just average loss)
|
||||
metrics = {"average_loss": avg_loss}
|
||||
return metrics
|
||||
|
||||
|
||||
# Example usage (can be removed or kept for testing):
|
||||
if __name__ == "__main__":
|
||||
# This is a dummy test and requires a model, dataloader, device
|
||||
print(
|
||||
"This script contains the evaluate function and cannot be run directly for testing without setup."
|
||||
)
|
||||
# Example:
|
||||
# device = torch.device('cpu')
|
||||
# # Create dummy model and dataloader
|
||||
# model = ...
|
||||
# data_loader = ...
|
||||
# metrics = evaluate(model, data_loader, device)
|
||||
# print(f"Dummy evaluation metrics: {metrics}")
|
||||
Reference in New Issue
Block a user