From 0f3a96ca819fb00bc15101c5d964c96f35478767 Mon Sep 17 00:00:00 2001
From: Craig
Date: Sat, 12 Apr 2025 10:55:10 +0100
Subject: [PATCH] Create eval loop and use full train dataset

---
 todo.md             |  52 ++++++------
 train.py            | 199 ++++++++++++++++++++++++++++++++++++++------
 utils/eval_utils.py |  65 +++++++++++++++
 3 files changed, 266 insertions(+), 50 deletions(-)
 create mode 100644 utils/eval_utils.py

diff --git a/todo.md b/todo.md
index c69816f..f5ea8bc 100644
--- a/todo.md
+++ b/todo.md
@@ -60,36 +60,36 @@ This list outlines the steps required to complete the Torchvision Finetuning pro
 - [x] Call `setup_logging`.
 - [x] Replace `print` with `logging.info`.
 - [x] Log config, device, and training progress/losses.
-- [ ] Implement full training loop in `train.py`.
-  - [ ] Remove single-step exit.
-  - [ ] Add LR scheduler (`torch.optim.lr_scheduler.StepLR`).
-  - [ ] Add epoch loop.
-  - [ ] Add batch loop, integrating the single training step logic.
-  - [ ] Log loss periodically within the batch loop.
-  - [ ] Step the LR scheduler at the end of each epoch.
-  - [ ] Log total training time.
-- [ ] Implement checkpointing in `train.py`.
-  - [ ] Define checkpoint directory.
-  - [ ] Implement logic to find and load the latest checkpoint (resume training).
-  - [ ] Save checkpoints periodically (based on frequency or final epoch).
-  - [ ] Include epoch, model state, optimizer state, scheduler state, config.
-  - [ ] Log checkpoint loading/saving.
+- [x] Implement full training loop in `train.py`.
+  - [x] Remove single-step exit.
+  - [x] Add LR scheduler (`torch.optim.lr_scheduler.StepLR`).
+  - [x] Add epoch loop.
+  - [x] Add batch loop, integrating the single training step logic.
+  - [x] Log loss periodically within the batch loop.
+  - [x] Step the LR scheduler at the end of each epoch.
+  - [x] Log total training time.
+- [x] Implement checkpointing in `train.py`.
+  - [x] Define checkpoint directory.
+  - [x] Implement logic to find and load the latest checkpoint (resume training).
+  - [x] Save checkpoints periodically (based on frequency or final epoch).
+  - [x] Include epoch, model state, optimizer state, scheduler state, config.
+  - [x] Log checkpoint loading/saving.
 
 ## Phase 4: Evaluation & Testing
 
 - [ ] Add evaluation dependencies (`pycocotools` - optional initially).
-- [ ] Create `utils/eval_utils.py` and implement `evaluate` function.
-  - [ ] Set `model.eval()`.
-  - [ ] Use `torch.no_grad()`.
-  - [ ] Loop through validation/test dataloader.
-  - [ ] Perform forward pass.
-  - [ ] Calculate/aggregate metrics (start with average loss, potentially add mAP later).
-  - [ ] Log evaluation metrics and time.
-  - [ ] Return metrics.
-- [ ] Integrate evaluation into `train.py`.
-  - [ ] Create validation `Dataset` and `DataLoader` (using `torch.utils.data.Subset`).
-  - [ ] Call `evaluate` at the end of each epoch.
-  - [ ] Log validation metrics.
+- [x] Create `utils/eval_utils.py` and implement `evaluate` function.
+  - [x] Set `model.eval()`.
+  - [x] Use `torch.no_grad()`.
+  - [x] Loop through validation/test dataloader.
+  - [x] Perform forward pass.
+  - [x] Calculate/aggregate metrics (start with average loss, potentially add mAP later).
+  - [x] Log evaluation metrics and time.
+  - [x] Return metrics.
+- [x] Integrate evaluation into `train.py`.
+  - [x] Create validation `Dataset` and `DataLoader` (using `torch.utils.data.Subset`).
+  - [x] Call `evaluate` at the end of each epoch.
+  - [x] Log validation metrics.
 - [ ] (Later) Implement logic to save the *best* model based on validation metric.
 - [ ] Implement `test.py` script.
   - [ ] Reuse argument parsing, config loading, device setup, dataset/dataloader (test split), model creation from `train.py`.
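
For reference (not part of the patch): the `train.py` changes below read several new settings via `config.get(...)`. The project's actual config file is not shown in this commit, so the following is only a sketch of the keys involved and the defaults the code falls back to.

# Reference only -- keys read by the updated train.py and their in-code defaults.
config = {
    "seed": 42,              # seeds torch.manual_seed for a reproducible train/val split
    "val_split_ratio": 0.1,  # fraction of the dataset held out for validation
    "batch_size": 2,         # shared by the training and validation DataLoaders
    "num_workers": 4,
    "pin_memory": True,
    "num_epochs": 10,
    "checkpoint_freq": 1,    # save a checkpoint every N epochs
}
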
diff --git a/train.py b/train.py
index e433b5c..9bad324 100644
--- a/train.py
+++ b/train.py
@@ -13,6 +13,7 @@ import torch.utils.data
 # Project specific imports
 from models.detection import get_maskrcnn_model
 from utils.data_utils import PennFudanDataset, collate_fn, get_transform
+from utils.eval_utils import evaluate  # Import evaluate function
 from utils.log_utils import setup_logging
 
 
@@ -108,31 +109,75 @@ def main(args):
         sys.exit(1)
 
     try:
-        dataset_train = PennFudanDataset(
+        # Create the full training dataset instance first
+        dataset_full = PennFudanDataset(
             root=data_root, transforms=get_transform(train=True)
         )
-        # Note: Validation split will be handled later (Prompt 12)
-        # dataset_val = PennFudanDataset(root=data_root, transforms=get_transform(train=False))
+        logging.info(f"Full dataset size: {len(dataset_full)}")
 
-        # TODO: Implement data splitting (e.g., using torch.utils.data.Subset)
+        # Create validation dataset instance with eval transforms
+        dataset_val_instance = PennFudanDataset(
+            root=data_root, transforms=get_transform(train=False)
+        )
+
+        # Split the dataset indices
+        torch.manual_seed(
+            config.get("seed", 42)
+        )  # Use the same seed for consistent splits
+        indices = torch.randperm(len(dataset_full)).tolist()
+        val_split_ratio = config.get(
+            "val_split_ratio", 0.1
+        )  # Default to 10% validation
+        val_split_count = int(val_split_ratio * len(dataset_full))
+        if val_split_count == 0 and len(dataset_full) > 0:
+            logging.warning(
+                f"Validation split resulted in 0 samples (ratio={val_split_ratio}, total={len(dataset_full)}). Using 1 sample for validation."
+            )
+            val_split_count = 1
+        elif val_split_count >= len(dataset_full):
+            logging.error(
+                f"Validation split ratio ({val_split_ratio}) too high, results in no training samples."
+            )
+            sys.exit(1)
+
+        train_indices = indices[:-val_split_count]
+        val_indices = indices[-val_split_count:]
+
+        # Create Subset datasets
+        dataset_train = torch.utils.data.Subset(dataset_full, train_indices)
+        dataset_val = torch.utils.data.Subset(dataset_val_instance, val_indices)
+
+        logging.info(
+            f"Using {len(train_indices)} samples for training and {len(val_indices)} for validation."
+        )
+
+        # Create DataLoaders
         data_loader_train = torch.utils.data.DataLoader(
             dataset_train,
             batch_size=config.get("batch_size", 2),
+            # Shuffle should be true for the training subset loader
             shuffle=True,
             num_workers=config.get("num_workers", 4),
             collate_fn=collate_fn,
-            pin_memory=config.get(
-                "pin_memory", True
-            ),  # Often improves GPU transfer speed
+            pin_memory=config.get("pin_memory", True),
         )
-        logging.info(f"Training dataset size: {len(dataset_train)}")
-        logging.info(
-            f"Training dataloader configured with batch size {config.get('batch_size', 2)}"
+        data_loader_val = torch.utils.data.DataLoader(
+            dataset_val,
+            batch_size=config.get(
+                "batch_size", 2
+            ),  # Often use same or larger batch size for validation
+            shuffle=False,  # No need to shuffle validation data
+            num_workers=config.get("num_workers", 4),
+            collate_fn=collate_fn,
+            pin_memory=config.get("pin_memory", True),
        )
-        # Placeholder for validation loader
-        # data_loader_val = torch.utils.data.DataLoader(...)
+        logging.info(
+            f"Training dataloader configured. Est. batches: {len(data_loader_train)}"
+        )
+        logging.info(
+            f"Validation dataloader configured. Est. batches: {len(data_loader_val)}"
+        )
     except Exception as e:
         logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
         sys.exit(1)
@@ -187,12 +232,72 @@ def main(args):
         logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
         sys.exit(1)
 
-    # --- Training Loop (Prompt 10) ---
+    # --- Resume Logic (Prompt 11) ---
+    start_epoch = 0
+    latest_checkpoint_path = None
+    if os.path.isdir(checkpoint_path):
+        checkpoints = sorted(
+            [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")],
+            # Sort numerically on the epoch in the filename so that, e.g.,
+            # checkpoint_epoch_10.pth sorts after checkpoint_epoch_9.pth.
+            key=lambda name: int("".join(filter(str.isdigit, name)) or 0),
+        )
+        if checkpoints:  # Check if list is not empty
+            latest_checkpoint_file = checkpoints[
+                -1
+            ]  # Get the last one (highest epoch number)
+            latest_checkpoint_path = os.path.join(
+                checkpoint_path, latest_checkpoint_file
+            )
+            logging.info(f"Found latest checkpoint: {latest_checkpoint_path}")
+        else:
+            logging.info("No checkpoints found in directory. Starting from scratch.")
+    else:
+        logging.info("Checkpoint directory not found. Starting from scratch.")
+
+    if latest_checkpoint_path:
+        try:
+            logging.info(f"Loading checkpoint '{latest_checkpoint_path}'")
+            # Ensure loading happens on the correct device
+            checkpoint = torch.load(latest_checkpoint_path, map_location=device)
+
+            # Load model state - handle potential 'module.' prefix if saved with DataParallel
+            model_state_dict = checkpoint["model_state_dict"]
+            # Simple check and correction for DataParallel prefix
+            if all(key.startswith("module.") for key in model_state_dict.keys()):
+                logging.info("Removing 'module.' prefix from checkpoint keys.")
+                model_state_dict = {
+                    k.replace("module.", ""): v for k, v in model_state_dict.items()
+                }
+            model.load_state_dict(model_state_dict)
+
+            # Load optimizer state
+            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+
+            # Load LR scheduler state
+            lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+
+            # Load starting epoch (epoch saved is the one *completed*, so start from next)
+            start_epoch = checkpoint["epoch"]
+            logging.info(f"Resuming training from epoch {start_epoch + 1}")
+
+            # Optionally load and verify config consistency
+            # loaded_config = checkpoint.get('config')
+            # if loaded_config:
+            #     # Perform checks if necessary
+            #     pass
+
+        except Exception as e:
+            logging.error(
+                f"Error loading checkpoint: {e}. Starting training from scratch.",
+                exc_info=True,
+            )
+            start_epoch = 0  # Reset start_epoch if loading fails
+
+    # --- Training Loop (Prompt 10, modified for Prompt 11) ---
     logging.info("--- Starting Training Loop ---")
     start_time = time.time()
     num_epochs = config.get("num_epochs", 10)
 
-    for epoch in range(num_epochs):
+    # Modify loop to start from start_epoch
+    for epoch in range(start_epoch, num_epochs):
         model.train()  # Set model to training mode for each epoch
         epoch_start_time = time.time()
         logging.info(f"--- Epoch {epoch + 1}/{num_epochs} ---")
 
@@ -256,17 +361,63 @@ def main(args):
         logging.info(f"  Learning Rate: {current_lr:.6f}")
         logging.info(f"  Epoch Duration: {epoch_duration:.2f}s")
 
-        # --- Checkpointing (Placeholder for Prompt 11) --- #
-        # Add checkpoint saving logic here, e.g.:
-        # if (epoch + 1) % config.get('checkpoint_freq', 1) == 0 or (epoch + 1) == num_epochs:
-        #     # ... save checkpoint ...
- # logging.info(f"Saved checkpoint for epoch {epoch + 1}") + # --- Checkpointing (Prompt 11) --- # + # Save checkpoint periodically or at the end + save_checkpoint = False + if (epoch + 1) % config.get("checkpoint_freq", 1) == 0: + save_checkpoint = True + logging.info(f"Checkpoint frequency met (epoch {epoch + 1})") + elif (epoch + 1) == num_epochs: + save_checkpoint = True + logging.info(f"Final epoch ({epoch + 1}) reached, saving checkpoint.") - # --- Evaluation (Placeholder for Prompt 12) --- # - # Add evaluation logic here, e.g.: - # if data_loader_val: - # evaluate(model, data_loader_val, device) - # logging.info(f"Ran evaluation for epoch {epoch + 1}") + if save_checkpoint: + checkpoint_filename = f"checkpoint_epoch_{epoch + 1}.pth" + save_path = os.path.join(checkpoint_path, checkpoint_filename) + try: + checkpoint_data = { + "epoch": epoch + 1, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": lr_scheduler.state_dict(), + "config": config, # Save config for reference + } + torch.save(checkpoint_data, save_path) + logging.info(f"Checkpoint saved to {save_path}") + except Exception as e: + logging.error( + f"Failed to save checkpoint for epoch {epoch + 1} to {save_path}: {e}", + exc_info=True, + ) + + # --- Evaluation (Prompt 12) --- # + if data_loader_val: + logging.info(f"Starting evaluation for epoch {epoch + 1}...") + try: + val_metrics = evaluate(model, data_loader_val, device) + logging.info(f"Epoch {epoch + 1} Validation Metrics: {val_metrics}") + + # --- Best Model Checkpoint Logic (Optional Add-on) --- + # Add logic here to track the best metric (e.g., val_metrics['average_loss']) + # and save a separate 'best_model.pth' checkpoint if the current epoch is better. + # Example: + # if 'average_loss' in val_metrics: + # current_val_loss = val_metrics['average_loss'] + # if best_val_loss is None or current_val_loss < best_val_loss: + # best_val_loss = current_val_loss + # best_model_path = os.path.join(output_path, 'best_model.pth') + # try: + # # Save only the model state_dict for the best model + # torch.save(model.state_dict(), best_model_path) + # logging.info(f"Saved NEW BEST model checkpoint to {best_model_path} (Val Loss: {best_val_loss:.4f})") + # except Exception as e: + # logging.error(f"Failed to save best model checkpoint: {e}", exc_info=True) + + except Exception as e: + logging.error( + f"Error during evaluation for epoch {epoch + 1}: {e}", exc_info=True + ) + # Decide if this error should stop the entire training process # --- End of Training --- # total_training_time = time.time() - start_time diff --git a/utils/eval_utils.py b/utils/eval_utils.py new file mode 100644 index 0000000..c882cd5 --- /dev/null +++ b/utils/eval_utils.py @@ -0,0 +1,65 @@ +import logging +import time + +import torch + + +def evaluate(model, data_loader, device): + """Performs evaluation on the dataset for one epoch. + + Args: + model (torch.nn.Module): The model to evaluate. + data_loader (torch.utils.data.DataLoader): DataLoader for the evaluation data. + device (torch.device): The device to run evaluation on. + + Returns: + dict: A dictionary containing evaluation metrics (e.g., average loss). 
+ """ + model.eval() # Set model to evaluation mode + total_loss = 0.0 + num_batches = len(data_loader) + eval_start_time = time.time() + status_interval = max(1, num_batches // 10) # Log status roughly 10 times + + logging.info("--- Starting Evaluation --- ") + + with torch.no_grad(): # Disable gradient calculations + for i, (images, targets) in enumerate(data_loader): + images = list(image.to(device) for image in images) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + # In eval mode with targets, Mask R-CNN should still return losses + # If it returned predictions, logic here would change to process predictions + loss_dict = model(images, targets) + losses = sum(loss for loss in loss_dict.values()) + loss_value = losses.item() + total_loss += loss_value + + if (i + 1) % status_interval == 0: + logging.info(f" Evaluated batch {i + 1}/{num_batches}") + + avg_loss = total_loss / num_batches if num_batches > 0 else 0 + eval_duration = time.time() - eval_start_time + + logging.info("--- Evaluation Finished ---") + logging.info(f" Average Evaluation Loss: {avg_loss:.4f}") + logging.info(f" Evaluation Duration: {eval_duration:.2f}s") + + # Return metrics (currently just average loss) + metrics = {"average_loss": avg_loss} + return metrics + + +# Example usage (can be removed or kept for testing): +if __name__ == "__main__": + # This is a dummy test and requires a model, dataloader, device + print( + "This script contains the evaluate function and cannot be run directly for testing without setup." + ) + # Example: + # device = torch.device('cpu') + # # Create dummy model and dataloader + # model = ... + # data_loader = ... + # metrics = evaluate(model, data_loader, device) + # print(f"Dummy evaluation metrics: {metrics}")