import argparse
import importlib
import logging
import os
import random
import sys

import numpy as np
import torch
import torch.utils.data

# Project specific imports
from models.detection import get_maskrcnn_model
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.log_utils import setup_logging


def main(args):
    # --- Configuration Loading ---
    try:
        config_path = os.path.abspath(args.config)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive the module path from the file path relative to the workspace root
        workspace_root = os.path.abspath(os.getcwd())  # Assumes the script is run from the root
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {args.config} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")
        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config
        print(f"Loaded configuration from: {config_path} (via module {module_path_str})")
    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print("Ensure the config file path is correct and relative imports within it are valid.")
        import traceback

        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}")
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {args.config}: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)

    # --- Output Directory Setup ---
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    checkpoint_path = os.path.join(output_path, "checkpoints")
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)
    print(f"Output will be saved to: {output_path}")

    # --- Logging Setup (Prompt 9) ---
    setup_logging(output_path, config_name)
    logging.info("--- Training Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Loaded configuration dictionary: {config}")
    logging.info(f"Output will be saved to: {output_path}")

    # --- Reproducibility ---
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Consider adding these for more determinism, though they may hurt performance:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    logging.info(f"Set random seed to: {seed}")

    # --- Device Setup ---
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    # --- Dataset and DataLoader ---
    data_root = config.get("data_root")
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)

    try:
        dataset_train = PennFudanDataset(
            root=data_root, transforms=get_transform(train=True)
        )
        # Note: Validation split will be handled later (Prompt 12).
        # dataset_val = PennFudanDataset(root=data_root, transforms=get_transform(train=False))
        # TODO: Implement data splitting (e.g., using torch.utils.data.Subset);
        # see the commented sketch below.
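        # One possible approach, sketched here as comments so the current
        # single-dataset behavior is unchanged. It assumes dataset_val is
        # created with the eval transforms as above; the 80/20 ratio is an
        # assumption, not something the config specifies:
        # indices = torch.randperm(len(dataset_train)).tolist()
        # split = int(0.8 * len(indices))
        # dataset_train = torch.utils.data.Subset(dataset_train, indices[:split])
        # dataset_val = torch.utils.data.Subset(dataset_val, indices[split:])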
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=config.get("batch_size", 2),
            shuffle=True,
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),  # Often speeds up host-to-GPU transfer
        )
        logging.info(f"Training dataset size: {len(dataset_train)}")
        logging.info(
            f"Training dataloader configured with batch size {config.get('batch_size', 2)}"
        )
        # Placeholder for the validation loader
        # data_loader_val = torch.utils.data.DataLoader(...)
    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # --- Model Instantiation ---
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=config.get("pretrained", True),
            pretrained_backbone=config.get("pretrained_backbone", True),
        )
        model.to(device)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}", exc_info=True)
        sys.exit(1)

    # --- Optimizer ---
    # Only optimize parameters that require gradients
    params = [p for p in model.parameters() if p.requires_grad]
    try:
        optimizer = torch.optim.SGD(
            params,
            lr=config.get("lr", 0.005),
            momentum=config.get("momentum", 0.9),
            weight_decay=config.get("weight_decay", 0.0005),
        )
        logging.info(
            f"Optimizer SGD configured with lr={config.get('lr', 0.005)}, "
            f"momentum={config.get('momentum', 0.9)}, "
            f"weight_decay={config.get('weight_decay', 0.0005)}"
        )
    except Exception as e:
        logging.error(f"Error creating optimizer: {e}", exc_info=True)
        sys.exit(1)

    # --- LR Scheduler (Placeholder for Prompt 10) ---
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer,
    #     step_size=config.get("lr_step_size", 3),
    #     gamma=config.get("lr_gamma", 0.1),
    # )

    # --- Minimal Training Step (Prompt 8 / Updated for Prompt 9) ---
    logging.info("--- Starting Minimal Training Step ---")
    model.train()  # Set model to training mode
    try:
        # Fetch one batch
        images, targets = next(iter(data_loader_train))

        # Move data to the device; targets is a list of dicts, so move each
        # tensor inside every dict as well
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass (the model returns a loss dict in train mode)
        loss_dict = model(images, targets)

        # Total loss as a single tensor, plus a scalar copy for logging
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        # Backward pass
        optimizer.zero_grad()  # Clear previous gradients
        losses.backward()  # Compute gradients
        optimizer.step()  # Update weights

        # Convert loss_dict tensors to scalar values for logging
        loss_dict_log = {k: v.item() for k, v in loss_dict.items()}
        logging.info(f"Single step loss dict: {loss_dict_log}")
        logging.info(f"Single step total loss: {loss_value:.4f}")
        logging.info("--- Minimal Training Step Completed Successfully ---")
    except Exception as e:
        # exc_info=True already logs the full traceback
        logging.error(f"Error during minimal training step: {e}", exc_info=True)
        sys.exit(1)

    # Temporarily exit after the single step (as per Prompt 8)
    logging.info("Exiting after single training step.")
    sys.exit(0)

    # --- Full Training Loop (Placeholder for Prompt 10) ---
    # print("Basic setup complete. Full training loop implementation pending.")
    # ... loop implementation ...
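    # A minimal sketch of what that loop might look like once the single-step
    # exit above is removed. num_epochs, the checkpoint filename, and the
    # lr_scheduler.step() placement are assumptions, not part of the config:
    # for epoch in range(config.get("num_epochs", 10)):
    #     model.train()
    #     for images, targets in data_loader_train:
    #         images = [image.to(device) for image in images]
    #         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    #         loss_dict = model(images, targets)
    #         losses = sum(loss for loss in loss_dict.values())
    #         optimizer.zero_grad()
    #         losses.backward()
    #         optimizer.step()
    #     # lr_scheduler.step()  # Once the StepLR above is enabled
    #     torch.save(
    #         model.state_dict(),
    #         os.path.join(checkpoint_path, f"model_epoch_{epoch}.pth"),
    #     )
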
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train Mask R-CNN on the Penn-Fudan dataset."
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
    )
    args = parser.parse_args()
    main(args)
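# Example invocation, assuming this script is saved as train.py at the project
# root (the config path is the one suggested by the --config help text):
#   python train.py --config configs/pennfudan_maskrcnn_config.py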