Create test script and refactor shared logic from the train script into a common file for use by both scripts
utils/common.py (185 lines, new file)
@@ -0,0 +1,185 @@
import importlib.util
import logging
import os
import random
import sys

import numpy as np
import torch


def load_config(config_path):
    """Load configuration from a Python file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The loaded configuration dictionary.
    """
    try:
        config_path = os.path.abspath(config_path)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive module path from file path relative to workspace root
        workspace_root = os.path.abspath(os.getcwd())
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {config_path} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")

        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config

        print(
            f"Loaded configuration from: {config_path} (via module {module_path_str})"
        )
        return config

    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print(
            "Ensure the config file path is correct and relative imports within it are valid."
        )
        import traceback

        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(
            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
        )
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {config_path}: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


def setup_environment(config):
    """Set up the environment based on configuration.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        tuple: (output_path, device) - the output directory path and torch device.
    """
    # Setup output directory
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    os.makedirs(output_path, exist_ok=True)

    # Set random seeds
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logging.info(f"Set random seed to: {seed}")

    # Setup device
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    return output_path, device


def load_checkpoint(
    checkpoint_path,
    model,
    device,
    load_optimizer=False,
    optimizer=None,
    load_scheduler=False,
    scheduler=None,
):
    """Load a checkpoint into the model and optionally optimizer and scheduler.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the weights into.
        device (torch.device): The device to load the checkpoint on.
        load_optimizer (bool): Whether to load optimizer state.
        optimizer (torch.optim.Optimizer, optional): The optimizer to load state into.
        load_scheduler (bool): Whether to load scheduler state.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to load state into.

    Returns:
        dict: The loaded checkpoint.
        int: The starting epoch (checkpoint epoch + 1).
    """
    try:
        logging.info(f"Loading checkpoint from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Handle potential DataParallel prefix
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        if isinstance(state_dict, dict):
            # Handle case where model was trained with DataParallel
            if all(k.startswith("module.") for k in state_dict.keys()):
                logging.info(
                    "Detected DataParallel checkpoint, removing 'module.' prefix"
                )
                state_dict = {
                    k.replace("module.", ""): v for k, v in state_dict.items()
                }

            model.load_state_dict(state_dict)
            logging.info("Model state loaded successfully")

            # Load optimizer state if requested
            if (
                load_optimizer
                and optimizer is not None
                and "optimizer_state_dict" in checkpoint
            ):
                optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                logging.info("Optimizer state loaded successfully")

            # Load scheduler state if requested
            if (
                load_scheduler
                and scheduler is not None
                and "scheduler_state_dict" in checkpoint
            ):
                scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
                logging.info("Scheduler state loaded successfully")

            # Get the epoch number
            start_epoch = checkpoint.get("epoch", 0) + 1 if load_optimizer else 0
            if "epoch" in checkpoint:
                logging.info(f"Loaded checkpoint from epoch: {checkpoint['epoch']}")

            return checkpoint, start_epoch
        else:
            logging.error("Checkpoint does not contain a valid state dictionary.")
            sys.exit(1)
    except Exception as e:
        logging.error(f"Error loading checkpoint: {e}", exc_info=True)
        sys.exit(1)


def check_data_path(data_root):
    """Check if the data path exists and is valid.

    Args:
        data_root (str): Path to the data directory.
    """
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)
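For context on how the refactored helpers are meant to be shared between the train and test scripts, a minimal, hypothetical usage sketch follows. The config path `configs/example.py`, the `data_root` and `checkpoint_path` config keys, and the stand-in Linear model are assumptions for illustration only and are not part of this commit.

# Hypothetical consuming script (sketch, not part of this commit):
# loads a config, prepares the environment, and restores model weights
# using the helpers from utils/common.py.
import logging

import torch

from utils.common import (
    check_data_path,
    load_checkpoint,
    load_config,
    setup_environment,
)

logging.basicConfig(level=logging.INFO)

config = load_config("configs/example.py")      # assumed config location
output_path, device = setup_environment(config)
check_data_path(config.get("data_root"))        # assumed config key

model = torch.nn.Linear(10, 2).to(device)       # stand-in for the real model
checkpoint, _ = load_checkpoint(
    config.get("checkpoint_path", "checkpoint.pth"),  # assumed key and path
    model,
    device,
)
model.eval()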