import argparse
import importlib
import logging
import os
import random
import sys
import time  # Used to time epochs and batches

import numpy as np
import torch
import torch.utils.data

# Project-specific imports
from models.detection import get_maskrcnn_model
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.log_utils import setup_logging


def main(args):
    # --- Configuration Loading ---
    try:
        config_path = os.path.abspath(args.config)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive the module path from the file path relative to the workspace
        # root (this assumes the script is run from the project root).
        workspace_root = os.path.abspath(os.getcwd())
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {args.config} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")

        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config

        print(
            f"Loaded configuration from: {config_path} (via module {module_path_str})"
        )

    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print(
            "Ensure the config file path is correct and relative imports within it are valid."
        )
        import traceback

        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(
            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
        )
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {args.config}: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
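    # For reference, the imported config module is expected to expose a dict
    # named `config`. The example below is illustrative only: the keys mirror
    # the config.get(...) lookups in this script, most values are this script's
    # fallback defaults, and data_root / num_classes are assumptions since they
    # have no defaults here.
    # config = {
    #     "config_name": "default_run",
    #     "output_dir": "outputs",
    #     "data_root": "data/PennFudanPed",  # assumed dataset location
    #     "num_classes": 2,                  # e.g. background + pedestrian
    #     "device": "cuda",
    #     "seed": 42,
    #     "batch_size": 2,
    #     "num_workers": 4,
    #     "pin_memory": True,
    #     "pretrained": True,
    #     "pretrained_backbone": True,
    #     "lr": 0.005,
    #     "momentum": 0.9,
    #     "weight_decay": 0.0005,
    #     "lr_step_size": 3,
    #     "lr_gamma": 0.1,
    #     "num_epochs": 10,
    #     "log_freq": 10,
    #     "checkpoint_freq": 1,
    # }
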
    # --- Output Directory Setup ---
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    checkpoint_path = os.path.join(output_path, "checkpoints")
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)
    print(f"Output will be saved to: {output_path}")

    # --- Logging Setup (Prompt 9) ---
    setup_logging(output_path, config_name)
    logging.info("--- Training Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Loaded configuration dictionary: {config}")
    logging.info(f"Output will be saved to: {output_path}")

    # --- Reproducibility ---
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Consider adding these for more determinism, but they might impact performance:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    logging.info(f"Set random seed to: {seed}")

    # --- Device Setup ---
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    # --- Dataset and DataLoader ---
    data_root = config.get("data_root")
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)

    try:
        dataset_train = PennFudanDataset(
            root=data_root, transforms=get_transform(train=True)
        )
        # Note: the validation split will be handled later (Prompt 12)
        # dataset_val = PennFudanDataset(root=data_root, transforms=get_transform(train=False))

        # TODO: Implement data splitting (e.g., using torch.utils.data.Subset)
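        # A possible (inactive) sketch of that split, kept here for reference.
        # The 80/20 ratio is illustrative; it reuses the torch RNG seeded above
        # so the split is reproducible, and builds the validation set with
        # train=False transforms as in the commented line above:
        # indices = torch.randperm(len(dataset_train)).tolist()
        # val_size = max(1, int(0.2 * len(indices)))
        # dataset_val = torch.utils.data.Subset(
        #     PennFudanDataset(root=data_root, transforms=get_transform(train=False)),
        #     indices[:val_size],
        # )
        # dataset_train = torch.utils.data.Subset(dataset_train, indices[val_size:])
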
        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=config.get("batch_size", 2),
            shuffle=True,
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),  # Often improves GPU transfer speed
        )
        logging.info(f"Training dataset size: {len(dataset_train)}")
        logging.info(
            f"Training dataloader configured with batch size {config.get('batch_size', 2)}"
        )

        # Placeholder for the validation loader
        # data_loader_val = torch.utils.data.DataLoader(...)

    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # --- Model Instantiation ---
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=config.get("pretrained", True),
            pretrained_backbone=config.get("pretrained_backbone", True),
        )
        model.to(device)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}", exc_info=True)
        sys.exit(1)

    # --- Optimizer ---
    # Only optimize parameters that require gradients
    params = [p for p in model.parameters() if p.requires_grad]
    try:
        optimizer = torch.optim.SGD(
            params,
            lr=config.get("lr", 0.005),
            momentum=config.get("momentum", 0.9),
            weight_decay=config.get("weight_decay", 0.0005),
        )
        logging.info(
            f"Optimizer SGD configured with lr={config.get('lr', 0.005)}, "
            f"momentum={config.get('momentum', 0.9)}, "
            f"weight_decay={config.get('weight_decay', 0.0005)}"
        )
    except Exception as e:
        logging.error(f"Error creating optimizer: {e}", exc_info=True)
        sys.exit(1)

    # --- LR Scheduler (Prompt 10) ---
    try:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.get("lr_step_size", 3),
            gamma=config.get("lr_gamma", 0.1),
        )
        logging.info(
            f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, "
            f"gamma={config.get('lr_gamma', 0.1)}"
        )
    except Exception as e:
        logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
        sys.exit(1)

    # --- Training Loop (Prompt 10) ---
    logging.info("--- Starting Training Loop ---")
    start_time = time.time()
    num_epochs = config.get("num_epochs", 10)

    for epoch in range(num_epochs):
        model.train()  # Set model to training mode for each epoch
        epoch_start_time = time.time()
        logging.info(f"--- Epoch {epoch + 1}/{num_epochs} ---")

        # Variables for tracking epoch progress (optional)
        epoch_loss_sum = 0.0
        num_batches = len(data_loader_train)

        for i, (images, targets) in enumerate(data_loader_train):
            batch_start_time = time.time()  # Optional: time each batch
            try:
                # Move data to the device
                images = [image.to(device) for image in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                # Forward pass: in training mode the model returns a dict of losses
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                loss_value = losses.item()
                epoch_loss_sum += loss_value

                # Backward pass and optimizer step
                optimizer.zero_grad()
                losses.backward()
                optimizer.step()

                # Log batch loss periodically
                if (i + 1) % config.get("log_freq", 10) == 0:
                    batch_time = time.time() - batch_start_time
                    # Individual loss components, logged at DEBUG level
                    loss_dict_items = {
                        k: f"{v.item():.4f}" for k, v in loss_dict.items()
                    }
                    logging.info(
                        f"  Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, "
                        f"Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
                    )
                    logging.debug(f"    Loss Dict: {loss_dict_items}")

            except Exception as e:
                logging.error(
                    f"Error during training epoch {epoch + 1}, batch {i + 1}: {e}",
                    exc_info=True,
                )
                # Decide whether to stop training or continue; here we skip the
                # rest of this epoch and move on to the next one.
                logging.warning("Skipping rest of epoch due to error.")
                break  # Exit the inner loop for this epoch

        # --- End of Epoch --- #
        # Step the learning rate scheduler
        lr_scheduler.step()

        # Log epoch summary
        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time
        avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
        current_lr = optimizer.param_groups[0]["lr"]  # Current learning rate
        logging.info(f"--- Epoch {epoch + 1} Summary ---")
        logging.info(f"  Average Loss: {avg_epoch_loss:.4f}")
        logging.info(f"  Learning Rate: {current_lr:.6f}")
        logging.info(f"  Epoch Duration: {epoch_duration:.2f}s")

        # --- Checkpointing (Placeholder for Prompt 11) --- #
        # Add checkpoint saving logic here, e.g.:
        # if (epoch + 1) % config.get('checkpoint_freq', 1) == 0 or (epoch + 1) == num_epochs:
        #     # ... save checkpoint ...
        #     logging.info(f"Saved checkpoint for epoch {epoch + 1}")
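        # One way the "... save checkpoint ..." body above could look; this is an
        # illustrative, inactive sketch (the file name is an assumption, everything
        # else reuses objects already defined in this script):
        #     checkpoint = {
        #         "epoch": epoch + 1,
        #         "model_state_dict": model.state_dict(),
        #         "optimizer_state_dict": optimizer.state_dict(),
        #         "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        #     }
        #     torch.save(
        #         checkpoint,
        #         os.path.join(checkpoint_path, f"checkpoint_epoch_{epoch + 1}.pth"),
        #     )
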
        # --- Evaluation (Placeholder for Prompt 12) --- #
        # Add evaluation logic here, e.g.:
        # if data_loader_val:
        #     evaluate(model, data_loader_val, device)
        #     logging.info(f"Ran evaluation for epoch {epoch + 1}")
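        # A minimal, inactive sketch of such an `evaluate` helper (an assumption,
        # not an existing utility): torchvision detection models return per-image
        # prediction dicts when called in eval mode without targets.
        # model.eval()
        # with torch.no_grad():
        #     for images_val, _targets_val in data_loader_val:
        #         images_val = [img.to(device) for img in images_val]
        #         predictions = model(images_val)
        #         # ... accumulate detection/segmentation metrics here ...
        # model.train()  # Restore training mode for the next epoch
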
    # --- End of Training --- #
    total_training_time = time.time() - start_time
    logging.info("--- Training Finished ---")
    logging.info(
        f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
    )

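
# Typical invocation (the script file name is an assumption; the example config
# path comes from the --config help text below):
#   python train.py --config configs/pennfudan_maskrcnn_config.py
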
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train Mask R-CNN on the Penn-Fudan dataset."
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
    )

    args = parser.parse_args()
    main(args)