Pausing for now, train loop should now work and adding some tests
This commit is contained in:
77
tests/test_visualization.py
Normal file
77
tests/test_visualization.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
|
||||
# Import visualization functions
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
from scripts.visualize_predictions import visualize_prediction
|
||||
|
||||
|
||||
def test_visualize_prediction():
|
||||
"""Test that the visualization function works."""
|
||||
# Create a dummy image tensor
|
||||
image = torch.rand(3, 400, 600)
|
||||
|
||||
# Create a dummy prediction dictionary
|
||||
prediction = {
|
||||
"boxes": torch.tensor(
|
||||
[[100, 100, 200, 200], [300, 300, 400, 400]], dtype=torch.float32
|
||||
),
|
||||
"scores": torch.tensor([0.9, 0.7], dtype=torch.float32),
|
||||
"labels": torch.tensor([1, 1], dtype=torch.int64),
|
||||
"masks": torch.zeros(2, 1, 400, 600, dtype=torch.float32),
|
||||
}
|
||||
|
||||
# Set some pixels in the mask to 1
|
||||
prediction["masks"][0, 0, 100:200, 100:200] = 1.0
|
||||
prediction["masks"][1, 0, 300:400, 300:400] = 1.0
|
||||
|
||||
# Call the visualization function
|
||||
fig = visualize_prediction(image, prediction, threshold=0.5)
|
||||
|
||||
# Check that a figure was returned
|
||||
assert isinstance(fig, plt.Figure)
|
||||
|
||||
# Check figure properties
|
||||
assert len(fig.axes) == 1
|
||||
|
||||
# Close the figure to avoid memory leaks
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def test_visualize_prediction_threshold():
|
||||
"""Test that the threshold parameter filters predictions correctly."""
|
||||
# Create a dummy image tensor
|
||||
image = torch.rand(3, 400, 600)
|
||||
|
||||
# Create a dummy prediction dictionary with varying scores
|
||||
prediction = {
|
||||
"boxes": torch.tensor(
|
||||
[[100, 100, 200, 200], [300, 300, 400, 400], [500, 100, 550, 150]],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"scores": torch.tensor([0.9, 0.7, 0.3], dtype=torch.float32),
|
||||
"labels": torch.tensor([1, 1, 1], dtype=torch.int64),
|
||||
"masks": torch.zeros(3, 1, 400, 600, dtype=torch.float32),
|
||||
}
|
||||
|
||||
# Call the visualization function with different thresholds
|
||||
fig_low = visualize_prediction(image, prediction, threshold=0.2)
|
||||
fig_med = visualize_prediction(image, prediction, threshold=0.5)
|
||||
fig_high = visualize_prediction(image, prediction, threshold=0.8)
|
||||
|
||||
# Low threshold should show all 3 boxes
|
||||
assert "Found 3" in fig_low.axes[0].get_xlabel()
|
||||
|
||||
# Medium threshold should show 2 boxes
|
||||
assert "Found 2" in fig_med.axes[0].get_xlabel()
|
||||
|
||||
# High threshold should show 1 box
|
||||
assert "Found 1" in fig_high.axes[0].get_xlabel()
|
||||
|
||||
# Close figures
|
||||
plt.close(fig_low)
|
||||
plt.close(fig_med)
|
||||
plt.close(fig_high)
|
||||
Reference in New Issue
Block a user