import argparse
import logging
import sys

import torch
import torch.utils.data

# Project-specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
    check_data_path,
    load_checkpoint,
    load_config,
    setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging


def main(args):
    # Load configuration
    config = load_config(args.config)

    # Set up output directory and get device
    output_path, device = setup_environment(config)

    # Set up logging
    setup_logging(output_path, f"{config['config_name']}_test")
    logging.info("--- Testing Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Checkpoint path: {args.checkpoint}")
    logging.info(f"Loaded configuration dictionary: {config}")
    if args.max_samples is not None:
        logging.info(f"Limiting evaluation to {args.max_samples} samples")

    # Validate data path
    data_root = config.get("data_root")
    check_data_path(data_root)

    try:
        # Create the full dataset instance for testing with eval transforms
        dataset_test = PennFudanDataset(
            root=data_root, transforms=get_transform(train=False)
        )
        logging.info(f"Test dataset size: {len(dataset_test)}")

        # Create test DataLoader
        data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=config.get("batch_size", 2),
            shuffle=False,  # No need to shuffle test data
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )
        logging.info(
            f"Test dataloader configured. Est. batches: {len(data_loader_test)}"
        )
    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # Create model
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        # Create the model with the same architecture as in training
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=False,  # Pretrained weights are unnecessary; the checkpoint is loaded below
            pretrained_backbone=False,
        )

        # Load checkpoint
        load_checkpoint(args.checkpoint, model, device)
        model.to(device)
        logging.info("Model loaded and moved to device successfully.")
    except Exception as e:
        logging.error(f"Error setting up model: {e}", exc_info=True)
        sys.exit(1)

    # Run evaluation
    try:
        logging.info("Starting model evaluation...")
        eval_metrics = evaluate(model, data_loader_test, device, args.max_samples)

        # Log detailed metrics
        logging.info("--- Evaluation Results ---")
        for metric_name, metric_value in eval_metrics.items():
            if isinstance(metric_value, (int, float)):
                logging.info(f"  {metric_name}: {metric_value:.4f}")
            else:
                logging.info(f"  {metric_name}: {metric_value}")
        logging.info("Evaluation completed successfully")
    except Exception as e:
        logging.error(f"Error during evaluation: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Test script for torchvision Mask R-CNN"
    )
    parser.add_argument(
        "--config", required=True, type=str, help="Path to configuration file"
    )
    parser.add_argument(
        "--checkpoint", required=True, type=str, help="Path to model checkpoint"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate",
    )
    args = parser.parse_args()
    main(args)
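
# ---------------------------------------------------------------------------
# Example invocation (a sketch; the script name, config path, and checkpoint
# path below are illustrative assumptions, not files shipped with this repo):
#
#   python test.py \
#       --config configs/maskrcnn_pennfudan.yaml \
#       --checkpoint outputs/maskrcnn_pennfudan/best_model.pth \
#       --max_samples 50
#
# Whatever format load_config in utils.common accepts, the config must supply
# the keys this script reads: config_name, data_root, and num_classes, plus
# the optional batch_size, num_workers, and pin_memory. For the Penn-Fudan
# pedestrian dataset, num_classes would typically be 2 (background +
# pedestrian).
# ---------------------------------------------------------------------------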