import importlib
import logging
import os
import random
import sys
import traceback

import numpy as np
import torch


def load_config(config_path):
    """Load configuration from a Python file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The loaded configuration dictionary.
    """
    try:
        config_path = os.path.abspath(config_path)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive a dotted module path from the file path relative to the
        # workspace root. This assumes the workspace root is on sys.path,
        # which holds when the script is launched from the project root.
        workspace_root = os.path.abspath(os.getcwd())
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {config_path} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")

        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config

        print(f"Loaded configuration from: {config_path} (via module {module_path_str})")
        return config

    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print("Ensure the config file path is correct and relative imports within it are valid.")
        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}")
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {config_path}: {e}")
        traceback.print_exc()
        sys.exit(1)
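
# A minimal, hypothetical sketch of the config module that load_config expects.
# The file only needs to expose a module-level dict named `config`; the keys
# shown mirror the lookups made elsewhere in this module (output_dir,
# config_name, seed, device), but the path and values are illustrative:
#
#     # configs/example.py
#     config = {
#         "config_name": "example_run",
#         "output_dir": "outputs",
#         "seed": 42,
#         "device": "cuda",
#     }
#
# which would then be loaded with: config = load_config("configs/example.py")
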
def setup_environment(config):
    """Set up the environment based on configuration.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        tuple: (output_path, device) - the output directory path and torch device.
    """
    # Set up the output directory
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    os.makedirs(output_path, exist_ok=True)

    # Set random seeds for reproducibility
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logging.info(f"Set random seed to: {seed}")

    # Select the torch device, falling back to CPU if CUDA is unavailable
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    return output_path, device
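
# Note: these helpers log through the standard logging module but never
# configure it themselves; the caller is assumed to install a handler first.
# A minimal, hypothetical usage sketch (the config values are illustrative):
#
#     logging.basicConfig(level=logging.INFO)
#     output_path, device = setup_environment({"config_name": "demo", "seed": 0})
#
# After this call, results can be written under output_path and models or
# tensors moved with .to(device).
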
def load_checkpoint(
    checkpoint_path,
    model,
    device,
    load_optimizer=False,
    optimizer=None,
    load_scheduler=False,
    scheduler=None,
):
    """Load a checkpoint into the model and optionally optimizer and scheduler.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the weights into.
        device (torch.device): The device to load the checkpoint on.
        load_optimizer (bool): Whether to load optimizer state.
        optimizer (torch.optim.Optimizer, optional): The optimizer to load state into.
        load_scheduler (bool): Whether to load scheduler state.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to load state into.

    Returns:
        tuple: (checkpoint, start_epoch) - the loaded checkpoint dict and the
            epoch to resume from (checkpoint epoch + 1 when optimizer state is
            restored, otherwise 0).
    """
    try:
        logging.info(f"Loading checkpoint from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Fall back to treating the whole object as the model state dict
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        if isinstance(state_dict, dict):
            # Handle case where model was trained with DataParallel
            if all(k.startswith("module.") for k in state_dict.keys()):
                logging.info(
                    "Detected DataParallel checkpoint, removing 'module.' prefix"
                )
                # Strip only the leading prefix so keys containing "module."
                # elsewhere are left intact
                state_dict = {
                    k[len("module."):]: v for k, v in state_dict.items()
                }

            model.load_state_dict(state_dict)
            logging.info("Model state loaded successfully")

            # Load optimizer state if requested
            if (
                load_optimizer
                and optimizer is not None
                and "optimizer_state_dict" in checkpoint
            ):
                optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                logging.info("Optimizer state loaded successfully")

            # Load scheduler state if requested
            if (
                load_scheduler
                and scheduler is not None
                and "scheduler_state_dict" in checkpoint
            ):
                scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
                logging.info("Scheduler state loaded successfully")

            # Resume from the next epoch only when training state is restored
            start_epoch = (checkpoint.get("epoch", 0) + 1) if load_optimizer else 0
            if "epoch" in checkpoint:
                logging.info(f"Loaded checkpoint from epoch: {checkpoint['epoch']}")

            return checkpoint, start_epoch
        else:
            logging.error("Checkpoint does not contain a valid state dictionary.")
            sys.exit(1)
    except Exception as e:
        logging.error(f"Error loading checkpoint: {e}", exc_info=True)
        sys.exit(1)
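
# For reference, a checkpoint compatible with load_checkpoint is assumed to be
# a dict saved with torch.save; the training loop that writes it is not part
# of this file, so the example below is illustrative only:
#
#     torch.save(
#         {
#             "epoch": epoch,
#             "model_state_dict": model.state_dict(),
#             "optimizer_state_dict": optimizer.state_dict(),
#             "scheduler_state_dict": scheduler.state_dict(),
#         },
#         os.path.join(output_path, "checkpoint.pt"),
#     )
#
# A bare state_dict saved on its own also works, since load_checkpoint falls
# back to treating the whole loaded object as the model state.
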
def check_data_path(data_root):
    """Check that the data path is specified and is an existing directory.

    Args:
        data_root (str): Path to the data directory.
    """
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)
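

if __name__ == "__main__":
    # Hypothetical smoke test (not part of the original utilities): wire the
    # helpers together in the order a training script is assumed to use them.
    # The default config path and the "data_root" key are illustrative.
    logging.basicConfig(level=logging.INFO)
    cfg = load_config(sys.argv[1] if len(sys.argv) > 1 else "configs/example.py")
    out_path, device = setup_environment(cfg)
    check_data_path(cfg.get("data_root", "data"))
    logging.info(f"Environment ready: outputs -> {out_path}, device -> {device}")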