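"""Evaluate a trained torchvision Mask R-CNN checkpoint on the PennFudan dataset.

Loads a config file and a saved checkpoint, rebuilds the model architecture used
during training, runs the evaluation loop, and logs the resulting metrics.

Example invocation (script and file paths below are placeholders):

    python test.py --config configs/maskrcnn.yaml --checkpoint outputs/checkpoint.pth
"""
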
import argparse
import logging
import sys

import torch
import torch.utils.data

# Project specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
    check_data_path,
    load_checkpoint,
    load_config,
    setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging


def main(args):
    # Load configuration
    config = load_config(args.config)

    # Setup output directory and get device
    output_path, device = setup_environment(config)

    # Setup logging
    setup_logging(output_path, f"{config['config_name']}_test")
    logging.info("--- Testing Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Checkpoint path: {args.checkpoint}")
    logging.info(f"Loaded configuration dictionary: {config}")
    if args.max_samples:
        logging.info(f"Limiting evaluation to {args.max_samples} samples")

    # Validate data path
    data_root = config.get("data_root")
    check_data_path(data_root)

    try:
        # Create the full dataset instance for testing with eval transforms
        dataset_test = PennFudanDataset(
            root=data_root, transforms=get_transform(train=False)
        )
        logging.info(f"Test dataset size: {len(dataset_test)}")

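        # NOTE: detection samples are (image, target) pairs whose targets vary in
        # size, so the custom collate_fn is expected to keep each batch as a tuple
        # of lists rather than stacking tensors (assuming it follows the usual
        # torchvision detection reference, e.g. `return tuple(zip(*batch))`).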
        # Create test DataLoader
        data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=config.get("batch_size", 2),
            shuffle=False,  # No need to shuffle test data
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )

        logging.info(
            f"Test dataloader configured. Est. batches: {len(data_loader_test)}"
        )

    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # Create model
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        # Create the model with the same architecture as in training
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=False,  # Don't need pretrained weights as we'll load checkpoint
            pretrained_backbone=False,
        )

        # Load checkpoint
        load_checkpoint(args.checkpoint, model, device)
        model.to(device)
        logging.info("Model loaded and moved to device successfully.")
    except Exception as e:
        logging.error(f"Error setting up model: {e}", exc_info=True)
        sys.exit(1)

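    # evaluate() from utils.eval_utils is expected to return a mapping of metric
    # names to values; when --max_samples is set, it caps how many test samples
    # are scored.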
    # Run Evaluation
    try:
        logging.info("Starting model evaluation...")
        eval_metrics = evaluate(model, data_loader_test, device, args.max_samples)

        # Log detailed metrics
        logging.info("--- Evaluation Results ---")
        for metric_name, metric_value in eval_metrics.items():
            if isinstance(metric_value, (int, float)):
                logging.info(f" {metric_name}: {metric_value:.4f}")
            else:
                logging.info(f" {metric_name}: {metric_value}")

        logging.info("Evaluation completed successfully")
    except Exception as e:
        logging.error(f"Error during evaluation: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Test script for torchvision Mask R-CNN"
    )
    parser.add_argument(
        "--config", required=True, type=str, help="Path to configuration file"
    )
    parser.add_argument(
        "--checkpoint", required=True, type=str, help="Path to model checkpoint"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate",
    )
    args = parser.parse_args()
    main(args)