Create test script, and refactor shared logic from train.py into a common file for use by both scripts
test.py (new file, 109 lines)
@@ -0,0 +1,109 @@
import argparse
import logging
import sys

import torch
import torch.utils.data

# Project specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
    check_data_path,
    load_checkpoint,
    load_config,
    setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging


def main(args):
    # Load configuration
    config = load_config(args.config)

    # Setup output directory and get device
    output_path, device = setup_environment(config)

    # Setup logging
    setup_logging(output_path, f"{config['config_name']}_test")
    logging.info("--- Testing Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Checkpoint path: {args.checkpoint}")
    logging.info(f"Loaded configuration dictionary: {config}")

    # Validate data path
    data_root = config.get("data_root")
    check_data_path(data_root)

    try:
        # Create the full dataset instance for testing with eval transforms
        dataset_test = PennFudanDataset(
            root=data_root, transforms=get_transform(train=False)
        )
        logging.info(f"Test dataset size: {len(dataset_test)}")

        # Create test DataLoader
        data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=config.get("batch_size", 2),
            shuffle=False,  # No need to shuffle test data
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )

        logging.info(
            f"Test dataloader configured. Est. batches: {len(data_loader_test)}"
        )

    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # Create model
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        # Create the model with the same architecture as in training
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=False,  # Don't need pretrained weights as we'll load checkpoint
            pretrained_backbone=False,
        )

        # Load checkpoint
        load_checkpoint(args.checkpoint, model, device)
        model.to(device)
        logging.info("Model loaded and moved to device successfully.")
    except Exception as e:
        logging.error(f"Error setting up model: {e}", exc_info=True)
        sys.exit(1)

    # Run Evaluation
    try:
        logging.info("Starting model evaluation...")
        eval_metrics = evaluate(model, data_loader_test, device)

        # Log detailed metrics
        logging.info("--- Evaluation Results ---")
        for metric_name, metric_value in eval_metrics.items():
            logging.info(f"  {metric_name}: {metric_value:.4f}")

        logging.info("Evaluation completed successfully")
    except Exception as e:
        logging.error(f"Error during evaluation: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test a trained Mask R-CNN model")
    parser.add_argument("--config", required=True, help="Path to configuration file")
    parser.add_argument(
        "--checkpoint", required=True, help="Path to model checkpoint file (.pth)"
    )
    args = parser.parse_args()
    main(args)
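
Note: both scripts begin by calling load_config(), which (as defined in utils/common.py below) imports the given file as a Python module and reads its module-level `config` dictionary. A minimal sketch of such a config module is shown here for orientation only; the key names mirror the config.get() lookups in test.py and train.py, but the file name and the values are illustrative placeholders, not part of this commit:

# configs/example_config.py -- hypothetical example module, not added by this commit.
# load_config() would import it and read the module-level `config` dict.
config = {
    "config_name": "default_run",      # used for the output sub-directory and log names
    "output_dir": "outputs",
    "device": "cuda",                  # falls back to CPU if CUDA is unavailable
    "seed": 42,
    "data_root": "data/PennFudanPed",  # placeholder path to the Penn-Fudan dataset
    "num_classes": 2,
    "batch_size": 2,
    "num_workers": 4,
    "pin_memory": True,
    "lr": 0.005,
    "momentum": 0.9,
    "weight_decay": 0.0005,
    "lr_step_size": 3,
    "lr_gamma": 0.1,
    "num_epochs": 10,
    "log_freq": 10,
    "checkpoint_freq": 1,
}
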
train.py (512 changed lines)
@@ -1,112 +1,44 @@
 import argparse
-import importlib.util
 import logging
 import os
-import random
 import sys
-import time  # Import time for timing
+import time
 
-import numpy as np
 import torch
 import torch.utils.data
 
 # Project specific imports
 from models.detection import get_maskrcnn_model
+from utils.common import (
+    check_data_path,
+    load_checkpoint,
+    load_config,
+    setup_environment,
+)
 from utils.data_utils import PennFudanDataset, collate_fn, get_transform
-from utils.eval_utils import evaluate  # Import evaluate function
+from utils.eval_utils import evaluate
 from utils.log_utils import setup_logging
 
 
 def main(args):
-    # --- Configuration Loading ---
-    try:
-        config_path = os.path.abspath(args.config)
-        if not os.path.exists(config_path):
-            print(f"Error: Config file not found at {config_path}")
-            sys.exit(1)
-
-        # Derive module path from file path relative to workspace root
-        workspace_root = os.path.abspath(
-            os.getcwd()
-        )  # Assuming script is run from root
-        relative_path = os.path.relpath(config_path, workspace_root)
-        if relative_path.startswith(".."):
-            print(f"Error: Config file {args.config} is outside the project directory.")
-            sys.exit(1)
-
-        module_path_no_ext, _ = os.path.splitext(relative_path)
-        module_path_str = module_path_no_ext.replace(os.sep, ".")
-
-        print(f"Attempting to import config module: {module_path_str}")
-        config_module = importlib.import_module(module_path_str)
-        config = config_module.config
-
-        print(
-            f"Loaded configuration from: {config_path} (via module {module_path_str})"
-        )
-
-    except ImportError as e:
-        print(f"Error importing config module '{module_path_str}': {e}")
-        print(
-            "Ensure the config file path is correct and relative imports within it are valid."
-        )
-        import traceback
-
-        traceback.print_exc()
-        sys.exit(1)
-    except AttributeError as e:
-        print(
-            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
-        )
-        sys.exit(1)
-    except Exception as e:
-        print(f"Error loading configuration file {args.config}: {e}")
-        import traceback
-
-        traceback.print_exc()
-        sys.exit(1)
-
-    # --- Output Directory Setup ---
-    output_dir = config.get("output_dir", "outputs")
-    config_name = config.get("config_name", "default_run")
-    output_path = os.path.join(output_dir, config_name)
+    # Load configuration
+    config = load_config(args.config)
+
+    # Setup output directory and get device
+    output_path, device = setup_environment(config)
     checkpoint_path = os.path.join(output_path, "checkpoints")
-    os.makedirs(output_path, exist_ok=True)
     os.makedirs(checkpoint_path, exist_ok=True)
-    print(f"Output will be saved to: {output_path}")
 
-    # --- Logging Setup (Prompt 9) ---
-    setup_logging(output_path, config_name)
+    # Setup logging
+    setup_logging(output_path, config.get("config_name", "default_run"))
     logging.info("--- Training Script Started ---")
     logging.info(f"Loaded configuration from: {args.config}")
     logging.info(f"Loaded configuration dictionary: {config}")
     logging.info(f"Output will be saved to: {output_path}")
 
-    # --- Reproducibility ---
-    seed = config.get("seed", 42)
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(seed)
-    # Consider adding these for more determinism, but they might impact performance
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False
-    logging.info(f"Set random seed to: {seed}")
-
-    # --- Device Setup ---
-    device_name = config.get("device", "cuda")
-    if device_name == "cuda" and not torch.cuda.is_available():
-        logging.warning("CUDA requested but not available, falling back to CPU.")
-        device_name = "cpu"
-    device = torch.device(device_name)
-    logging.info(f"Using device: {device}")
-
-    # --- Dataset and DataLoader ---
+    # Validate data path
     data_root = config.get("data_root")
-    if not data_root or not os.path.isdir(data_root):
-        logging.error(f"Data root directory not found or not specified: {data_root}")
-        sys.exit(1)
+    check_data_path(data_root)
 
     try:
         # Create the full training dataset instance first
@@ -183,7 +115,7 @@ def main(args):
         logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
         sys.exit(1)
 
-    # --- Model Instantiation ---
+    # Create model
     num_classes = config.get("num_classes")
     if num_classes is None:
         logging.error("'num_classes' not specified in configuration.")
@@ -198,245 +130,215 @@ def main(args):
         model.to(device)
         logging.info("Model loaded successfully.")
     except Exception as e:
-        logging.error(f"Error loading model: {e}", exc_info=True)
+        logging.error(f"Error creating model: {e}", exc_info=True)
         sys.exit(1)
 
-    # --- Optimizer ---
-    # Filter parameters that require gradients
-    params = [p for p in model.parameters() if p.requires_grad]
-    try:
-        optimizer = torch.optim.SGD(
-            params,
-            lr=config.get("lr", 0.005),
-            momentum=config.get("momentum", 0.9),
-            weight_decay=config.get("weight_decay", 0.0005),
-        )
-        logging.info(
-            f"Optimizer SGD configured with lr={config.get('lr', 0.005)}, momentum={config.get('momentum', 0.9)}, weight_decay={config.get('weight_decay', 0.0005)}"
-        )
-    except Exception as e:
-        logging.error(f"Error creating optimizer: {e}", exc_info=True)
-        sys.exit(1)
-
-    # --- LR Scheduler (Prompt 10) ---
-    try:
-        lr_scheduler = torch.optim.lr_scheduler.StepLR(
-            optimizer,
-            step_size=config.get("lr_step_size", 3),
-            gamma=config.get("lr_gamma", 0.1),
-        )
-        logging.info(
-            f"LR scheduler StepLR configured with step_size={config.get('lr_step_size', 3)}, gamma={config.get('lr_gamma', 0.1)}"
-        )
-    except Exception as e:
-        logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
-        sys.exit(1)
+    # Create optimizer and learning rate scheduler
+    optimizer = torch.optim.SGD(
+        model.parameters(),
+        lr=config.get("lr", 0.005),
+        momentum=config.get("momentum", 0.9),
+        weight_decay=config.get("weight_decay", 0.0005),
+    )
+
+    lr_scheduler = torch.optim.lr_scheduler.StepLR(
+        optimizer,
+        step_size=config.get("lr_step_size", 3),
+        gamma=config.get("lr_gamma", 0.1),
+    )
 
-    # --- Resume Logic (Prompt 11) ---
+    # --- Resume from Checkpoint (if specified) ---
     start_epoch = 0
-    latest_checkpoint_path = None
-    if os.path.isdir(checkpoint_path):
-        checkpoints = sorted(
-            [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")]
-        )
-        if checkpoints:  # Check if list is not empty
-            latest_checkpoint_file = checkpoints[
-                -1
-            ]  # Get the last one (assuming naming convention like epoch_N.pth)
-            latest_checkpoint_path = os.path.join(
-                checkpoint_path, latest_checkpoint_file
-            )
-            logging.info(f"Found latest checkpoint: {latest_checkpoint_path}")
-        else:
-            logging.info("No checkpoints found in directory. Starting from scratch.")
-    else:
-        logging.info("Checkpoint directory not found. Starting from scratch.")
-
-    if latest_checkpoint_path:
+    if args.resume:
         try:
-            logging.info(f"Loading checkpoint '{latest_checkpoint_path}'")
-            # Ensure loading happens on the correct device
-            checkpoint = torch.load(latest_checkpoint_path, map_location=device)
-
-            # Load model state - handle potential 'module.' prefix if saved with DataParallel
-            model_state_dict = checkpoint["model_state_dict"]
-            # Simple check and correction for DataParallel prefix
-            if all(key.startswith("module.") for key in model_state_dict.keys()):
-                logging.info("Removing 'module.' prefix from checkpoint keys.")
-                model_state_dict = {
-                    k.replace("module.", ""): v for k, v in model_state_dict.items()
-                }
-            model.load_state_dict(model_state_dict)
-
-            # Load optimizer state
-            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-
-            # Load LR scheduler state
-            lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
-
-            # Load starting epoch (epoch saved is the one *completed*, so start from next)
-            start_epoch = checkpoint["epoch"]
-            logging.info(f"Resuming training from epoch {start_epoch + 1}")
-
-            # Optionally load and verify config consistency
-            # loaded_config = checkpoint.get('config')
-            # if loaded_config:
-            #     # Perform checks if necessary
-            #     pass
-
-        except Exception as e:
-            logging.error(
-                f"Error loading checkpoint: {e}. Starting training from scratch.",
-                exc_info=True,
-            )
-            start_epoch = 0  # Reset start_epoch if loading fails
-
-    # --- Training Loop (Prompt 10, modified for Prompt 11) ---
-    logging.info("--- Starting Training Loop --- ")
-    start_time = time.time()
-    num_epochs = config.get("num_epochs", 10)
-
-    # Modify loop to start from start_epoch
-    for epoch in range(start_epoch, num_epochs):
-        model.train()  # Set model to training mode for each epoch
-        epoch_start_time = time.time()
-        logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ")
-
-        # Variables for tracking epoch progress (optional)
-        epoch_loss_sum = 0.0
-        num_batches = len(data_loader_train)
-
-        for i, (images, targets) in enumerate(data_loader_train):
-            batch_start_time = time.time()  # Optional: time each batch
-            try:
-                # Move data to the device
-                images = list(image.to(device) for image in images)
-                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-                # Perform forward pass
-                loss_dict = model(images, targets)
-                losses = sum(loss for loss in loss_dict.values())
-                loss_value = losses.item()
-                epoch_loss_sum += loss_value
-
-                # Perform backward pass
-                optimizer.zero_grad()
-                losses.backward()
-                optimizer.step()
-
-                # Log batch loss periodically
-                if (i + 1) % config.get("log_freq", 10) == 0:
-                    batch_time = time.time() - batch_start_time
-                    # Include individual losses if desired
-                    loss_dict_items = {
-                        k: f"{v.item():.4f}" for k, v in loss_dict.items()
-                    }
-                    logging.info(
-                        f"  Epoch {epoch + 1}, Iter {i + 1}/{num_batches}, Loss: {loss_value:.4f}, Batch Time: {batch_time:.2f}s"
-                    )
-                    logging.debug(
-                        f"  Loss Dict: {loss_dict_items}"
-                    )  # Log individual losses at DEBUG level
-
-            except Exception as e:
-                logging.error(
-                    f"Error during training epoch {epoch+1}, batch {i+1}: {e}",
-                    exc_info=True,
-                )
-                # Decide if you want to stop training or continue to next batch/epoch
-                logging.warning("Skipping rest of epoch due to error.")
-                break  # Exit the inner loop for this epoch
-
-        # --- End of Epoch --- #
-        # Step the learning rate scheduler
+            # Find latest checkpoint
+            checkpoints = [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")]
+            if not checkpoints:
+                logging.warning(
+                    f"No checkpoints found in {checkpoint_path}, starting from scratch."
+                )
+            else:
+                # Extract epoch numbers from filenames and find the latest
+                max_epoch = -1
+                latest_checkpoint = None
+                for ckpt in checkpoints:
+                    if ckpt.startswith("checkpoint_epoch_"):
+                        try:
+                            epoch_num = int(
+                                ckpt.replace("checkpoint_epoch_", "").replace(
+                                    ".pth", ""
+                                )
+                            )
+                            if epoch_num > max_epoch:
+                                max_epoch = epoch_num
+                                latest_checkpoint = ckpt
+                        except ValueError:
+                            continue
+
+                if latest_checkpoint:
+                    checkpoint_file = os.path.join(checkpoint_path, latest_checkpoint)
+                    logging.info(f"Resuming from checkpoint: {checkpoint_file}")
+
+                    # Load checkpoint
+                    checkpoint, start_epoch = load_checkpoint(
+                        checkpoint_file,
+                        model,
+                        device,
+                        load_optimizer=True,
+                        optimizer=optimizer,
+                        load_scheduler=True,
+                        scheduler=lr_scheduler,
+                    )
+
+                    logging.info(f"Resuming from epoch {start_epoch}")
+                else:
+                    logging.warning(
+                        f"No valid checkpoints found in {checkpoint_path}, starting from scratch."
+                    )
+        except Exception as e:
+            logging.error(f"Error loading checkpoint: {e}", exc_info=True)
+            logging.warning("Starting training from scratch.")
+            start_epoch = 0
+
+    # --- Training Loop ---
+    train_time_start = time.time()
+    logging.info("--- Starting Training Loop ---")
+
+    for epoch in range(start_epoch, config.get("num_epochs", 10)):
+        # Set model to training mode
+        model.train()
+
+        # Initialize epoch metrics
+        epoch_loss = 0.0
+        epoch_loss_classifier = 0.0
+        epoch_loss_box_reg = 0.0
+        epoch_loss_mask = 0.0
+        epoch_loss_objectness = 0.0
+        epoch_loss_rpn_box_reg = 0.0
+
+        logging.info(f"--- Epoch {epoch + 1}/{config.get('num_epochs', 10)} ---")
+        epoch_start_time = time.time()
+
+        # Train loop
+        for i, (images, targets) in enumerate(data_loader_train):
+            # Move data to device
+            images = list(image.to(device) for image in images)
+            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+            # Forward pass
+            loss_dict = model(images, targets)
+
+            # Sum loss components
+            losses = sum(loss for loss in loss_dict.values())
+
+            # Backward and optimize
+            optimizer.zero_grad()
+            losses.backward()
+            optimizer.step()
+
+            # Log batch results
+            loss_value = losses.item()
+            epoch_loss += loss_value
+
+            # Accumulate individual loss components
+            if "loss_classifier" in loss_dict:
+                epoch_loss_classifier += loss_dict["loss_classifier"].item()
+            if "loss_box_reg" in loss_dict:
+                epoch_loss_box_reg += loss_dict["loss_box_reg"].item()
+            if "loss_mask" in loss_dict:
+                epoch_loss_mask += loss_dict["loss_mask"].item()
+            if "loss_objectness" in loss_dict:
+                epoch_loss_objectness += loss_dict["loss_objectness"].item()
+            if "loss_rpn_box_reg" in loss_dict:
+                epoch_loss_rpn_box_reg += loss_dict["loss_rpn_box_reg"].item()
+
+            # Periodic logging
+            if (i + 1) % config.get("log_freq", 10) == 0:
+                log_str = f"Epoch [{epoch + 1}/{config.get('num_epochs', 10)}], "
+                log_str += f"Iter [{i + 1}/{len(data_loader_train)}], "
+                log_str += f"Loss: {loss_value:.4f}"
+
+                # Add per-component losses for richer logging
+                comp_log = []
+                if "loss_classifier" in loss_dict:
+                    comp_log.append(f"cls: {loss_dict['loss_classifier'].item():.4f}")
+                if "loss_box_reg" in loss_dict:
+                    comp_log.append(f"box: {loss_dict['loss_box_reg'].item():.4f}")
+                if "loss_mask" in loss_dict:
+                    comp_log.append(f"mask: {loss_dict['loss_mask'].item():.4f}")
+                if "loss_objectness" in loss_dict:
+                    comp_log.append(f"obj: {loss_dict['loss_objectness'].item():.4f}")
+                if "loss_rpn_box_reg" in loss_dict:
+                    comp_log.append(f"rpn: {loss_dict['loss_rpn_box_reg'].item():.4f}")
+
+                if comp_log:
+                    log_str += f" [{', '.join(comp_log)}]"
+
+                logging.info(log_str)
+
+        # Step learning rate scheduler after each epoch
         lr_scheduler.step()
 
-        # Log epoch summary
-        epoch_end_time = time.time()
-        epoch_duration = epoch_end_time - epoch_start_time
-        avg_epoch_loss = epoch_loss_sum / num_batches if num_batches > 0 else 0
-        current_lr = optimizer.param_groups[0]["lr"]  # Get current learning rate
-        logging.info(f"--- Epoch {epoch + 1} Summary --- ")
-        logging.info(f"  Average Loss: {avg_epoch_loss:.4f}")
-        logging.info(f"  Learning Rate: {current_lr:.6f}")
-        logging.info(f"  Epoch Duration: {epoch_duration:.2f}s")
-
-        # --- Checkpointing (Prompt 11) --- #
-        # Save checkpoint periodically or at the end
-        save_checkpoint = False
-        if (epoch + 1) % config.get("checkpoint_freq", 1) == 0:
-            save_checkpoint = True
-            logging.info(f"Checkpoint frequency met (epoch {epoch + 1})")
-        elif (epoch + 1) == num_epochs:
-            save_checkpoint = True
-            logging.info(f"Final epoch ({epoch + 1}) reached, saving checkpoint.")
+        # Calculate and log epoch metrics
+        if len(data_loader_train) > 0:
+            avg_loss = epoch_loss / len(data_loader_train)
+            avg_loss_classifier = epoch_loss_classifier / len(data_loader_train)
+            avg_loss_box_reg = epoch_loss_box_reg / len(data_loader_train)
+            avg_loss_mask = epoch_loss_mask / len(data_loader_train)
+            avg_loss_objectness = epoch_loss_objectness / len(data_loader_train)
+            avg_loss_rpn_box_reg = epoch_loss_rpn_box_reg / len(data_loader_train)
+
+            logging.info(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}")
+            logging.info(f"  Classifier Loss: {avg_loss_classifier:.4f}")
+            logging.info(f"  Box Reg Loss: {avg_loss_box_reg:.4f}")
+            logging.info(f"  Mask Loss: {avg_loss_mask:.4f}")
+            logging.info(f"  Objectness Loss: {avg_loss_objectness:.4f}")
+            logging.info(f"  RPN Box Reg Loss: {avg_loss_rpn_box_reg:.4f}")
+        else:
+            logging.warning("No training batches were processed in this epoch.")
 
-        if save_checkpoint:
-            checkpoint_filename = f"checkpoint_epoch_{epoch + 1}.pth"
-            save_path = os.path.join(checkpoint_path, checkpoint_filename)
+        epoch_duration = time.time() - epoch_start_time
+        logging.info(f"Epoch duration: {epoch_duration:.2f}s")
+
+        # --- Validation ---
+        logging.info("Running validation...")
+        val_metrics = evaluate(model, data_loader_val, device)
+        logging.info(f"Validation Loss: {val_metrics['average_loss']:.4f}")
+
+        # --- Checkpoint Saving ---
+        if (epoch + 1) % config.get("checkpoint_freq", 1) == 0 or epoch == config.get(
+            "num_epochs", 10
+        ) - 1:
+            checkpoint_file = os.path.join(
+                checkpoint_path, f"checkpoint_epoch_{epoch + 1}.pth"
+            )
+            checkpoint = {
+                "epoch": epoch + 1,
+                "model_state_dict": model.state_dict(),
+                "optimizer_state_dict": optimizer.state_dict(),
+                "scheduler_state_dict": lr_scheduler.state_dict(),
+                "config": config,
+                "val_loss": val_metrics["average_loss"],
+            }
             try:
-                checkpoint_data = {
-                    "epoch": epoch + 1,
-                    "model_state_dict": model.state_dict(),
-                    "optimizer_state_dict": optimizer.state_dict(),
-                    "scheduler_state_dict": lr_scheduler.state_dict(),
-                    "config": config,  # Save config for reference
-                }
-                torch.save(checkpoint_data, save_path)
-                logging.info(f"Checkpoint saved to {save_path}")
+                torch.save(checkpoint, checkpoint_file)
+                logging.info(f"Checkpoint saved to {checkpoint_file}")
             except Exception as e:
-                logging.error(
-                    f"Failed to save checkpoint for epoch {epoch + 1} to {save_path}: {e}",
-                    exc_info=True,
-                )
+                logging.error(f"Error saving checkpoint: {e}", exc_info=True)
 
-        # --- Evaluation (Prompt 12) --- #
-        if data_loader_val:
-            logging.info(f"Starting evaluation for epoch {epoch + 1}...")
-            try:
-                val_metrics = evaluate(model, data_loader_val, device)
-                logging.info(f"Epoch {epoch + 1} Validation Metrics: {val_metrics}")
-
-                # --- Best Model Checkpoint Logic (Optional Add-on) ---
-                # Add logic here to track the best metric (e.g., val_metrics['average_loss'])
-                # and save a separate 'best_model.pth' checkpoint if the current epoch is better.
-                # Example:
-                # if 'average_loss' in val_metrics:
-                #     current_val_loss = val_metrics['average_loss']
-                #     if best_val_loss is None or current_val_loss < best_val_loss:
-                #         best_val_loss = current_val_loss
-                #         best_model_path = os.path.join(output_path, 'best_model.pth')
-                #         try:
-                #             # Save only the model state_dict for the best model
-                #             torch.save(model.state_dict(), best_model_path)
-                #             logging.info(f"Saved NEW BEST model checkpoint to {best_model_path} (Val Loss: {best_val_loss:.4f})")
-                #         except Exception as e:
-                #             logging.error(f"Failed to save best model checkpoint: {e}", exc_info=True)
-
-            except Exception as e:
-                logging.error(
-                    f"Error during evaluation for epoch {epoch + 1}: {e}", exc_info=True
-                )
-                # Decide if this error should stop the entire training process
-
-    # --- End of Training --- #
-    total_training_time = time.time() - start_time
-    logging.info("--- Training Finished --- ")
-    logging.info(
-        f"Total Training Time: {total_training_time:.2f}s ({total_training_time / 3600:.2f} hours)"
-    )
+    # --- Final Metrics and Cleanup ---
+    total_training_time = time.time() - train_time_start
+    hours, remainder = divmod(total_training_time, 3600)
+    minutes, seconds = divmod(remainder, 60)
+    logging.info(f"Training completed in {int(hours)}h {int(minutes)}m {seconds:.2f}s")
+    logging.info(f"Final model saved to {checkpoint_path}")
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Train Mask R-CNN on Penn-Fudan dataset."
-    )
+    parser = argparse.ArgumentParser(description="Train a Mask R-CNN model")
+    parser.add_argument("--config", required=True, help="Path to configuration file")
     parser.add_argument(
-        "--config",
-        type=str,
-        required=True,
-        help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
+        "--resume", action="store_true", help="Resume training from latest checkpoint"
    )
 
     args = parser.parse_args()
     main(args)
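
For reference, the checkpoints this loop writes are ordinary torch.save() dictionaries named checkpoint_epoch_<N>.pth under <output_dir>/<config_name>/checkpoints/. A small sketch of inspecting one is shown below; the path is illustrative only, built from the default output_dir and config_name values:

import torch

# Hypothetical path; adjust to your own run directory and epoch number.
ckpt = torch.load(
    "outputs/default_run/checkpoints/checkpoint_epoch_5.pth", map_location="cpu"
)
# Keys written by the torch.save() call in the training loop:
# epoch, model_state_dict, optimizer_state_dict, scheduler_state_dict, config, val_loss
print(sorted(ckpt.keys()))
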
utils/common.py (new file, 185 lines)
@@ -0,0 +1,185 @@
import importlib.util
import logging
import os
import random
import sys

import numpy as np
import torch


def load_config(config_path):
    """Load configuration from a Python file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The loaded configuration dictionary.
    """
    try:
        config_path = os.path.abspath(config_path)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive module path from file path relative to workspace root
        workspace_root = os.path.abspath(os.getcwd())
        relative_path = os.path.relpath(config_path, workspace_root)
        if relative_path.startswith(".."):
            print(f"Error: Config file {config_path} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")

        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config

        print(
            f"Loaded configuration from: {config_path} (via module {module_path_str})"
        )
        return config

    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print(
            "Ensure the config file path is correct and relative imports within it are valid."
        )
        import traceback

        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(
            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
        )
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {config_path}: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


def setup_environment(config):
    """Set up the environment based on configuration.

    Args:
        config (dict): Configuration dictionary.

    Returns:
        tuple: (output_path, device) - the output directory path and torch device.
    """
    # Setup output directory
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    os.makedirs(output_path, exist_ok=True)

    # Set random seeds
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logging.info(f"Set random seed to: {seed}")

    # Setup device
    device_name = config.get("device", "cuda")
    if device_name == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device_name = "cpu"
    device = torch.device(device_name)
    logging.info(f"Using device: {device}")

    return output_path, device


def load_checkpoint(
    checkpoint_path,
    model,
    device,
    load_optimizer=False,
    optimizer=None,
    load_scheduler=False,
    scheduler=None,
):
    """Load a checkpoint into the model and optionally optimizer and scheduler.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the weights into.
        device (torch.device): The device to load the checkpoint on.
        load_optimizer (bool): Whether to load optimizer state.
        optimizer (torch.optim.Optimizer, optional): The optimizer to load state into.
        load_scheduler (bool): Whether to load scheduler state.
        scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The scheduler to load state into.

    Returns:
        dict: The loaded checkpoint.
        int: The starting epoch (checkpoint epoch + 1).
    """
    try:
        logging.info(f"Loading checkpoint from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Handle potential DataParallel prefix
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        if isinstance(state_dict, dict):
            # Handle case where model was trained with DataParallel
            if all(k.startswith("module.") for k in state_dict.keys()):
                logging.info(
                    "Detected DataParallel checkpoint, removing 'module.' prefix"
                )
                state_dict = {
                    k.replace("module.", ""): v for k, v in state_dict.items()
                }

            model.load_state_dict(state_dict)
            logging.info("Model state loaded successfully")

            # Load optimizer state if requested
            if (
                load_optimizer
                and optimizer is not None
                and "optimizer_state_dict" in checkpoint
            ):
                optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                logging.info("Optimizer state loaded successfully")

            # Load scheduler state if requested
            if (
                load_scheduler
                and scheduler is not None
                and "scheduler_state_dict" in checkpoint
            ):
                scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
                logging.info("Scheduler state loaded successfully")

            # Get the epoch number
            start_epoch = checkpoint.get("epoch", 0) + 1 if load_optimizer else 0
            if "epoch" in checkpoint:
                logging.info(f"Loaded checkpoint from epoch: {checkpoint['epoch']}")

            return checkpoint, start_epoch
        else:
            logging.error("Checkpoint does not contain a valid state dictionary.")
            sys.exit(1)
    except Exception as e:
        logging.error(f"Error loading checkpoint: {e}", exc_info=True)
        sys.exit(1)


def check_data_path(data_root):
    """Check if the data path exists and is valid.

    Args:
        data_root (str): Path to the data directory.
    """
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)
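
As a usage sketch, these are the two ways load_checkpoint() is exercised by this commit; `model`, `device`, `optimizer` and `lr_scheduler` stand for objects constructed as in the scripts above, and the checkpoint path is a placeholder:

# Evaluation: restore model weights only (this is what test.py does).
load_checkpoint("path/to/checkpoint.pth", model, device)

# Resume training: also restore optimizer and scheduler state and obtain the
# starting epoch for the training loop (this is what train.py does with --resume).
checkpoint, start_epoch = load_checkpoint(
    "path/to/checkpoint.pth",
    model,
    device,
    load_optimizer=True,
    optimizer=optimizer,
    load_scheduler=True,
    scheduler=lr_scheduler,
)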