Implement full training loop
train.py (144 changed lines)
@@ -4,6 +4,7 @@ import logging
 import os
 import random
 import sys
+import time  # Import time for timing
 
 import numpy as np
 import torch
@@ -172,58 +173,107 @@ def main(args):
         logging.error(f"Error creating optimizer: {e}", exc_info=True)
         sys.exit(1)
 
-    # --- LR Scheduler (Placeholder for Prompt 10) ---
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-    #     optimizer,
-    #     step_size=config.get('lr_step_size', 3),
-    #     gamma=config.get('lr_gamma', 0.1)
-    # )
-
-    # --- Minimal Training Step (Prompt 8 / Updated for Prompt 9) ---
-    logging.info("--- Starting Minimal Training Step --- ")
-    model.train()  # Set model to training mode
-
+    # --- LR Scheduler (Prompt 10) ---
     try:
-        # Fetch one batch
-        images, targets = next(iter(data_loader_train))
-
-        # Move data to the device
-        images = list(image.to(device) for image in images)
-        # Targets is a list of dicts. Move each tensor in the dict to the device.
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-        # Perform forward pass (model returns loss dict in train mode)
-        loss_dict = model(images, targets)
-
-        # Calculate total loss
-        losses = sum(loss for loss in loss_dict.values())
-        loss_value = losses.item()  # Get scalar value
-
-        # Perform backward pass
-        optimizer.zero_grad()  # Clear previous gradients
-        losses.backward()  # Compute gradients
-        optimizer.step()  # Update weights
-
-        # Convert loss_dict tensors to scalar values for logging
-        loss_dict_log = {k: v.item() for k, v in loss_dict.items()}
-        logging.info(f"Single step loss dict: {loss_dict_log}")
-        logging.info(f"Single step total loss: {loss_value:.4f}")
-        logging.info("--- Minimal Training Step Completed Successfully --- ")
-
+        lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            optimizer,
+            step_size=config.get("lr_step_size", 3),
+            gamma=config.get("lr_gamma", 0.1),
+        )
+        logging.info(
+            f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, gamma={config.get('lr_gamma', 0.1)}"
+        )
     except Exception as e:
-        logging.error(f"Error during minimal training step: {e}", exc_info=True)
-        import traceback
-
-        # traceback.print_exc() # Already logged with exc_info=True
+        logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
         sys.exit(1)
 
-    # Temporarily exit after the single step (as per Prompt 8)
-    logging.info("Exiting after single training step.")
-    sys.exit(0)
+    # --- Training Loop (Prompt 10) ---
+    logging.info("--- Starting Training Loop --- ")
+    start_time = time.time()
+    num_epochs = config.get("num_epochs", 10)
 
-    # --- Full Training Loop (Placeholder for Prompt 10) ---
-    # print("Basic setup complete. Full training loop implementation pending.")
-    # ... loop implementation ...
+    for epoch in range(num_epochs):
+        model.train()  # Set model to training mode for each epoch
+        epoch_start_time = time.time()
+        logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ")
+
+        # Variables for tracking epoch progress (optional)
+        epoch_loss_sum = 0.0
+        num_batches = len(data_loader_train)
+
+        for i, (images, targets) in enumerate(data_loader_train):
+            batch_start_time = time.time()  # Optional: time each batch
+            try:
+                # Move data to the device
+                images = list(image.to(device) for image in images)
+                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+                # Perform forward pass
+                loss_dict = model(images, targets)
+                losses = sum(loss for loss in loss_dict.values())
+                loss_value = losses.item()
+                epoch_loss_sum += loss_value
+
+                # Perform backward pass
+                optimizer.zero_grad()
+                losses.backward()
+                optimizer.step()
+
+                # Log batch loss periodically
+                if (i + 1) % config.get("log_freq", 10) == 0:
+                    batch_time = time.time() - batch_start_time
+                    # Include individual losses if desired
+                    loss_dict_items = {
+                        k: f"{v.item():.4f}" for k, v in loss_dict.items()
+                    }
+                    logging.info(
+                        f"  Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
+                    )
+                    logging.debug(
+                        f"  Loss Dict: {loss_dict_items}"
+                    )  # Log individual losses at DEBUG level
+
+            except Exception as e:
+                logging.error(
+                    f"Error during training epoch {epoch+1}, batch {i+1}: {e}",
+                    exc_info=True,
+                )
+                # Decide if you want to stop training or continue to next batch/epoch
+                logging.warning("Skipping rest of epoch due to error.")
+                break  # Exit the inner loop for this epoch
+
+        # --- End of Epoch --- #
+        # Step the learning rate scheduler
+        lr_scheduler.step()
+
+        # Log epoch summary
+        epoch_end_time = time.time()
+        epoch_duration = epoch_end_time - epoch_start_time
+        avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
+        current_lr = optimizer.param_groups[0]["lr"]  # Get current learning rate
+        logging.info(f"--- Epoch {epoch + 1} Summary --- ")
+        logging.info(f"  Average Loss: {avg_epoch_loss:.4f}")
+        logging.info(f"  Learning Rate: {current_lr:.6f}")
+        logging.info(f"  Epoch Duration: {epoch_duration:.2f}s")
+
+        # --- Checkpointing (Placeholder for Prompt 11) --- #
+        # Add checkpoint saving logic here, e.g.:
+        # if (epoch + 1) % config.get('checkpoint_freq', 1) == 0 or (epoch + 1) == num_epochs:
+        #     # ... save checkpoint ...
+        #     logging.info(f"Saved checkpoint for epoch {epoch + 1}")
+
+        # --- Evaluation (Placeholder for Prompt 12) --- #
+        # Add evaluation logic here, e.g.:
+        # if data_loader_val:
+        #     evaluate(model, data_loader_val, device)
+        #     logging.info(f"Ran evaluation for epoch {epoch + 1}")
+
+    # --- End of Training --- #
+    total_training_time = time.time() - start_time
+    logging.info("--- Training Finished --- ")
+    logging.info(
+        f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
+    )
 
 
 if __name__ == "__main__":
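
Note: with the defaults above (step_size=3, gamma=0.1), StepLR multiplies every parameter group's learning rate by 0.1 after every third call to lr_scheduler.step(), i.e. every 3 epochs in this loop. A minimal standalone sketch of that decay; the base lr of 0.005 is an assumed illustration, not a value from this commit:

import torch

# Illustrative only: assumes a base lr of 0.005; the real value comes from
# the optimizer created earlier in train.py.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(6):
    optimizer.step()     # stands in for one epoch of training
    lr_scheduler.step()  # stepped once per epoch, as in the loop above
    print(epoch + 1, optimizer.param_groups[0]["lr"])
# Prints 0.005 for epochs 1-2, 0.0005 for epochs 3-5, and 5e-05 at epoch 6.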
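For the Prompt 11 placeholder, checkpoint saving could look roughly like the sketch below. The save_checkpoint helper, the checkpoint_dir config key, and the filename scheme are assumptions for illustration; only the checkpoint_freq condition appears (commented out) in the diff.

import logging
import os
import torch

def save_checkpoint(model, optimizer, lr_scheduler, epoch, config):
    # Hypothetical helper; 'checkpoint_dir' and the filename scheme are
    # assumed, not part of this commit.
    ckpt_dir = config.get("checkpoint_dir", "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch + 1}.pth")
    torch.save(
        {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        },
        path,
    )
    logging.info(f"Saved checkpoint for epoch {epoch + 1} to {path}")

This would be called at the end of each epoch under the commented condition, i.e. when (epoch + 1) % config.get('checkpoint_freq', 1) == 0 or (epoch + 1) == num_epochs. Saving optimizer and scheduler state alongside the model is what makes resuming training possible.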
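The Prompt 12 placeholder calls an evaluate(model, data_loader_val, device) that does not exist yet. Assuming a torchvision-style detection model (suggested by the loss-dict/targets API used in the loop), a minimal smoke-test version might look like the sketch below; a real implementation would compute mAP with COCO evaluation tooling.

import logging
import torch

@torch.no_grad()
def evaluate(model, data_loader, device):
    # Minimal sketch of the evaluate() stub referenced above; assumes a
    # torchvision-style detection model, which returns per-image prediction
    # dicts ('boxes', 'labels', 'scores') in eval mode.
    model.eval()
    total_detections = 0
    for images, _targets in data_loader:
        images = [img.to(device) for img in images]
        outputs = model(images)
        total_detections += sum(len(out["boxes"]) for out in outputs)
    logging.info(f"Validation pass complete: {total_detections} detections")
    model.train()  # restore training mode for the next epoch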