Files
torchvision-vibecoding-project/scripts/visualize_predictions.py

176 lines
5.1 KiB
Python
Executable File

#!/usr/bin/env python
"""
Visualization script for model predictions on the Penn-Fudan dataset.
This helps visualize and debug model predictions.
"""
import argparse
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch
# Add project root to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from models.detection import get_maskrcnn_model
from utils.common import load_checkpoint, load_config
from utils.data_utils import PennFudanDataset, get_transform
def visualize_prediction(image, prediction, threshold=0.5, mask_threshold=0.5):
    """
    Visualize model prediction on an image.

    Every detection whose score is at least ``threshold`` is drawn as a red
    bounding box with a "Person: <score>" label and, when the prediction
    contains masks, a semi-transparent red overlay of its instance mask.

    Args:
        image (torch.Tensor): The input image [C, H, W]. Floating-point images
            whose maximum is <= 1.0 are treated as [0, 1] and rescaled to
            0-255 for display.
        prediction (dict): Model prediction dict with "boxes" and "scores"
            tensors and, optionally, a "masks" tensor (presumably
            [N, 1, H, W] soft masks as produced by Mask R-CNN -- confirm).
        threshold (float): Score threshold for visualization.
        mask_threshold (float): Per-pixel value above which a mask pixel is
            treated as foreground.

    Returns:
        plt.Figure: The matplotlib figure with the visualization.
    """
    # [C, H, W] tensor -> [H, W, C] numpy array for imshow.
    img_np = image.detach().permute(1, 2, 0).cpu().numpy()
    if np.issubdtype(img_np.dtype, np.floating):
        if img_np.max() <= 1.0:
            # Clip before scaling so slightly negative values cannot wrap
            # around when cast to uint8.
            img_np = np.clip(img_np, 0.0, 1.0) * 255
        # Floats outside [0, 1] would be saturated by imshow; show 0-255 ints.
        img_np = np.clip(img_np, 0, 255).astype(np.uint8)

    fig, ax = plt.subplots(1, 1, figsize=(12, 9))
    ax.imshow(img_np)
    ax.set_title("Model Predictions")

    boxes = prediction["boxes"].detach().cpu().numpy()
    scores = prediction["scores"].detach().cpu().numpy()
    raw_masks = prediction.get("masks")
    masks = None if raw_masks is None else raw_masks.detach().cpu().numpy()

    # Keep only confident detections.
    keep = scores >= threshold
    boxes = boxes[keep]
    scores = scores[keep]
    if masks is not None:
        masks = masks[keep]

    for i, (box, score) in enumerate(zip(boxes, scores)):
        x1, y1, x2, y2 = box
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="red", linewidth=2
            )
        )
        ax.text(
            x1, y1, f"Person: {score:.2f}", bbox=dict(facecolor="yellow", alpha=0.5)
        )

        if masks is None:
            continue
        instance_mask = masks[i]
        if instance_mask.ndim == 3:
            # Drop the leading singleton channel of a [1, H, W] mask.
            instance_mask = instance_mask[0]
        foreground = instance_mask > mask_threshold
        # RGBA overlay: red at 30% opacity on mask pixels and fully
        # transparent elsewhere. A global imshow alpha would instead blend
        # black over the entire image once per detection.
        overlay = np.zeros((*foreground.shape, 4), dtype=np.float32)
        overlay[foreground] = (1.0, 0.0, 0.0, 0.3)
        ax.imshow(overlay)

    ax.set_xlabel(f"Found {len(boxes)} pedestrians with confidence >= {threshold}")
    plt.tight_layout()
    return fig
def run_inference(model, dataset, device, idx=0):
    """
    Predict on a single dataset sample with gradient tracking disabled.

    Side effect: the model is switched to evaluation mode.

    Args:
        model (torch.nn.Module): The model; it is called with a one-element
            list holding the image, and the first element of its output is
            taken as the prediction.
        dataset (PennFudanDataset): Indexable dataset whose items unpack into
            exactly two values, ``(image, target)``; the target is ignored.
        device (torch.device): Device the image is moved to before inference.
        idx (int): Index of the sample to run on.

    Returns:
        tuple: ``(image, prediction)`` -- the image tensor on ``device`` and
        the first element of the model output.
    """
    sample_image, _target = dataset[idx]
    batch = [sample_image.to(device)]

    model.eval()
    with torch.no_grad():
        outputs = model(batch)

    return batch[0], outputs[0]
def main():
    """
    Command-line entry point.

    Loads a config and checkpoint, runs the model on one Penn-Fudan image and
    either saves the visualization (``--output``) or shows it interactively.

    Returns:
        int: Process exit status -- 0 on success, 1 when the data directory
        is missing or ``--index`` is out of range.
    """
    parser = argparse.ArgumentParser(description="Visualize model predictions")
    parser.add_argument("--config", required=True, help="Path to config file")
    parser.add_argument("--checkpoint", required=True, help="Path to checkpoint file")
    parser.add_argument("--index", type=int, default=0, help="Image index to visualize")
    parser.add_argument("--threshold", type=float, default=0.5, help="Score threshold")
    parser.add_argument("--output", help="Path to save visualization image")
    args = parser.parse_args()

    config = load_config(args.config)

    # Fail fast on missing data, before building the model and loading the
    # checkpoint.
    data_root = config.get("data_root", "data/PennFudanPed")
    if not os.path.exists(data_root):
        print(f"Error: Data not found at {data_root}", file=sys.stderr)
        return 1

    device = torch.device(config.get("device", "cpu"))
    print(f"Using device: {device}")

    model = get_maskrcnn_model(
        num_classes=config.get("num_classes", 2),
        pretrained=False,
        pretrained_backbone=False,
    )
    # The return values are not needed here; presumably the call restores the
    # weights into `model` in place (verify against utils.common).
    load_checkpoint(args.checkpoint, model, device)
    model.to(device)
    print(f"Loaded checkpoint from: {args.checkpoint}")

    dataset = PennFudanDataset(root=data_root, transforms=get_transform(train=False))
    print(f"Dataset loaded with {len(dataset)} images")

    if not 0 <= args.index < len(dataset):
        print(
            f"Error: Index {args.index} out of range (0-{len(dataset)-1})",
            file=sys.stderr,
        )
        return 1

    print(f"Running inference on image {args.index}...")
    image, prediction = run_inference(model, dataset, device, args.index)

    print("Visualizing predictions...")
    fig = visualize_prediction(image, prediction, threshold=args.threshold)

    if args.output:
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
        fig.savefig(args.output)
        plt.close(fig)
        print(f"Visualization saved to: {args.output}")
    else:
        # plt.show() blocks until the window is closed, so print the hint first.
        print("Visualization displayed. Close window to continue.")
        plt.show()
    return 0


if __name__ == "__main__":
    sys.exit(main())