diff --git a/test.py b/test.py index e69de29..e461190 100644 --- a/test.py +++ b/test.py @@ -0,0 +1,109 @@ +import argparse +import logging +import sys + +import torch +import torch.utils.data + +# Project specific imports +from models.detection import get_maskrcnn_model +from utils.common import ( + check_data_path, + load_checkpoint, + load_config, + setup_environment, +) +from utils.data_utils import PennFudanDataset, collate_fn, get_transform +from utils.eval_utils import evaluate +from utils.log_utils import setup_logging + + +def main(args): + # Load configuration + config = load_config(args.config) + + # Setup output directory and get device + output_path, device = setup_environment(config) + + # Setup logging + setup_logging(output_path, f"{config['config_name']}_test") + logging.info("--- Testing Script Started ---") + logging.info(f"Loaded configuration from: {args.config}") + logging.info(f"Checkpoint path: {args.checkpoint}") + logging.info(f"Loaded configuration dictionary: {config}") + + # Validate data path + data_root = config.get("data_root") + check_data_path(data_root) + + try: + # Create the full dataset instance for testing with eval transforms + dataset_test = PennFudanDataset( + root=data_root, transforms=get_transform(train=False) + ) + logging.info(f"Test dataset size: {len(dataset_test)}") + + # Create test DataLoader + data_loader_test = torch.utils.data.DataLoader( + dataset_test, + batch_size=config.get("batch_size", 2), + shuffle=False, # No need to shuffle test data + num_workers=config.get("num_workers", 4), + collate_fn=collate_fn, + pin_memory=config.get("pin_memory", True), + ) + + logging.info( + f"Test dataloader configured. Est. batches: {len(data_loader_test)}" + ) + + except Exception as e: + logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True) + sys.exit(1) + + # Create model + num_classes = config.get("num_classes") + if num_classes is None: + logging.error("'num_classes' not specified in configuration.") + sys.exit(1) + + try: + # Create the model with the same architecture as in training + model = get_maskrcnn_model( + num_classes=num_classes, + pretrained=False, # Don't need pretrained weights as we'll load checkpoint + pretrained_backbone=False, + ) + + # Load checkpoint + load_checkpoint(args.checkpoint, model, device) + model.to(device) + logging.info("Model loaded and moved to device successfully.") + except Exception as e: + logging.error(f"Error setting up model: {e}", exc_info=True) + sys.exit(1) + + # Run Evaluation + try: + logging.info("Starting model evaluation...") + eval_metrics = evaluate(model, data_loader_test, device) + + # Log detailed metrics + logging.info("--- Evaluation Results ---") + for metric_name, metric_value in eval_metrics.items(): + logging.info(f" {metric_name}: {metric_value:.4f}") + + logging.info("Evaluation completed successfully") + except Exception as e: + logging.error(f"Error during evaluation: {e}", exc_info=True) + sys.exit(1) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test a trained Mask R-CNN model") + parser.add_argument("--config", required=True, help="Path to configuration file") + parser.add_argument( + "--checkpoint", required=True, help="Path to model checkpoint file (.pth)" + ) + args = parser.parse_args() + main(args) diff --git a/train.py b/train.py index 9bad324..f17a17c 100644 --- a/train.py +++ b/train.py @@ -1,112 +1,44 @@ import argparse -import importlib.util import logging import os -import random import sys -import time # Import 
time for timing +import time -import numpy as np import torch import torch.utils.data # Project specific imports from models.detection import get_maskrcnn_model +from utils.common import ( + check_data_path, + load_checkpoint, + load_config, + setup_environment, +) from utils.data_utils import PennFudanDataset, collate_fn, get_transform -from utils.eval_utils import evaluate # Import evaluate function +from utils.eval_utils import evaluate from utils.log_utils import setup_logging def main(args): - # --- Configuration Loading --- - try: - config_path = os.path.abspath(args.config) - if not os.path.exists(config_path): - print(f"Error: Config file not found at {config_path}") - sys.exit(1) + # Load configuration + config = load_config(args.config) - # Derive module path from file path relative to workspace root - workspace_root = os.path.abspath( - os.getcwd() - ) # Assuming script is run from root - relative_path = os.path.relpath(config_path, workspace_root) - if relative_path.startswith(".."): - print(f"Error: Config file {args.config} is outside the project directory.") - sys.exit(1) - - module_path_no_ext, _ = os.path.splitext(relative_path) - module_path_str = module_path_no_ext.replace(os.sep, ".") - - print(f"Attempting to import config module: {module_path_str}") - config_module = importlib.import_module(module_path_str) - config = config_module.config - - print( - f"Loaded configuration from: {config_path} (via module {module_path_str})" - ) - - except ImportError as e: - print(f"Error importing config module '{module_path_str}': {e}") - print( - "Ensure the config file path is correct and relative imports within it are valid." - ) - import traceback - - traceback.print_exc() - sys.exit(1) - except AttributeError as e: - print( - f"Error: Could not find 'config' dictionary in module {module_path_str}. 
{e}" - ) - sys.exit(1) - except Exception as e: - print(f"Error loading configuration file {args.config}: {e}") - import traceback - - traceback.print_exc() - sys.exit(1) - - # --- Output Directory Setup --- - output_dir = config.get("output_dir", "outputs") - config_name = config.get("config_name", "default_run") - output_path = os.path.join(output_dir, config_name) + # Setup output directory and get device + output_path, device = setup_environment(config) checkpoint_path = os.path.join(output_path, "checkpoints") - os.makedirs(output_path, exist_ok=True) os.makedirs(checkpoint_path, exist_ok=True) - print(f"Output will be saved to: {output_path}") - # --- Logging Setup (Prompt 9) --- - setup_logging(output_path, config_name) + # Setup logging + setup_logging(output_path, config.get("config_name", "default_run")) logging.info("--- Training Script Started ---") logging.info(f"Loaded configuration from: {args.config}") logging.info(f"Loaded configuration dictionary: {config}") logging.info(f"Output will be saved to: {output_path}") - # --- Reproducibility --- - seed = config.get("seed", 42) - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - # Consider adding these for more determinism, but they might impact performance - # torch.backends.cudnn.deterministic = True - # torch.backends.cudnn.benchmark = False - logging.info(f"Set random seed to: {seed}") - - # --- Device Setup --- - device_name = config.get("device", "cuda") - if device_name == "cuda" and not torch.cuda.is_available(): - logging.warning("CUDA requested but not available, falling back to CPU.") - device_name = "cpu" - device = torch.device(device_name) - logging.info(f"Using device: {device}") - - # --- Dataset and DataLoader --- + # Validate data path data_root = config.get("data_root") - if not data_root or not os.path.isdir(data_root): - logging.error(f"Data root directory not found or not specified: {data_root}") - sys.exit(1) + check_data_path(data_root) try: # Create the full training dataset instance first @@ -183,7 +115,7 @@ def main(args): logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True) sys.exit(1) - # --- Model Instantiation --- + # Create model num_classes = config.get("num_classes") if num_classes is None: logging.error("'num_classes' not specified in configuration.") @@ -198,245 +130,215 @@ def main(args): model.to(device) logging.info("Model loaded successfully.") except Exception as e: - logging.error(f"Error loading model: {e}", exc_info=True) + logging.error(f"Error creating model: {e}", exc_info=True) sys.exit(1) - # --- Optimizer --- - # Filter parameters that require gradients - params = [p for p in model.parameters() if p.requires_grad] - try: - optimizer = torch.optim.SGD( - params, - lr=config.get("lr", 0.005), - momentum=config.get("momentum", 0.9), - weight_decay=config.get("weight_decay", 0.0005), - ) - logging.info( - f"Optimizer SGD configured with lr={config.get('lr', 0.005)}, momentum={config.get('momentum', 0.9)}, weight_decay={config.get('weight_decay', 0.0005)}" - ) - except Exception as e: - logging.error(f"Error creating optimizer: {e}", exc_info=True) - sys.exit(1) + # Create optimizer and learning rate scheduler + optimizer = torch.optim.SGD( + model.parameters(), + lr=config.get("lr", 0.005), + momentum=config.get("momentum", 0.9), + weight_decay=config.get("weight_decay", 0.0005), + ) - # --- LR Scheduler (Prompt 10) --- - try: - lr_scheduler = torch.optim.lr_scheduler.StepLR( - 
optimizer, - step_size=config.get("lr_step_size", 3), - gamma=config.get("lr_gamma", 0.1), - ) - logging.info( - f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, gamma={config.get('lr_gamma', 0.1)}" - ) - except Exception as e: - logging.error(f"Error creating LR scheduler: {e}", exc_info=True) - sys.exit(1) + lr_scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=config.get("lr_step_size", 3), + gamma=config.get("lr_gamma", 0.1), + ) - # --- Resume Logic (Prompt 11) --- + # --- Resume from Checkpoint (if specified) --- start_epoch = 0 - latest_checkpoint_path = None - if os.path.isdir(checkpoint_path): - checkpoints = sorted( - [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")] - ) - if checkpoints: # Check if list is not empty - latest_checkpoint_file = checkpoints[ - -1 - ] # Get the last one (assuming naming convention like epoch_N.pth) - latest_checkpoint_path = os.path.join( - checkpoint_path, latest_checkpoint_file - ) - logging.info(f"Found latest checkpoint: {latest_checkpoint_path}") - else: - logging.info("No checkpoints found in directory. Starting from scratch.") - else: - logging.info("Checkpoint directory not found. Starting from scratch.") - - if latest_checkpoint_path: + if args.resume: try: - logging.info(f"Loading checkpoint '{latest_checkpoint_path}'") - # Ensure loading happens on the correct device - checkpoint = torch.load(latest_checkpoint_path, map_location=device) - - # Load model state - handle potential 'module.' prefix if saved with DataParallel - model_state_dict = checkpoint["model_state_dict"] - # Simple check and correction for DataParallel prefix - if all(key.startswith("module.") for key in model_state_dict.keys()): - logging.info("Removing 'module.' prefix from checkpoint keys.") - model_state_dict = { - k.replace("module.", ""): v for k, v in model_state_dict.items() - } - model.load_state_dict(model_state_dict) - - # Load optimizer state - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - - # Load LR scheduler state - lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) - - # Load starting epoch (epoch saved is the one *completed*, so start from next) - start_epoch = checkpoint["epoch"] - logging.info(f"Resuming training from epoch {start_epoch + 1}") - - # Optionally load and verify config consistency - # loaded_config = checkpoint.get('config') - # if loaded_config: - # # Perform checks if necessary - # pass - - except Exception as e: - logging.error( - f"Error loading checkpoint: {e}. 
Starting training from scratch.", - exc_info=True, - ) - start_epoch = 0 # Reset start_epoch if loading fails - - # --- Training Loop (Prompt 10, modified for Prompt 11) --- - logging.info("--- Starting Training Loop --- ") - start_time = time.time() - num_epochs = config.get("num_epochs", 10) - - # Modify loop to start from start_epoch - for epoch in range(start_epoch, num_epochs): - model.train() # Set model to training mode for each epoch - epoch_start_time = time.time() - logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ") - - # Variables for tracking epoch progress (optional) - epoch_loss_sum = 0.0 - num_batches = len(data_loader_train) - - for i, (images, targets) in enumerate(data_loader_train): - batch_start_time = time.time() # Optional: time each batch - try: - # Move data to the device - images = list(image.to(device) for image in images) - targets = [{k: v.to(device) for k, v in t.items()} for t in targets] - - # Perform forward pass - loss_dict = model(images, targets) - losses = sum(loss for loss in loss_dict.values()) - loss_value = losses.item() - epoch_loss_sum += loss_value - - # Perform backward pass - optimizer.zero_grad() - losses.backward() - optimizer.step() - - # Log batch loss periodically - if (i + 1) % config.get("log_freq", 10) == 0: - batch_time = time.time() - batch_start_time - # Include individual losses if desired - loss_dict_items = { - k: f"{v.item():.4f}" for k, v in loss_dict.items() - } - logging.info( - f" Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s" - ) - logging.debug( - f" Loss Dict: {loss_dict_items}" - ) # Log individual losses at DEBUG level - - except Exception as e: - logging.error( - f"Error during training epoch {epoch+1}, batch {i+1}: {e}", - exc_info=True, + # Find latest checkpoint + checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")] + if not checkpoints: + logging.warning( + f"No checkpoints found in {checkpoint_path}, starting from scratch." ) - # Decide if you want to stop training or continue to next batch/epoch - logging.warning("Skipping rest of epoch due to error.") - break # Exit the inner loop for this epoch + else: + # Extract epoch numbers from filenames and find the latest + max_epoch = -1 + latest_checkpoint = None + for ckpt in checkpoints: + if ckpt.startswith("checkpoint_epoch_"): + try: + epoch_num = int( + ckpt.replace("checkpoint_epoch_", "").replace( + ".pth", "" + ) + ) + if epoch_num > max_epoch: + max_epoch = epoch_num + latest_checkpoint = ckpt + except ValueError: + continue - # --- End of Epoch --- # - # Step the learning rate scheduler + if latest_checkpoint: + checkpoint_file = os.path.join(checkpoint_path, latest_checkpoint) + logging.info(f"Resuming from checkpoint: {checkpoint_file}") + + # Load checkpoint + checkpoint, start_epoch = load_checkpoint( + checkpoint_file, + model, + device, + load_optimizer=True, + optimizer=optimizer, + load_scheduler=True, + scheduler=lr_scheduler, + ) + + logging.info(f"Resuming from epoch {start_epoch}") + else: + logging.warning( + f"No valid checkpoints found in {checkpoint_path}, starting from scratch." 
+ ) + except Exception as e: + logging.error(f"Error loading checkpoint: {e}", exc_info=True) + logging.warning("Starting training from scratch.") + start_epoch = 0 + + # --- Training Loop --- + train_time_start = time.time() + logging.info("--- Starting Training Loop ---") + + for epoch in range(start_epoch, config.get("num_epochs", 10)): + # Set model to training mode + model.train() + + # Initialize epoch metrics + epoch_loss = 0.0 + epoch_loss_classifier = 0.0 + epoch_loss_box_reg = 0.0 + epoch_loss_mask = 0.0 + epoch_loss_objectness = 0.0 + epoch_loss_rpn_box_reg = 0.0 + + logging.info(f"--- Epoch {epoch + 1}/{config.get('num_epochs', 10)} ---") + epoch_start_time = time.time() + + # Train loop + for i, (images, targets) in enumerate(data_loader_train): + # Move data to device + images = list(image.to(device) for image in images) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + # Forward pass + loss_dict = model(images, targets) + + # Sum loss components + losses = sum(loss for loss in loss_dict.values()) + + # Backward and optimize + optimizer.zero_grad() + losses.backward() + optimizer.step() + + # Log batch results + loss_value = losses.item() + epoch_loss += loss_value + + # Accumulate individual loss components + if "loss_classifier" in loss_dict: + epoch_loss_classifier += loss_dict["loss_classifier"].item() + if "loss_box_reg" in loss_dict: + epoch_loss_box_reg += loss_dict["loss_box_reg"].item() + if "loss_mask" in loss_dict: + epoch_loss_mask += loss_dict["loss_mask"].item() + if "loss_objectness" in loss_dict: + epoch_loss_objectness += loss_dict["loss_objectness"].item() + if "loss_rpn_box_reg" in loss_dict: + epoch_loss_rpn_box_reg += loss_dict["loss_rpn_box_reg"].item() + + # Periodic logging + if (i + 1) % config.get("log_freq", 10) == 0: + log_str = f"Epoch [{epoch + 1}/{config.get('num_epochs', 10)}], " + log_str += f"Iter [{i + 1}/{len(data_loader_train)}], " + log_str += f"Loss: {loss_value:.4f}" + + # Add per-component losses for richer logging + comp_log = [] + if "loss_classifier" in loss_dict: + comp_log.append(f"cls: {loss_dict['loss_classifier'].item():.4f}") + if "loss_box_reg" in loss_dict: + comp_log.append(f"box: {loss_dict['loss_box_reg'].item():.4f}") + if "loss_mask" in loss_dict: + comp_log.append(f"mask: {loss_dict['loss_mask'].item():.4f}") + if "loss_objectness" in loss_dict: + comp_log.append(f"obj: {loss_dict['loss_objectness'].item():.4f}") + if "loss_rpn_box_reg" in loss_dict: + comp_log.append(f"rpn: {loss_dict['loss_rpn_box_reg'].item():.4f}") + + if comp_log: + log_str += f" [{', '.join(comp_log)}]" + + logging.info(log_str) + + # Step learning rate scheduler after each epoch lr_scheduler.step() - # Log epoch summary - epoch_end_time = time.time() - epoch_duration = epoch_end_time - epoch_start_time - avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0 - current_lr = optimizer.param_groups[0]["lr"] # Get current learning rate - logging.info(f"--- Epoch {epoch + 1} Summary --- ") - logging.info(f" Average Loss: {avg_epoch_loss:.4f}") - logging.info(f" Learning Rate: {current_lr:.6f}") - logging.info(f" Epoch Duration: {epoch_duration:.2f}s") + # Calculate and log epoch metrics + if len(data_loader_train) > 0: + avg_loss = epoch_loss / len(data_loader_train) + avg_loss_classifier = epoch_loss_classifier / len(data_loader_train) + avg_loss_box_reg = epoch_loss_box_reg / len(data_loader_train) + avg_loss_mask = epoch_loss_mask / len(data_loader_train) + avg_loss_objectness = epoch_loss_objectness / 
len(data_loader_train) + avg_loss_rpn_box_reg = epoch_loss_rpn_box_reg / len(data_loader_train) - # --- Checkpointing (Prompt 11) --- # - # Save checkpoint periodically or at the end - save_checkpoint = False - if (epoch + 1) % config.get("checkpoint_freq", 1) == 0: - save_checkpoint = True - logging.info(f"Checkpoint frequency met (epoch {epoch + 1})") - elif (epoch + 1) == num_epochs: - save_checkpoint = True - logging.info(f"Final epoch ({epoch + 1}) reached, saving checkpoint.") + logging.info(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}") + logging.info(f" Classifier Loss: {avg_loss_classifier:.4f}") + logging.info(f" Box Reg Loss: {avg_loss_box_reg:.4f}") + logging.info(f" Mask Loss: {avg_loss_mask:.4f}") + logging.info(f" Objectness Loss: {avg_loss_objectness:.4f}") + logging.info(f" RPN Box Reg Loss: {avg_loss_rpn_box_reg:.4f}") + else: + logging.warning("No training batches were processed in this epoch.") - if save_checkpoint: - checkpoint_filename = f"checkpoint_epoch_{epoch + 1}.pth" - save_path = os.path.join(checkpoint_path, checkpoint_filename) + epoch_duration = time.time() - epoch_start_time + logging.info(f"Epoch duration: {epoch_duration:.2f}s") + + # --- Validation --- + logging.info("Running validation...") + val_metrics = evaluate(model, data_loader_val, device) + logging.info(f"Validation Loss: {val_metrics['average_loss']:.4f}") + + # --- Checkpoint Saving --- + if (epoch + 1) % config.get("checkpoint_freq", 1) == 0 or epoch == config.get( + "num_epochs", 10 + ) - 1: + checkpoint_file = os.path.join( + checkpoint_path, f"checkpoint_epoch_{epoch + 1}.pth" + ) + checkpoint = { + "epoch": epoch + 1, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": lr_scheduler.state_dict(), + "config": config, + "val_loss": val_metrics["average_loss"], + } try: - checkpoint_data = { - "epoch": epoch + 1, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": optimizer.state_dict(), - "scheduler_state_dict": lr_scheduler.state_dict(), - "config": config, # Save config for reference - } - torch.save(checkpoint_data, save_path) - logging.info(f"Checkpoint saved to {save_path}") + torch.save(checkpoint, checkpoint_file) + logging.info(f"Checkpoint saved to {checkpoint_file}") except Exception as e: - logging.error( - f"Failed to save checkpoint for epoch {epoch + 1} to {save_path}: {e}", - exc_info=True, - ) + logging.error(f"Error saving checkpoint: {e}", exc_info=True) - # --- Evaluation (Prompt 12) --- # - if data_loader_val: - logging.info(f"Starting evaluation for epoch {epoch + 1}...") - try: - val_metrics = evaluate(model, data_loader_val, device) - logging.info(f"Epoch {epoch + 1} Validation Metrics: {val_metrics}") - - # --- Best Model Checkpoint Logic (Optional Add-on) --- - # Add logic here to track the best metric (e.g., val_metrics['average_loss']) - # and save a separate 'best_model.pth' checkpoint if the current epoch is better. 
- # Example: - # if 'average_loss' in val_metrics: - # current_val_loss = val_metrics['average_loss'] - # if best_val_loss is None or current_val_loss < best_val_loss: - # best_val_loss = current_val_loss - # best_model_path = os.path.join(output_path, 'best_model.pth') - # try: - # # Save only the model state_dict for the best model - # torch.save(model.state_dict(), best_model_path) - # logging.info(f"Saved NEW BEST model checkpoint to {best_model_path} (Val Loss: {best_val_loss:.4f})") - # except Exception as e: - # logging.error(f"Failed to save best model checkpoint: {e}", exc_info=True) - - except Exception as e: - logging.error( - f"Error during evaluation for epoch {epoch + 1}: {e}", exc_info=True - ) - # Decide if this error should stop the entire training process - - # --- End of Training --- # - total_training_time = time.time() - start_time - logging.info("--- Training Finished --- ") - logging.info( - f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)" - ) + # --- Final Metrics and Cleanup --- + total_training_time = time.time() - train_time_start + hours, remainder = divmod(total_training_time, 3600) + minutes, seconds = divmod(remainder, 60) + logging.info(f"Training completed in {int(hours)}h {int(minutes)}m {seconds:.2f}s") + logging.info(f"Final model saved to {checkpoint_path}") if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Train Mask R-CNN on Penn-Fudan dataset." - ) + parser = argparse.ArgumentParser(description="Train a Mask R-CNN model") + parser.add_argument("--config", required=True, help="Path to configuration file") parser.add_argument( - "--config", - type=str, - required=True, - help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)", + "--resume", action="store_true", help="Resume training from latest checkpoint" ) - args = parser.parse_args() main(args) diff --git a/utils/common.py b/utils/common.py new file mode 100644 index 0000000..b48a7c7 --- /dev/null +++ b/utils/common.py @@ -0,0 +1,185 @@ +import importlib.util +import logging +import os +import random +import sys + +import numpy as np +import torch + + +def load_config(config_path): + """Load configuration from a Python file. + + Args: + config_path (str): Path to the configuration file. + + Returns: + dict: The loaded configuration dictionary. + """ + try: + config_path = os.path.abspath(config_path) + if not os.path.exists(config_path): + print(f"Error: Config file not found at {config_path}") + sys.exit(1) + + # Derive module path from file path relative to workspace root + workspace_root = os.path.abspath(os.getcwd()) + relative_path = os.path.relpath(config_path, workspace_root) + if relative_path.startswith(".."): + print(f"Error: Config file {config_path} is outside the project directory.") + sys.exit(1) + + module_path_no_ext, _ = os.path.splitext(relative_path) + module_path_str = module_path_no_ext.replace(os.sep, ".") + + print(f"Attempting to import config module: {module_path_str}") + config_module = importlib.import_module(module_path_str) + config = config_module.config + + print( + f"Loaded configuration from: {config_path} (via module {module_path_str})" + ) + return config + + except ImportError as e: + print(f"Error importing config module '{module_path_str}': {e}") + print( + "Ensure the config file path is correct and relative imports within it are valid." 
+ ) + import traceback + + traceback.print_exc() + sys.exit(1) + except AttributeError as e: + print( + f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}" + ) + sys.exit(1) + except Exception as e: + print(f"Error loading configuration file {config_path}: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +def setup_environment(config): + """Set up the environment based on configuration. + + Args: + config (dict): Configuration dictionary. + + Returns: + tuple: (output_path, device) - the output directory path and torch device. + """ + # Setup output directory + output_dir = config.get("output_dir", "outputs") + config_name = config.get("config_name", "default_run") + output_path = os.path.join(output_dir, config_name) + os.makedirs(output_path, exist_ok=True) + + # Set random seeds + seed = config.get("seed", 42) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + logging.info(f"Set random seed to: {seed}") + + # Setup device + device_name = config.get("device", "cuda") + if device_name == "cuda" and not torch.cuda.is_available(): + logging.warning("CUDA requested but not available, falling back to CPU.") + device_name = "cpu" + device = torch.device(device_name) + logging.info(f"Using device: {device}") + + return output_path, device + + +def load_checkpoint( + checkpoint_path, + model, + device, + load_optimizer=False, + optimizer=None, + load_scheduler=False, + scheduler=None, +): + """Load a checkpoint into the model and optionally optimizer and scheduler. + + Args: + checkpoint_path (str): Path to the checkpoint file. + model (torch.nn.Module): The model to load the weights into. + device (torch.device): The device to load the checkpoint on. + load_optimizer (bool): Whether to load optimizer state. + optimizer (torch.optim.Optimizer, optional): The optimizer to load state into. + load_scheduler (bool): Whether to load scheduler state. + scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to load state into. + + Returns: + dict: The loaded checkpoint. + int: The starting epoch (checkpoint epoch + 1). + """ + try: + logging.info(f"Loading checkpoint from: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Handle potential DataParallel prefix + state_dict = checkpoint.get("model_state_dict", checkpoint) + if isinstance(state_dict, dict): + # Handle case where model was trained with DataParallel + if all(k.startswith("module.") for k in state_dict.keys()): + logging.info( + "Detected DataParallel checkpoint, removing 'module.' 
prefix" + ) + state_dict = { + k.replace("module.", ""): v for k, v in state_dict.items() + } + + model.load_state_dict(state_dict) + logging.info("Model state loaded successfully") + + # Load optimizer state if requested + if ( + load_optimizer + and optimizer is not None + and "optimizer_state_dict" in checkpoint + ): + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + logging.info("Optimizer state loaded successfully") + + # Load scheduler state if requested + if ( + load_scheduler + and scheduler is not None + and "scheduler_state_dict" in checkpoint + ): + scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + logging.info("Scheduler state loaded successfully") + + # Get the epoch number + start_epoch = checkpoint.get("epoch", 0) + 1 if load_optimizer else 0 + if "epoch" in checkpoint: + logging.info(f"Loaded checkpoint from epoch: {checkpoint['epoch']}") + + return checkpoint, start_epoch + else: + logging.error("Checkpoint does not contain a valid state dictionary.") + sys.exit(1) + except Exception as e: + logging.error(f"Error loading checkpoint: {e}", exc_info=True) + sys.exit(1) + + +def check_data_path(data_root): + """Check if the data path exists and is valid. + + Args: + data_root (str): Path to the data directory. + """ + if not data_root or not os.path.isdir(data_root): + logging.error(f"Data root directory not found or not specified: {data_root}") + sys.exit(1)