Create test script and refactor shared logic from train.py into a common file for use across both scripts

Craig
2025-04-12 11:09:36 +01:00
parent 0f3a96ca81
commit 217cfba9ba
3 changed files with 501 additions and 305 deletions
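
For orientation, the sketch below condenses the setup path that both scripts now share through utils/common.py. It is illustrative only: the config and checkpoint paths are placeholders, and the error handling and logging performed by the real scripts (shown in the diffs below) are omitted.

from models.detection import get_maskrcnn_model
from utils.common import check_data_path, load_checkpoint, load_config, setup_environment

# Load the `config` dict from a Python config module (placeholder path).
config = load_config("configs/pennfudan_maskrcnn_config.py")

# Create the output directory, seed the RNGs, and select the torch device.
output_path, device = setup_environment(config)

# Exit early if the dataset root is missing or not configured.
check_data_path(config.get("data_root"))

# Build the model the same way both scripts do, then restore weights from a checkpoint.
model = get_maskrcnn_model(
    num_classes=config.get("num_classes"),
    pretrained=False,
    pretrained_backbone=False,
)
checkpoint, start_epoch = load_checkpoint(
    "outputs/default_run/checkpoints/checkpoint_epoch_1.pth",  # placeholder path
    model,
    device,
)
model.to(device)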

test.py (new file, 109 lines added)

@@ -0,0 +1,109 @@
import argparse
import logging
import sys
import torch
import torch.utils.data
# Project specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
check_data_path,
load_checkpoint,
load_config,
setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging
def main(args):
# Load configuration
config = load_config(args.config)
# Setup output directory and get device
output_path, device = setup_environment(config)
# Setup logging
setup_logging(output_path, f"{config['config_name']}_test")
logging.info("--- Testing Script Started ---")
logging.info(f"Loaded configuration from: {args.config}")
logging.info(f"Checkpoint path: {args.checkpoint}")
logging.info(f"Loaded configuration dictionary: {config}")
# Validate data path
data_root = config.get("data_root")
check_data_path(data_root)
try:
# Create the full dataset instance for testing with eval transforms
dataset_test = PennFudanDataset(
root=data_root, transforms=get_transform(train=False)
)
logging.info(f"Test dataset size: {len(dataset_test)}")
# Create test DataLoader
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=config.get("batch_size", 2),
shuffle=False, # No need to shuffle test data
num_workers=config.get("num_workers", 4),
collate_fn=collate_fn,
pin_memory=config.get("pin_memory", True),
)
logging.info(
f"Test dataloader configured. Est. batches: {len(data_loader_test)}"
)
except Exception as e:
logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
sys.exit(1)
# Create model
num_classes = config.get("num_classes")
if num_classes is None:
logging.error("'num_classes' not specified in configuration.")
sys.exit(1)
try:
# Create the model with the same architecture as in training
model = get_maskrcnn_model(
num_classes=num_classes,
pretrained=False, # Don't need pretrained weights as we'll load checkpoint
pretrained_backbone=False,
)
# Load checkpoint
load_checkpoint(args.checkpoint, model, device)
model.to(device)
logging.info("Model loaded and moved to device successfully.")
except Exception as e:
logging.error(f"Error setting up model: {e}", exc_info=True)
sys.exit(1)
# Run Evaluation
try:
logging.info("Starting model evaluation...")
eval_metrics = evaluate(model, data_loader_test, device)
# Log detailed metrics
logging.info("--- Evaluation Results ---")
for metric_name, metric_value in eval_metrics.items():
logging.info(f" {metric_name}: {metric_value:.4f}")
logging.info("Evaluation completed successfully")
except Exception as e:
logging.error(f"Error during evaluation: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test a trained Mask R-CNN model")
parser.add_argument("--config", required=True, help="Path to configuration file")
parser.add_argument(
"--checkpoint", required=True, help="Path to model checkpoint file (.pth)"
)
args = parser.parse_args()
main(args)
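
The checkpoint passed to test.py via --checkpoint is one of the checkpoint_epoch_N.pth files that train.py writes under <output_dir>/<config_name>/checkpoints/. For convenience, the most recent one can be located with a small helper that mirrors the resume logic in train.py below; the helper itself is hypothetical and not part of this commit.

import os

def latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-numbered checkpoint_epoch_N.pth file, or None."""
    best_epoch, best_path = -1, None
    for name in os.listdir(checkpoint_dir):
        if name.startswith("checkpoint_epoch_") and name.endswith(".pth"):
            try:
                epoch_num = int(name[len("checkpoint_epoch_"):-len(".pth")])
            except ValueError:
                continue  # Skip files that do not follow the naming convention
            if epoch_num > best_epoch:
                best_epoch, best_path = epoch_num, os.path.join(checkpoint_dir, name)
    return best_path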

train.py (512 lines changed: 207 additions, 305 deletions)

@@ -1,112 +1,44 @@
import argparse
import importlib.util
import logging
import os
import random
import sys
import time # Import time for timing
import time
import numpy as np
import torch
import torch.utils.data
# Project specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
check_data_path,
load_checkpoint,
load_config,
setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate # Import evaluate function
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging
def main(args):
# --- Configuration Loading ---
try:
config_path = os.path.abspath(args.config)
if not os.path.exists(config_path):
print(f"Error: Config file not found at {config_path}")
sys.exit(1)
# Load configuration
config = load_config(args.config)
# Derive module path from file path relative to workspace root
workspace_root = os.path.abspath(
os.getcwd()
) # Assuming script is run from root
relative_path = os.path.relpath(config_path, workspace_root)
if relative_path.startswith(".."):
print(f"Error: Config file {args.config} is outside the project directory.")
sys.exit(1)
module_path_no_ext, _ = os.path.splitext(relative_path)
module_path_str = module_path_no_ext.replace(os.sep, ".")
print(f"Attempting to import config module: {module_path_str}")
config_module = importlib.import_module(module_path_str)
config = config_module.config
print(
f"Loaded configuration from: {config_path} (via module {module_path_str})"
)
except ImportError as e:
print(f"Error importing config module '{module_path_str}': {e}")
print(
"Ensure the config file path is correct and relative imports within it are valid."
)
import traceback
traceback.print_exc()
sys.exit(1)
except AttributeError as e:
print(
f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
)
sys.exit(1)
except Exception as e:
print(f"Error loading configuration file {args.config}: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
# --- Output Directory Setup ---
output_dir = config.get("output_dir", "outputs")
config_name = config.get("config_name", "default_run")
output_path = os.path.join(output_dir, config_name)
# Setup output directory and get device
output_path, device = setup_environment(config)
checkpoint_path = os.path.join(output_path, "checkpoints")
os.makedirs(output_path, exist_ok=True)
os.makedirs(checkpoint_path, exist_ok=True)
print(f"Output will be saved to: {output_path}")
# --- Logging Setup (Prompt 9) ---
setup_logging(output_path, config_name)
# Setup logging
setup_logging(output_path, config.get("config_name", "default_run"))
logging.info("--- Training Script Started ---")
logging.info(f"Loaded configuration from: {args.config}")
logging.info(f"Loaded configuration dictionary: {config}")
logging.info(f"Output will be saved to: {output_path}")
# --- Reproducibility ---
seed = config.get("seed", 42)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# Consider adding these for more determinism, but they might impact performance
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
logging.info(f"Set random seed to: {seed}")
# --- Device Setup ---
device_name = config.get("device", "cuda")
if device_name == "cuda" and not torch.cuda.is_available():
logging.warning("CUDA requested but not available, falling back to CPU.")
device_name = "cpu"
device = torch.device(device_name)
logging.info(f"Using device: {device}")
# --- Dataset and DataLoader ---
# Validate data path
data_root = config.get("data_root")
if not data_root or not os.path.isdir(data_root):
logging.error(f"Data root directory not found or not specified: {data_root}")
sys.exit(1)
check_data_path(data_root)
try:
# Create the full training dataset instance first
@@ -183,7 +115,7 @@ def main(args):
logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
sys.exit(1)
# --- Model Instantiation ---
# Create model
num_classes = config.get("num_classes")
if num_classes is None:
logging.error("'num_classes' not specified in configuration.")
@@ -198,245 +130,215 @@ def main(args):
model.to(device)
logging.info("Model loaded successfully.")
except Exception as e:
logging.error(f"Error loading model: {e}", exc_info=True)
logging.error(f"Error creating model: {e}", exc_info=True)
sys.exit(1)
# --- Optimizer ---
# Filter parameters that require gradients
params = [p for p in model.parameters() if p.requires_grad]
try:
optimizer = torch.optim.SGD(
params,
lr=config.get("lr", 0.005),
momentum=config.get("momentum", 0.9),
weight_decay=config.get("weight_decay", 0.0005),
)
logging.info(
f"Optimizer SGD configured with lr={config.get('lr', 0.005)}, momentum={config.get('momentum', 0.9)}, weight_decay={config.get('weight_decay', 0.0005)}"
)
except Exception as e:
logging.error(f"Error creating optimizer: {e}", exc_info=True)
sys.exit(1)
# Create optimizer and learning rate scheduler
optimizer = torch.optim.SGD(
model.parameters(),
lr=config.get("lr", 0.005),
momentum=config.get("momentum", 0.9),
weight_decay=config.get("weight_decay", 0.0005),
)
# --- LR Scheduler (Prompt 10) ---
try:
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=config.get("lr_step_size", 3),
gamma=config.get("lr_gamma", 0.1),
)
logging.info(
f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, gamma={config.get('lr_gamma', 0.1)}"
)
except Exception as e:
logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
sys.exit(1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=config.get("lr_step_size", 3),
gamma=config.get("lr_gamma", 0.1),
)
# --- Resume Logic (Prompt 11) ---
# --- Resume from Checkpoint (if specified) ---
start_epoch = 0
latest_checkpoint_path = None
if os.path.isdir(checkpoint_path):
checkpoints = sorted(
[f for f in os.listdir(checkpoint_path) if f.endswith(".pth")]
)
if checkpoints: # Check if list is not empty
latest_checkpoint_file = checkpoints[
-1
] # Get the last one (assuming naming convention like epoch_N.pth)
latest_checkpoint_path = os.path.join(
checkpoint_path, latest_checkpoint_file
)
logging.info(f"Found latest checkpoint: {latest_checkpoint_path}")
else:
logging.info("No checkpoints found in directory. Starting from scratch.")
else:
logging.info("Checkpoint directory not found. Starting from scratch.")
if latest_checkpoint_path:
if args.resume:
try:
logging.info(f"Loading checkpoint '{latest_checkpoint_path}'")
# Ensure loading happens on the correct device
checkpoint = torch.load(latest_checkpoint_path, map_location=device)
# Load model state - handle potential 'module.' prefix if saved with DataParallel
model_state_dict = checkpoint["model_state_dict"]
# Simple check and correction for DataParallel prefix
if all(key.startswith("module.") for key in model_state_dict.keys()):
logging.info("Removing 'module.' prefix from checkpoint keys.")
model_state_dict = {
k.replace("module.", ""): v for k, v in model_state_dict.items()
}
model.load_state_dict(model_state_dict)
# Load optimizer state
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# Load LR scheduler state
lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
# Load starting epoch (epoch saved is the one *completed*, so start from next)
start_epoch = checkpoint["epoch"]
logging.info(f"Resuming training from epoch {start_epoch + 1}")
# Optionally load and verify config consistency
# loaded_config = checkpoint.get('config')
# if loaded_config:
# # Perform checks if necessary
# pass
except Exception as e:
logging.error(
f"Error loading checkpoint: {e}. Starting training from scratch.",
exc_info=True,
)
start_epoch = 0 # Reset start_epoch if loading fails
# --- Training Loop (Prompt 10, modified for Prompt 11) ---
logging.info("--- Starting Training Loop --- ")
start_time = time.time()
num_epochs = config.get("num_epochs", 10)
# Modify loop to start from start_epoch
for epoch in range(start_epoch, num_epochs):
model.train() # Set model to training mode for each epoch
epoch_start_time = time.time()
logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ")
# Variables for tracking epoch progress (optional)
epoch_loss_sum = 0.0
num_batches = len(data_loader_train)
for i, (images, targets) in enumerate(data_loader_train):
batch_start_time = time.time() # Optional: time each batch
try:
# Move data to the device
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
# Perform forward pass
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
loss_value = losses.item()
epoch_loss_sum += loss_value
# Perform backward pass
optimizer.zero_grad()
losses.backward()
optimizer.step()
# Log batch loss periodically
if (i + 1) % config.get("log_freq", 10) == 0:
batch_time = time.time() - batch_start_time
# Include individual losses if desired
loss_dict_items = {
k: f"{v.item():.4f}" for k, v in loss_dict.items()
}
logging.info(
f" Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
)
logging.debug(
f" Loss Dict: {loss_dict_items}"
) # Log individual losses at DEBUG level
except Exception as e:
logging.error(
f"Error during training epoch {epoch+1}, batch {i+1}: {e}",
exc_info=True,
# Find latest checkpoint
checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")]
if not checkpoints:
logging.warning(
f"No checkpoints found in {checkpoint_path}, starting from scratch."
)
# Decide if you want to stop training or continue to next batch/epoch
logging.warning("Skipping rest of epoch due to error.")
break # Exit the inner loop for this epoch
else:
# Extract epoch numbers from filenames and find the latest
max_epoch = -1
latest_checkpoint = None
for ckpt in checkpoints:
if ckpt.startswith("checkpoint_epoch_"):
try:
epoch_num = int(
ckpt.replace("checkpoint_epoch_", "").replace(
".pth", ""
)
)
if epoch_num > max_epoch:
max_epoch = epoch_num
latest_checkpoint = ckpt
except ValueError:
continue
# --- End of Epoch --- #
# Step the learning rate scheduler
if latest_checkpoint:
checkpoint_file = os.path.join(checkpoint_path, latest_checkpoint)
logging.info(f"Resuming from checkpoint: {checkpoint_file}")
# Load checkpoint
checkpoint, start_epoch = load_checkpoint(
checkpoint_file,
model,
device,
load_optimizer=True,
optimizer=optimizer,
load_scheduler=True,
scheduler=lr_scheduler,
)
logging.info(f"Resuming from epoch {start_epoch}")
else:
logging.warning(
f"No valid checkpoints found in {checkpoint_path}, starting from scratch."
)
except Exception as e:
logging.error(f"Error loading checkpoint: {e}", exc_info=True)
logging.warning("Starting training from scratch.")
start_epoch = 0
# --- Training Loop ---
train_time_start = time.time()
logging.info("--- Starting Training Loop ---")
for epoch in range(start_epoch, config.get("num_epochs", 10)):
# Set model to training mode
model.train()
# Initialize epoch metrics
epoch_loss = 0.0
epoch_loss_classifier = 0.0
epoch_loss_box_reg = 0.0
epoch_loss_mask = 0.0
epoch_loss_objectness = 0.0
epoch_loss_rpn_box_reg = 0.0
logging.info(f"--- Epoch {epoch + 1}/{config.get('num_epochs', 10)} ---")
epoch_start_time = time.time()
# Train loop
for i, (images, targets) in enumerate(data_loader_train):
# Move data to device
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
# Forward pass
loss_dict = model(images, targets)
# Sum loss components
losses = sum(loss for loss in loss_dict.values())
# Backward and optimize
optimizer.zero_grad()
losses.backward()
optimizer.step()
# Log batch results
loss_value = losses.item()
epoch_loss += loss_value
# Accumulate individual loss components
if "loss_classifier" in loss_dict:
epoch_loss_classifier += loss_dict["loss_classifier"].item()
if "loss_box_reg" in loss_dict:
epoch_loss_box_reg += loss_dict["loss_box_reg"].item()
if "loss_mask" in loss_dict:
epoch_loss_mask += loss_dict["loss_mask"].item()
if "loss_objectness" in loss_dict:
epoch_loss_objectness += loss_dict["loss_objectness"].item()
if "loss_rpn_box_reg" in loss_dict:
epoch_loss_rpn_box_reg += loss_dict["loss_rpn_box_reg"].item()
# Periodic logging
if (i + 1) % config.get("log_freq", 10) == 0:
log_str = f"Epoch [{epoch + 1}/{config.get('num_epochs', 10)}], "
log_str += f"Iter [{i + 1}/{len(data_loader_train)}], "
log_str += f"Loss: {loss_value:.4f}"
# Add per-component losses for richer logging
comp_log = []
if "loss_classifier" in loss_dict:
comp_log.append(f"cls: {loss_dict['loss_classifier'].item():.4f}")
if "loss_box_reg" in loss_dict:
comp_log.append(f"box: {loss_dict['loss_box_reg'].item():.4f}")
if "loss_mask" in loss_dict:
comp_log.append(f"mask: {loss_dict['loss_mask'].item():.4f}")
if "loss_objectness" in loss_dict:
comp_log.append(f"obj: {loss_dict['loss_objectness'].item():.4f}")
if "loss_rpn_box_reg" in loss_dict:
comp_log.append(f"rpn: {loss_dict['loss_rpn_box_reg'].item():.4f}")
if comp_log:
log_str += f" [{', '.join(comp_log)}]"
logging.info(log_str)
# Step learning rate scheduler after each epoch
lr_scheduler.step()
# Log epoch summary
epoch_end_time = time.time()
epoch_duration = epoch_end_time - epoch_start_time
avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
current_lr = optimizer.param_groups[0]["lr"] # Get current learning rate
logging.info(f"--- Epoch {epoch + 1} Summary --- ")
logging.info(f" Average Loss: {avg_epoch_loss:.4f}")
logging.info(f" Learning Rate: {current_lr:.6f}")
logging.info(f" Epoch Duration: {epoch_duration:.2f}s")
# Calculate and log epoch metrics
if len(data_loader_train) > 0:
avg_loss = epoch_loss / len(data_loader_train)
avg_loss_classifier = epoch_loss_classifier / len(data_loader_train)
avg_loss_box_reg = epoch_loss_box_reg / len(data_loader_train)
avg_loss_mask = epoch_loss_mask / len(data_loader_train)
avg_loss_objectness = epoch_loss_objectness / len(data_loader_train)
avg_loss_rpn_box_reg = epoch_loss_rpn_box_reg / len(data_loader_train)
# --- Checkpointing (Prompt 11) --- #
# Save checkpoint periodically or at the end
save_checkpoint = False
if (epoch + 1) % config.get("checkpoint_freq", 1) == 0:
save_checkpoint = True
logging.info(f"Checkpoint frequency met (epoch {epoch + 1})")
elif (epoch + 1) == num_epochs:
save_checkpoint = True
logging.info(f"Final epoch ({epoch + 1}) reached, saving checkpoint.")
logging.info(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}")
logging.info(f" Classifier Loss: {avg_loss_classifier:.4f}")
logging.info(f" Box Reg Loss: {avg_loss_box_reg:.4f}")
logging.info(f" Mask Loss: {avg_loss_mask:.4f}")
logging.info(f" Objectness Loss: {avg_loss_objectness:.4f}")
logging.info(f" RPN Box Reg Loss: {avg_loss_rpn_box_reg:.4f}")
else:
logging.warning("No training batches were processed in this epoch.")
if save_checkpoint:
checkpoint_filename = f"checkpoint_epoch_{epoch + 1}.pth"
save_path = os.path.join(checkpoint_path, checkpoint_filename)
epoch_duration = time.time() - epoch_start_time
logging.info(f"Epoch duration: {epoch_duration:.2f}s")
# --- Validation ---
logging.info("Running validation...")
val_metrics = evaluate(model, data_loader_val, device)
logging.info(f"Validation Loss: {val_metrics['average_loss']:.4f}")
# --- Checkpoint Saving ---
if (epoch + 1) % config.get("checkpoint_freq", 1) == 0 or epoch == config.get(
"num_epochs", 10
) - 1:
checkpoint_file = os.path.join(
checkpoint_path, f"checkpoint_epoch_{epoch + 1}.pth"
)
checkpoint = {
"epoch": epoch + 1,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": lr_scheduler.state_dict(),
"config": config,
"val_loss": val_metrics["average_loss"],
}
try:
checkpoint_data = {
"epoch": epoch + 1,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": lr_scheduler.state_dict(),
"config": config, # Save config for reference
}
torch.save(checkpoint_data, save_path)
logging.info(f"Checkpoint saved to {save_path}")
torch.save(checkpoint, checkpoint_file)
logging.info(f"Checkpoint saved to {checkpoint_file}")
except Exception as e:
logging.error(
f"Failed to save checkpoint for epoch {epoch + 1} to {save_path}: {e}",
exc_info=True,
)
logging.error(f"Error saving checkpoint: {e}", exc_info=True)
# --- Evaluation (Prompt 12) --- #
if data_loader_val:
logging.info(f"Starting evaluation for epoch {epoch + 1}...")
try:
val_metrics = evaluate(model, data_loader_val, device)
logging.info(f"Epoch {epoch + 1} Validation Metrics: {val_metrics}")
# --- Best Model Checkpoint Logic (Optional Add-on) ---
# Add logic here to track the best metric (e.g., val_metrics['average_loss'])
# and save a separate 'best_model.pth' checkpoint if the current epoch is better.
# Example:
# if 'average_loss' in val_metrics:
# current_val_loss = val_metrics['average_loss']
# if best_val_loss is None or current_val_loss < best_val_loss:
# best_val_loss = current_val_loss
# best_model_path = os.path.join(output_path, 'best_model.pth')
# try:
# # Save only the model state_dict for the best model
# torch.save(model.state_dict(), best_model_path)
# logging.info(f"Saved NEW BEST model checkpoint to {best_model_path} (Val Loss: {best_val_loss:.4f})")
# except Exception as e:
# logging.error(f"Failed to save best model checkpoint: {e}", exc_info=True)
except Exception as e:
logging.error(
f"Error during evaluation for epoch {epoch + 1}: {e}", exc_info=True
)
# Decide if this error should stop the entire training process
# --- End of Training --- #
total_training_time = time.time() - start_time
logging.info("--- Training Finished --- ")
logging.info(
f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
)
# --- Final Metrics and Cleanup ---
total_training_time = time.time() - train_time_start
hours, remainder = divmod(total_training_time, 3600)
minutes, seconds = divmod(remainder, 60)
logging.info(f"Training completed in {int(hours)}h {int(minutes)}m {seconds:.2f}s")
logging.info(f"Final model saved to {checkpoint_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train Mask R-CNN on Penn-Fudan dataset."
)
parser = argparse.ArgumentParser(description="Train a Mask R-CNN model")
parser.add_argument("--config", required=True, help="Path to configuration file")
parser.add_argument(
"--config",
type=str,
required=True,
help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
"--resume", action="store_true", help="Resume training from latest checkpoint"
)
args = parser.parse_args()
main(args)

utils/common.py (new file, 185 lines added)

@@ -0,0 +1,185 @@
import importlib.util
import logging
import os
import random
import sys
import numpy as np
import torch
def load_config(config_path):
"""Load configuration from a Python file.
Args:
config_path (str): Path to the configuration file.
Returns:
dict: The loaded configuration dictionary.
"""
try:
config_path = os.path.abspath(config_path)
if not os.path.exists(config_path):
print(f"Error: Config file not found at {config_path}")
sys.exit(1)
# Derive module path from file path relative to workspace root
workspace_root = os.path.abspath(os.getcwd())
relative_path = os.path.relpath(config_path, workspace_root)
if relative_path.startswith(".."):
print(f"Error: Config file {config_path} is outside the project directory.")
sys.exit(1)
module_path_no_ext, _ = os.path.splitext(relative_path)
module_path_str = module_path_no_ext.replace(os.sep, ".")
print(f"Attempting to import config module: {module_path_str}")
config_module = importlib.import_module(module_path_str)
config = config_module.config
print(
f"Loaded configuration from: {config_path} (via module {module_path_str})"
)
return config
except ImportError as e:
print(f"Error importing config module '{module_path_str}': {e}")
print(
"Ensure the config file path is correct and relative imports within it are valid."
)
import traceback
traceback.print_exc()
sys.exit(1)
except AttributeError as e:
print(
f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
)
sys.exit(1)
except Exception as e:
print(f"Error loading configuration file {config_path}: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
def setup_environment(config):
"""Set up the environment based on configuration.
Args:
config (dict): Configuration dictionary.
Returns:
tuple: (output_path, device) - the output directory path and torch device.
"""
# Setup output directory
output_dir = config.get("output_dir", "outputs")
config_name = config.get("config_name", "default_run")
output_path = os.path.join(output_dir, config_name)
os.makedirs(output_path, exist_ok=True)
# Set random seeds
seed = config.get("seed", 42)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
logging.info(f"Set random seed to: {seed}")
# Setup device
device_name = config.get("device", "cuda")
if device_name == "cuda" and not torch.cuda.is_available():
logging.warning("CUDA requested but not available, falling back to CPU.")
device_name = "cpu"
device = torch.device(device_name)
logging.info(f"Using device: {device}")
return output_path, device
def load_checkpoint(
checkpoint_path,
model,
device,
load_optimizer=False,
optimizer=None,
load_scheduler=False,
scheduler=None,
):
"""Load a checkpoint into the model and optionally optimizer and scheduler.
Args:
checkpoint_path (str): Path to the checkpoint file.
model (torch.nn.Module): The model to load the weights into.
device (torch.device): The device to load the checkpoint on.
load_optimizer (bool): Whether to load optimizer state.
optimizer (torch.optim.Optimizer, optional): The optimizer to load state into.
load_scheduler (bool): Whether to load scheduler state.
scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to load state into.
    Returns:
        dict: The loaded checkpoint.
        int: The epoch index to resume training from (0 if optimizer state is not loaded).
    """
try:
logging.info(f"Loading checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location=device)
# Handle potential DataParallel prefix
state_dict = checkpoint.get("model_state_dict", checkpoint)
if isinstance(state_dict, dict):
# Handle case where model was trained with DataParallel
if all(k.startswith("module.") for k in state_dict.keys()):
logging.info(
"Detected DataParallel checkpoint, removing 'module.' prefix"
)
state_dict = {
k.replace("module.", ""): v for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
logging.info("Model state loaded successfully")
# Load optimizer state if requested
if (
load_optimizer
and optimizer is not None
and "optimizer_state_dict" in checkpoint
):
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
logging.info("Optimizer state loaded successfully")
# Load scheduler state if requested
if (
load_scheduler
and scheduler is not None
and "scheduler_state_dict" in checkpoint
):
scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
logging.info("Scheduler state loaded successfully")
        # The saved "epoch" is the 1-indexed epoch that was completed, which is
        # also the 0-indexed epoch the training loop should resume from.
        start_epoch = checkpoint.get("epoch", 0) if load_optimizer else 0
if "epoch" in checkpoint:
logging.info(f"Loaded checkpoint from epoch: {checkpoint['epoch']}")
return checkpoint, start_epoch
else:
logging.error("Checkpoint does not contain a valid state dictionary.")
sys.exit(1)
except Exception as e:
logging.error(f"Error loading checkpoint: {e}", exc_info=True)
sys.exit(1)
def check_data_path(data_root):
"""Check if the data path exists and is valid.
Args:
data_root (str): Path to the data directory.
"""
if not data_root or not os.path.isdir(data_root):
logging.error(f"Data root directory not found or not specified: {data_root}")
sys.exit(1)
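
As a usage sketch for the new load_checkpoint helper, the snippet below performs a minimal round trip with a tiny stand-in model so it runs without the Penn-Fudan dataset or the Mask R-CNN model; the file name and toy model are illustrative only. In train.py the same call restores the real model, optimizer, and StepLR scheduler before training continues from start_epoch.

import torch

from utils.common import load_checkpoint

# Tiny stand-in model, optimizer, and scheduler so the example is self-contained.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
device = torch.device("cpu")

# Save a checkpoint with the same keys train.py writes.
torch.save(
    {
        "epoch": 1,  # 1-indexed epoch that was completed
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": lr_scheduler.state_dict(),
    },
    "tmp_checkpoint.pth",  # illustrative path
)

# Restore model weights plus optimizer and scheduler state, as train.py's resume path does.
checkpoint, start_epoch = load_checkpoint(
    "tmp_checkpoint.pth",
    model,
    device,
    load_optimizer=True,
    optimizer=optimizer,
    load_scheduler=True,
    scheduler=lr_scheduler,
)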