Create test script, and refactor logic from train into common file for usage across both scripts
This commit is contained in:
109
test.py
109
test.py
@@ -0,0 +1,109 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
|
||||
# Project specific imports
|
||||
from models.detection import get_maskrcnn_model
|
||||
from utils.common import (
|
||||
check_data_path,
|
||||
load_checkpoint,
|
||||
load_config,
|
||||
setup_environment,
|
||||
)
|
||||
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
|
||||
from utils.eval_utils import evaluate
|
||||
from utils.log_utils import setup_logging
|
||||
|
||||
|
||||
def main(args):
|
||||
# Load configuration
|
||||
config = load_config(args.config)
|
||||
|
||||
# Setup output directory and get device
|
||||
output_path, device = setup_environment(config)
|
||||
|
||||
# Setup logging
|
||||
setup_logging(output_path, f"{config['config_name']}_test")
|
||||
logging.info("--- Testing Script Started ---")
|
||||
logging.info(f"Loaded configuration from: {args.config}")
|
||||
logging.info(f"Checkpoint path: {args.checkpoint}")
|
||||
logging.info(f"Loaded configuration dictionary: {config}")
|
||||
|
||||
# Validate data path
|
||||
data_root = config.get("data_root")
|
||||
check_data_path(data_root)
|
||||
|
||||
try:
|
||||
# Create the full dataset instance for testing with eval transforms
|
||||
dataset_test = PennFudanDataset(
|
||||
root=data_root, transforms=get_transform(train=False)
|
||||
)
|
||||
logging.info(f"Test dataset size: {len(dataset_test)}")
|
||||
|
||||
# Create test DataLoader
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test,
|
||||
batch_size=config.get("batch_size", 2),
|
||||
shuffle=False, # No need to shuffle test data
|
||||
num_workers=config.get("num_workers", 4),
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=config.get("pin_memory", True),
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"Test dataloader configured. Est. batches: {len(data_loader_test)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
# Create model
|
||||
num_classes = config.get("num_classes")
|
||||
if num_classes is None:
|
||||
logging.error("'num_classes' not specified in configuration.")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
# Create the model with the same architecture as in training
|
||||
model = get_maskrcnn_model(
|
||||
num_classes=num_classes,
|
||||
pretrained=False, # Don't need pretrained weights as we'll load checkpoint
|
||||
pretrained_backbone=False,
|
||||
)
|
||||
|
||||
# Load checkpoint
|
||||
load_checkpoint(args.checkpoint, model, device)
|
||||
model.to(device)
|
||||
logging.info("Model loaded and moved to device successfully.")
|
||||
except Exception as e:
|
||||
logging.error(f"Error setting up model: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
# Run Evaluation
|
||||
try:
|
||||
logging.info("Starting model evaluation...")
|
||||
eval_metrics = evaluate(model, data_loader_test, device)
|
||||
|
||||
# Log detailed metrics
|
||||
logging.info("--- Evaluation Results ---")
|
||||
for metric_name, metric_value in eval_metrics.items():
|
||||
logging.info(f" {metric_name}: {metric_value:.4f}")
|
||||
|
||||
logging.info("Evaluation completed successfully")
|
||||
except Exception as e:
|
||||
logging.error(f"Error during evaluation: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Test a trained Mask R-CNN model")
|
||||
parser.add_argument("--config", required=True, help="Path to configuration file")
|
||||
parser.add_argument(
|
||||
"--checkpoint", required=True, help="Path to model checkpoint file (.pth)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user