# torchvision-vibecoding-project/train.py
import argparse
import importlib
import logging
import os
import random
import sys
import traceback

import numpy as np
import torch
import torch.utils.data

# Project-specific imports
from models.detection import get_maskrcnn_model
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.log_utils import setup_logging


def main(args):
    # --- Configuration Loading ---
    try:
        config_path = os.path.abspath(args.config)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)
        # Derive the module path from the file path relative to the workspace
        # root (assumes the script is run from the project root).
        workspace_root = os.path.abspath(os.getcwd())
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {args.config} is outside the project directory.")
            sys.exit(1)
        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")
        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config
        print(f"Loaded configuration from: {config_path} (via module {module_path_str})")
    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print("Ensure the config file path is correct and relative imports within it are valid.")
        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}")
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {args.config}: {e}")
        traceback.print_exc()
        sys.exit(1)
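
    # For reference, the imported module is expected to expose a plain dict
    # named `config`. A hypothetical minimal example, using only the keys this
    # script actually reads via config.get(...) (values are illustrative):
    #
    # config = {
    #     "config_name": "pennfudan_maskrcnn",
    #     "output_dir": "outputs",
    #     "seed": 42,
    #     "device": "cuda",
    #     "data_root": "data/PennFudanPed",
    #     "num_classes": 2,  # background + pedestrian
    #     "batch_size": 2,
    #     "num_workers": 4,
    #     "pin_memory": True,
    #     "pretrained": True,
    #     "pretrained_backbone": True,
    #     "lr": 0.005,
    #     "momentum": 0.9,
    #     "weight_decay": 0.0005,
    # }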

    # --- Output Directory Setup ---
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    checkpoint_path = os.path.join(output_path, "checkpoints")
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)
    print(f"Output will be saved to: {output_path}")

    # --- Logging Setup (Prompt 9) ---
    setup_logging(output_path, config_name)
    logging.info("--- Training Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Loaded configuration dictionary: {config}")
    logging.info(f"Output will be saved to: {output_path}")

    # --- Reproducibility ---
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Consider also enabling these for stricter determinism, at a possible
    # performance cost:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    logging.info(f"Set random seed to: {seed}")

    # --- Device Setup ---
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    # --- Dataset and DataLoader ---
    data_root = config.get("data_root")
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)
    try:
        dataset_train = PennFudanDataset(
            root=data_root, transforms=get_transform(train=True)
        )
        # Note: the validation split will be handled later (Prompt 12).
        # dataset_val = PennFudanDataset(root=data_root, transforms=get_transform(train=False))
        # TODO: Implement data splitting (e.g., using torch.utils.data.Subset);
        # a commented sketch follows.
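        # A possible shape for that split (hypothetical, untested sketch):
        # shuffle the indices and hold out the last 50 samples for validation,
        # as in the torchvision object detection tutorial.
        #
        # indices = torch.randperm(len(dataset_train)).tolist()
        # dataset_val = PennFudanDataset(root=data_root, transforms=get_transform(train=False))
        # dataset_train = torch.utils.data.Subset(dataset_train, indices[:-50])
        # dataset_val = torch.utils.data.Subset(dataset_val, indices[-50:])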
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=config.get("batch_size", 2),
            shuffle=True,
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),  # Often improves GPU transfer speed
        )
        logging.info(f"Training dataset size: {len(dataset_train)}")
        logging.info(
            f"Training dataloader configured with batch size {config.get('batch_size', 2)}"
        )
        # Placeholder for the validation loader:
        # data_loader_val = torch.utils.data.DataLoader(...)
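        # A plausible configuration for it (hypothetical sketch): no shuffling,
        # batch size 1, and the same collate_fn as the training loader.
        #
        # data_loader_val = torch.utils.data.DataLoader(
        #     dataset_val,
        #     batch_size=1,
        #     shuffle=False,
        #     num_workers=config.get("num_workers", 4),
        #     collate_fn=collate_fn,
        # )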
    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # --- Model Instantiation ---
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)
    try:
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=config.get("pretrained", True),
            pretrained_backbone=config.get("pretrained_backbone", True),
        )
        model.to(device)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}", exc_info=True)
        sys.exit(1)

    # --- Optimizer ---
    # Optimize only the parameters that require gradients.
    params = [p for p in model.parameters() if p.requires_grad]
    try:
        optimizer = torch.optim.SGD(
            params,
            lr=config.get("lr", 0.005),
            momentum=config.get("momentum", 0.9),
            weight_decay=config.get("weight_decay", 0.0005),
        )
        logging.info(
            f"Optimizer SGD configured with lr={config.get('lr', 0.005)}, "
            f"momentum={config.get('momentum', 0.9)}, "
            f"weight_decay={config.get('weight_decay', 0.0005)}"
        )
    except Exception as e:
        logging.error(f"Error creating optimizer: {e}", exc_info=True)
        sys.exit(1)

    # --- LR Scheduler (Placeholder for Prompt 10) ---
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer,
    #     step_size=config.get('lr_step_size', 3),
    #     gamma=config.get('lr_gamma', 0.1)
    # )
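    # Note: once enabled, StepLR is stepped once per epoch (lr_scheduler.step()
    # after each full pass over data_loader_train), so the learning rate decays
    # by `lr_gamma` every `lr_step_size` epochs.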

    # --- Minimal Training Step (Prompt 8 / Updated for Prompt 9) ---
    logging.info("--- Starting Minimal Training Step ---")
    model.train()  # Set model to training mode
    try:
        # Fetch one batch.
        images, targets = next(iter(data_loader_train))
        # Move data to the device. `targets` is a list of dicts, so move each
        # tensor in each dict individually.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # Forward pass (in train mode the model returns a dict of losses).
        loss_dict = model(images, targets)
        # Total loss is the sum of the individual loss components.
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()  # Scalar value for logging
        # Backward pass and parameter update.
        optimizer.zero_grad()  # Clear previous gradients
        losses.backward()      # Compute gradients
        optimizer.step()       # Update weights
        # Convert the loss_dict tensors to scalars for logging.
        loss_dict_log = {k: v.item() for k, v in loss_dict.items()}
        logging.info(f"Single step loss dict: {loss_dict_log}")
        logging.info(f"Single step total loss: {loss_value:.4f}")
        logging.info("--- Minimal Training Step Completed Successfully ---")
    except Exception as e:
        # exc_info=True already captures the traceback in the log.
        logging.error(f"Error during minimal training step: {e}", exc_info=True)
        sys.exit(1)

    # Temporarily exit after the single step (as per Prompt 8).
    logging.info("Exiting after single training step.")
    sys.exit(0)

    # --- Full Training Loop (Placeholder for Prompt 10) ---
    # print("Basic setup complete. Full training loop implementation pending.")
    # ... loop implementation ...
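    # A rough shape for that loop (hypothetical sketch; `num_epochs` and the
    # checkpoint naming are assumptions, not part of the current config):
    #
    # num_epochs = config.get("num_epochs", 10)
    # for epoch in range(num_epochs):
    #     model.train()
    #     for images, targets in data_loader_train:
    #         images = [image.to(device) for image in images]
    #         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
    #         loss_dict = model(images, targets)
    #         losses = sum(loss for loss in loss_dict.values())
    #         optimizer.zero_grad()
    #         losses.backward()
    #         optimizer.step()
    #     # lr_scheduler.step()  # once the scheduler is enabled
    #     torch.save(model.state_dict(),
    #                os.path.join(checkpoint_path, f"epoch_{epoch}.pth"))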


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train Mask R-CNN on the Penn-Fudan dataset."
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
    )
    args = parser.parse_args()
    main(args)