Pausing for now: the train loop should now work; adding some tests
152
scripts/test_model.py
Executable file
@@ -0,0 +1,152 @@
#!/usr/bin/env python
"""
Quick model testing script to verify model creation and inference.
"""

import os
import sys
import time

import torch

# Add project root to the path to enable imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from models.detection import get_maskrcnn_model
from utils.data_utils import PennFudanDataset, get_transform


def test_model_creation():
    """Test that we can create the model."""
    print("Testing model creation...")
    model = get_maskrcnn_model(
        num_classes=2, pretrained=False, pretrained_backbone=False
    )
    print("✓ Model created successfully")
    return model


def test_model_forward(model, device):
    """Test model forward pass with random inputs."""
    print("\nTesting model forward pass...")

    # Create a random batch
    image = torch.rand(3, 300, 400, device=device)  # Random image

    # Create a random target
    target = {
        "boxes": torch.tensor(
            [[100, 100, 200, 200]], dtype=torch.float32, device=device
        ),
        "labels": torch.tensor([1], dtype=torch.int64, device=device),
        "masks": torch.randint(0, 2, (1, 300, 400), dtype=torch.uint8, device=device),
        "image_id": torch.tensor([0], device=device),
        "area": torch.tensor([10000.0], dtype=torch.float32, device=device),
        "iscrowd": torch.tensor([0], dtype=torch.uint8, device=device),
    }

    # Test inference mode (no targets)
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        output_inference = model([image])
        inference_time = time.time() - start_time

    # Verify inference output
    print(f"✓ Inference mode output: {type(output_inference)}")
    print(f"✓ Inference time: {inference_time:.3f}s")
    print(f"✓ Detection boxes shape: {output_inference[0]['boxes'].shape}")
    print(f"✓ Detection scores shape: {output_inference[0]['scores'].shape}")

    # Test training mode (with targets)
    model.train()
    start_time = time.time()
    output_train = model([image], [target])
    train_time = time.time() - start_time

    # Verify training output
    print(f"✓ Training mode output: {type(output_train)}")
    print(f"✓ Training time: {train_time:.3f}s")

    # Print loss values
    for loss_name, loss_value in output_train.items():
        print(f"✓ {loss_name}: {loss_value.item():.4f}")

    return output_train


def test_model_backward(model, loss_dict, device):
    """Test model backward pass."""
    print("\nTesting model backward pass...")

    # Calculate total loss
    total_loss = sum(loss for loss in loss_dict.values())

    # Create optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    # Backward pass
    start_time = time.time()
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    backward_time = time.time() - start_time

    print("✓ Backward pass and optimization completed")
    print(f"✓ Backward time: {backward_time:.3f}s")

    # Check that gradients were calculated
    has_gradients = any(
        param.grad is not None for param in model.parameters() if param.requires_grad
    )
    print(f"✓ Model has gradients: {has_gradients}")


def test_dataset():
    """Test that we can load the dataset."""
    print("\nTesting dataset loading...")

    data_root = "data/PennFudanPed"
    if not os.path.exists(data_root):
        print("✗ Dataset not found at", data_root)
        return None

    # Create dataset
    dataset = PennFudanDataset(root=data_root, transforms=get_transform(train=True))
    print(f"✓ Dataset loaded with {len(dataset)} samples")

    # Test loading a sample
    start_time = time.time()
    img, target = dataset[0]
    load_time = time.time() - start_time

    print(f"✓ Sample loaded in {load_time:.3f}s")
    print(f"✓ Image shape: {img.shape}")
    print(f"✓ Target boxes shape: {target['boxes'].shape}")

    return dataset


def main():
    """Run all tests."""
    print("=== Quick Model Testing Script ===")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Run tests
    model = test_model_creation()
    model.to(device)

    loss_dict = test_model_forward(model, device)

    test_model_backward(model, loss_dict, device)

    test_dataset()

    print("\n=== All tests completed successfully ===")


if __name__ == "__main__":
    main()
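Note: the script above prints its results rather than asserting them. A minimal pytest-style sketch of the same forward/backward smoke check is shown below; it assumes the project imports resolve the same way as in the script (project root on the path) and that get_maskrcnn_model wraps torchvision's standard Mask R-CNN, whose training-mode loss dict uses keys like loss_classifier and loss_mask. If the heads are customized, the key assertions would need adjusting.

# Sketch only: pytest-style version of the forward/backward smoke test above.
# Loss key names follow torchvision's Mask R-CNN convention and are an assumption.
import torch

from models.detection import get_maskrcnn_model


def test_forward_backward_smoke():
    model = get_maskrcnn_model(
        num_classes=2, pretrained=False, pretrained_backbone=False
    )
    model.train()

    image = torch.rand(3, 300, 400)
    target = {
        "boxes": torch.tensor([[100.0, 100.0, 200.0, 200.0]]),
        "labels": torch.tensor([1], dtype=torch.int64),
        "masks": torch.randint(0, 2, (1, 300, 400), dtype=torch.uint8),
        "image_id": torch.tensor([0]),
        "area": torch.tensor([10000.0]),
        "iscrowd": torch.tensor([0], dtype=torch.uint8),
    }

    # Training-mode forward returns a dict of losses
    loss_dict = model([image], [target])
    assert "loss_classifier" in loss_dict and "loss_mask" in loss_dict

    # Backward should populate gradients on trainable parameters
    total_loss = sum(loss_dict.values())
    total_loss.backward()
    assert any(p.grad is not None for p in model.parameters() if p.requires_grad)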
175
scripts/visualize_predictions.py
Executable file
@@ -0,0 +1,175 @@
#!/usr/bin/env python
"""
Visualization script for model predictions on the Penn-Fudan dataset.
This helps visualize and debug model predictions.
"""

import argparse
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch

# Add project root to path for imports
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from models.detection import get_maskrcnn_model
from utils.common import load_checkpoint, load_config
from utils.data_utils import PennFudanDataset, get_transform


def visualize_prediction(image, prediction, threshold=0.5):
    """
    Visualize model prediction on an image.

    Args:
        image (torch.Tensor): The input image [C, H, W]
        prediction (dict): Model prediction dict with boxes, scores, labels, masks
        threshold (float): Score threshold for visualization

    Returns:
        plt.Figure: The matplotlib figure with the visualization
    """
    # Convert image from tensor to numpy
    img_np = image.permute(1, 2, 0).cpu().numpy()

    # Denormalize if needed
    if img_np.max() <= 1.0:
        img_np = (img_np * 255).astype(np.uint8)

    # Create figure and axes
    fig, ax = plt.subplots(1, 1, figsize=(12, 9))
    ax.imshow(img_np)
    ax.set_title("Model Predictions")

    # Get predictions
    boxes = prediction["boxes"].cpu().numpy()
    scores = prediction["scores"].cpu().numpy()
    labels = prediction["labels"].cpu().numpy()
    masks = prediction["masks"].cpu().numpy()

    # Filter by threshold
    mask = scores >= threshold
    boxes = boxes[mask]
    scores = scores[mask]
    labels = labels[mask]
    masks = masks[mask]

    # Draw predictions
    for box, score, label, mask in zip(boxes, scores, labels, masks):
        # Draw box
        x1, y1, x2, y2 = box
        rect = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="red", linewidth=2
        )
        ax.add_patch(rect)

        # Add label and score
        ax.text(
            x1, y1, f"Person: {score:.2f}", bbox=dict(facecolor="yellow", alpha=0.5)
        )

        # Draw mask (with transparency)
        mask = mask[0] > 0.5  # Threshold mask
        mask_color = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        mask_color[mask] = [255, 0, 0]  # Red color
        ax.imshow(mask_color, alpha=0.3)

    # Show count of detections
    ax.set_xlabel(f"Found {len(boxes)} pedestrians with confidence >= {threshold}")

    plt.tight_layout()
    return fig


def run_inference(model, dataset, device, idx=0):
    """
    Run inference on a single image from the dataset.

    Args:
        model (torch.nn.Module): The model
        dataset (PennFudanDataset): The dataset
        device (torch.device): The device
        idx (int): Index of the image to test

    Returns:
        tuple: (image, prediction)
    """
    # Get image
    image, _ = dataset[idx]

    # Prepare for model
    image = image.to(device)

    # Run inference
    model.eval()
    with torch.no_grad():
        prediction = model([image])[0]

    return image, prediction


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Visualize model predictions")
    parser.add_argument("--config", required=True, help="Path to config file")
    parser.add_argument("--checkpoint", required=True, help="Path to checkpoint file")
    parser.add_argument("--index", type=int, default=0, help="Image index to visualize")
    parser.add_argument("--threshold", type=float, default=0.5, help="Score threshold")
    parser.add_argument("--output", help="Path to save visualization image")
    args = parser.parse_args()

    # Load config
    config = load_config(args.config)

    # Setup device
    device = torch.device(config.get("device", "cpu"))
    print(f"Using device: {device}")

    # Create model
    model = get_maskrcnn_model(
        num_classes=config.get("num_classes", 2),
        pretrained=False,
        pretrained_backbone=False,
    )

    # Load checkpoint
    checkpoint, _ = load_checkpoint(args.checkpoint, model, device)
    model.to(device)
    print(f"Loaded checkpoint from: {args.checkpoint}")

    # Create dataset
    data_root = config.get("data_root", "data/PennFudanPed")
    if not os.path.exists(data_root):
        print(f"Error: Data not found at {data_root}")
        return

    dataset = PennFudanDataset(root=data_root, transforms=get_transform(train=False))
    print(f"Dataset loaded with {len(dataset)} images")

    # Validate index
    if args.index < 0 or args.index >= len(dataset):
        print(f"Error: Index {args.index} out of range (0-{len(dataset)-1})")
        return

    # Run inference
    print(f"Running inference on image {args.index}...")
    image, prediction = run_inference(model, dataset, device, args.index)

    # Visualize prediction
    print("Visualizing predictions...")
    fig = visualize_prediction(image, prediction, threshold=args.threshold)

    # Save or show
    if args.output:
        fig.savefig(args.output)
        print(f"Visualization saved to: {args.output}")
    else:
        plt.show()
        print("Visualization displayed. Close window to continue.")


if __name__ == "__main__":
    main()
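The script is driven from the command line via --config and --checkpoint, with optional --index, --threshold, and --output flags. For quicker iteration (e.g. in a notebook), the same two helpers can also be called directly. The snippet below is only a sketch: the config and checkpoint paths are placeholders, and it assumes run_inference and visualize_prediction can be imported from this script (or pasted alongside it).

# Sketch: drive run_inference/visualize_prediction programmatically and save the
# first few predictions to disk. Paths, config keys, and the import of the two
# helpers from scripts.visualize_predictions are assumptions.
import torch

from models.detection import get_maskrcnn_model
from utils.common import load_checkpoint, load_config
from utils.data_utils import PennFudanDataset, get_transform
from scripts.visualize_predictions import run_inference, visualize_prediction

config = load_config("configs/mask_rcnn.yaml")  # placeholder path
device = torch.device(config.get("device", "cpu"))

model = get_maskrcnn_model(
    num_classes=config.get("num_classes", 2),
    pretrained=False,
    pretrained_backbone=False,
)
load_checkpoint("checkpoints/last.pth", model, device)  # placeholder path
model.to(device)

dataset = PennFudanDataset(
    root=config.get("data_root", "data/PennFudanPed"),
    transforms=get_transform(train=False),
)

# Save visualizations for the first three images
for idx in range(3):
    image, prediction = run_inference(model, dataset, device, idx=idx)
    fig = visualize_prediction(image, prediction, threshold=0.5)
    fig.savefig(f"prediction_{idx}.png")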