From e9b97ac2b5daec864abe5d16d0bebb0c7c6f6de5 Mon Sep 17 00:00:00 2001
From: Craig
Date: Sat, 12 Apr 2025 10:52:01 +0100
Subject: [PATCH] Implement full training loop

---
 train.py | 144 +++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 97 insertions(+), 47 deletions(-)

diff --git a/train.py b/train.py
index 690d29d..e433b5c 100644
--- a/train.py
+++ b/train.py
@@ -4,6 +4,7 @@ import logging
 import os
 import random
 import sys
+import time  # Import time for timing

 import numpy as np
 import torch
@@ -172,58 +173,107 @@ def main(args):
         logging.error(f"Error creating optimizer: {e}", exc_info=True)
         sys.exit(1)

-    # --- LR Scheduler (Placeholder for Prompt 10) ---
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-    #     optimizer,
-    #     step_size=config.get('lr_step_size', 3),
-    #     gamma=config.get('lr_gamma', 0.1)
-    # )
-
-    # --- Minimal Training Step (Prompt 8 / Updated for Prompt 9) ---
-    logging.info("--- Starting Minimal Training Step --- ")
-    model.train()  # Set model to training mode
-
+    # --- LR Scheduler (Prompt 10) ---
     try:
-        # Fetch one batch
-        images, targets = next(iter(data_loader_train))
-
-        # Move data to the device
-        images = list(image.to(device) for image in images)
-        # Targets is a list of dicts. Move each tensor in the dict to the device.
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-        # Perform forward pass (model returns loss dict in train mode)
-        loss_dict = model(images, targets)
-
-        # Calculate total loss
-        losses = sum(loss for loss in loss_dict.values())
-        loss_value = losses.item()  # Get scalar value
-
-        # Perform backward pass
-        optimizer.zero_grad()  # Clear previous gradients
-        losses.backward()  # Compute gradients
-        optimizer.step()  # Update weights
-
-        # Convert loss_dict tensors to scalar values for logging
-        loss_dict_log = {k: v.item() for k, v in loss_dict.items()}
-        logging.info(f"Single step loss dict: {loss_dict_log}")
-        logging.info(f"Single step total loss: {loss_value:.4f}")
-        logging.info("--- Minimal Training Step Completed Successfully --- ")
-
+        lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            optimizer,
+            step_size=config.get("lr_step_size", 3),
+            gamma=config.get("lr_gamma", 0.1),
+        )
+        logging.info(
+            f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, gamma={config.get('lr_gamma', 0.1)}"
+        )
     except Exception as e:
-        logging.error(f"Error during minimal training step: {e}", exc_info=True)
-        import traceback
-
-        # traceback.print_exc()  # Already logged with exc_info=True
+        logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
         sys.exit(1)

-    # Temporarily exit after the single step (as per Prompt 8)
-    logging.info("Exiting after single training step.")
-    sys.exit(0)
+    # --- Training Loop (Prompt 10) ---
+    logging.info("--- Starting Training Loop --- ")
+    start_time = time.time()
+    num_epochs = config.get("num_epochs", 10)

-    # --- Full Training Loop (Placeholder for Prompt 10) ---
-    # print("Basic setup complete. Full training loop implementation pending.")
-    # ... loop implementation ...
+    for epoch in range(num_epochs):
+        model.train()  # Set model to training mode for each epoch
+        epoch_start_time = time.time()
+        logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ")
+
+        # Variables for tracking epoch progress (optional)
+        epoch_loss_sum = 0.0
+        num_batches = len(data_loader_train)
+
+        for i, (images, targets) in enumerate(data_loader_train):
+            batch_start_time = time.time()  # Optional: time each batch
+            try:
+                # Move data to the device
+                images = list(image.to(device) for image in images)
+                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+                # Perform forward pass
+                loss_dict = model(images, targets)
+                losses = sum(loss for loss in loss_dict.values())
+                loss_value = losses.item()
+                epoch_loss_sum += loss_value
+
+                # Perform backward pass
+                optimizer.zero_grad()
+                losses.backward()
+                optimizer.step()
+
+                # Log batch loss periodically
+                if (i + 1) % config.get("log_freq", 10) == 0:
+                    batch_time = time.time() - batch_start_time
+                    # Include individual losses if desired
+                    loss_dict_items = {
+                        k: f"{v.item():.4f}" for k, v in loss_dict.items()
+                    }
+                    logging.info(
+                        f"  Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
+                    )
+                    logging.debug(
+                        f"    Loss Dict: {loss_dict_items}"
+                    )  # Log individual losses at DEBUG level
+
+            except Exception as e:
+                logging.error(
+                    f"Error during training epoch {epoch+1}, batch {i+1}: {e}",
+                    exc_info=True,
+                )
+                # Decide if you want to stop training or continue to next batch/epoch
+                logging.warning("Skipping rest of epoch due to error.")
+                break  # Exit the inner loop for this epoch
+
+        # --- End of Epoch --- #
+        # Step the learning rate scheduler
+        lr_scheduler.step()
+
+        # Log epoch summary
+        epoch_end_time = time.time()
+        epoch_duration = epoch_end_time - epoch_start_time
+        avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
+        current_lr = optimizer.param_groups[0]["lr"]  # Get current learning rate
+        logging.info(f"--- Epoch {epoch + 1} Summary --- ")
+        logging.info(f"  Average Loss: {avg_epoch_loss:.4f}")
+        logging.info(f"  Learning Rate: {current_lr:.6f}")
+        logging.info(f"  Epoch Duration: {epoch_duration:.2f}s")
+
+        # --- Checkpointing (Placeholder for Prompt 11) --- #
+        # Add checkpoint saving logic here, e.g.:
+        # if (epoch + 1) % config.get('checkpoint_freq', 1) == 0 or (epoch + 1) == num_epochs:
+        #     # ... save checkpoint ...
+        #     logging.info(f"Saved checkpoint for epoch {epoch + 1}")
+
+        # --- Evaluation (Placeholder for Prompt 12) --- #
+        # Add evaluation logic here, e.g.:
+        # if data_loader_val:
+        #     evaluate(model, data_loader_val, device)
+        #     logging.info(f"Ran evaluation for epoch {epoch + 1}")
+
+    # --- End of Training --- #
+    total_training_time = time.time() - start_time
+    logging.info("--- Training Finished --- ")
+    logging.info(
+        f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
+    )


 if __name__ == "__main__":