import importlib
import logging
import os
import random
import sys

import numpy as np
import torch


def load_config(config_path):
    """Load configuration from a Python file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The loaded configuration dictionary.
    """
    try:
        config_path = os.path.abspath(config_path)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive the module path from the file path relative to the workspace root
        workspace_root = os.path.abspath(os.getcwd())
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {config_path} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")

        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config
        print(f"Loaded configuration from: {config_path} (via module {module_path_str})")
        return config
    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print("Ensure the config file path is correct and relative imports within it are valid.")
        import traceback

        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}")
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {config_path}: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


def setup_environment(config):
    """Set up the environment based on configuration.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        tuple: (output_path, device) - the output directory path and torch device.
    """
    # Set up the output directory
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    os.makedirs(output_path, exist_ok=True)

    # Set random seeds for reproducibility
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logging.info(f"Set random seed to: {seed}")

    # Set up the device
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    return output_path, device


def load_checkpoint(
    checkpoint_path,
    model,
    device,
    load_optimizer=False,
    optimizer=None,
    load_scheduler=False,
    scheduler=None,
):
    """Load a checkpoint into the model and optionally optimizer and scheduler.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the weights into.
        device (torch.device): The device to load the checkpoint on.
        load_optimizer (bool): Whether to load optimizer state.
        optimizer (torch.optim.Optimizer, optional): The optimizer to load state into.
        load_scheduler (bool): Whether to load scheduler state.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to load state into.

    Returns:
        tuple: (checkpoint, start_epoch) - the loaded checkpoint dictionary and the epoch
            to start from (checkpoint epoch + 1 when resuming with the optimizer, otherwise 0).
""" try: logging.info(f"Loading checkpoint from: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location=device) # Handle potential DataParallel prefix state_dict = checkpoint.get("model_state_dict", checkpoint) if isinstance(state_dict, dict): # Handle case where model was trained with DataParallel if all(k.startswith("module.") for k in state_dict.keys()): logging.info( "Detected DataParallel checkpoint, removing 'module.' prefix" ) state_dict = { k.replace("module.", ""): v for k, v in state_dict.items() } model.load_state_dict(state_dict) logging.info("Model state loaded successfully") # Load optimizer state if requested if ( load_optimizer and optimizer is not None and "optimizer_state_dict" in checkpoint ): optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) logging.info("Optimizer state loaded successfully") # Load scheduler state if requested if ( load_scheduler and scheduler is not None and "scheduler_state_dict" in checkpoint ): scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) logging.info("Scheduler state loaded successfully") # Get the epoch number start_epoch = checkpoint.get("epoch", 0) + 1 if load_optimizer else 0 if "epoch" in checkpoint: logging.info(f"Loaded checkpoint from epoch: {checkpoint['epoch']}") return checkpoint, start_epoch else: logging.error("Checkpoint does not contain a valid state dictionary.") sys.exit(1) except Exception as e: logging.error(f"Error loading checkpoint: {e}", exc_info=True) sys.exit(1) def check_data_path(data_root): """Check if the data path exists and is valid. Args: data_root (str): Path to the data directory. """ if not data_root or not os.path.isdir(data_root): logging.error(f"Data root directory not found or not specified: {data_root}") sys.exit(1)