"""Tests for ``visualize_prediction`` from ``scripts/visualize_predictions.py``."""

import os
import sys

import matplotlib.pyplot as plt
import torch

# Put the parent of this file's directory on ``sys.path`` so that
# ``scripts.visualize_predictions`` is importable as a package.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from scripts.visualize_predictions import visualize_prediction  # noqa: E402


def _close_figures(*figures):
    """Close every argument that is a matplotlib ``Figure``; ignore anything else.

    Non-``Figure`` values are skipped so that, inside a ``finally`` clause, a
    bad return value from ``visualize_prediction`` surfaces as the original
    assertion failure instead of a ``TypeError`` raised by ``plt.close``.
    """
    for fig in figures:
        if isinstance(fig, plt.Figure):
            plt.close(fig)


def test_visualize_prediction():
    """Test that the visualization function works."""
    # Dummy 3-channel image tensor, 400 rows x 600 columns.
    image = torch.rand(3, 400, 600)

    # Dummy prediction: two boxes, both scored above the 0.5 threshold below.
    prediction = {
        "boxes": torch.tensor(
            [[100, 100, 200, 200], [300, 300, 400, 400]], dtype=torch.float32
        ),
        "scores": torch.tensor([0.9, 0.7], dtype=torch.float32),
        "labels": torch.tensor([1, 1], dtype=torch.int64),
        "masks": torch.zeros(2, 1, 400, 600, dtype=torch.float32),
    }
    # Fill each instance's mask with ones over the same region as its box.
    prediction["masks"][0, 0, 100:200, 100:200] = 1.0
    prediction["masks"][1, 0, 300:400, 300:400] = 1.0

    fig = visualize_prediction(image, prediction, threshold=0.5)
    try:
        # A single-axes matplotlib figure should be returned.
        assert isinstance(fig, plt.Figure)
        assert len(fig.axes) == 1
    finally:
        # Close even when an assertion fails, so failing runs do not leave
        # open figures in pyplot's global state for later tests.
        _close_figures(fig)


def test_visualize_prediction_threshold():
    """Test that the threshold parameter filters predictions correctly."""
    # Dummy 3-channel image tensor, 400 rows x 600 columns.
    image = torch.rand(3, 400, 600)

    # Dummy prediction with three boxes at distinct scores (0.9, 0.7, 0.3).
    prediction = {
        "boxes": torch.tensor(
            [[100, 100, 200, 200], [300, 300, 400, 400], [500, 100, 550, 150]],
            dtype=torch.float32,
        ),
        "scores": torch.tensor([0.9, 0.7, 0.3], dtype=torch.float32),
        "labels": torch.tensor([1, 1, 1], dtype=torch.int64),
        "masks": torch.zeros(3, 1, 400, 600, dtype=torch.float32),
    }

    # (threshold, number of boxes expected to survive it)
    cases = [
        (0.2, 3),  # low threshold keeps all three boxes
        (0.5, 2),  # medium threshold drops the 0.3-score box
        (0.8, 1),  # high threshold keeps only the 0.9-score box
    ]

    figures = []
    try:
        for threshold, expected_count in cases:
            fig = visualize_prediction(image, prediction, threshold=threshold)
            figures.append(fig)
            # This test expects the kept-detection count to be reported in
            # the x-axis label as "Found <N>".
            xlabel = fig.axes[0].get_xlabel()
            assert f"Found {expected_count}" in xlabel, (threshold, xlabel)
    finally:
        # Close every figure created so far, even if a call or assertion failed.
        _close_figures(*figures)