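"""Train a Mask R-CNN instance segmentation model on the Penn-Fudan pedestrian dataset.

All hyperparameters are read from the configuration file passed via --config; the run
can be resumed from the latest checkpoint in the output directory via --resume.
"""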
import argparse
import logging
import os
import sys
import time

import torch
import torch.utils.data

# Project specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
    check_data_path,
    load_checkpoint,
    load_config,
    setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging


def main(args):
    # Load configuration
    config = load_config(args.config)

    # Setup output directory and get device
    output_path, device = setup_environment(config)
    checkpoint_path = os.path.join(output_path, "checkpoints")
    os.makedirs(checkpoint_path, exist_ok=True)

    # Setup logging
    setup_logging(output_path, config.get("config_name", "default_run"))
    logging.info("--- Training Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Loaded configuration dictionary: {config}")
    logging.info(f"Output will be saved to: {output_path}")

    # Validate data path
    data_root = config.get("data_root")
    check_data_path(data_root)

    try:
        # Create the full training dataset instance first
        dataset_full = PennFudanDataset(
            root=data_root, transforms=get_transform(train=True)
        )
        logging.info(f"Full dataset size: {len(dataset_full)}")

        # Create validation dataset instance with eval transforms
        dataset_val_instance = PennFudanDataset(
            root=data_root, transforms=get_transform(train=False)
        )

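        # Both dataset instances wrap the same images; they differ only in their
        # transforms, so the index split below yields disjoint train/val subsets
        # while the validation subset is read with the eval-time transforms.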
        # Split the dataset indices
        torch.manual_seed(
            config.get("seed", 42)
        )  # Use the same seed for consistent splits
        indices = torch.randperm(len(dataset_full)).tolist()
        val_split_ratio = config.get(
            "val_split_ratio", 0.1
        )  # Default to 10% validation
        val_split_count = int(val_split_ratio * len(dataset_full))
        if val_split_count == 0 and len(dataset_full) > 0:
            logging.warning(
                f"Validation split resulted in 0 samples (ratio={val_split_ratio}, total={len(dataset_full)}). Using 1 sample for validation."
            )
            val_split_count = 1
        elif val_split_count >= len(dataset_full):
            logging.error(
                f"Validation split ratio ({val_split_ratio}) is too high; no samples would remain for training."
            )
            sys.exit(1)

        train_indices = indices[:-val_split_count]
        val_indices = indices[-val_split_count:]

        # Create Subset datasets
        dataset_train = torch.utils.data.Subset(dataset_full, train_indices)
        dataset_val = torch.utils.data.Subset(dataset_val_instance, val_indices)

        logging.info(
            f"Using {len(train_indices)} samples for training and {len(val_indices)} for validation."
        )

        # Create DataLoaders
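        # Detection batches hold variable-sized images and per-image target dicts,
        # so collate_fn (from utils.data_utils) is expected to return each batch as
        # tuples of lists rather than stacking samples into a single tensor.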
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=config.get("batch_size", 2),
            # Shuffle should be true for the training subset loader
            shuffle=True,
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )
        data_loader_val = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=config.get(
                "batch_size", 2
            ),  # Often use same or larger batch size for validation
            shuffle=False,  # No need to shuffle validation data
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )

        logging.info(
            f"Training dataloader configured. Est. batches: {len(data_loader_train)}"
        )
        logging.info(
            f"Validation dataloader configured. Est. batches: {len(data_loader_val)}"
        )

    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # Create model
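    # For torchvision-style Mask R-CNN heads, num_classes includes the background
    # class (e.g. 2 for Penn-Fudan: background + pedestrian), assuming
    # get_maskrcnn_model passes the value through unchanged.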
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=config.get("pretrained", True),
            pretrained_backbone=config.get("pretrained_backbone", True),
        )
        model.to(device)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error creating model: {e}", exc_info=True)
        sys.exit(1)

    # Create optimizer and learning rate scheduler
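    # The defaults below (SGD with lr=0.005, momentum=0.9, weight_decay=0.0005, and
    # StepLR with step_size=3, gamma=0.1) follow common Mask R-CNN fine-tuning
    # settings; every value can be overridden from the config file.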
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.005),
        momentum=config.get("momentum", 0.9),
        weight_decay=config.get("weight_decay", 0.0005),
    )

    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.get("lr_step_size", 3),
        gamma=config.get("lr_gamma", 0.1),
    )

    # --- Resume from Checkpoint (if specified) ---
    start_epoch = 0
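    # Checkpoints are located via the naming convention used when saving below
    # (checkpoint_epoch_<N>.pth); the file with the highest N is treated as latest.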
    if args.resume:
        try:
            # Find latest checkpoint
            checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")]
            if not checkpoints:
                logging.warning(
                    f"No checkpoints found in {checkpoint_path}, starting from scratch."
                )
            else:
                # Extract epoch numbers from filenames and find the latest
                max_epoch = -1
                latest_checkpoint = None
                for ckpt in checkpoints:
                    if ckpt.startswith("checkpoint_epoch_"):
                        try:
                            epoch_num = int(
                                ckpt.replace("checkpoint_epoch_", "").replace(
                                    ".pth", ""
                                )
                            )
                            if epoch_num > max_epoch:
                                max_epoch = epoch_num
                                latest_checkpoint = ckpt
                        except ValueError:
                            continue

                if latest_checkpoint:
                    checkpoint_file = os.path.join(checkpoint_path, latest_checkpoint)
                    logging.info(f"Resuming from checkpoint: {checkpoint_file}")

                    # Load checkpoint
                    checkpoint, start_epoch = load_checkpoint(
                        checkpoint_file,
                        model,
                        device,
                        load_optimizer=True,
                        optimizer=optimizer,
                        load_scheduler=True,
                        scheduler=lr_scheduler,
                    )

                    logging.info(f"Resuming from epoch {start_epoch}")
                else:
                    logging.warning(
                        f"No valid checkpoints found in {checkpoint_path}, starting from scratch."
                    )
        except Exception as e:
            logging.error(f"Error loading checkpoint: {e}", exc_info=True)
            logging.warning("Starting training from scratch.")
            start_epoch = 0

    # --- Training Loop ---
    train_time_start = time.time()
    logging.info("--- Starting Training Loop ---")

    for epoch in range(start_epoch, config.get("num_epochs", 10)):
        # Set model to training mode
        model.train()

        # Initialize epoch metrics
        epoch_loss = 0.0
        epoch_loss_classifier = 0.0
        epoch_loss_box_reg = 0.0
        epoch_loss_mask = 0.0
        epoch_loss_objectness = 0.0
        epoch_loss_rpn_box_reg = 0.0

        logging.info(f"--- Epoch {epoch + 1}/{config.get('num_epochs', 10)} ---")
        epoch_start_time = time.time()

        # Train loop
        for i, (images, targets) in enumerate(data_loader_train):
            # Move data to device
            images = list(image.to(device) for image in images)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Forward pass
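            # In training mode, torchvision-style detection models return a dict of
            # named losses (classifier, box regression, mask, RPN objectness and box
            # regression) instead of predictions; the components are summed for the
            # backward pass and accumulated individually below for logging.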
            loss_dict = model(images, targets)

            # Sum loss components
            losses = sum(loss for loss in loss_dict.values())

            # Backward and optimize
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            # Log batch results
            loss_value = losses.item()
            epoch_loss += loss_value

            # Accumulate individual loss components
            if "loss_classifier" in loss_dict:
                epoch_loss_classifier += loss_dict["loss_classifier"].item()
            if "loss_box_reg" in loss_dict:
                epoch_loss_box_reg += loss_dict["loss_box_reg"].item()
            if "loss_mask" in loss_dict:
                epoch_loss_mask += loss_dict["loss_mask"].item()
            if "loss_objectness" in loss_dict:
                epoch_loss_objectness += loss_dict["loss_objectness"].item()
            if "loss_rpn_box_reg" in loss_dict:
                epoch_loss_rpn_box_reg += loss_dict["loss_rpn_box_reg"].item()

            # Periodic logging
            if (i + 1) % config.get("log_freq", 10) == 0:
                log_str = f"Epoch [{epoch + 1}/{config.get('num_epochs', 10)}], "
                log_str += f"Iter [{i + 1}/{len(data_loader_train)}], "
                log_str += f"Loss: {loss_value:.4f}"

                # Add per-component losses for richer logging
                comp_log = []
                if "loss_classifier" in loss_dict:
                    comp_log.append(f"cls: {loss_dict['loss_classifier'].item():.4f}")
                if "loss_box_reg" in loss_dict:
                    comp_log.append(f"box: {loss_dict['loss_box_reg'].item():.4f}")
                if "loss_mask" in loss_dict:
                    comp_log.append(f"mask: {loss_dict['loss_mask'].item():.4f}")
                if "loss_objectness" in loss_dict:
                    comp_log.append(f"obj: {loss_dict['loss_objectness'].item():.4f}")
                if "loss_rpn_box_reg" in loss_dict:
                    comp_log.append(f"rpn: {loss_dict['loss_rpn_box_reg'].item():.4f}")

                if comp_log:
                    log_str += f" [{', '.join(comp_log)}]"

                logging.info(log_str)

        # Step learning rate scheduler after each epoch
        lr_scheduler.step()

        # Calculate and log epoch metrics
        if len(data_loader_train) > 0:
            avg_loss = epoch_loss / len(data_loader_train)
            avg_loss_classifier = epoch_loss_classifier / len(data_loader_train)
            avg_loss_box_reg = epoch_loss_box_reg / len(data_loader_train)
            avg_loss_mask = epoch_loss_mask / len(data_loader_train)
            avg_loss_objectness = epoch_loss_objectness / len(data_loader_train)
            avg_loss_rpn_box_reg = epoch_loss_rpn_box_reg / len(data_loader_train)

            logging.info(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}")
            logging.info(f"  Classifier Loss: {avg_loss_classifier:.4f}")
            logging.info(f"  Box Reg Loss: {avg_loss_box_reg:.4f}")
            logging.info(f"  Mask Loss: {avg_loss_mask:.4f}")
            logging.info(f"  Objectness Loss: {avg_loss_objectness:.4f}")
            logging.info(f"  RPN Box Reg Loss: {avg_loss_rpn_box_reg:.4f}")
        else:
            logging.warning("No training batches were processed in this epoch.")

        epoch_duration = time.time() - epoch_start_time
        logging.info(f"Epoch duration: {epoch_duration:.2f}s")

        # --- Validation ---
        logging.info("Running validation...")
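        # evaluate (from utils.eval_utils) is expected to run the model over the
        # validation loader without gradient updates and return a metrics dict that
        # includes at least "average_loss".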
        val_metrics = evaluate(model, data_loader_val, device)
        logging.info(f"Validation Loss: {val_metrics['average_loss']:.4f}")

        # --- Checkpoint Saving ---
        if (epoch + 1) % config.get("checkpoint_freq", 1) == 0 or epoch == config.get(
            "num_epochs", 10
        ) - 1:
            checkpoint_file = os.path.join(
                checkpoint_path, f"checkpoint_epoch_{epoch + 1}.pth"
            )
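            # Bundle everything needed for --resume (model/optimizer/scheduler state,
            # the config, and this epoch's validation loss) into a single file.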
            checkpoint = {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": lr_scheduler.state_dict(),
                "config": config,
                "val_loss": val_metrics["average_loss"],
            }
            try:
                torch.save(checkpoint, checkpoint_file)
                logging.info(f"Checkpoint saved to {checkpoint_file}")
            except Exception as e:
                logging.error(f"Error saving checkpoint: {e}", exc_info=True)

    # --- Final Metrics and Cleanup ---
    total_training_time = time.time() - train_time_start
    hours, remainder = divmod(total_training_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    logging.info(f"Training completed in {int(hours)}h {int(minutes)}m {seconds:.2f}s")
    logging.info(f"Checkpoints saved to {checkpoint_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a Mask R-CNN model")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument(
        "--resume", action="store_true", help="Resume training from latest checkpoint"
    )
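    # Example invocation (script and config paths are illustrative):
    #   python train.py --config configs/maskrcnn_pennfudan.yaml --resume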
    args = parser.parse_args()
    main(args)