import importlib.util
import logging
import os
import random
import sys
import numpy as np
import torch


def load_config(config_path):
    """Load configuration from a Python file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The loaded configuration dictionary.
    """
    try:
        config_path = os.path.abspath(config_path)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive module path from file path relative to workspace root
        workspace_root = os.path.abspath(os.getcwd())
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {config_path} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")
        print(f"Attempting to import config module: {module_path_str}")

        config_module = importlib.import_module(module_path_str)
        config = config_module.config
        print(
            f"Loaded configuration from: {config_path} (via module {module_path_str})"
        )
        return config
    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print(
            "Ensure the config file path is correct and relative imports within it are valid."
        )
        import traceback

        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(
            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
        )
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {config_path}: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
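
# A config file passed to load_config must live inside the project directory and
# define a module-level dict named `config`. A minimal, hypothetical example
# (configs/example.py is an illustrative path, not part of this module):
#
#     config = {
#         "config_name": "example_run",
#         "output_dir": "outputs",
#         "seed": 42,
#         "device": "cuda",
#     }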


def setup_environment(config):
    """Set up the environment based on configuration.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        tuple: (output_path, device) - the output directory path and torch device.
    """
    # Setup output directory
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    os.makedirs(output_path, exist_ok=True)

    # Set random seeds
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logging.info(f"Set random seed to: {seed}")

    # Setup device
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    return output_path, device
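
# Example use of the returned values (a sketch; `cfg` comes from load_config above
# and `model` is a hypothetical placeholder for whatever module the caller builds):
#
#     output_path, device = setup_environment(cfg)
#     torch.save(model.state_dict(), os.path.join(output_path, "init.pt"))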


def load_checkpoint(
    checkpoint_path,
    model,
    device,
    load_optimizer=False,
    optimizer=None,
    load_scheduler=False,
    scheduler=None,
):
    """Load a checkpoint into the model and optionally optimizer and scheduler.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the weights into.
        device (torch.device): The device to load the checkpoint on.
        load_optimizer (bool): Whether to load optimizer state.
        optimizer (torch.optim.Optimizer, optional): The optimizer to load state into.
        load_scheduler (bool): Whether to load scheduler state.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to load state into.

    Returns:
        tuple: (checkpoint, start_epoch) - the loaded checkpoint dictionary and the
            epoch to resume from (checkpoint epoch + 1 when the optimizer is restored,
            otherwise 0).
    """
    try:
        logging.info(f"Loading checkpoint from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Accept both full checkpoints ({"model_state_dict": ...}) and raw state dicts
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        if isinstance(state_dict, dict):
            # Handle case where model was trained with DataParallel
            if all(k.startswith("module.") for k in state_dict.keys()):
                logging.info(
                    "Detected DataParallel checkpoint, removing 'module.' prefix"
                )
                state_dict = {
                    k[len("module."):]: v for k, v in state_dict.items()
                }
            model.load_state_dict(state_dict)
            logging.info("Model state loaded successfully")

            # Load optimizer state if requested
            if (
                load_optimizer
                and optimizer is not None
                and "optimizer_state_dict" in checkpoint
            ):
                optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                logging.info("Optimizer state loaded successfully")

            # Load scheduler state if requested
            if (
                load_scheduler
                and scheduler is not None
                and "scheduler_state_dict" in checkpoint
            ):
                scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
                logging.info("Scheduler state loaded successfully")

            # Get the epoch number to resume from
            start_epoch = (checkpoint.get("epoch", 0) + 1) if load_optimizer else 0
            if "epoch" in checkpoint:
                logging.info(f"Loaded checkpoint from epoch: {checkpoint['epoch']}")
            return checkpoint, start_epoch
        else:
            logging.error("Checkpoint does not contain a valid state dictionary.")
            sys.exit(1)
    except Exception as e:
        logging.error(f"Error loading checkpoint: {e}", exc_info=True)
        sys.exit(1)
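
# load_checkpoint accepts either a raw state dict or a dict containing the keys
# read above. A compatible checkpoint could be written like this (a sketch; the
# variable names are placeholders):
#
#     torch.save(
#         {
#             "epoch": epoch,
#             "model_state_dict": model.state_dict(),
#             "optimizer_state_dict": optimizer.state_dict(),
#             "scheduler_state_dict": scheduler.state_dict(),
#         },
#         checkpoint_path,
#     )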


def check_data_path(data_root):
    """Check if the data path exists and is valid.

    Args:
        data_root (str): Path to the data directory.
    """
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)
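

if __name__ == "__main__":
    # Minimal smoke-test sketch. The config path, the "data_root" key, and the
    # placeholder model are assumptions for illustration, not part of the module above.
    logging.basicConfig(level=logging.INFO)

    cfg = load_config("configs/example.py")
    output_path, device = setup_environment(cfg)
    check_data_path(cfg.get("data_root", ""))

    model = torch.nn.Linear(4, 2).to(device)  # placeholder model
    # Resuming from a checkpoint would look like:
    #     checkpoint, start_epoch = load_checkpoint(
    #         os.path.join(output_path, "last.pt"), model, device
    #     )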