# Page header captured along with this file (not part of the code):
# Files
# torchvision-vibecoding-project/train.py
# 2025-04-12 10:55:10 +01:00
# 443 lines
# 18 KiB
# Python

import argparse
import importlib.util
import logging
import math
import os
import random
import re
import sys
import time  # Import time for timing
import traceback

import numpy as np
import torch
import torch.utils.data

# Project specific imports
from models.detection import get_maskrcnn_model
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate  # Import evaluate function
from utils.log_utils import setup_logging
def _load_config(config_arg):
    """Import the Python config file at ``config_arg`` and return its ``config`` object.

    The file is imported as a regular module. Its path relative to the current
    working directory (assumed to be the project root) is turned into a dotted
    module name, e.g. ``configs/foo.py`` -> ``configs.foo``.

    NOTE(review): ``importlib.import_module`` can only find that module if the
    project root is on ``sys.path`` (e.g. the script is run from the root);
    confirm for other launch setups.

    Args:
        config_arg: Config file path as given on the command line.

    Returns:
        The module-level ``config`` attribute of the imported config module.

    Exits:
        Calls ``sys.exit(1)`` after printing a diagnostic if the file does not
        exist, lies outside the working directory, cannot be imported, or
        defines no ``config`` attribute.
    """
    # Bound up front so the error handlers below can always refer to it.
    module_path_str = None
    try:
        config_path = os.path.abspath(config_arg)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)
        # Derive module path from file path relative to workspace root
        workspace_root = os.path.abspath(os.getcwd())  # Assuming script is run from root
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {config_arg} is outside the project directory.")
            sys.exit(1)
        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")
        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config
        print(
            f"Loaded configuration from: {config_path} (via module {module_path_str})"
        )
    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print(
            "Ensure the config file path is correct and relative imports within it are valid."
        )
        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(
            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
        )
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {config_arg}: {e}")
        traceback.print_exc()
        sys.exit(1)
    return config


def _build_dataloaders(config, data_root):
    """Build train/validation DataLoaders from a seeded random split of the dataset.

    Two ``PennFudanDataset`` instances are created over ``data_root``: one with
    ``get_transform(train=True)`` and one with ``get_transform(train=False)``.
    Disjoint index subsets of them form the training and validation sets.

    Args:
        config: Configuration mapping (``seed``, ``val_split_ratio``,
            ``batch_size``, ``num_workers``, ``pin_memory`` are read).
        data_root: Existing dataset root directory.

    Returns:
        ``(data_loader_train, data_loader_val)``.

    Exits:
        Calls ``sys.exit(1)`` if the dataset is empty, if the split would leave
        no training samples (this includes a one-sample dataset), or if
        building the datasets/loaders raises.
    """
    try:
        # Create the full training dataset instance first
        dataset_full = PennFudanDataset(
            root=data_root, transforms=get_transform(train=True)
        )
        num_samples = len(dataset_full)
        logging.info(f"Full dataset size: {num_samples}")
        if num_samples == 0:
            logging.error(f"Dataset at {data_root} is empty; nothing to train on.")
            sys.exit(1)
        # Create validation dataset instance with eval transforms
        dataset_val_instance = PennFudanDataset(
            root=data_root, transforms=get_transform(train=False)
        )

        # Re-seed so the split is identical across runs (and across resumes).
        torch.manual_seed(config.get("seed", 42))
        indices = torch.randperm(num_samples).tolist()
        val_split_ratio = config.get("val_split_ratio", 0.1)  # Default to 10% validation
        val_split_count = int(val_split_ratio * num_samples)
        if val_split_count <= 0:
            logging.warning(
                f"Validation split resulted in 0 samples (ratio={val_split_ratio}, total={num_samples}). Using 1 sample for validation."
            )
            val_split_count = 1
        # Checked *after* the bump to 1, so a dataset too small to hold both a
        # validation and a training sample is rejected instead of silently
        # producing an empty training set.
        if val_split_count >= num_samples:
            logging.error(
                f"Validation split ratio ({val_split_ratio}) too high, results in no training samples."
            )
            sys.exit(1)
        train_indices = indices[:-val_split_count]
        val_indices = indices[-val_split_count:]

        # Create Subset datasets
        dataset_train = torch.utils.data.Subset(dataset_full, train_indices)
        dataset_val = torch.utils.data.Subset(dataset_val_instance, val_indices)
        logging.info(
            f"Using {len(train_indices)} samples for training and {len(val_indices)} for validation."
        )

        batch_size = config.get("batch_size", 2)
        num_workers = config.get("num_workers", 4)
        pin_memory = config.get("pin_memory", True)
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=batch_size,
            shuffle=True,  # Shuffle the training subset every epoch
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=batch_size,  # Often use same or larger batch size for validation
            shuffle=False,  # No need to shuffle validation data
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
        )
        logging.info(
            f"Training dataloader configured. Est. batches: {len(data_loader_train)}"
        )
        logging.info(
            f"Validation dataloader configured. Est. batches: {len(data_loader_val)}"
        )
    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)
    return data_loader_train, data_loader_val


def _checkpoint_sort_key(filename):
    """Rank checkpoint file names by the integer right before ``.pth``.

    ``checkpoint_epoch_10.pth`` must rank after ``checkpoint_epoch_9.pth``, which
    a plain lexicographic sort gets wrong. Names without such a number rank
    below all numbered ones. Ties fall back to the file name.
    """
    match = re.search(r"(\d+)\.pth$", filename)
    return (int(match.group(1)) if match else -1, filename)


def _find_latest_checkpoint(checkpoint_dir):
    """Return the path of the newest ``.pth`` file in ``checkpoint_dir``, or ``None``."""
    if not os.path.isdir(checkpoint_dir):
        logging.info("Checkpoint directory not found. Starting from scratch.")
        return None
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
    if not checkpoints:
        logging.info("No checkpoints found in directory. Starting from scratch.")
        return None
    latest_path = os.path.join(checkpoint_dir, max(checkpoints, key=_checkpoint_sort_key))
    logging.info(f"Found latest checkpoint: {latest_path}")
    return latest_path


def _strip_module_prefix(state_dict):
    """Remove the leading ``module.`` that DataParallel adds to every state-dict key.

    The prefix is removed only when *all* keys carry it. Only the leading
    occurrence is dropped, so a ``module.`` appearing later inside a key is
    preserved.
    """
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        logging.info("Removing 'module.' prefix from checkpoint keys.")
        return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


def _resume_from_checkpoint(checkpoint_file, model, optimizer, lr_scheduler, device):
    """Restore model/optimizer/scheduler state and return the epoch to resume at.

    Returns:
        The checkpoint's ``"epoch"`` entry. That entry is the number of
        *completed* epochs, which is also the 0-based index of the next epoch
        to train. Returns ``0`` (train from scratch) if loading fails.
    """
    try:
        logging.info(f"Loading checkpoint '{checkpoint_file}'")
        # Ensure loading happens on the correct device
        checkpoint = torch.load(checkpoint_file, map_location=device)
        # Read every required entry before mutating anything, so a checkpoint
        # with missing keys leaves the model/optimizer/scheduler untouched.
        model_state_dict = _strip_module_prefix(checkpoint["model_state_dict"])
        optimizer_state_dict = checkpoint["optimizer_state_dict"]
        scheduler_state_dict = checkpoint["scheduler_state_dict"]
        start_epoch = checkpoint["epoch"]

        model.load_state_dict(model_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)
        lr_scheduler.load_state_dict(scheduler_state_dict)
        logging.info(f"Resuming training from epoch {start_epoch + 1}")
        return start_epoch
    except Exception as e:
        # NOTE(review): a failure inside one of the load_state_dict calls can
        # still leave the objects loaded before it partially restored.
        logging.error(
            f"Error loading checkpoint: {e}. Starting training from scratch.",
            exc_info=True,
        )
        return 0


def _train_one_epoch(model, optimizer, data_loader, device, epoch, log_freq):
    """Run one pass over ``data_loader`` and return the mean loss of completed batches.

    ``model(images, targets)`` must return a dict of scalar loss tensors. Their
    sum is the loss that is back-propagated.

    Any exception inside a batch is logged and ends the epoch early; the caller
    then carries on with the next epoch. This includes a non-finite loss, which
    is rejected *before* ``backward()`` so it cannot poison the weights.

    Args:
        model: Model to train; switched to training mode here.
        optimizer: Optimizer over the model's trainable parameters.
        data_loader: Yields ``(images, targets)``, where images are tensors
            and targets are dicts of tensors.
        device: Device the batch tensors are moved to.
        epoch: 0-based epoch index, used only in log messages.
        log_freq: Log every ``log_freq`` iterations; ``0``/``None`` disables.

    Returns:
        Mean total loss over batches that completed an optimizer step, or
        ``0.0`` if none did.
    """
    # Re-enable training mode every epoch; evaluation may have changed it.
    model.train()
    num_batches = len(data_loader)
    epoch_loss_sum = 0.0
    completed_batches = 0
    for i, (images, targets) in enumerate(data_loader):
        batch_start_time = time.time()
        try:
            # Move data to the device
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # Forward pass: the model returns a dict of individual losses
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()
            if not math.isfinite(loss_value):
                loss_parts = {k: v.item() for k, v in loss_dict.items()}
                raise FloatingPointError(
                    f"Non-finite loss {loss_value}; individual losses: {loss_parts}"
                )
            # Backward pass and parameter update
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            # Only batches that completed an optimizer step count towards the mean.
            epoch_loss_sum += loss_value
            completed_batches += 1
            # Log batch loss periodically
            if log_freq and (i + 1) % log_freq == 0:
                batch_time = time.time() - batch_start_time
                loss_dict_items = {k: f"{v.item():.4f}" for k, v in loss_dict.items()}
                logging.info(
                    f" Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
                )
                logging.debug(f" Loss Dict: {loss_dict_items}")  # Individual losses at DEBUG level
        except Exception as e:
            logging.error(
                f"Error during training epoch {epoch+1}, batch {i+1}: {e}",
                exc_info=True,
            )
            logging.warning("Skipping rest of epoch due to error.")
            break  # Exit the inner loop for this epoch
    return epoch_loss_sum / completed_batches if completed_batches else 0.0


def _save_checkpoint(save_path, epoch, model, optimizer, lr_scheduler, config):
    """Write a resumable checkpoint for 0-based ``epoch`` to ``save_path``.

    The data is first written to ``save_path + ".tmp"`` and then moved into
    place with ``os.replace``. An interrupted save therefore never leaves a
    truncated ``.pth`` file for the resume logic to pick up. Failures are
    logged, not raised, so training continues.
    """
    tmp_path = save_path + ".tmp"
    try:
        checkpoint_data = {
            "epoch": epoch + 1,  # Completed epochs == next 0-based epoch index
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": lr_scheduler.state_dict(),
            "config": config,  # Save config for reference
        }
        torch.save(checkpoint_data, tmp_path)
        os.replace(tmp_path, save_path)
        logging.info(f"Checkpoint saved to {save_path}")
    except Exception as e:
        logging.error(
            f"Failed to save checkpoint for epoch {epoch + 1} to {save_path}: {e}",
            exc_info=True,
        )


def main(args):
    """Train Mask R-CNN as described by the Python config file ``args.config``.

    Steps, in order:
      1. Load the config and create ``output_dir/config_name/checkpoints``.
      2. Set up logging, seed the RNGs and pick the device.
      3. Build the train/validation loaders.
      4. Build the model, the SGD optimizer and the StepLR scheduler.
      5. Resume from the newest checkpoint if one exists.
      6. Train up to ``num_epochs``, checkpointing and evaluating after each
         epoch as configured.

    Args:
        args: Parsed command-line namespace; only ``args.config`` is read.

    Exits:
        Calls ``sys.exit(1)`` on unrecoverable setup errors: bad config,
        missing data root, or failures building the dataset, model, optimizer
        or scheduler. Errors inside a training batch, a checkpoint save or an
        evaluation are logged, and training continues.
    """
    config = _load_config(args.config)

    # --- Output Directory Setup ---
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    checkpoint_path = os.path.join(output_path, "checkpoints")
    os.makedirs(checkpoint_path, exist_ok=True)  # Also creates output_path
    print(f"Output will be saved to: {output_path}")

    # --- Logging Setup (Prompt 9) ---
    setup_logging(output_path, config_name)
    logging.info("--- Training Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Loaded configuration dictionary: {config}")
    logging.info(f"Output will be saved to: {output_path}")

    # --- Reproducibility ---
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Consider adding these for more determinism, but they might impact performance
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False
    logging.info(f"Set random seed to: {seed}")

    # --- Device Setup ---
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    # --- Dataset and DataLoader ---
    data_root = config.get("data_root")
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)
    data_loader_train, data_loader_val = _build_dataloaders(config, data_root)

    # --- Model Instantiation ---
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)
    try:
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=config.get("pretrained", True),
            pretrained_backbone=config.get("pretrained_backbone", True),
        )
        model.to(device)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}", exc_info=True)
        sys.exit(1)

    # --- Optimizer ---
    # Only parameters that require gradients are handed to the optimizer.
    params = [p for p in model.parameters() if p.requires_grad]
    lr = config.get("lr", 0.005)
    momentum = config.get("momentum", 0.9)
    weight_decay = config.get("weight_decay", 0.0005)
    try:
        optimizer = torch.optim.SGD(
            params, lr=lr, momentum=momentum, weight_decay=weight_decay
        )
        logging.info(
            f"Optimizer SGD configured with lr={lr}, momentum={momentum}, weight_decay={weight_decay}"
        )
    except Exception as e:
        logging.error(f"Error creating optimizer: {e}", exc_info=True)
        sys.exit(1)

    # --- LR Scheduler (Prompt 10) ---
    lr_step_size = config.get("lr_step_size", 3)
    lr_gamma = config.get("lr_gamma", 0.1)
    try:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=lr_step_size, gamma=lr_gamma
        )
        logging.info(
            f"LR scheduler StepLR configured with step_size={lr_step_size}, gamma={lr_gamma}"
        )
    except Exception as e:
        logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
        sys.exit(1)

    # --- Resume Logic (Prompt 11) ---
    start_epoch = 0
    latest_checkpoint_path = _find_latest_checkpoint(checkpoint_path)
    if latest_checkpoint_path:
        start_epoch = _resume_from_checkpoint(
            latest_checkpoint_path, model, optimizer, lr_scheduler, device
        )

    # --- Training Loop (Prompt 10, modified for Prompt 11) ---
    logging.info("--- Starting Training Loop --- ")
    start_time = time.time()
    num_epochs = config.get("num_epochs", 10)
    log_freq = config.get("log_freq", 10)
    checkpoint_freq = config.get("checkpoint_freq", 1)
    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()
        logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ")
        avg_epoch_loss = _train_one_epoch(
            model, optimizer, data_loader_train, device, epoch, log_freq
        )

        # --- End of Epoch --- #
        # Read the LR this epoch actually trained with, before the scheduler
        # advances it for the next epoch.
        current_lr = optimizer.param_groups[0]["lr"]
        lr_scheduler.step()
        epoch_duration = time.time() - epoch_start_time
        logging.info(f"--- Epoch {epoch + 1} Summary --- ")
        logging.info(f" Average Loss: {avg_epoch_loss:.4f}")
        logging.info(f" Learning Rate: {current_lr:.6f}")
        logging.info(f" Epoch Duration: {epoch_duration:.2f}s")

        # --- Checkpointing (Prompt 11) --- #
        # Save periodically, and always after the final epoch. A falsy
        # checkpoint_freq disables the periodic saves.
        save_checkpoint = False
        if checkpoint_freq and (epoch + 1) % checkpoint_freq == 0:
            save_checkpoint = True
            logging.info(f"Checkpoint frequency met (epoch {epoch + 1})")
        elif (epoch + 1) == num_epochs:
            save_checkpoint = True
            logging.info(f"Final epoch ({epoch + 1}) reached, saving checkpoint.")
        if save_checkpoint:
            save_path = os.path.join(checkpoint_path, f"checkpoint_epoch_{epoch + 1}.pth")
            _save_checkpoint(save_path, epoch, model, optimizer, lr_scheduler, config)

        # --- Evaluation (Prompt 12) --- #
        if data_loader_val:
            logging.info(f"Starting evaluation for epoch {epoch + 1}...")
            try:
                val_metrics = evaluate(model, data_loader_val, device)
                logging.info(f"Epoch {epoch + 1} Validation Metrics: {val_metrics}")
                # TODO: track the best validation metric (e.g. val_metrics['average_loss'])
                # and save a separate best_model.pth under output_path when it improves.
            except Exception as e:
                logging.error(
                    f"Error during evaluation for epoch {epoch + 1}: {e}", exc_info=True
                )

    # --- End of Training --- #
    total_training_time = time.time() - start_time
    logging.info("--- Training Finished --- ")
    logging.info(
        f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
    )
if __name__ == "__main__":
    # Script entry point: parse the single required --config option and
    # hand the resulting namespace straight to main().
    cli = argparse.ArgumentParser(description="Train Mask R-CNN on Penn-Fudan dataset.")
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
    )
    main(cli.parse_args())