Files
torchvision-vibecoding-project/tests/test_visualization.py

78 lines
2.5 KiB
Python

import os
import sys
import matplotlib.pyplot as plt
import torch
# Import visualization functions
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from scripts.visualize_predictions import visualize_prediction # noqa: E402
def test_visualize_prediction():
    """Test that the visualization function returns a single-axes figure.

    Two detections (scores 0.9 and 0.7) with square masks covering their
    boxes are rendered at threshold 0.5, so both should be reported.
    """
    # Dummy image tensor (channels, height, width)
    image = torch.rand(3, 400, 600)
    # Dummy prediction dictionary; masks are (N, 1, H, W) to match the image
    prediction = {
        "boxes": torch.tensor(
            [[100, 100, 200, 200], [300, 300, 400, 400]], dtype=torch.float32
        ),
        "scores": torch.tensor([0.9, 0.7], dtype=torch.float32),
        "labels": torch.tensor([1, 1], dtype=torch.int64),
        "masks": torch.zeros(2, 1, 400, 600, dtype=torch.float32),
    }
    # Fill each mask over the region of its corresponding box
    prediction["masks"][0, 0, 100:200, 100:200] = 1.0
    prediction["masks"][1, 0, 300:400, 300:400] = 1.0

    fig = visualize_prediction(image, prediction, threshold=0.5)
    # Type check first: plt.close would raise on a non-Figure and mask the
    # real failure if it ran in the finally clause below.
    assert isinstance(fig, plt.Figure)
    try:
        assert len(fig.axes) == 1
        # Both scores clear 0.5, so the label should report two detections
        xlabel = fig.axes[0].get_xlabel()
        assert "Found 2" in xlabel, f"unexpected xlabel {xlabel!r}"
    finally:
        # Always close the figure, even when an assertion fails, to avoid
        # accumulating open figures across the test session.
        plt.close(fig)
def test_visualize_prediction_threshold():
    """Test that the threshold parameter filters predictions correctly.

    Three boxes with scores 0.9, 0.7 and 0.3 are rendered at thresholds
    0.2, 0.5 and 0.8. The x-axis label of each returned figure is expected
    to report how many detections passed the score filter ("Found N").
    """
    import re

    # Dummy image tensor (channels, height, width)
    image = torch.rand(3, 400, 600)
    # Dummy prediction dictionary with one score in each threshold band
    prediction = {
        "boxes": torch.tensor(
            [[100, 100, 200, 200], [300, 300, 400, 400], [500, 100, 550, 150]],
            dtype=torch.float32,
        ),
        "scores": torch.tensor([0.9, 0.7, 0.3], dtype=torch.float32),
        "labels": torch.tensor([1, 1, 1], dtype=torch.int64),
        "masks": torch.zeros(3, 1, 400, 600, dtype=torch.float32),
    }
    # threshold -> number of boxes expected to survive the score filter.
    # The thresholds sit well away from every score, so the result does not
    # depend on whether the comparison is strict.
    expected_counts = {0.2: 3, 0.5: 2, 0.8: 1}

    for threshold, expected in expected_counts.items():
        fig = visualize_prediction(image, prediction, threshold=threshold)
        assert isinstance(fig, plt.Figure)
        try:
            xlabel = fig.axes[0].get_xlabel()
            # Trailing \b keeps "Found 1" from also matching e.g. "Found 10".
            assert re.search(rf"Found {expected}\b", xlabel), (
                f"threshold={threshold}: expected {expected} boxes, "
                f"got xlabel {xlabel!r}"
            )
        finally:
            # Close each figure even if its assertion fails, so a failure
            # does not leak figures.
            plt.close(fig)