Implement full training loop
train.py | 144
@@ -4,6 +4,7 @@ import logging
 import os
 import random
 import sys
+import time  # Import time for timing
 
 import numpy as np
 import torch
@@ -172,58 +173,107 @@ def main(args):
         logging.error(f"Error creating optimizer: {e}", exc_info=True)
         sys.exit(1)
 
-    # --- LR Scheduler (Placeholder for Prompt 10) ---
-    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
-    #     optimizer,
-    #     step_size=config.get('lr_step_size', 3),
-    #     gamma=config.get('lr_gamma', 0.1)
-    # )
-
-    # --- Minimal Training Step (Prompt 8 / Updated for Prompt 9) ---
-    logging.info("--- Starting Minimal Training Step --- ")
-    model.train()  # Set model to training mode
-
+    # --- LR Scheduler (Prompt 10) ---
     try:
-        # Fetch one batch
-        images, targets = next(iter(data_loader_train))
-
-        # Move data to the device
-        images = list(image.to(device) for image in images)
-        # Targets is a list of dicts. Move each tensor in the dict to the device.
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-        # Perform forward pass (model returns loss dict in train mode)
-        loss_dict = model(images, targets)
-
-        # Calculate total loss
-        losses = sum(loss for loss in loss_dict.values())
-        loss_value = losses.item()  # Get scalar value
-
-        # Perform backward pass
-        optimizer.zero_grad()  # Clear previous gradients
-        losses.backward()  # Compute gradients
-        optimizer.step()  # Update weights
-
-        # Convert loss_dict tensors to scalar values for logging
-        loss_dict_log = {k: v.item() for k, v in loss_dict.items()}
-        logging.info(f"Single step loss dict: {loss_dict_log}")
-        logging.info(f"Single step total loss: {loss_value:.4f}")
-        logging.info("--- Minimal Training Step Completed Successfully --- ")
+        lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            optimizer,
+            step_size=config.get("lr_step_size", 3),
+            gamma=config.get("lr_gamma", 0.1),
+        )
+        logging.info(
+            f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, gamma={config.get('lr_gamma', 0.1)}"
+        )
     except Exception as e:
-        logging.error(f"Error during minimal training step: {e}", exc_info=True)
-        import traceback
-
-        # traceback.print_exc()  # Already logged with exc_info=True
+        logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
         sys.exit(1)
 
-    # Temporarily exit after the single step (as per Prompt 8)
-    logging.info("Exiting after single training step.")
-    sys.exit(0)
+    # --- Training Loop (Prompt 10) ---
+    logging.info("--- Starting Training Loop --- ")
+    start_time = time.time()
+    num_epochs = config.get("num_epochs", 10)
 
-    # --- Full Training Loop (Placeholder for Prompt 10) ---
-    # print("Basic setup complete. Full training loop implementation pending.")
-    # ... loop implementation ...
+    for epoch in range(num_epochs):
+        model.train()  # Set model to training mode for each epoch
+        epoch_start_time = time.time()
+        logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ")
+
+        # Variables for tracking epoch progress (optional)
+        epoch_loss_sum = 0.0
+        num_batches = len(data_loader_train)
+
+        for i, (images, targets) in enumerate(data_loader_train):
+            batch_start_time = time.time()  # Optional: time each batch
+            try:
+                # Move data to the device
+                images = list(image.to(device) for image in images)
+                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+                # Perform forward pass
+                loss_dict = model(images, targets)
+                losses = sum(loss for loss in loss_dict.values())
+                loss_value = losses.item()
+                epoch_loss_sum += loss_value
+
+                # Perform backward pass
+                optimizer.zero_grad()
+                losses.backward()
+                optimizer.step()
+
+                # Log batch loss periodically
+                if (i + 1) % config.get("log_freq", 10) == 0:
+                    batch_time = time.time() - batch_start_time
+                    # Include individual losses if desired
+                    loss_dict_items = {
+                        k: f"{v.item():.4f}" for k, v in loss_dict.items()
+                    }
+                    logging.info(
+                        f"  Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
+                    )
+                    logging.debug(
+                        f"  Loss Dict: {loss_dict_items}"
+                    )  # Log individual losses at DEBUG level
+
+            except Exception as e:
+                logging.error(
+                    f"Error during training epoch {epoch+1}, batch {i+1}: {e}",
+                    exc_info=True,
+                )
+                # Decide if you want to stop training or continue to next batch/epoch
+                logging.warning("Skipping rest of epoch due to error.")
+                break  # Exit the inner loop for this epoch
+
+        # --- End of Epoch --- #
+        # Step the learning rate scheduler
+        lr_scheduler.step()
+
+        # Log epoch summary
+        epoch_end_time = time.time()
+        epoch_duration = epoch_end_time - epoch_start_time
+        avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
+        current_lr = optimizer.param_groups[0]["lr"]  # Get current learning rate
+        logging.info(f"--- Epoch {epoch + 1} Summary --- ")
+        logging.info(f"  Average Loss: {avg_epoch_loss:.4f}")
+        logging.info(f"  Learning Rate: {current_lr:.6f}")
+        logging.info(f"  Epoch Duration: {epoch_duration:.2f}s")
+
+        # --- Checkpointing (Placeholder for Prompt 11) --- #
+        # Add checkpoint saving logic here, e.g.:
+        # if (epoch + 1) % config.get('checkpoint_freq', 1) == 0 or (epoch + 1) == num_epochs:
+        #     # ... save checkpoint ...
+        #     logging.info(f"Saved checkpoint for epoch {epoch + 1}")
+
+        # --- Evaluation (Placeholder for Prompt 12) --- #
+        # Add evaluation logic here, e.g.:
+        # if data_loader_val:
+        #     evaluate(model, data_loader_val, device)
+        #     logging.info(f"Ran evaluation for epoch {epoch + 1}")
+
+    # --- End of Training --- #
+    total_training_time = time.time() - start_time
+    logging.info("--- Training Finished --- ")
+    logging.info(
+        f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
+    )
 
 
 if __name__ == "__main__":
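
For reference, the new loop reads its hyperparameters through config.get(...) with hard-coded fallbacks. A minimal sketch of the corresponding config dict follows; the keys and defaults are taken from the diff above, but the dict itself and the idea of setting these values explicitly are illustrative only:

config = {
    "num_epochs": 10,    # length of the training run (code default: 10)
    "lr_step_size": 3,   # StepLR: decay the LR every N epochs (code default: 3)
    "lr_gamma": 0.1,     # StepLR: multiplicative decay factor (code default: 0.1)
    "log_freq": 10,      # log batch loss every N iterations (code default: 10)
}

The commented-out Prompt 11 placeholder additionally references config.get('checkpoint_freq', 1), so that key would join this dict once checkpointing lands.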
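The Prompt 11 placeholder only sketches the intended checkpointing logic. Below is one hedged way it could be filled in, assuming the names already in scope in the loop (model, optimizer, lr_scheduler, epoch); the save_checkpoint helper and the checkpoints/ output directory are hypothetical and not part of this commit:

import os

import torch


def save_checkpoint(model, optimizer, lr_scheduler, epoch, out_dir="checkpoints"):
    # Hypothetical helper: bundle the state needed to resume training.
    os.makedirs(out_dir, exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        },
        os.path.join(out_dir, f"checkpoint_epoch_{epoch + 1}.pth"),
    )

Inside the epoch loop it would be guarded the way the placeholder comment suggests, e.g. called when (epoch + 1) % config.get("checkpoint_freq", 1) == 0 or (epoch + 1) == num_epochs.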