torchvision-vibecoding-project/test.py

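"""Evaluate a trained Mask R-CNN checkpoint on the PennFudan dataset.

Loads a config, rebuilds the model architecture used for training, restores
the checkpoint, and runs the evaluation loop on the test DataLoader.

Example invocation (the config/checkpoint paths below are illustrative
placeholders, not files shipped with the repo):

    python test.py \
        --config configs/maskrcnn_pennfudan.yaml \
        --checkpoint outputs/maskrcnn_pennfudan/model_final.pth \
        --max_samples 50
"""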
import argparse
import logging
import sys

import torch
import torch.utils.data

# Project-specific imports
from models.detection import get_maskrcnn_model
from utils.common import (
    check_data_path,
    load_checkpoint,
    load_config,
    setup_environment,
)
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
from utils.eval_utils import evaluate
from utils.log_utils import setup_logging

def main(args):
    # Load configuration
    config = load_config(args.config)

    # Setup output directory and get device
    output_path, device = setup_environment(config)

    # Setup logging
    setup_logging(output_path, f"{config['config_name']}_test")
    logging.info("--- Testing Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Checkpoint path: {args.checkpoint}")
    logging.info(f"Loaded configuration dictionary: {config}")
    if args.max_samples:
        logging.info(f"Limiting evaluation to {args.max_samples} samples")

    # Validate data path
    data_root = config.get("data_root")
    check_data_path(data_root)

    try:
        # Create the full dataset instance for testing with eval transforms
        dataset_test = PennFudanDataset(
            root=data_root, transforms=get_transform(train=False)
        )
        logging.info(f"Test dataset size: {len(dataset_test)}")

        # Create test DataLoader
        data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=config.get("batch_size", 2),
            shuffle=False,  # No need to shuffle test data
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            pin_memory=config.get("pin_memory", True),
        )
        logging.info(
            f"Test dataloader configured. Est. batches: {len(data_loader_test)}"
        )
    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)
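    # Note on the DataLoader above: `collate_fn` comes from utils.data_utils
    # and is assumed to regroup the batch without stacking tensors, e.g.
    #
    #     def collate_fn(batch):
    #         return tuple(zip(*batch))
    #
    # so variable-sized images and their target dicts stay per-sample, which
    # is the input format torchvision detection models expect.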
    # Create model
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        # Create the model with the same architecture as in training
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=False,  # Don't need pretrained weights as we'll load checkpoint
            pretrained_backbone=False,
        )

        # Load checkpoint
        load_checkpoint(args.checkpoint, model, device)
        model.to(device)
        logging.info("Model loaded and moved to device successfully.")
    except Exception as e:
        logging.error(f"Error setting up model: {e}", exc_info=True)
        sys.exit(1)
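    # Note on checkpoint loading above: `load_checkpoint` (utils.common) is
    # assumed to do roughly the following, with map_location so CPU-only
    # machines can load GPU-trained weights:
    #
    #     state = torch.load(checkpoint_path, map_location=device)
    #     model.load_state_dict(state.get("model_state_dict", state))
    #
    # The exact key layout depends on how the training script saved the
    # checkpoint.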
    # Run evaluation
    try:
        logging.info("Starting model evaluation...")
        eval_metrics = evaluate(model, data_loader_test, device, args.max_samples)

        # Log detailed metrics
        logging.info("--- Evaluation Results ---")
        for metric_name, metric_value in eval_metrics.items():
            if isinstance(metric_value, (int, float)):
                logging.info(f" {metric_name}: {metric_value:.4f}")
            else:
                logging.info(f" {metric_name}: {metric_value}")
        logging.info("Evaluation completed successfully")
    except Exception as e:
        logging.error(f"Error during evaluation: {e}", exc_info=True)
        sys.exit(1)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Test script for torchvision Mask R-CNN"
    )
    parser.add_argument(
        "--config", required=True, type=str, help="Path to configuration file"
    )
    parser.add_argument(
        "--checkpoint", required=True, type=str, help="Path to model checkpoint"
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples to evaluate",
    )
    args = parser.parse_args()
    main(args)