#!/usr/bin/env python
"""
Visualization script for model predictions on the Penn-Fudan dataset.

This helps visualize and debug model predictions.
"""
import argparse
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Add project root to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from models.detection import get_maskrcnn_model
from utils.common import load_checkpoint, load_config
from utils.data_utils import PennFudanDataset, get_transform


def visualize_prediction(image, prediction, threshold=0.5):
    """
    Visualize model prediction on an image.

    Args:
        image (torch.Tensor): The input image [C, H, W]
        prediction (dict): Model prediction dict with boxes, scores, labels, masks
        threshold (float): Score threshold for visualization

    Returns:
        plt.Figure: The matplotlib figure with the visualization
    """
    # Convert image from tensor to numpy
    img_np = image.permute(1, 2, 0).cpu().numpy()

    # Denormalize if needed
    if img_np.max() <= 1.0:
        img_np = (img_np * 255).astype(np.uint8)

    # Create figure and axes
    fig, ax = plt.subplots(1, 1, figsize=(12, 9))
    ax.imshow(img_np)
    ax.set_title("Model Predictions")

    # Get predictions
    boxes = prediction["boxes"].cpu().numpy()
    scores = prediction["scores"].cpu().numpy()
    labels = prediction["labels"].cpu().numpy()
    masks = prediction["masks"].cpu().numpy()

    # Filter by threshold
    keep = scores >= threshold
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]
    masks = masks[keep]

    # Draw predictions
    for box, score, label, mask in zip(boxes, scores, labels, masks):
        # Draw box
        x1, y1, x2, y2 = box
        rect = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="red", linewidth=2
        )
        ax.add_patch(rect)

        # Add label and score
        ax.text(
            x1, y1, f"Person: {score:.2f}", bbox=dict(facecolor="yellow", alpha=0.5)
        )

        # Draw mask (with transparency)
        mask = mask[0] > 0.5  # Threshold mask
        mask_color = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        mask_color[mask] = [255, 0, 0]  # Red color
        ax.imshow(mask_color, alpha=0.3)

    # Show count of detections
    ax.set_xlabel(f"Found {len(boxes)} pedestrians with confidence >= {threshold}")

    plt.tight_layout()
    return fig


def run_inference(model, dataset, device, idx=0):
    """
    Run inference on a single image from the dataset.
    Args:
        model (torch.nn.Module): The model
        dataset (PennFudanDataset): The dataset
        device (torch.device): The device
        idx (int): Index of the image to test

    Returns:
        tuple: (image, prediction)
    """
    # Get image
    image, _ = dataset[idx]

    # Prepare for model
    image = image.to(device)

    # Run inference
    model.eval()
    with torch.no_grad():
        prediction = model([image])[0]

    return image, prediction


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Visualize model predictions")
    parser.add_argument("--config", required=True, help="Path to config file")
    parser.add_argument("--checkpoint", required=True, help="Path to checkpoint file")
    parser.add_argument("--index", type=int, default=0, help="Image index to visualize")
    parser.add_argument("--threshold", type=float, default=0.5, help="Score threshold")
    parser.add_argument("--output", help="Path to save visualization image")
    args = parser.parse_args()

    # Load config
    config = load_config(args.config)

    # Setup device
    device = torch.device(config.get("device", "cpu"))
    print(f"Using device: {device}")

    # Create model
    model = get_maskrcnn_model(
        num_classes=config.get("num_classes", 2),
        pretrained=False,
        pretrained_backbone=False,
    )

    # Load checkpoint
    checkpoint, _ = load_checkpoint(args.checkpoint, model, device)
    model.to(device)
    print(f"Loaded checkpoint from: {args.checkpoint}")

    # Create dataset
    data_root = config.get("data_root", "data/PennFudanPed")
    if not os.path.exists(data_root):
        print(f"Error: Data not found at {data_root}")
        return
    dataset = PennFudanDataset(root=data_root, transforms=get_transform(train=False))
    print(f"Dataset loaded with {len(dataset)} images")

    # Validate index
    if args.index < 0 or args.index >= len(dataset):
        print(f"Error: Index {args.index} out of range (0-{len(dataset)-1})")
        return

    # Run inference
    print(f"Running inference on image {args.index}...")
    image, prediction = run_inference(model, dataset, device, args.index)

    # Visualize prediction
    print("Visualizing predictions...")
    fig = visualize_prediction(image, prediction, threshold=args.threshold)

    # Save or show
    if args.output:
        fig.savefig(args.output)
        print(f"Visualization saved to: {args.output}")
    else:
        plt.show()
        print("Visualization displayed. Close window to continue.")


if __name__ == "__main__":
    main()
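
# Example invocation (a minimal sketch: the script path, config file, and checkpoint
# paths below are placeholders, not files defined by this repository; adjust them to
# your own layout). The flags themselves match the argparse options defined above.
#
#   python scripts/visualize_predictions.py \
#       --config configs/maskrcnn_pennfudan.yaml \
#       --checkpoint outputs/checkpoints/model_final.pth \
#       --index 3 --threshold 0.7 \
#       --output outputs/prediction_3.png
#
# Omitting --output opens an interactive matplotlib window instead of saving a file.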