import argparse
import importlib.util
import logging
import os
import random
import sys
import time  # Import time for timing

import numpy as np
import torch
import torch.utils.data

# Project specific imports
from models.detection import get_maskrcnn_model
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate  # Import evaluate function
from utils.log_utils import setup_logging


def main(args):
    # --- Configuration Loading ---
    try:
        config_path = os.path.abspath(args.config)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive module path from file path relative to workspace root
        workspace_root = os.path.abspath(
            os.getcwd()
        )  # Assuming script is run from root
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {args.config} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")

        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config
        print(
            f"Loaded configuration from: {config_path} (via module {module_path_str})"
        )

    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print(
            "Ensure the config file path is correct and relative imports within it are valid."
        )
        import traceback

        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(
            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
        )
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {args.config}: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
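
    # Illustrative sketch (not taken from the project's actual config file): the keys
    # and defaults below simply mirror the config.get(...) lookups used throughout
    # this script, so a minimal configs/<name>.py is assumed to expose a dict like:
    #
    # config = {
    #     "config_name": "default_run",
    #     "output_dir": "outputs",
    #     "seed": 42,
    #     "device": "cuda",
    #     "data_root": "data/PennFudanPed",  # assumed location, adjust as needed
    #     "val_split_ratio": 0.1,
    #     "batch_size": 2,
    #     "num_workers": 4,
    #     "pin_memory": True,
    #     "num_classes": 2,  # background + pedestrian (assumed for Penn-Fudan)
    #     "pretrained": True,
    #     "pretrained_backbone": True,
    #     "lr": 0.005,
    #     "momentum": 0.9,
    #     "weight_decay": 0.0005,
    #     "lr_step_size": 3,
    #     "lr_gamma": 0.1,
    #     "num_epochs": 10,
    #     "log_freq": 10,
    #     "checkpoint_freq": 1,
    # }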

    # --- Output Directory Setup ---
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    checkpoint_path = os.path.join(output_path, "checkpoints")
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)
    print(f"Output will be saved to: {output_path}")

    # --- Logging Setup (Prompt 9) ---
    setup_logging(output_path, config_name)
    logging.info("--- Training Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Loaded configuration dictionary: {config}")
    logging.info(f"Output will be saved to: {output_path}")

    # --- Reproducibility ---
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Consider adding these for more determinism, but they might impact performance
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    logging.info(f"Set random seed to: {seed}")

    # --- Device Setup ---
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    # --- Dataset and DataLoader ---
    data_root = config.get("data_root")
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)

    try:
        # Create the full training dataset instance first
        dataset_full = PennFudanDataset(
            root=data_root, transforms=get_transform(train=True)
        )
        logging.info(f"Full dataset size: {len(dataset_full)}")

        # Create validation dataset instance with eval transforms
        dataset_val_instance = PennFudanDataset(
            root=data_root, transforms=get_transform(train=False)
        )

        # Split the dataset indices
        torch.manual_seed(
            config.get("seed", 42)
        )  # Use the same seed for consistent splits
        indices = torch.randperm(len(dataset_full)).tolist()
        val_split_ratio = config.get(
            "val_split_ratio", 0.1
        )  # Default to 10% validation
        val_split_count = int(val_split_ratio * len(dataset_full))
        if val_split_count == 0 and len(dataset_full) > 0:
            logging.warning(
                f"Validation split resulted in 0 samples (ratio={val_split_ratio}, total={len(dataset_full)}). Using 1 sample for validation."
            )
            val_split_count = 1
        elif val_split_count >= len(dataset_full):
            logging.error(
                f"Validation split ratio ({val_split_ratio}) too high, results in no training samples."
            )
            sys.exit(1)

        train_indices = indices[:-val_split_count]
        val_indices = indices[-val_split_count:]

        # Create Subset datasets
        dataset_train = torch.utils.data.Subset(dataset_full, train_indices)
        dataset_val = torch.utils.data.Subset(dataset_val_instance, val_indices)
        logging.info(
            f"Using {len(train_indices)} samples for training and {len(val_indices)} for validation."
        )

        # Create DataLoaders
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=config.get("batch_size", 2),
            # Shuffle should be true for the training subset loader
            shuffle=True,
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=config.get(
                "batch_size", 2
            ),  # Often use same or larger batch size for validation
            shuffle=False,  # No need to shuffle validation data
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )
        logging.info(
            f"Training dataloader configured. Est. batches: {len(data_loader_train)}"
        )
        logging.info(
            f"Validation dataloader configured. Est. batches: {len(data_loader_val)}"
        )

    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)
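
    # Note (assumption): PennFudanDataset and collate_fn come from utils/data_utils.py,
    # which is not shown here. They are assumed to follow the torchvision detection
    # convention: each sample is (image_tensor, target) where target is a dict with at
    # least "boxes" (FloatTensor[N, 4]), "labels" (Int64Tensor[N]) and "masks"
    # (UInt8Tensor[N, H, W]), and the collate function keeps images and targets as
    # per-sample sequences rather than stacking them, which is what the training loop
    # below expects when it calls model(images, targets).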

    # --- Model Instantiation ---
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=config.get("pretrained", True),
            pretrained_backbone=config.get("pretrained_backbone", True),
        )
        model.to(device)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}", exc_info=True)
        sys.exit(1)

    # --- Optimizer ---
    # Filter parameters that require gradients
    params = [p for p in model.parameters() if p.requires_grad]
    try:
        optimizer = torch.optim.SGD(
            params,
            lr=config.get("lr", 0.005),
            momentum=config.get("momentum", 0.9),
            weight_decay=config.get("weight_decay", 0.0005),
        )
        logging.info(
            f"Optimizer SGD configured with lr={config.get('lr', 0.005)}, momentum={config.get('momentum', 0.9)}, weight_decay={config.get('weight_decay', 0.0005)}"
        )
    except Exception as e:
        logging.error(f"Error creating optimizer: {e}", exc_info=True)
        sys.exit(1)

    # --- LR Scheduler (Prompt 10) ---
    try:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.get("lr_step_size", 3),
            gamma=config.get("lr_gamma", 0.1),
        )
        logging.info(
            f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, gamma={config.get('lr_gamma', 0.1)}"
        )
    except Exception as e:
        logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
        sys.exit(1)

    # --- Resume Logic (Prompt 11) ---
    start_epoch = 0
    latest_checkpoint_path = None

    def _epoch_num(filename):
        # Extract the trailing epoch number from "checkpoint_epoch_N.pth";
        # files that do not match the naming convention sort first.
        suffix = os.path.splitext(filename)[0].split("_")[-1]
        return int(suffix) if suffix.isdigit() else -1

    if os.path.isdir(checkpoint_path):
        checkpoints = sorted(
            (f for f in os.listdir(checkpoint_path) if f.endswith(".pth")),
            # Sort numerically by epoch: a plain lexicographic sort would put
            # "checkpoint_epoch_10.pth" before "checkpoint_epoch_2.pth".
            key=_epoch_num,
        )
        if checkpoints:  # Check if list is not empty
            # Get the last one (naming convention: checkpoint_epoch_N.pth)
            latest_checkpoint_file = checkpoints[-1]
            latest_checkpoint_path = os.path.join(
                checkpoint_path, latest_checkpoint_file
            )
            logging.info(f"Found latest checkpoint: {latest_checkpoint_path}")
        else:
            logging.info("No checkpoints found in directory. Starting from scratch.")
    else:
        logging.info("Checkpoint directory not found. Starting from scratch.")

    if latest_checkpoint_path:
        try:
            logging.info(f"Loading checkpoint '{latest_checkpoint_path}'")
            # Ensure loading happens on the correct device
            checkpoint = torch.load(latest_checkpoint_path, map_location=device)

            # Load model state - handle potential 'module.' prefix if saved with DataParallel
            model_state_dict = checkpoint["model_state_dict"]
            # Simple check and correction for DataParallel prefix
            if all(key.startswith("module.") for key in model_state_dict.keys()):
                logging.info("Removing 'module.' prefix from checkpoint keys.")
                model_state_dict = {
                    k.replace("module.", ""): v for k, v in model_state_dict.items()
                }
            model.load_state_dict(model_state_dict)

            # Load optimizer state
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

            # Load LR scheduler state
            lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

            # Load starting epoch (epoch saved is the one *completed*, so start from next)
            start_epoch = checkpoint["epoch"]
            logging.info(f"Resuming training from epoch {start_epoch + 1}")

            # Optionally load and verify config consistency
            # loaded_config = checkpoint.get('config')
            # if loaded_config:
            #     # Perform checks if necessary
            #     pass

        except Exception as e:
            logging.error(
                f"Error loading checkpoint: {e}. Starting training from scratch.",
                exc_info=True,
            )
            start_epoch = 0  # Reset start_epoch if loading fails
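
    # Note (assumption): get_maskrcnn_model() is expected to wrap torchvision's Mask
    # R-CNN. In train() mode such a model returns a dict of losses rather than
    # predictions -- typically loss_classifier, loss_box_reg, loss_mask,
    # loss_objectness and loss_rpn_box_reg -- and the loop below sums them into a
    # single scalar before calling backward().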

    # --- Training Loop (Prompt 10, modified for Prompt 11) ---
    logging.info("--- Starting Training Loop ---")
    start_time = time.time()
    num_epochs = config.get("num_epochs", 10)

    # Modify loop to start from start_epoch
    for epoch in range(start_epoch, num_epochs):
        model.train()  # Set model to training mode for each epoch
        epoch_start_time = time.time()
        logging.info(f"--- Epoch {epoch + 1}/{num_epochs} ---")

        # Variables for tracking epoch progress (optional)
        epoch_loss_sum = 0.0
        num_batches = len(data_loader_train)

        for i, (images, targets) in enumerate(data_loader_train):
            batch_start_time = time.time()  # Optional: time each batch
            try:
                # Move data to the device
                images = list(image.to(device) for image in images)
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                # Perform forward pass
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                loss_value = losses.item()
                epoch_loss_sum += loss_value

                # Perform backward pass
                optimizer.zero_grad()
                losses.backward()
                optimizer.step()

                # Log batch loss periodically
                if (i + 1) % config.get("log_freq", 10) == 0:
                    batch_time = time.time() - batch_start_time
                    # Include individual losses if desired
                    loss_dict_items = {
                        k: f"{v.item():.4f}" for k, v in loss_dict.items()
                    }
                    logging.info(
                        f"  Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
                    )
                    logging.debug(
                        f"  Loss Dict: {loss_dict_items}"
                    )  # Log individual losses at DEBUG level

            except Exception as e:
                logging.error(
                    f"Error during training epoch {epoch + 1}, batch {i + 1}: {e}",
                    exc_info=True,
                )
                # Decide if you want to stop training or continue to next batch/epoch
                logging.warning("Skipping rest of epoch due to error.")
                break  # Exit the inner loop for this epoch

        # --- End of Epoch --- #
        # Step the learning rate scheduler
        lr_scheduler.step()

        # Log epoch summary
        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time
        avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
        current_lr = optimizer.param_groups[0]["lr"]  # Get current learning rate
        logging.info(f"--- Epoch {epoch + 1} Summary ---")
        logging.info(f"  Average Loss: {avg_epoch_loss:.4f}")
        logging.info(f"  Learning Rate: {current_lr:.6f}")
        logging.info(f"  Epoch Duration: {epoch_duration:.2f}s")

        # --- Checkpointing (Prompt 11) --- #
        # Save checkpoint periodically or at the end
        save_checkpoint = False
        if (epoch + 1) % config.get("checkpoint_freq", 1) == 0:
            save_checkpoint = True
            logging.info(f"Checkpoint frequency met (epoch {epoch + 1})")
        elif (epoch + 1) == num_epochs:
            save_checkpoint = True
            logging.info(f"Final epoch ({epoch + 1}) reached, saving checkpoint.")

        if save_checkpoint:
            checkpoint_filename = f"checkpoint_epoch_{epoch + 1}.pth"
            save_path = os.path.join(checkpoint_path, checkpoint_filename)
            try:
                checkpoint_data = {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": lr_scheduler.state_dict(),
                    "config": config,  # Save config for reference
                }
                torch.save(checkpoint_data, save_path)
                logging.info(f"Checkpoint saved to {save_path}")
            except Exception as e:
                logging.error(
                    f"Failed to save checkpoint for epoch {epoch + 1} to {save_path}: {e}",
                    exc_info=True,
                )
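
        # Note: the keys written above ("epoch", "model_state_dict",
        # "optimizer_state_dict", "scheduler_state_dict", "config") are exactly the
        # keys the resume logic near the top of main() reads back when restarting.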

        # --- Evaluation (Prompt 12) --- #
        if data_loader_val:
            logging.info(f"Starting evaluation for epoch {epoch + 1}...")
            try:
                val_metrics = evaluate(model, data_loader_val, device)
                logging.info(f"Epoch {epoch + 1} Validation Metrics: {val_metrics}")

                # --- Best Model Checkpoint Logic (Optional Add-on) ---
                # Add logic here to track the best metric (e.g., val_metrics['average_loss'])
                # and save a separate 'best_model.pth' checkpoint if the current epoch is better.
                # Example (remember to initialise best_val_loss = None before the epoch loop):
                # if 'average_loss' in val_metrics:
                #     current_val_loss = val_metrics['average_loss']
                #     if best_val_loss is None or current_val_loss < best_val_loss:
                #         best_val_loss = current_val_loss
                #         best_model_path = os.path.join(output_path, 'best_model.pth')
                #         try:
                #             # Save only the model state_dict for the best model
                #             torch.save(model.state_dict(), best_model_path)
                #             logging.info(f"Saved NEW BEST model checkpoint to {best_model_path} (Val Loss: {best_val_loss:.4f})")
                #         except Exception as e:
                #             logging.error(f"Failed to save best model checkpoint: {e}", exc_info=True)

            except Exception as e:
                logging.error(
                    f"Error during evaluation for epoch {epoch + 1}: {e}", exc_info=True
                )
                # Decide if this error should stop the entire training process

    # --- End of Training --- #
    total_training_time = time.time() - start_time
    logging.info("--- Training Finished ---")
    logging.info(
        f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train Mask R-CNN on Penn-Fudan dataset."
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
    )
    args = parser.parse_args()

    main(args)
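
# Example usage (config path taken from the --config help string above; run from the
# project root, since the config loader builds the module path relative to the
# current working directory):
#
#   python <this_script>.py --config configs/pennfudan_maskrcnn_config.py
#
# Sketch of restoring a saved checkpoint later, e.g. for inference (assumes the
# checkpoint layout written above and that get_maskrcnn_model() accepts num_classes
# on its own; the file name is only an example of the checkpoint_epoch_N.pth pattern):
#
#   checkpoint = torch.load(
#       "outputs/<config_name>/checkpoints/checkpoint_epoch_10.pth", map_location="cpu"
#   )
#   model = get_maskrcnn_model(num_classes=checkpoint["config"]["num_classes"])
#   model.load_state_dict(checkpoint["model_state_dict"])
#   model.eval()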