Implement full training loop

Craig
2025-04-12 10:52:01 +01:00
parent bd6b5170b7
commit e9b97ac2b5

train.py

@@ -4,6 +4,7 @@ import logging
 import os
 import random
 import sys
+import time  # Import time for timing
 import numpy as np
 import torch
@@ -172,58 +173,107 @@ def main(args):
         logging.error(f"Error creating optimizer: {e}", exc_info=True)
         sys.exit(1)
-    # --- LR Scheduler (Placeholder for Prompt 10) ---
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-    #     optimizer,
-    #     step_size=config.get('lr_step_size', 3),
-    #     gamma=config.get('lr_gamma', 0.1)
-    # )
-    # --- Minimal Training Step (Prompt 8 / Updated for Prompt 9) ---
-    logging.info("--- Starting Minimal Training Step --- ")
-    model.train()  # Set model to training mode
+    # --- LR Scheduler (Prompt 10) ---
     try:
-        # Fetch one batch
-        images, targets = next(iter(data_loader_train))
-        # Move data to the device
-        images = list(image.to(device) for image in images)
-        # Targets is a list of dicts. Move each tensor in the dict to the device.
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-        # Perform forward pass (model returns loss dict in train mode)
-        loss_dict = model(images, targets)
-        # Calculate total loss
-        losses = sum(loss for loss in loss_dict.values())
-        loss_value = losses.item()  # Get scalar value
-        # Perform backward pass
-        optimizer.zero_grad()  # Clear previous gradients
-        losses.backward()  # Compute gradients
-        optimizer.step()  # Update weights
-        # Convert loss_dict tensors to scalar values for logging
-        loss_dict_log = {k: v.item() for k, v in loss_dict.items()}
-        logging.info(f"Single step loss dict: {loss_dict_log}")
-        logging.info(f"Single step total loss: {loss_value:.4f}")
-        logging.info("--- Minimal Training Step Completed Successfully --- ")
+        lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            optimizer,
+            step_size=config.get("lr_step_size", 3),
+            gamma=config.get("lr_gamma", 0.1),
+        )
+        logging.info(
+            f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, gamma={config.get('lr_gamma', 0.1)}"
+        )
     except Exception as e:
-        logging.error(f"Error during minimal training step: {e}", exc_info=True)
-        import traceback
-        # traceback.print_exc() # Already logged with exc_info=True
+        logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
         sys.exit(1)
-    # Temporarily exit after the single step (as per Prompt 8)
-    logging.info("Exiting after single training step.")
-    sys.exit(0)
+    # --- Training Loop (Prompt 10) ---
+    logging.info("--- Starting Training Loop --- ")
+    start_time = time.time()
+    num_epochs = config.get("num_epochs", 10)
-    # --- Full Training Loop (Placeholder for Prompt 10) ---
-    # print("Basic setup complete. Full training loop implementation pending.")
-    # ... loop implementation ...
+    for epoch in range(num_epochs):
+        model.train()  # Set model to training mode for each epoch
+        epoch_start_time = time.time()
+        logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ")
+        # Variables for tracking epoch progress (optional)
+        epoch_loss_sum = 0.0
+        num_batches = len(data_loader_train)
+        for i, (images, targets) in enumerate(data_loader_train):
+            batch_start_time = time.time()  # Optional: time each batch
+            try:
+                # Move data to the device
+                images = list(image.to(device) for image in images)
+                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+                # Perform forward pass
+                loss_dict = model(images, targets)
+                losses = sum(loss for loss in loss_dict.values())
+                loss_value = losses.item()
+                epoch_loss_sum += loss_value
+                # Perform backward pass
+                optimizer.zero_grad()
+                losses.backward()
+                optimizer.step()
+                # Log batch loss periodically
+                if (i + 1) % config.get("log_freq", 10) == 0:
+                    batch_time = time.time() - batch_start_time
+                    # Include individual losses if desired
+                    loss_dict_items = {
+                        k: f"{v.item():.4f}" for k, v in loss_dict.items()
+                    }
+                    logging.info(
+                        f" Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
+                    )
+                    logging.debug(
+                        f" Loss Dict: {loss_dict_items}"
+                    )  # Log individual losses at DEBUG level
+            except Exception as e:
+                logging.error(
+                    f"Error during training epoch {epoch+1}, batch {i+1}: {e}",
+                    exc_info=True,
+                )
+                # Decide if you want to stop training or continue to next batch/epoch
+                logging.warning("Skipping rest of epoch due to error.")
+                break  # Exit the inner loop for this epoch
+        # --- End of Epoch --- #
+        # Step the learning rate scheduler
+        lr_scheduler.step()
+        # Log epoch summary
+        epoch_end_time = time.time()
+        epoch_duration = epoch_end_time - epoch_start_time
+        avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
+        current_lr = optimizer.param_groups[0]["lr"]  # Get current learning rate
+        logging.info(f"--- Epoch {epoch + 1} Summary --- ")
+        logging.info(f" Average Loss: {avg_epoch_loss:.4f}")
+        logging.info(f" Learning Rate: {current_lr:.6f}")
+        logging.info(f" Epoch Duration: {epoch_duration:.2f}s")
+        # --- Checkpointing (Placeholder for Prompt 11) --- #
+        # Add checkpoint saving logic here, e.g.:
+        # if (epoch + 1) % config.get('checkpoint_freq', 1) == 0 or (epoch + 1) == num_epochs:
+        #     # ... save checkpoint ...
+        #     logging.info(f"Saved checkpoint for epoch {epoch + 1}")
+        # --- Evaluation (Placeholder for Prompt 12) --- #
+        # Add evaluation logic here, e.g.:
+        # if data_loader_val:
+        #     evaluate(model, data_loader_val, device)
+        #     logging.info(f"Ran evaluation for epoch {epoch + 1}")
+    # --- End of Training --- #
+    total_training_time = time.time() - start_time
+    logging.info("--- Training Finished --- ")
+    logging.info(
+        f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
+    )
 if __name__ == "__main__":
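
Note on the two placeholders left inside the epoch loop: the checkpointing (Prompt 11) and evaluation (Prompt 12) hooks are only sketched as comments in this commit. One possible shape for the checkpoint helper is sketched below; the function name, the default output directory, and the filename pattern are assumptions for illustration, not part of this commit.

import os
import torch

def save_checkpoint(model, optimizer, lr_scheduler, epoch, output_dir="checkpoints"):
    """Illustrative checkpoint helper: persists enough state to resume training."""
    os.makedirs(output_dir, exist_ok=True)
    state = {
        "epoch": epoch + 1,  # the next epoch to run when resuming
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
    }
    path = os.path.join(output_dir, f"checkpoint_epoch_{epoch + 1}.pth")
    torch.save(state, path)
    return path

Resuming would mirror this with torch.load followed by the corresponding load_state_dict calls. The evaluation hook would likewise switch the model to eval() mode and iterate data_loader_val under torch.no_grad(), or call a dedicated evaluate(model, data_loader_val, device) helper as the placeholder comment suggests.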