# torchvision-vibecoding-project/train.py
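"""Training entry point for the Mask R-CNN pedestrian-segmentation project.

Typical invocations (the config path below is illustrative, not a file that is
guaranteed to exist in the repo):

    python train.py --config configs/maskrcnn_pennfudan.yaml
    python train.py --config configs/maskrcnn_pennfudan.yaml --resume
"""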

import argparse
import logging
import os
import sys
import time
import torch
import torch.utils.data
# Project-specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
check_data_path,
load_checkpoint,
load_config,
setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging


def main(args):
# Load configuration
config = load_config(args.config)
# Setup output directory and get device
output_path, device = setup_environment(config)
checkpoint_path = os.path.join(output_path, "checkpoints")
os.makedirs(checkpoint_path, exist_ok=True)
# Setup logging
setup_logging(output_path, config.get("config_name", "default_run"))
logging.info("--- Training Script Started ---")
logging.info(f"Loaded configuration from: {args.config}")
logging.info(f"Loaded configuration dictionary: {config}")
logging.info(f"Output will be saved to: {output_path}")
# Validate data path
data_root = config.get("data_root")
check_data_path(data_root)
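    # For reference, a minimal config covering the keys this script reads might look
    # like the sketch below (values are illustrative, not project-mandated; the file
    # format is whatever utils.common.load_config expects, and setup_environment may
    # read additional keys for the output directory):
    #
    #   config_name: maskrcnn_pennfudan
    #   data_root: /path/to/PennFudanPed
    #   num_classes: 2            # background + pedestrian
    #   num_epochs: 10
    #   batch_size: 2
    #   num_workers: 4
    #   pin_memory: true
    #   seed: 42
    #   val_split_ratio: 0.1
    #   lr: 0.005
    #   momentum: 0.9
    #   weight_decay: 0.0005
    #   lr_step_size: 3
    #   lr_gamma: 0.1
    #   log_freq: 10
    #   checkpoint_freq: 1
    #   pretrained: true
    #   pretrained_backbone: true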
try:
# Create the full training dataset instance first
dataset_full = PennFudanDataset(
root=data_root, transforms=get_transform(train=True)
)
logging.info(f"Full dataset size: {len(dataset_full)}")
# Create validation dataset instance with eval transforms
dataset_val_instance = PennFudanDataset(
root=data_root, transforms=get_transform(train=False)
)
# Split the dataset indices
torch.manual_seed(
config.get("seed", 42)
) # Use the same seed for consistent splits
indices = torch.randperm(len(dataset_full)).tolist()
val_split_ratio = config.get(
"val_split_ratio", 0.1
) # Default to 10% validation
val_split_count = int(val_split_ratio * len(dataset_full))
if val_split_count == 0 and len(dataset_full) > 0:
logging.warning(
f"Validation split resulted in 0 samples (ratio={val_split_ratio}, total={len(dataset_full)}). Using 1 sample for validation."
)
val_split_count = 1
elif val_split_count >= len(dataset_full):
logging.error(
f"Validation split ratio ({val_split_ratio}) too high, results in no training samples."
)
sys.exit(1)
train_indices = indices[:-val_split_count]
val_indices = indices[-val_split_count:]
# Create Subset datasets
dataset_train = torch.utils.data.Subset(dataset_full, train_indices)
dataset_val = torch.utils.data.Subset(dataset_val_instance, val_indices)
logging.info(
f"Using {len(train_indices)} samples for training and {len(val_indices)} for validation."
)
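        # Worked example: the Penn-Fudan dataset contains 170 images, so with the
        # default val_split_ratio of 0.1 this yields 17 validation and 153 training samples.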
# Create DataLoaders
data_loader_train = torch.utils.data.DataLoader(
dataset_train,
batch_size=config.get("batch_size", 2),
# Shuffle should be true for the training subset loader
shuffle=True,
num_workers=config.get("num_workers", 4),
collate_fn=collate_fn,
pin_memory=config.get("pin_memory", True),
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=config.get(
"batch_size", 2
), # Often use same or larger batch size for validation
shuffle=False, # No need to shuffle validation data
num_workers=config.get("num_workers", 4),
collate_fn=collate_fn,
pin_memory=config.get("pin_memory", True),
)
        logging.info(
            f"Training dataloader configured: {len(data_loader_train)} batches per epoch."
        )
        logging.info(
            f"Validation dataloader configured: {len(data_loader_val)} batches per epoch."
        )
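        # Note on collate_fn: detection samples carry a variable number of boxes and
        # masks per image, so the default batching (stacking into a single tensor)
        # does not apply. In the torchvision detection references the collate
        # function is simply `tuple(zip(*batch))`; utils.data_utils.collate_fn is
        # assumed to do the same.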
except Exception as e:
logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
sys.exit(1)
# Create model
num_classes = config.get("num_classes")
if num_classes is None:
logging.error("'num_classes' not specified in configuration.")
sys.exit(1)
try:
model = get_maskrcnn_model(
num_classes=num_classes,
pretrained=config.get("pretrained", True),
pretrained_backbone=config.get("pretrained_backbone", True),
)
model.to(device)
logging.info("Model loaded successfully.")
except Exception as e:
logging.error(f"Error creating model: {e}", exc_info=True)
sys.exit(1)
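    # get_maskrcnn_model is project code; it presumably follows the standard
    # torchvision recipe: load maskrcnn_resnet50_fpn and replace the box and mask
    # heads to match num_classes. A minimal sketch under that assumption:
    #
    #   from torchvision.models.detection import maskrcnn_resnet50_fpn
    #   from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
    #   from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
    #
    #   model = maskrcnn_resnet50_fpn(weights="DEFAULT")
    #   in_features = model.roi_heads.box_predictor.cls_score.in_features
    #   model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    #   in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    #   model.roi_heads.mask_predictor = MaskRCNNPredictor(
    #       in_features_mask, 256, num_classes
    #   )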
# Create optimizer and learning rate scheduler
optimizer = torch.optim.SGD(
model.parameters(),
lr=config.get("lr", 0.005),
momentum=config.get("momentum", 0.9),
weight_decay=config.get("weight_decay", 0.0005),
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=config.get("lr_step_size", 3),
gamma=config.get("lr_gamma", 0.1),
)
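    # With the defaults above (lr=0.005, lr_step_size=3, lr_gamma=0.1), the learning
    # rate is 0.005 for epochs 1-3, 0.0005 for epochs 4-6, and 0.00005 from epoch 7 on.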
# --- Resume from Checkpoint (if specified) ---
start_epoch = 0
if args.resume:
try:
# Find latest checkpoint
checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")]
if not checkpoints:
logging.warning(
f"No checkpoints found in {checkpoint_path}, starting from scratch."
)
else:
# Extract epoch numbers from filenames and find the latest
max_epoch = -1
latest_checkpoint = None
for ckpt in checkpoints:
if ckpt.startswith("checkpoint_epoch_"):
try:
epoch_num = int(
ckpt.replace("checkpoint_epoch_", "").replace(
".pth", ""
)
)
if epoch_num > max_epoch:
max_epoch = epoch_num
latest_checkpoint = ckpt
except ValueError:
continue
if latest_checkpoint:
checkpoint_file = os.path.join(checkpoint_path, latest_checkpoint)
logging.info(f"Resuming from checkpoint: {checkpoint_file}")
# Load checkpoint
checkpoint, start_epoch = load_checkpoint(
checkpoint_file,
model,
device,
load_optimizer=True,
optimizer=optimizer,
load_scheduler=True,
scheduler=lr_scheduler,
)
logging.info(f"Resuming from epoch {start_epoch}")
else:
logging.warning(
f"No valid checkpoints found in {checkpoint_path}, starting from scratch."
)
except Exception as e:
logging.error(f"Error loading checkpoint: {e}", exc_info=True)
logging.warning("Starting training from scratch.")
start_epoch = 0
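    # Resuming relies on the filename convention used when saving checkpoints below:
    # e.g. "checkpoint_epoch_7.pth" is parsed back to epoch 7. load_checkpoint is
    # assumed to return the raw checkpoint dict together with the epoch to resume from.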
# --- Training Loop ---
train_time_start = time.time()
logging.info("--- Starting Training Loop ---")
for epoch in range(start_epoch, config.get("num_epochs", 10)):
# Set model to training mode
model.train()
# Initialize epoch metrics
epoch_loss = 0.0
epoch_loss_classifier = 0.0
epoch_loss_box_reg = 0.0
epoch_loss_mask = 0.0
epoch_loss_objectness = 0.0
epoch_loss_rpn_box_reg = 0.0
logging.info(f"--- Epoch {epoch + 1}/{config.get('num_epochs', 10)} ---")
epoch_start_time = time.time()
# Train loop
for i, (images, targets) in enumerate(data_loader_train):
# Move data to device
            images = [image.to(device) for image in images]
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
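            # torchvision detection models expect each target dict to contain at least
            # "boxes" (FloatTensor[N, 4]), "labels" (Int64Tensor[N]) and, for Mask R-CNN,
            # "masks" (UInt8Tensor[N, H, W]); PennFudanDataset is assumed to provide these.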
# Forward pass
loss_dict = model(images, targets)
# Sum loss components
losses = sum(loss for loss in loss_dict.values())
# Backward and optimize
optimizer.zero_grad()
losses.backward()
optimizer.step()
# Log batch results
loss_value = losses.item()
epoch_loss += loss_value
# Accumulate individual loss components
if "loss_classifier" in loss_dict:
epoch_loss_classifier += loss_dict["loss_classifier"].item()
if "loss_box_reg" in loss_dict:
epoch_loss_box_reg += loss_dict["loss_box_reg"].item()
if "loss_mask" in loss_dict:
epoch_loss_mask += loss_dict["loss_mask"].item()
if "loss_objectness" in loss_dict:
epoch_loss_objectness += loss_dict["loss_objectness"].item()
if "loss_rpn_box_reg" in loss_dict:
epoch_loss_rpn_box_reg += loss_dict["loss_rpn_box_reg"].item()
# Periodic logging
if (i + 1) % config.get("log_freq", 10) == 0:
log_str = f"Epoch [{epoch + 1}/{config.get('num_epochs', 10)}], "
log_str += f"Iter [{i + 1}/{len(data_loader_train)}], "
log_str += f"Loss: {loss_value:.4f}"
# Add per-component losses for richer logging
comp_log = []
if "loss_classifier" in loss_dict:
comp_log.append(f"cls: {loss_dict['loss_classifier'].item():.4f}")
if "loss_box_reg" in loss_dict:
comp_log.append(f"box: {loss_dict['loss_box_reg'].item():.4f}")
if "loss_mask" in loss_dict:
comp_log.append(f"mask: {loss_dict['loss_mask'].item():.4f}")
if "loss_objectness" in loss_dict:
comp_log.append(f"obj: {loss_dict['loss_objectness'].item():.4f}")
if "loss_rpn_box_reg" in loss_dict:
comp_log.append(f"rpn: {loss_dict['loss_rpn_box_reg'].item():.4f}")
if comp_log:
log_str += f" [{', '.join(comp_log)}]"
logging.info(log_str)
# Step learning rate scheduler after each epoch
lr_scheduler.step()
# Calculate and log epoch metrics
if len(data_loader_train) > 0:
avg_loss = epoch_loss / len(data_loader_train)
avg_loss_classifier = epoch_loss_classifier / len(data_loader_train)
avg_loss_box_reg = epoch_loss_box_reg / len(data_loader_train)
avg_loss_mask = epoch_loss_mask / len(data_loader_train)
avg_loss_objectness = epoch_loss_objectness / len(data_loader_train)
avg_loss_rpn_box_reg = epoch_loss_rpn_box_reg / len(data_loader_train)
logging.info(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}")
logging.info(f" Classifier Loss: {avg_loss_classifier:.4f}")
logging.info(f" Box Reg Loss: {avg_loss_box_reg:.4f}")
logging.info(f" Mask Loss: {avg_loss_mask:.4f}")
logging.info(f" Objectness Loss: {avg_loss_objectness:.4f}")
logging.info(f" RPN Box Reg Loss: {avg_loss_rpn_box_reg:.4f}")
else:
logging.warning("No training batches were processed in this epoch.")
epoch_duration = time.time() - epoch_start_time
logging.info(f"Epoch duration: {epoch_duration:.2f}s")
# --- Validation ---
logging.info("Running validation...")
val_metrics = evaluate(model, data_loader_val, device)
logging.info(f"Validation Loss: {val_metrics['average_loss']:.4f}")
# --- Checkpoint Saving ---
if (epoch + 1) % config.get("checkpoint_freq", 1) == 0 or epoch == config.get(
"num_epochs", 10
) - 1:
checkpoint_file = os.path.join(
checkpoint_path, f"checkpoint_epoch_{epoch+1}.pth"
)
checkpoint = {
"epoch": epoch + 1,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": lr_scheduler.state_dict(),
"config": config,
"val_loss": val_metrics["average_loss"],
}
try:
torch.save(checkpoint, checkpoint_file)
logging.info(f"Checkpoint saved to {checkpoint_file}")
except Exception as e:
logging.error(f"Error saving checkpoint: {e}", exc_info=True)
# --- Final Metrics and Cleanup ---
total_training_time = time.time() - train_time_start
hours, remainder = divmod(total_training_time, 3600)
minutes, seconds = divmod(remainder, 60)
logging.info(f"Training completed in {int(hours)}h {int(minutes)}m {seconds:.2f}s")
logging.info(f"Final model saved to {checkpoint_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train a Mask R-CNN model")
parser.add_argument("--config", required=True, help="Path to configuration file")
parser.add_argument(
"--resume", action="store_true", help="Resume training from latest checkpoint"
)
args = parser.parse_args()
main(args)