From 620c34bf13e75a405b88c624aaeded8c081dd485 Mon Sep 17 00:00:00 2001
From: Craig
Date: Sat, 12 Apr 2025 10:44:52 +0100
Subject: [PATCH] Formatting

---
 models/detection.py |  55 ++++++++++
 todo.md             |  50 ++++-----
 train.py            | 241 ++++++++++++++++++++++++++++++++++++++++++++
 utils/log_utils.py  |  44 ++++++++
 4 files changed, 365 insertions(+), 25 deletions(-)
 create mode 100644 utils/log_utils.py

diff --git a/models/detection.py b/models/detection.py
index e69de29..eb9f902 100644
--- a/models/detection.py
+++ b/models/detection.py
@@ -0,0 +1,55 @@
+import torchvision
+from torchvision.models import ResNet50_Weights
+
+# Import weights enums for clarity
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
+from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
+from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
+
+
+def get_maskrcnn_model(num_classes, pretrained=True, pretrained_backbone=True):
+    """Loads a Mask R-CNN model with a ResNet-50-FPN backbone.
+
+    Args:
+        num_classes (int): Number of output classes (including background).
+        pretrained (bool): If True, loads weights pre-trained on COCO.
+        pretrained_backbone (bool): If True (and pretrained=False), loads backbone
+            weights pre-trained on ImageNet.
+
+    Returns:
+        torchvision.models.detection.MaskRCNN: The modified Mask R-CNN model.
+    """
+
+    # Determine weights based on arguments
+    if pretrained:
+        weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
+        weights_backbone = None  # Backbone weights are included in MaskRCNN weights
+    elif pretrained_backbone:
+        weights = None
+        weights_backbone = ResNet50_Weights.DEFAULT
+    else:
+        weights = None
+        weights_backbone = None
+
+    # Load the model structure with the specified weights
+    # Use maskrcnn_resnet50_fpn_v2 for compatibility with V2 weights
+    model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
+        weights=weights, weights_backbone=weights_backbone
+    )
+
+    # 1. Replace the box predictor
+    # Get the number of input features for the classifier
+    in_features_box = model.roi_heads.box_predictor.cls_score.in_features
+    # Replace the pre-trained head with a new one
+    model.roi_heads.box_predictor = FastRCNNPredictor(in_features_box, num_classes)
+
+    # 2. Replace the mask predictor
+    # Get the number of input features for the mask classifier
+    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
+    hidden_layer = 256  # Default value
+    # Replace the mask predictor with a new one
+    model.roi_heads.mask_predictor = MaskRCNNPredictor(
+        in_features_mask, hidden_layer, num_classes
+    )
+
+    return model

diff --git a/todo.md b/todo.md
index 5dc937a..70b705f 100644
--- a/todo.md
+++ b/todo.md
@@ -26,34 +26,34 @@ This list outlines the steps required to complete the Torchvision Finetuning project.
 - [x] `__len__`: Return dataset size.
 - [x] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
 - [x] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
-- [ ] Implement `get_maskrcnn_model(num_classes, ...)` function in `models/detection.py`.
-  - [ ] Load pre-trained Mask R-CNN (`maskrcnn_resnet50_fpn_v2`).
-  - [ ] Replace box predictor head (`FastRCNNPredictor`).
-  - [ ] Replace mask predictor head (`MaskRCNNPredictor`).
+- [x] Implement `get_maskrcnn_model(num_classes, ...)` function in `models/detection.py`.
+  - [x] Load pre-trained Mask R-CNN (`maskrcnn_resnet50_fpn_v2`).
+  - [x] Replace box predictor head (`FastRCNNPredictor`).
+  - [x] Replace mask predictor head (`MaskRCNNPredictor`).
 
 ## Phase 3: Training Script & Core Logic
 
-- [ ] Set up basic `train.py` structure.
-  - [ ] Add imports.
-  - [ ] Implement `argparse` for `--config` argument.
-  - [ ] Implement dynamic config loading (`importlib`).
-  - [ ] Set random seeds.
-  - [ ] Determine compute device (`cuda` or `cpu`).
-  - [ ] Create output directory structure (`outputs/<config_name>/checkpoints`).
-  - [ ] Instantiate `PennFudanDataset` (train).
-  - [ ] Instantiate `DataLoader` (train) using `collate_fn`.
-  - [ ] Instantiate model using `get_maskrcnn_model`.
-  - [ ] Move model to device.
-  - [ ] Add `if __name__ == "__main__":` guard.
-- [ ] Implement minimal training step in `train.py`.
-  - [ ] Instantiate optimizer (`torch.optim.SGD`).
-  - [ ] Set `model.train()`.
-  - [ ] Fetch one batch.
-  - [ ] Move data to device.
-  - [ ] Perform forward pass (`loss_dict = model(...)`).
-  - [ ] Calculate total loss (`sum(...)`).
-  - [ ] Perform backward pass (`optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`).
-  - [ ] Print/log loss for the single step (and temporarily exit).
+- [x] Set up basic `train.py` structure.
+  - [x] Add imports.
+  - [x] Implement `argparse` for `--config` argument.
+  - [x] Implement dynamic config loading (`importlib`).
+  - [x] Set random seeds.
+  - [x] Determine compute device (`cuda` or `cpu`).
+  - [x] Create output directory structure (`outputs/<config_name>/checkpoints`).
+  - [x] Instantiate `PennFudanDataset` (train).
+  - [x] Instantiate `DataLoader` (train) using `collate_fn`.
+  - [x] Instantiate model using `get_maskrcnn_model`.
+  - [x] Move model to device.
+  - [x] Add `if __name__ == "__main__":` guard.
+- [x] Implement minimal training step in `train.py`.
+  - [x] Instantiate optimizer (`torch.optim.SGD`).
+  - [x] Set `model.train()`.
+  - [x] Fetch one batch.
+  - [x] Move data to device.
+  - [x] Perform forward pass (`loss_dict = model(...)`).
+  - [x] Calculate total loss (`sum(...)`).
+  - [x] Perform backward pass (`optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`).
+  - [x] Print/log loss for the single step (and temporarily exit).
 - [ ] Implement logging setup in `utils/log_utils.py` (`setup_logging` function).
   - [ ] Configure `logging.basicConfig` for file and console output.
 - [ ] Integrate logging into `train.py`.
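As a quick sanity check of the new `get_maskrcnn_model` factory before it is wired into `train.py`, the sketch below runs a dummy forward pass. It assumes the two-class (background + pedestrian) Penn-Fudan setup; the input sizes are arbitrary and purely illustrative, not part of the patch:

```python
import torch

from models.detection import get_maskrcnn_model

# Penn-Fudan: background + pedestrian
model = get_maskrcnn_model(num_classes=2, pretrained=True)
model.eval()

# Mask R-CNN accepts a list of 3xHxW tensors; images may differ in size
dummy_images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
with torch.no_grad():
    predictions = model(dummy_images)

# One dict per image: boxes (N, 4), labels (N,), scores (N,), masks (N, 1, H, W)
print(predictions[0]["boxes"].shape, predictions[0]["masks"].shape)
```

In train mode (when called with targets), the same model instead returns the dictionary of losses that `train.py` sums in the minimal training step below.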
diff --git a/train.py b/train.py
index e69de29..690d29d 100644
--- a/train.py
+++ b/train.py
@@ -0,0 +1,241 @@
+import argparse
+import importlib
+import logging
+import os
+import random
+import sys
+
+import numpy as np
+import torch
+import torch.utils.data
+
+# Project-specific imports
+from models.detection import get_maskrcnn_model
+from utils.data_utils import PennFudanDataset, collate_fn, get_transform
+from utils.log_utils import setup_logging
+
+
+def main(args):
+    # --- Configuration Loading ---
+    try:
+        config_path = os.path.abspath(args.config)
+        if not os.path.exists(config_path):
+            print(f"Error: Config file not found at {config_path}")
+            sys.exit(1)
+
+        # Derive the module path from the file path relative to the workspace root
+        workspace_root = os.path.abspath(
+            os.getcwd()
+        )  # Assumes the script is run from the project root
+        relative_path = os.path.relpath(config_path, workspace_root)
+        if relative_path.startswith(".."):
+            print(f"Error: Config file {args.config} is outside the project directory.")
+            sys.exit(1)
+
+        module_path_no_ext, _ = os.path.splitext(relative_path)
+        module_path_str = module_path_no_ext.replace(os.sep, ".")
+
+        print(f"Attempting to import config module: {module_path_str}")
+        config_module = importlib.import_module(module_path_str)
+        config = config_module.config
+
+        print(
+            f"Loaded configuration from: {config_path} (via module {module_path_str})"
+        )
+
+    except ImportError as e:
+        print(f"Error importing config module '{module_path_str}': {e}")
+        print(
+            "Ensure the config file path is correct and relative imports within it are valid."
+        )
+        import traceback
+
+        traceback.print_exc()
+        sys.exit(1)
+    except AttributeError as e:
+        print(
+            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
+        )
+        sys.exit(1)
+    except Exception as e:
+        print(f"Error loading configuration file {args.config}: {e}")
+        import traceback
+
+        traceback.print_exc()
+        sys.exit(1)
+
+    # --- Output Directory Setup ---
+    output_dir = config.get("output_dir", "outputs")
+    config_name = config.get("config_name", "default_run")
+    output_path = os.path.join(output_dir, config_name)
+    checkpoint_path = os.path.join(output_path, "checkpoints")
+    os.makedirs(output_path, exist_ok=True)
+    os.makedirs(checkpoint_path, exist_ok=True)
+    print(f"Output will be saved to: {output_path}")
+
+    # --- Logging Setup (Prompt 9) ---
+    setup_logging(output_path, config_name)
+    logging.info("--- Training Script Started ---")
+    logging.info(f"Loaded configuration from: {args.config}")
+    logging.info(f"Loaded configuration dictionary: {config}")
+    logging.info(f"Output will be saved to: {output_path}")
+
+    # --- Reproducibility ---
+    seed = config.get("seed", 42)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+        # Consider enabling these for stricter determinism; they may impact performance
+        # torch.backends.cudnn.deterministic = True
+        # torch.backends.cudnn.benchmark = False
+    logging.info(f"Set random seed to: {seed}")
+
+    # --- Device Setup ---
+    device_name = config.get("device", "cuda")
+    if device_name == "cuda" and not torch.cuda.is_available():
+        logging.warning("CUDA requested but not available, falling back to CPU.")
+        device_name = "cpu"
+    device = torch.device(device_name)
+    logging.info(f"Using device: {device}")
+
+    # --- Dataset and DataLoader ---
+    data_root = config.get("data_root")
+    if not data_root or not os.path.isdir(data_root):
+        logging.error(f"Data root directory not found or not specified: {data_root}")
specified: {data_root}") + sys.exit(1) + + try: + dataset_train = PennFudanDataset( + root=data_root, transforms=get_transform(train=True) + ) + # Note: Validation split will be handled later (Prompt 12) + # dataset_val = PennFudanDataset(root=data_root, transforms=get_transform(train=False)) + + # TODO: Implement data splitting (e.g., using torch.utils.data.Subset) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, + batch_size=config.get("batch_size", 2), + shuffle=True, + num_workers=config.get("num_workers", 4), + collate_fn=collate_fn, + pin_memory=config.get( + "pin_memory", True + ), # Often improves GPU transfer speed + ) + logging.info(f"Training dataset size: {len(dataset_train)}") + logging.info( + f"Training dataloader configured with batch size {config.get('batch_size', 2)}" + ) + + # Placeholder for validation loader + # data_loader_val = torch.utils.data.DataLoader(...) + + except Exception as e: + logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True) + sys.exit(1) + + # --- Model Instantiation --- + num_classes = config.get("num_classes") + if num_classes is None: + logging.error("'num_classes' not specified in configuration.") + sys.exit(1) + + try: + model = get_maskrcnn_model( + num_classes=num_classes, + pretrained=config.get("pretrained", True), + pretrained_backbone=config.get("pretrained_backbone", True), + ) + model.to(device) + logging.info("Model loaded successfully.") + except Exception as e: + logging.error(f"Error loading model: {e}", exc_info=True) + sys.exit(1) + + # --- Optimizer --- + # Filter parameters that require gradients + params = [p for p in model.parameters() if p.requires_grad] + try: + optimizer = torch.optim.SGD( + params, + lr=config.get("lr", 0.005), + momentum=config.get("momentum", 0.9), + weight_decay=config.get("weight_decay", 0.0005), + ) + logging.info( + f"Optimizer SGD configured with lr={config.get('lr', 0.005)}, momentum={config.get('momentum', 0.9)}, weight_decay={config.get('weight_decay', 0.0005)}" + ) + except Exception as e: + logging.error(f"Error creating optimizer: {e}", exc_info=True) + sys.exit(1) + + # --- LR Scheduler (Placeholder for Prompt 10) --- + # lr_scheduler = torch.optim.lr_scheduler.StepLR( + # optimizer, + # step_size=config.get('lr_step_size', 3), + # gamma=config.get('lr_gamma', 0.1) + # ) + + # --- Minimal Training Step (Prompt 8 / Updated for Prompt 9) --- + logging.info("--- Starting Minimal Training Step --- ") + model.train() # Set model to training mode + + try: + # Fetch one batch + images, targets = next(iter(data_loader_train)) + + # Move data to the device + images = list(image.to(device) for image in images) + # Targets is a list of dicts. Move each tensor in the dict to the device. 
+        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+        # Forward pass (the model returns a loss dict in train mode)
+        loss_dict = model(images, targets)
+
+        # Calculate the total loss
+        losses = sum(loss for loss in loss_dict.values())
+        loss_value = losses.item()  # Get the scalar value
+
+        # Backward pass
+        optimizer.zero_grad()  # Clear previous gradients
+        losses.backward()  # Compute gradients
+        optimizer.step()  # Update weights
+
+        # Convert loss_dict tensors to scalar values for logging
+        loss_dict_log = {k: v.item() for k, v in loss_dict.items()}
+        logging.info(f"Single step loss dict: {loss_dict_log}")
+        logging.info(f"Single step total loss: {loss_value:.4f}")
+        logging.info("--- Minimal Training Step Completed Successfully ---")
+
+    except Exception as e:
+        logging.error(f"Error during minimal training step: {e}", exc_info=True)
+        import traceback
+
+        # traceback.print_exc()  # Already logged with exc_info=True
+        sys.exit(1)
+
+    # Temporarily exit after the single step (as per Prompt 8)
+    logging.info("Exiting after single training step.")
+    sys.exit(0)
+
+    # --- Full Training Loop (Placeholder for Prompt 10) ---
+    # print("Basic setup complete. Full training loop implementation pending.")
+    # ... loop implementation ...
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Train Mask R-CNN on the Penn-Fudan dataset."
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
+    )
+
+    args = parser.parse_args()
+    main(args)

diff --git a/utils/log_utils.py b/utils/log_utils.py
new file mode 100644
index 0000000..b70ee6e
--- /dev/null
+++ b/utils/log_utils.py
@@ -0,0 +1,44 @@
+import logging
+import os
+import sys
+
+
+def setup_logging(log_dir, config_name):
+    """Configures logging to output to both a file and the console.
+
+    Args:
+        log_dir (str): The directory where the log file should be saved.
+        config_name (str): The name of the configuration run, used for the log filename.
+    """
+    # Ensure the log directory exists
+    os.makedirs(log_dir, exist_ok=True)
+
+    log_filename = f"{config_name}_train.log"
+    log_filepath = os.path.join(log_dir, log_filename)
+
+    # Configure the root logger
+    logging.basicConfig(
+        level=logging.INFO,  # Log INFO level and above (INFO, WARNING, ERROR, CRITICAL)
+        format="%(asctime)s [%(levelname)s] %(message)s",
+        datefmt="%Y-%m-%d %H:%M:%S",
+        handlers=[
+            logging.FileHandler(log_filepath),  # Log to a file
+            logging.StreamHandler(sys.stdout),  # Log to the console (stdout)
+        ],
+        # force=True ensures that if basicConfig was called before (e.g., by a
+        # library), this configuration overwrites it. Use with caution if
+        # libraries configure logging themselves in complex ways.
+        force=True,
+    )
+
+    logging.info(f"Logging configured. Log file: {log_filepath}")
+
+
+# Example usage (can be removed or kept for testing):
+if __name__ == "__main__":
+    print("Testing logging setup...")
+    setup_logging("temp_logs", "test_config")
+    logging.info("This is an info message.")
+    logging.warning("This is a warning message.")
+    logging.error("This is an error message.")
+    print("Check 'temp_logs/test_config_train.log' and console output.")
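`train.py` imports the file passed via `--config` as a module and reads a `config` dictionary from it. For reference, here is a minimal sketch of such a module — every key below corresponds to a `config.get(...)` call in the patch, but the file name and the concrete values are illustrative assumptions, not part of the patch:

```python
# configs/pennfudan_maskrcnn_config.py (illustrative)
config = {
    "config_name": "pennfudan_maskrcnn",
    "output_dir": "outputs",
    "data_root": "data/PennFudanPed",  # assumed dataset location
    "num_classes": 2,  # background + pedestrian
    "seed": 42,
    "device": "cuda",
    "batch_size": 2,
    "num_workers": 4,
    "pin_memory": True,
    "pretrained": True,
    "pretrained_backbone": True,
    "lr": 0.005,
    "momentum": 0.9,
    "weight_decay": 0.0005,
}
```

With such a file in place, the single-step run would be launched as `python train.py --config configs/pennfudan_maskrcnn_config.py`. Note that the loader converts the path into a dotted module path relative to the working directory, so the config file must live inside the project tree.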