78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
import os
|
|
import sys
|
|
|
|
import matplotlib.pyplot as plt
|
|
import torch
|
|
|
|
# Import visualization functions
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|
from scripts.visualize_predictions import visualize_prediction # noqa: E402
|
|
|
|
|
|
def test_visualize_prediction():
    """Smoke-test ``visualize_prediction`` on a small synthetic prediction.

    Builds a random image and two square detections with matching binary
    masks, renders them at ``threshold=0.5``, and checks that a matplotlib
    figure with exactly one axes comes back.
    """
    # Random image tensor of shape (3, 400, 600); presumably channels-first
    # (C, H, W), matching the 400x600 spatial size of the masks below.
    image = torch.rand(3, 400, 600)

    # Two detections, both scoring above the 0.5 threshold used below.
    prediction = {
        "boxes": torch.tensor(
            [[100, 100, 200, 200], [300, 300, 400, 400]], dtype=torch.float32
        ),
        "scores": torch.tensor([0.9, 0.7], dtype=torch.float32),
        "labels": torch.tensor([1, 1], dtype=torch.int64),
        "masks": torch.zeros(2, 1, 400, 600, dtype=torch.float32),
    }

    # Turn on each mask over the square region its box appears to describe
    # (boxes assumed to be x1, y1, x2, y2 -- the squares are symmetric anyway).
    prediction["masks"][0, 0, 100:200, 100:200] = 1.0
    prediction["masks"][1, 0, 300:400, 300:400] = 1.0

    fig = visualize_prediction(image, prediction, threshold=0.5)
    try:
        assert isinstance(fig, plt.Figure)
        assert len(fig.axes) == 1
    finally:
        # Close even when an assertion fails, so failing runs do not leave
        # open figures behind.
        plt.close(fig)
|
|
|
|
|
|
def test_visualize_prediction_threshold():
    """Check that ``threshold`` filters out detections scoring below it.

    Three detections with scores 0.9, 0.7 and 0.3 are rendered at thresholds
    0.2, 0.5 and 0.8. The x-axis label of each resulting figure is expected
    to contain ``"Found N"``, where N is the number of detections whose score
    clears the threshold.
    """
    # Random image tensor of shape (3, 400, 600); presumably channels-first
    # (C, H, W), matching the 400x600 spatial size of the masks below.
    image = torch.rand(3, 400, 600)

    # Three detections whose scores straddle the thresholds tested below.
    prediction = {
        "boxes": torch.tensor(
            [[100, 100, 200, 200], [300, 300, 400, 400], [500, 100, 550, 150]],
            dtype=torch.float32,
        ),
        "scores": torch.tensor([0.9, 0.7, 0.3], dtype=torch.float32),
        "labels": torch.tensor([1, 1, 1], dtype=torch.int64),
        "masks": torch.zeros(3, 1, 400, 600, dtype=torch.float32),
    }

    # threshold -> number of detections expected to survive it.
    expected_counts = {0.2: 3, 0.5: 2, 0.8: 1}

    figures = []
    try:
        for threshold, count in expected_counts.items():
            fig = visualize_prediction(image, prediction, threshold=threshold)
            # Record the figure before asserting so `finally` can close it.
            figures.append(fig)
            xlabel = fig.axes[0].get_xlabel()
            assert f"Found {count}" in xlabel, (
                f"threshold={threshold}: expected 'Found {count}' in xlabel, "
                f"got {xlabel!r}"
            )
    finally:
        # Close every figure created so far, even if an assertion failed or
        # a later call raised, so failing runs do not accumulate figures.
        for fig in figures:
            plt.close(fig)
|