import argparse
import logging
import os
import sys
import time

import torch
import torch.utils.data

# Project specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
    check_data_path,
    load_checkpoint,
    load_config,
    setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging


def main(args):
    # Load configuration
    config = load_config(args.config)

    # Setup output directory and get device
    output_path, device = setup_environment(config)
    checkpoint_path = os.path.join(output_path, "checkpoints")
    os.makedirs(checkpoint_path, exist_ok=True)

    # Setup logging
    setup_logging(output_path, config.get("config_name", "default_run"))
    logging.info("--- Training Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Loaded configuration dictionary: {config}")
    logging.info(f"Output will be saved to: {output_path}")

    # Validate data path
    data_root = config.get("data_root")
    check_data_path(data_root)

    try:
        # Create the full training dataset instance first
        dataset_full = PennFudanDataset(
            root=data_root, transforms=get_transform(train=True)
        )
        logging.info(f"Full dataset size: {len(dataset_full)}")

        # Create validation dataset instance with eval transforms
        dataset_val_instance = PennFudanDataset(
            root=data_root, transforms=get_transform(train=False)
        )

        # Split the dataset indices
        torch.manual_seed(
            config.get("seed", 42)
        )  # Use the same seed for consistent splits
        indices = torch.randperm(len(dataset_full)).tolist()

        val_split_ratio = config.get(
            "val_split_ratio", 0.1
        )  # Default to 10% validation
        val_split_count = int(val_split_ratio * len(dataset_full))

        if val_split_count == 0 and len(dataset_full) > 0:
            logging.warning(
                f"Validation split resulted in 0 samples (ratio={val_split_ratio}, total={len(dataset_full)}). Using 1 sample for validation."
            )
            val_split_count = 1
        elif val_split_count >= len(dataset_full):
            logging.error(
                f"Validation split ratio ({val_split_ratio}) too high, results in no training samples."
            )
            sys.exit(1)

        train_indices = indices[:-val_split_count]
        val_indices = indices[-val_split_count:]

        # Create Subset datasets
        dataset_train = torch.utils.data.Subset(dataset_full, train_indices)
        dataset_val = torch.utils.data.Subset(dataset_val_instance, val_indices)
        logging.info(
            f"Using {len(train_indices)} samples for training and {len(val_indices)} for validation."
        )

        # Create DataLoaders
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=config.get("batch_size", 2),
            # Shuffle should be true for the training subset loader
            shuffle=True,
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=config.get(
                "batch_size", 2
            ),  # Often use same or larger batch size for validation
            shuffle=False,  # No need to shuffle validation data
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )
        logging.info(
            f"Training dataloader configured. Est. batches: {len(data_loader_train)}"
        )
        logging.info(
            f"Validation dataloader configured. Est. batches: {len(data_loader_val)}"
        )
    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # Create model
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=config.get("pretrained", True),
            pretrained_backbone=config.get("pretrained_backbone", True),
        )
        model.to(device)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error creating model: {e}", exc_info=True)
        sys.exit(1)

    # Create optimizer and learning rate scheduler
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.005),
        momentum=config.get("momentum", 0.9),
        weight_decay=config.get("weight_decay", 0.0005),
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.get("lr_step_size", 3),
        gamma=config.get("lr_gamma", 0.1),
    )

    # --- Resume from Checkpoint (if specified) ---
    start_epoch = 0
    if args.resume:
        try:
            # Find latest checkpoint
            checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")]
            if not checkpoints:
                logging.warning(
                    f"No checkpoints found in {checkpoint_path}, starting from scratch."
                )
            else:
                # Extract epoch numbers from filenames and find the latest
                max_epoch = -1
                latest_checkpoint = None
                for ckpt in checkpoints:
                    if ckpt.startswith("checkpoint_epoch_"):
                        try:
                            epoch_num = int(
                                ckpt.replace("checkpoint_epoch_", "").replace(
                                    ".pth", ""
                                )
                            )
                            if epoch_num > max_epoch:
                                max_epoch = epoch_num
                                latest_checkpoint = ckpt
                        except ValueError:
                            continue

                if latest_checkpoint:
                    checkpoint_file = os.path.join(checkpoint_path, latest_checkpoint)
                    logging.info(f"Resuming from checkpoint: {checkpoint_file}")

                    # Load checkpoint
                    checkpoint, start_epoch = load_checkpoint(
                        checkpoint_file,
                        model,
                        device,
                        load_optimizer=True,
                        optimizer=optimizer,
                        load_scheduler=True,
                        scheduler=lr_scheduler,
                    )
                    logging.info(f"Resuming from epoch {start_epoch}")
                else:
                    logging.warning(
                        f"No valid checkpoints found in {checkpoint_path}, starting from scratch."
                    )
        except Exception as e:
            logging.error(f"Error loading checkpoint: {e}", exc_info=True)
            logging.warning("Starting training from scratch.")
            start_epoch = 0

    # --- Training Loop ---
    train_time_start = time.time()
    logging.info("--- Starting Training Loop ---")

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Set model to training mode
        model.train()

        # Initialize epoch metrics
        epoch_loss = 0.0
        epoch_loss_classifier = 0.0
        epoch_loss_box_reg = 0.0
        epoch_loss_mask = 0.0
        epoch_loss_objectness = 0.0
        epoch_loss_rpn_box_reg = 0.0

        logging.info(f"--- Epoch {epoch + 1}/{config.get('num_epochs', 10)} ---")
        epoch_start_time = time.time()

        # Train loop
        for i, (images, targets) in enumerate(data_loader_train):
            # Move data to device
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Forward pass
            loss_dict = model(images, targets)

            # Sum loss components
            losses = sum(loss for loss in loss_dict.values())

            # Backward and optimize
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            # Log batch results
            loss_value = losses.item()
            epoch_loss += loss_value

            # Accumulate individual loss components
            if "loss_classifier" in loss_dict:
                epoch_loss_classifier += loss_dict["loss_classifier"].item()
            if "loss_box_reg" in loss_dict:
                epoch_loss_box_reg += loss_dict["loss_box_reg"].item()
            if "loss_mask" in loss_dict:
                epoch_loss_mask += loss_dict["loss_mask"].item()
            if "loss_objectness" in loss_dict:
                epoch_loss_objectness += loss_dict["loss_objectness"].item()
            if "loss_rpn_box_reg" in loss_dict:
                epoch_loss_rpn_box_reg += loss_dict["loss_rpn_box_reg"].item()

            # Periodic logging
            if (i + 1) % config.get("log_freq", 10) == 0:
                log_str = f"Epoch [{epoch + 1}/{config.get('num_epochs', 10)}], "
                log_str += f"Iter [{i + 1}/{len(data_loader_train)}], "
                log_str += f"Loss: {loss_value:.4f}"

                # Add per-component losses for richer logging
                comp_log = []
                if "loss_classifier" in loss_dict:
                    comp_log.append(f"cls: {loss_dict['loss_classifier'].item():.4f}")
                if "loss_box_reg" in loss_dict:
                    comp_log.append(f"box: {loss_dict['loss_box_reg'].item():.4f}")
                if "loss_mask" in loss_dict:
                    comp_log.append(f"mask: {loss_dict['loss_mask'].item():.4f}")
                if "loss_objectness" in loss_dict:
                    comp_log.append(f"obj: {loss_dict['loss_objectness'].item():.4f}")
                if "loss_rpn_box_reg" in loss_dict:
                    comp_log.append(f"rpn: {loss_dict['loss_rpn_box_reg'].item():.4f}")
                if comp_log:
                    log_str += f" [{', '.join(comp_log)}]"

                logging.info(log_str)

        # Step learning rate scheduler after each epoch
        lr_scheduler.step()

        # Calculate and log epoch metrics
        if len(data_loader_train) > 0:
            avg_loss = epoch_loss / len(data_loader_train)
            avg_loss_classifier = epoch_loss_classifier / len(data_loader_train)
            avg_loss_box_reg = epoch_loss_box_reg / len(data_loader_train)
            avg_loss_mask = epoch_loss_mask / len(data_loader_train)
            avg_loss_objectness = epoch_loss_objectness / len(data_loader_train)
            avg_loss_rpn_box_reg = epoch_loss_rpn_box_reg / len(data_loader_train)

            logging.info(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}")
            logging.info(f" Classifier Loss: {avg_loss_classifier:.4f}")
            logging.info(f" Box Reg Loss: {avg_loss_box_reg:.4f}")
            logging.info(f" Mask Loss: {avg_loss_mask:.4f}")
            logging.info(f" Objectness Loss: {avg_loss_objectness:.4f}")
            logging.info(f" RPN Box Reg Loss: {avg_loss_rpn_box_reg:.4f}")
        else:
            logging.warning("No training batches were processed in this epoch.")

        epoch_duration = time.time() - epoch_start_time
        logging.info(f"Epoch duration: {epoch_duration:.2f}s")

        # --- Validation ---
        logging.info("Running validation...")
        val_metrics = evaluate(model, data_loader_val, device)
        logging.info(f"Validation Loss: {val_metrics['average_loss']:.4f}")

        # --- Checkpoint Saving ---
        if (epoch + 1) % config.get("checkpoint_freq", 1) == 0 or epoch == config.get(
            "num_epochs", 10
        ) - 1:
            checkpoint_file = os.path.join(
                checkpoint_path, f"checkpoint_epoch_{epoch+1}.pth"
            )
            checkpoint = {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": lr_scheduler.state_dict(),
                "config": config,
                "val_loss": val_metrics["average_loss"],
            }
            try:
                torch.save(checkpoint, checkpoint_file)
                logging.info(f"Checkpoint saved to {checkpoint_file}")
            except Exception as e:
                logging.error(f"Error saving checkpoint: {e}", exc_info=True)

    # --- Final Metrics and Cleanup ---
    total_training_time = time.time() - train_time_start
    hours, remainder = divmod(total_training_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    logging.info(f"Training completed in {int(hours)}h {int(minutes)}m {seconds:.2f}s")
    logging.info(f"Final model saved to {checkpoint_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a Mask R-CNN model")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument(
        "--resume", action="store_true", help="Resume training from latest checkpoint"
    )
    args = parser.parse_args()

    main(args)