Compare commits
3 Commits
046e36678e
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27082cbf33 | ||
|
|
e3b0f2a368 | ||
|
|
baba9b9b9f |
232
README.md
232
README.md
@@ -1,3 +1,233 @@
|
||||
This project was a test run using Cursor and "vibe coding" to create a full object detection project. I wrote almost no lines of code to get to this point, which kind of works. The technology is definitely impressive, but really feels more suited to things that can be developed in a more test-driven way. I'll update this later with other things I've learned along the way.
|
||||
|
||||
I stopped this project here because it got trapped in a doom loop not being able to fix a bug in the eval code and I wanted this to be an investigation into how well I could do with very low intervention.
|
||||
|
||||
|
||||
# Torchvision Vibecoding Project
|
||||
|
||||
A project demonstrating finetuning torchvision object detection models, built with the help of Vibecoding AI.
|
||||
A PyTorch-based object detection project using Mask R-CNN to detect pedestrians in the Penn-Fudan dataset. This project demonstrates model training, evaluation, and visualization with PyTorch and Torchvision.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Project Setup](#project-setup)
|
||||
- [Project Structure](#project-structure)
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Configuration](#configuration)
|
||||
- [Training](#training)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Visualization](#visualization)
|
||||
- [Testing](#testing)
|
||||
- [Debugging](#debugging)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.10+
|
||||
- [uv](https://github.com/astral-sh/uv) for package management
|
||||
- CUDA-compatible GPU (optional but recommended)
|
||||
|
||||
## Project Setup
|
||||
|
||||
1. Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/yourusername/torchvision-vibecoding-project.git
|
||||
cd torchvision-vibecoding-project
|
||||
```
|
||||
|
||||
2. Set up the environment with uv:
|
||||
```bash
|
||||
uv init
|
||||
uv sync
|
||||
```
|
||||
|
||||
3. Install development dependencies:
|
||||
```bash
|
||||
uv add ruff pytest matplotlib
|
||||
```
|
||||
|
||||
4. Set up pre-commit hooks:
|
||||
```bash
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
├── configs/ # Configuration files
|
||||
│ ├── base_config.py # Base configuration with defaults
|
||||
│ ├── debug_config.py # Configuration for quick debugging
|
||||
│ └── pennfudan_maskrcnn_config.py # Configuration for Penn-Fudan dataset
|
||||
├── data/ # Dataset directory (not tracked by git)
|
||||
│ └── PennFudanPed/ # Penn-Fudan pedestrian dataset
|
||||
├── models/ # Model definitions
|
||||
│ └── detection.py # Mask R-CNN model definition
|
||||
├── outputs/ # Training outputs (not tracked by git)
|
||||
│ └── <config_name>/ # Named by configuration
|
||||
│ ├── checkpoints/ # Model checkpoints
|
||||
│ └── *.log # Log files
|
||||
├── scripts/ # Utility scripts
|
||||
│ ├── download_data.sh # Script to download dataset
|
||||
│ ├── test_model.py # Script for quick model testing
|
||||
│ └── visualize_predictions.py # Script for prediction visualization
|
||||
├── tests/ # Unit tests
|
||||
│ ├── conftest.py # Test fixtures
|
||||
│ ├── test_data_utils.py # Tests for data utilities
|
||||
│ ├── test_model.py # Tests for model functionality
|
||||
│ └── test_visualization.py # Tests for visualization
|
||||
├── utils/ # Utility modules
|
||||
│ ├── common.py # Common functionality
|
||||
│ ├── data_utils.py # Dataset handling
|
||||
│ ├── eval_utils.py # Evaluation functions
|
||||
│ └── log_utils.py # Logging utilities
|
||||
├── train.py # Training script
|
||||
├── test.py # Evaluation script
|
||||
├── pyproject.toml # Project dependencies and configuration
|
||||
├── .pre-commit-config.yaml # Pre-commit configuration
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Data Preparation
|
||||
|
||||
Download the Penn-Fudan pedestrian dataset:
|
||||
|
||||
```bash
|
||||
./scripts/download_data.sh
|
||||
```
|
||||
|
||||
This will download and extract the dataset to the `data/PennFudanPed` directory.
|
||||
|
||||
## Configuration
|
||||
|
||||
The project uses Python dictionaries for configuration:
|
||||
|
||||
- `configs/base_config.py`: Default configuration values
|
||||
- `configs/pennfudan_maskrcnn_config.py`: Configuration for training on Penn-Fudan
|
||||
- `configs/debug_config.py`: Configuration for quick testing (CPU, minimal training)
|
||||
|
||||
Key configuration parameters:
|
||||
|
||||
- `data_root`: Path to dataset
|
||||
- `output_dir`: Directory for outputs
|
||||
- `device`: Computing device ('cuda' or 'cpu')
|
||||
- `batch_size`: Batch size for training
|
||||
- `num_epochs`: Number of training epochs
|
||||
- `lr`, `momentum`, `weight_decay`: Optimizer parameters
|
||||
|
||||
## Training
|
||||
|
||||
Run the training script with a configuration file:
|
||||
|
||||
```bash
|
||||
python train.py --config configs/pennfudan_maskrcnn_config.py
|
||||
```
|
||||
|
||||
For quick debugging on CPU:
|
||||
|
||||
```bash
|
||||
python train.py --config configs/debug_config.py
|
||||
```
|
||||
|
||||
To resume training from the latest checkpoint:
|
||||
|
||||
```bash
|
||||
python train.py --config configs/pennfudan_maskrcnn_config.py --resume
|
||||
```
|
||||
|
||||
Training outputs (logs, checkpoints) are saved to `outputs/<config_name>/`.
|
||||
|
||||
## Evaluation
|
||||
|
||||
Evaluate a trained model:
|
||||
|
||||
```bash
|
||||
python test.py --config configs/pennfudan_maskrcnn_config.py --checkpoint outputs/pennfudan_maskrcnn_v1/checkpoints/checkpoint_epoch_10.pth
|
||||
```
|
||||
|
||||
This runs the model on the test dataset and reports metrics.
|
||||
|
||||
## Visualization
|
||||
|
||||
Visualize model predictions on images:
|
||||
|
||||
```bash
|
||||
python scripts/visualize_predictions.py --config configs/pennfudan_maskrcnn_config.py --checkpoint outputs/pennfudan_maskrcnn_v1/checkpoints/checkpoint_epoch_10.pth --index 0 --output prediction.png
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `--config`: Configuration file path
|
||||
- `--checkpoint`: Model checkpoint path
|
||||
- `--index`: Image index in dataset (default: 0)
|
||||
- `--threshold`: Detection confidence threshold (default: 0.5)
|
||||
- `--output`: Output image path (optional, displays interactively if not specified)
|
||||
|
||||
## Testing
|
||||
|
||||
Run all tests:
|
||||
|
||||
```bash
|
||||
python -m pytest
|
||||
```
|
||||
|
||||
Run specific test file:
|
||||
|
||||
```bash
|
||||
python -m pytest tests/test_data_utils.py
|
||||
```
|
||||
|
||||
Run tests with verbosity:
|
||||
|
||||
```bash
|
||||
python -m pytest -v
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
For quick model testing without full training:
|
||||
|
||||
```bash
|
||||
python scripts/test_model.py
|
||||
```
|
||||
|
||||
This verifies:
|
||||
- Model creation
|
||||
- Forward pass
|
||||
- Backward pass
|
||||
- Dataset loading
|
||||
|
||||
For training with minimal resources:
|
||||
|
||||
```bash
|
||||
python train.py --config configs/debug_config.py
|
||||
```
|
||||
|
||||
This uses:
|
||||
- CPU computation
|
||||
- Minimal epochs (1)
|
||||
- Small batch size (1)
|
||||
- No multiprocessing
|
||||
|
||||
## Code Quality
|
||||
|
||||
Format code:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
```
|
||||
|
||||
Run linter:
|
||||
|
||||
```bash
|
||||
ruff check .
|
||||
```
|
||||
|
||||
Fix auto-fixable issues:
|
||||
|
||||
```bash
|
||||
ruff check --fix .
|
||||
```
|
||||
|
||||
Run pre-commit checks:
|
||||
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
|
||||
@@ -1,33 +1,34 @@
|
||||
"""
|
||||
Configuration for training Mask R-CNN on the Penn-Fudan dataset.
|
||||
Configuration for MaskRCNN training on the PennFudan Dataset.
|
||||
"""
|
||||
|
||||
from configs.base_config import base_config
|
||||
|
||||
# Create a copy of the base configuration
|
||||
config = base_config.copy()
|
||||
|
||||
# Update specific values for this experiment
|
||||
config.update(
|
||||
{
|
||||
# Core configuration
|
||||
"config_name": "pennfudan_maskrcnn_v1",
|
||||
"data_root": "data/PennFudanPed",
|
||||
"num_classes": 2, # background + pedestrian
|
||||
# Training parameters - modified for memory constraints
|
||||
"batch_size": 1, # Reduced from 2 to 1 to save memory
|
||||
"num_epochs": 10,
|
||||
# Optimizer settings
|
||||
"lr": 0.002, # Slightly reduced learning rate for smaller batch size
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 0.0005,
|
||||
# Memory optimization settings
|
||||
"pin_memory": False, # Set to False to reduce memory pressure
|
||||
"num_workers": 2, # Reduced from 4 to 2
|
||||
# Device settings
|
||||
"device": "cuda",
|
||||
}
|
||||
)
|
||||
config = {
|
||||
# Data settings
|
||||
"data_root": "data/PennFudanPed",
|
||||
"output_dir": "outputs",
|
||||
# Hardware settings
|
||||
"device": "cuda", # "cuda" or "cpu"
|
||||
# Model settings
|
||||
"num_classes": 2, # Background + person
|
||||
# Training settings
|
||||
"batch_size": 1, # Reduced from 2 to 1 to save memory
|
||||
"num_epochs": 10,
|
||||
"seed": 42,
|
||||
# Optimizer settings
|
||||
"lr": 0.002,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 0.0005,
|
||||
"lr_step_size": 3,
|
||||
"lr_gamma": 0.1,
|
||||
# Logging and checkpoints
|
||||
"log_freq": 10, # Log every N steps
|
||||
"checkpoint_freq": 1, # Save checkpoint every N epochs
|
||||
# Run identification
|
||||
"config_name": "pennfudan_maskrcnn_v1",
|
||||
# DataLoader settings
|
||||
"pin_memory": False, # Set to False to reduce memory usage
|
||||
"num_workers": 2, # Reduced from 4 to 2 to reduce memory pressure
|
||||
}
|
||||
|
||||
# Ensure derived paths or settings are consistent if needed
|
||||
# (Not strictly necessary with this simple structure)
|
||||
|
||||
25
test.py
25
test.py
@@ -31,6 +31,8 @@ def main(args):
|
||||
logging.info(f"Loaded configuration from: {args.config}")
|
||||
logging.info(f"Checkpoint path: {args.checkpoint}")
|
||||
logging.info(f"Loaded configuration dictionary: {config}")
|
||||
if args.max_samples:
|
||||
logging.info(f"Limiting evaluation to {args.max_samples} samples")
|
||||
|
||||
# Validate data path
|
||||
data_root = config.get("data_root")
|
||||
@@ -86,12 +88,15 @@ def main(args):
|
||||
# Run Evaluation
|
||||
try:
|
||||
logging.info("Starting model evaluation...")
|
||||
eval_metrics = evaluate(model, data_loader_test, device)
|
||||
eval_metrics = evaluate(model, data_loader_test, device, args.max_samples)
|
||||
|
||||
# Log detailed metrics
|
||||
logging.info("--- Evaluation Results ---")
|
||||
for metric_name, metric_value in eval_metrics.items():
|
||||
logging.info(f" {metric_name}: {metric_value:.4f}")
|
||||
if isinstance(metric_value, (int, float)):
|
||||
logging.info(f" {metric_name}: {metric_value:.4f}")
|
||||
else:
|
||||
logging.info(f" {metric_name}: {metric_value}")
|
||||
|
||||
logging.info("Evaluation completed successfully")
|
||||
except Exception as e:
|
||||
@@ -100,10 +105,20 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Test a trained Mask R-CNN model")
|
||||
parser.add_argument("--config", required=True, help="Path to configuration file")
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test script for torchvision Mask R-CNN"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint", required=True, help="Path to model checkpoint file (.pth)"
|
||||
"--config", required=True, type=str, help="Path to configuration file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint", required=True, type=str, help="Path to model checkpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of samples to evaluate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -1,40 +1,84 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.ops import box_iou
|
||||
|
||||
|
||||
def evaluate(model, data_loader, device):
|
||||
def evaluate(model, data_loader, device, max_samples=None):
|
||||
"""Performs evaluation on the dataset for one epoch.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model to evaluate.
|
||||
data_loader (torch.utils.data.DataLoader): DataLoader for the evaluation data.
|
||||
device (torch.device): The device to run evaluation on.
|
||||
max_samples (int, optional): Maximum number of batches to evaluate. If None, evaluate all.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing evaluation metrics (e.g., average loss).
|
||||
dict: A dictionary containing evaluation metrics (e.g., average loss, mAP).
|
||||
"""
|
||||
model.eval() # Set model to evaluation mode
|
||||
total_loss = 0.0
|
||||
num_batches = len(data_loader)
|
||||
|
||||
# Limit evaluation samples if specified
|
||||
if max_samples is not None:
|
||||
num_batches = min(num_batches, max_samples)
|
||||
logging.info(f"Limiting evaluation to {num_batches} batches")
|
||||
|
||||
eval_start_time = time.time()
|
||||
status_interval = max(1, num_batches // 10) # Log status roughly 10 times
|
||||
|
||||
# Initialize metrics collection
|
||||
inference_times = []
|
||||
|
||||
# IoU thresholds for mAP calculation
|
||||
iou_thresholds = [0.5, 0.75, 0.5] # 0.5, 0.75, 0.5:0.95 (COCO standard)
|
||||
confidence_thresholds = [0.5, 0.75, 0.9] # Different confidence thresholds
|
||||
|
||||
# Initialize counters for metrics
|
||||
metric_accumulators = initialize_metric_accumulators(
|
||||
iou_thresholds, confidence_thresholds
|
||||
)
|
||||
|
||||
logging.info("--- Starting Evaluation --- ")
|
||||
|
||||
with torch.no_grad(): # Disable gradient calculations
|
||||
for i, (images, targets) in enumerate(data_loader):
|
||||
# Stop if we've reached the max samples
|
||||
if max_samples is not None and i >= max_samples:
|
||||
break
|
||||
|
||||
# Free cached memory
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
images = list(image.to(device) for image in images)
|
||||
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
|
||||
|
||||
# To handle the different behavior of Mask R-CNN in eval mode,
|
||||
# we explicitly reset the model to training mode to compute losses,
|
||||
# then switch back to eval mode for the rest of the evaluation
|
||||
# Measure inference time
|
||||
start_time = time.time()
|
||||
# Get predictions in eval mode
|
||||
predictions = model(images)
|
||||
inference_time = time.time() - start_time
|
||||
inference_times.append(inference_time)
|
||||
|
||||
# Process metrics on-the-fly for this batch only
|
||||
process_batch_metrics(
|
||||
predictions,
|
||||
targets,
|
||||
metric_accumulators,
|
||||
iou_thresholds,
|
||||
confidence_thresholds,
|
||||
)
|
||||
|
||||
# Compute losses (switch to train mode temporarily)
|
||||
model.train()
|
||||
loss_dict = model(images, targets)
|
||||
model.eval()
|
||||
|
||||
# Calculate total loss
|
||||
losses = sum(loss for loss in loss_dict.values())
|
||||
loss_value = losses.item()
|
||||
total_loss += loss_value
|
||||
@@ -42,18 +86,727 @@ def evaluate(model, data_loader, device):
|
||||
if (i + 1) % status_interval == 0:
|
||||
logging.info(f" Evaluated batch {i + 1}/{num_batches}")
|
||||
|
||||
# Explicitly clean up to help with memory
|
||||
del images, targets, predictions, loss_dict
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Calculate basic metrics
|
||||
avg_loss = total_loss / num_batches if num_batches > 0 else 0
|
||||
avg_inference_time = np.mean(inference_times) if inference_times else 0
|
||||
|
||||
# Calculate final metrics from accumulators
|
||||
metrics = {
|
||||
"average_loss": avg_loss,
|
||||
"average_inference_time": avg_inference_time,
|
||||
}
|
||||
|
||||
# Compute final metrics from accumulators
|
||||
metrics.update(finalize_metrics(metric_accumulators))
|
||||
|
||||
eval_duration = time.time() - eval_start_time
|
||||
|
||||
# Log results
|
||||
logging.info("--- Evaluation Finished ---")
|
||||
logging.info(f" Average Evaluation Loss: {avg_loss:.4f}")
|
||||
logging.info(f" Average Inference Time: {avg_inference_time:.4f}s per batch")
|
||||
|
||||
# Log detailed metrics
|
||||
for metric_name, metric_value in metrics.items():
|
||||
if metric_name != "average_loss": # Already logged
|
||||
if isinstance(metric_value, (int, float)):
|
||||
logging.info(f" {metric_name}: {metric_value:.4f}")
|
||||
else:
|
||||
logging.info(f" {metric_name}: {metric_value}")
|
||||
|
||||
logging.info(f" Evaluation Duration: {eval_duration:.2f}s")
|
||||
|
||||
# Return metrics (currently just average loss)
|
||||
metrics = {"average_loss": avg_loss}
|
||||
return metrics
|
||||
|
||||
|
||||
def initialize_metric_accumulators(iou_thresholds, confidence_thresholds):
|
||||
"""Initialize accumulators for incremental metric calculation"""
|
||||
accumulators = {
|
||||
"total_gt": 0,
|
||||
"map_accumulators": {},
|
||||
"conf_accumulators": {},
|
||||
"size_accumulators": {
|
||||
"small_gt": 0,
|
||||
"medium_gt": 0,
|
||||
"large_gt": 0,
|
||||
"small_tp": 0,
|
||||
"medium_tp": 0,
|
||||
"large_tp": 0,
|
||||
"small_det": 0,
|
||||
"medium_det": 0,
|
||||
"large_det": 0,
|
||||
},
|
||||
}
|
||||
|
||||
# Initialize map accumulators for each IoU threshold
|
||||
for iou in iou_thresholds:
|
||||
accumulators["map_accumulators"][iou] = {
|
||||
"true_positives": 0,
|
||||
"false_positives": 0,
|
||||
"total_detections": 0,
|
||||
}
|
||||
|
||||
# Initialize confidence accumulators
|
||||
for conf in confidence_thresholds:
|
||||
accumulators["conf_accumulators"][conf] = {
|
||||
"true_positives": 0,
|
||||
"detections": 0,
|
||||
}
|
||||
|
||||
return accumulators
|
||||
|
||||
|
||||
def process_batch_metrics(
|
||||
predictions, targets, accumulators, iou_thresholds, confidence_thresholds
|
||||
):
|
||||
"""Process metrics for a single batch incrementally"""
|
||||
small_threshold = 32 * 32 # Small objects: area < 32²
|
||||
medium_threshold = 96 * 96 # Medium objects: 32² <= area < 96²
|
||||
|
||||
# Count total ground truth boxes in this batch
|
||||
batch_gt = sum(len(target["boxes"]) for target in targets)
|
||||
accumulators["total_gt"] += batch_gt
|
||||
|
||||
# Process all predictions in the batch
|
||||
for pred, target in zip(predictions, targets):
|
||||
pred_boxes = pred["boxes"]
|
||||
pred_scores = pred["scores"]
|
||||
pred_labels = pred["labels"]
|
||||
gt_boxes = target["boxes"]
|
||||
gt_labels = target["labels"]
|
||||
|
||||
# Skip if no predictions or no ground truth
|
||||
if len(pred_boxes) == 0 or len(gt_boxes) == 0:
|
||||
continue
|
||||
|
||||
# Calculate IoU between predictions and ground truth
|
||||
iou_matrix = box_iou(pred_boxes, gt_boxes)
|
||||
|
||||
# Process size-based metrics
|
||||
gt_areas = target.get("area", None)
|
||||
if gt_areas is None:
|
||||
# Calculate if not provided
|
||||
gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (
|
||||
gt_boxes[:, 3] - gt_boxes[:, 1]
|
||||
)
|
||||
|
||||
# Count ground truth by size
|
||||
small_mask_gt = gt_areas < small_threshold
|
||||
medium_mask_gt = (gt_areas >= small_threshold) & (gt_areas < medium_threshold)
|
||||
large_mask_gt = gt_areas >= medium_threshold
|
||||
|
||||
accumulators["size_accumulators"]["small_gt"] += torch.sum(small_mask_gt).item()
|
||||
accumulators["size_accumulators"]["medium_gt"] += torch.sum(
|
||||
medium_mask_gt
|
||||
).item()
|
||||
accumulators["size_accumulators"]["large_gt"] += torch.sum(large_mask_gt).item()
|
||||
|
||||
# Calculate areas for predictions
|
||||
pred_areas = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (
|
||||
pred_boxes[:, 3] - pred_boxes[:, 1]
|
||||
)
|
||||
|
||||
# Count predictions by size (with confidence >= 0.5)
|
||||
conf_mask = pred_scores >= 0.5
|
||||
if torch.sum(conf_mask) == 0:
|
||||
continue # Skip if no predictions meet confidence threshold
|
||||
|
||||
small_mask = (pred_areas < small_threshold) & conf_mask
|
||||
medium_mask = (
|
||||
(pred_areas >= small_threshold)
|
||||
& (pred_areas < medium_threshold)
|
||||
& conf_mask
|
||||
)
|
||||
large_mask = (pred_areas >= medium_threshold) & conf_mask
|
||||
|
||||
accumulators["size_accumulators"]["small_det"] += torch.sum(small_mask).item()
|
||||
accumulators["size_accumulators"]["medium_det"] += torch.sum(medium_mask).item()
|
||||
accumulators["size_accumulators"]["large_det"] += torch.sum(large_mask).item()
|
||||
|
||||
# Process metrics for each IoU threshold
|
||||
for iou_threshold in iou_thresholds:
|
||||
process_iou_metrics(
|
||||
pred_boxes,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
gt_boxes,
|
||||
gt_labels,
|
||||
iou_matrix,
|
||||
accumulators["map_accumulators"][iou_threshold],
|
||||
iou_threshold,
|
||||
)
|
||||
|
||||
# Process metrics for each confidence threshold
|
||||
for conf_threshold in confidence_thresholds:
|
||||
process_confidence_metrics(
|
||||
pred_boxes,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
gt_boxes,
|
||||
gt_labels,
|
||||
iou_matrix,
|
||||
accumulators["conf_accumulators"][conf_threshold],
|
||||
conf_threshold,
|
||||
)
|
||||
|
||||
# Process size-based true positives with fixed IoU threshold of 0.5
|
||||
# Use a new gt_matched array to avoid interference with other metric calculations
|
||||
gt_matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
|
||||
filtered_mask = pred_scores >= 0.5
|
||||
|
||||
if torch.sum(filtered_mask) > 0:
|
||||
filtered_boxes = pred_boxes[filtered_mask]
|
||||
filtered_scores = pred_scores[filtered_mask]
|
||||
filtered_labels = pred_labels[filtered_mask]
|
||||
# Recalculate IoU for filtered boxes
|
||||
filtered_iou_matrix = box_iou(filtered_boxes, gt_boxes)
|
||||
|
||||
# Sort predictions by confidence
|
||||
sorted_indices = torch.argsort(filtered_scores, descending=True)
|
||||
|
||||
for idx in sorted_indices:
|
||||
best_iou, best_gt_idx = torch.max(filtered_iou_matrix[idx], dim=0)
|
||||
|
||||
if best_iou >= 0.5 and not gt_matched[best_gt_idx]:
|
||||
if filtered_labels[idx] == gt_labels[best_gt_idx]:
|
||||
gt_matched[best_gt_idx] = True
|
||||
|
||||
# Categorize true positive by ground truth size (not prediction size)
|
||||
area = gt_areas[best_gt_idx].item()
|
||||
if area < small_threshold:
|
||||
accumulators["size_accumulators"]["small_tp"] += 1
|
||||
elif area < medium_threshold:
|
||||
accumulators["size_accumulators"]["medium_tp"] += 1
|
||||
else:
|
||||
accumulators["size_accumulators"]["large_tp"] += 1
|
||||
|
||||
|
||||
def process_iou_metrics(
|
||||
pred_boxes,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
gt_boxes,
|
||||
gt_labels,
|
||||
iou_matrix,
|
||||
accumulator,
|
||||
iou_threshold,
|
||||
):
|
||||
"""Process metrics for a specific IoU threshold"""
|
||||
# Apply a minimum confidence threshold of 0.05 for metrics
|
||||
min_conf_threshold = 0.05
|
||||
conf_mask = pred_scores >= min_conf_threshold
|
||||
|
||||
if torch.sum(conf_mask) == 0:
|
||||
return # Skip if no predictions after confidence filtering
|
||||
|
||||
# Filter predictions by confidence
|
||||
filtered_boxes = pred_boxes[conf_mask]
|
||||
filtered_scores = pred_scores[conf_mask]
|
||||
filtered_labels = pred_labels[conf_mask]
|
||||
|
||||
# Initialize array to track which gt boxes have been matched
|
||||
gt_matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
|
||||
|
||||
# We may need a filtered IoU matrix if we're filtering predictions
|
||||
if len(filtered_boxes) < len(pred_boxes):
|
||||
# Recalculate IoU for filtered predictions
|
||||
filtered_iou_matrix = box_iou(filtered_boxes, gt_boxes)
|
||||
else:
|
||||
filtered_iou_matrix = iou_matrix
|
||||
|
||||
# Sort predictions by confidence score (high to low)
|
||||
sorted_indices = torch.argsort(filtered_scores, descending=True)
|
||||
|
||||
# True positives count for this batch
|
||||
batch_tp = 0
|
||||
|
||||
for idx in sorted_indices:
|
||||
# Find best matching ground truth box
|
||||
iou_values = filtered_iou_matrix[idx]
|
||||
|
||||
# Skip if no ground truth boxes
|
||||
if len(iou_values) == 0:
|
||||
# This is a false positive since there's no ground truth to match
|
||||
accumulator["false_positives"] += 1
|
||||
continue
|
||||
|
||||
best_iou, best_gt_idx = torch.max(iou_values, dim=0)
|
||||
|
||||
# Check if the prediction matches a ground truth box
|
||||
if (
|
||||
best_iou >= iou_threshold
|
||||
and not gt_matched[best_gt_idx]
|
||||
and filtered_labels[idx] == gt_labels[best_gt_idx]
|
||||
):
|
||||
batch_tp += 1
|
||||
gt_matched[best_gt_idx] = True
|
||||
else:
|
||||
accumulator["false_positives"] += 1
|
||||
|
||||
# Update true positives - Important: Don't artificially cap true positives here
|
||||
# Let finalize_metrics handle the capping to avoid recall underestimation during intermediate calculations
|
||||
accumulator["true_positives"] += batch_tp
|
||||
|
||||
# Count total detection (after confidence filtering)
|
||||
accumulator["total_detections"] += len(filtered_boxes)
|
||||
|
||||
|
||||
def process_confidence_metrics(
|
||||
pred_boxes,
|
||||
pred_scores,
|
||||
pred_labels,
|
||||
gt_boxes,
|
||||
gt_labels,
|
||||
iou_matrix,
|
||||
accumulator,
|
||||
conf_threshold,
|
||||
):
|
||||
"""Process metrics for a specific confidence threshold"""
|
||||
# Filter by confidence
|
||||
mask = pred_scores >= conf_threshold
|
||||
|
||||
if torch.sum(mask) == 0:
|
||||
return # Skip if no predictions after filtering
|
||||
|
||||
filtered_boxes = pred_boxes[mask]
|
||||
filtered_scores = pred_scores[mask]
|
||||
filtered_labels = pred_labels[mask]
|
||||
|
||||
accumulator["detections"] += len(filtered_boxes)
|
||||
|
||||
if len(filtered_boxes) == 0 or len(gt_boxes) == 0:
|
||||
return
|
||||
|
||||
# Calculate matches with fixed IoU threshold of 0.5
|
||||
gt_matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
|
||||
|
||||
# We need to recalculate IoU for the filtered boxes
|
||||
filtered_iou_matrix = box_iou(filtered_boxes, gt_boxes)
|
||||
|
||||
# Sort by confidence for consistent ordering
|
||||
sorted_indices = torch.argsort(filtered_scores, descending=True)
|
||||
|
||||
for pred_idx in sorted_indices:
|
||||
best_iou, best_gt_idx = torch.max(filtered_iou_matrix[pred_idx], dim=0)
|
||||
if best_iou >= 0.5 and not gt_matched[best_gt_idx]:
|
||||
if filtered_labels[pred_idx] == gt_labels[best_gt_idx]:
|
||||
accumulator["true_positives"] += 1
|
||||
gt_matched[best_gt_idx] = True
|
||||
|
||||
|
||||
def finalize_metrics(accumulators):
|
||||
"""Calculate final metrics from accumulators"""
|
||||
metrics = {}
|
||||
total_gt = accumulators["total_gt"]
|
||||
|
||||
# Calculate mAP metrics
|
||||
for iou_threshold, map_acc in accumulators["map_accumulators"].items():
|
||||
true_positives = map_acc["true_positives"]
|
||||
false_positives = map_acc["false_positives"]
|
||||
|
||||
# Calculate metrics - Only cap true positives at the very end for final metrics
|
||||
# to prevent recall underestimation during intermediate calculations
|
||||
precision = true_positives / max(true_positives + false_positives, 1)
|
||||
recall = true_positives / max(total_gt, 1)
|
||||
|
||||
# Cap metrics for final reporting to ensure they're in valid range
|
||||
precision = min(1.0, precision)
|
||||
recall = min(1.0, recall)
|
||||
|
||||
f1_score = 2 * precision * recall / max(precision + recall, 1e-6)
|
||||
|
||||
# Simple average precision calculation (precision * recall)
|
||||
# This is a simplification; full AP calculation requires a PR curve
|
||||
ap = precision * recall
|
||||
|
||||
metrics.update(
|
||||
{
|
||||
f"mAP@{iou_threshold}": ap,
|
||||
f"precision@{iou_threshold}": precision,
|
||||
f"recall@{iou_threshold}": recall,
|
||||
f"f1_score@{iou_threshold}": f1_score,
|
||||
f"tp@{iou_threshold}": true_positives,
|
||||
f"fp@{iou_threshold}": false_positives,
|
||||
"gt_total": total_gt,
|
||||
}
|
||||
)
|
||||
|
||||
# Calculate confidence threshold metrics
|
||||
for conf_threshold, conf_acc in accumulators["conf_accumulators"].items():
|
||||
true_positives = conf_acc["true_positives"]
|
||||
detections = conf_acc["detections"]
|
||||
|
||||
# Calculate metrics without artificial capping to prevent recall underestimation
|
||||
precision = true_positives / max(detections, 1)
|
||||
recall = true_positives / max(total_gt, 1)
|
||||
|
||||
# Cap metrics for final reporting only
|
||||
precision = min(1.0, precision)
|
||||
recall = min(1.0, recall)
|
||||
|
||||
f1_score = 2 * precision * recall / max(precision + recall, 1e-6)
|
||||
|
||||
metrics.update(
|
||||
{
|
||||
f"precision@conf{conf_threshold}": precision,
|
||||
f"recall@conf{conf_threshold}": recall,
|
||||
f"f1_score@conf{conf_threshold}": f1_score,
|
||||
f"detections@conf{conf_threshold}": detections,
|
||||
f"tp@conf{conf_threshold}": true_positives,
|
||||
}
|
||||
)
|
||||
|
||||
# Calculate size metrics
|
||||
size_acc = accumulators["size_accumulators"]
|
||||
small_gt = size_acc["small_gt"]
|
||||
medium_gt = size_acc["medium_gt"]
|
||||
large_gt = size_acc["large_gt"]
|
||||
small_tp = size_acc["small_tp"]
|
||||
medium_tp = size_acc["medium_tp"]
|
||||
large_tp = size_acc["large_tp"]
|
||||
small_det = size_acc["small_det"]
|
||||
medium_det = size_acc["medium_det"]
|
||||
large_det = size_acc["large_det"]
|
||||
|
||||
# Calculate precision and recall without artificial capping
|
||||
small_precision = small_tp / max(small_det, 1) if small_det > 0 else 0
|
||||
small_recall = small_tp / max(small_gt, 1) if small_gt > 0 else 0
|
||||
|
||||
medium_precision = medium_tp / max(medium_det, 1) if medium_det > 0 else 0
|
||||
medium_recall = medium_tp / max(medium_gt, 1) if medium_gt > 0 else 0
|
||||
|
||||
large_precision = large_tp / max(large_det, 1) if large_det > 0 else 0
|
||||
large_recall = large_tp / max(large_gt, 1) if large_gt > 0 else 0
|
||||
|
||||
# Cap metrics for final reporting
|
||||
small_precision = min(1.0, small_precision)
|
||||
small_recall = min(1.0, small_recall)
|
||||
medium_precision = min(1.0, medium_precision)
|
||||
medium_recall = min(1.0, medium_recall)
|
||||
large_precision = min(1.0, large_precision)
|
||||
large_recall = min(1.0, large_recall)
|
||||
|
||||
metrics.update(
|
||||
{
|
||||
"small_precision": small_precision,
|
||||
"small_recall": small_recall,
|
||||
"small_count": small_gt,
|
||||
"small_tp": small_tp,
|
||||
"small_det": small_det,
|
||||
"medium_precision": medium_precision,
|
||||
"medium_recall": medium_recall,
|
||||
"medium_count": medium_gt,
|
||||
"medium_tp": medium_tp,
|
||||
"medium_det": medium_det,
|
||||
"large_precision": large_precision,
|
||||
"large_recall": large_recall,
|
||||
"large_count": large_gt,
|
||||
"large_tp": large_tp,
|
||||
"large_det": large_det,
|
||||
}
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def calculate_map(predictions, targets, iou_threshold=0.5):
|
||||
"""
|
||||
Calculate mean Average Precision (mAP) at a specific IoU threshold.
|
||||
|
||||
Args:
|
||||
predictions (list): List of prediction dictionaries
|
||||
targets (list): List of target dictionaries
|
||||
iou_threshold (float): IoU threshold for considering a detection as correct
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with mAP, precision, recall and F1 score
|
||||
"""
|
||||
# Initialize counters
|
||||
total_gt = 0
|
||||
total_detections = 0
|
||||
true_positives = 0
|
||||
false_positives = 0
|
||||
|
||||
# Count total ground truth boxes
|
||||
for target in targets:
|
||||
total_gt += len(target["boxes"])
|
||||
|
||||
# Process all predictions
|
||||
for pred, target in zip(predictions, targets):
|
||||
pred_boxes = pred["boxes"]
|
||||
pred_scores = pred["scores"]
|
||||
pred_labels = pred["labels"]
|
||||
gt_boxes = target["boxes"]
|
||||
gt_labels = target["labels"]
|
||||
|
||||
# Skip if no predictions or no ground truth
|
||||
if len(pred_boxes) == 0 or len(gt_boxes) == 0:
|
||||
continue
|
||||
|
||||
# Calculate IoU between predictions and ground truth
|
||||
iou_matrix = box_iou(pred_boxes, gt_boxes)
|
||||
|
||||
# Initialize array to track which gt boxes have been matched
|
||||
gt_matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
|
||||
|
||||
# Sort predictions by confidence score (high to low)
|
||||
sorted_indices = torch.argsort(pred_scores, descending=True)
|
||||
|
||||
# Count true positives and false positives
|
||||
for idx in sorted_indices:
|
||||
# Find best matching ground truth box
|
||||
iou_values = iou_matrix[idx]
|
||||
best_iou, best_gt_idx = torch.max(iou_values, dim=0)
|
||||
|
||||
# Check if the prediction matches a ground truth box
|
||||
if (
|
||||
best_iou >= iou_threshold
|
||||
and not gt_matched[best_gt_idx]
|
||||
and pred_labels[idx] == gt_labels[best_gt_idx]
|
||||
):
|
||||
true_positives += 1
|
||||
gt_matched[best_gt_idx] = True
|
||||
else:
|
||||
false_positives += 1
|
||||
|
||||
total_detections += len(pred_boxes)
|
||||
|
||||
# Calculate metrics
|
||||
precision = true_positives / max(true_positives + false_positives, 1)
|
||||
recall = true_positives / max(total_gt, 1)
|
||||
|
||||
# Cap metrics for final reporting
|
||||
precision = min(1.0, precision)
|
||||
recall = min(1.0, recall)
|
||||
|
||||
f1_score = 2 * precision * recall / max(precision + recall, 1e-6)
|
||||
|
||||
return {
|
||||
"mAP": precision * recall, # Simplified mAP calculation
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1_score": f1_score,
|
||||
"true_positives": true_positives,
|
||||
"false_positives": false_positives,
|
||||
"total_gt": total_gt,
|
||||
"total_detections": total_detections,
|
||||
}
|
||||
|
||||
|
||||
def calculate_metrics_at_confidence(predictions, targets, confidence_threshold=0.5):
|
||||
"""
|
||||
Calculate detection metrics at a specific confidence threshold.
|
||||
|
||||
Args:
|
||||
predictions (list): List of prediction dictionaries
|
||||
targets (list): List of target dictionaries
|
||||
confidence_threshold (float): Confidence threshold to filter predictions
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with precision, recall, F1 score and detection count
|
||||
"""
|
||||
# Initialize counters
|
||||
total_gt = 0
|
||||
detections = 0
|
||||
true_positives = 0
|
||||
|
||||
# Count total ground truth boxes
|
||||
for target in targets:
|
||||
total_gt += len(target["boxes"])
|
||||
|
||||
# Process all predictions with confidence filter
|
||||
for pred, target in zip(predictions, targets):
|
||||
# Filter predictions by confidence threshold
|
||||
mask = pred["scores"] >= confidence_threshold
|
||||
filtered_boxes = pred["boxes"][mask]
|
||||
filtered_labels = pred["labels"][mask] if len(mask) > 0 else []
|
||||
|
||||
detections += len(filtered_boxes)
|
||||
|
||||
# Skip if no predictions after filtering
|
||||
if len(filtered_boxes) == 0:
|
||||
continue
|
||||
|
||||
# Calculate IoU with ground truth
|
||||
gt_boxes = target["boxes"]
|
||||
gt_labels = target["labels"]
|
||||
|
||||
# Skip if no ground truth
|
||||
if len(gt_boxes) == 0:
|
||||
continue
|
||||
|
||||
iou_matrix = box_iou(filtered_boxes, gt_boxes)
|
||||
|
||||
# Initialize array to track which gt boxes have been matched
|
||||
gt_matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
|
||||
|
||||
# Find matches based on IoU threshold of 0.5
|
||||
for pred_idx in range(len(filtered_boxes)):
|
||||
best_iou, best_gt_idx = torch.max(iou_matrix[pred_idx], dim=0)
|
||||
if best_iou >= 0.5 and not gt_matched[best_gt_idx]:
|
||||
if (
|
||||
len(filtered_labels) > 0
|
||||
and filtered_labels[pred_idx] == gt_labels[best_gt_idx]
|
||||
):
|
||||
true_positives += 1
|
||||
gt_matched[best_gt_idx] = True
|
||||
|
||||
# Calculate metrics
|
||||
precision = true_positives / max(detections, 1)
|
||||
recall = true_positives / max(total_gt, 1)
|
||||
|
||||
# Cap metrics for final reporting
|
||||
precision = min(1.0, precision)
|
||||
recall = min(1.0, recall)
|
||||
|
||||
f1_score = 2 * precision * recall / max(precision + recall, 1e-6)
|
||||
|
||||
return {
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1_score": f1_score,
|
||||
"detections": detections,
|
||||
"true_positives": true_positives,
|
||||
}
|
||||
|
||||
|
||||
def calculate_size_based_metrics(predictions, targets):
|
||||
"""
|
||||
Calculate detection performance by object size.
|
||||
|
||||
Args:
|
||||
predictions (list): List of prediction dictionaries
|
||||
targets (list): List of target dictionaries
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with size-based metrics
|
||||
"""
|
||||
# Define size categories (in pixels²)
|
||||
small_threshold = 32 * 32 # Small objects: area < 32²
|
||||
medium_threshold = 96 * 96 # Medium objects: 32² <= area < 96²
|
||||
# Large objects: area >= 96²
|
||||
|
||||
# Initialize counters for each size category
|
||||
size_metrics = {
|
||||
"small_recall": 0,
|
||||
"small_precision": 0,
|
||||
"small_count": 0,
|
||||
"medium_recall": 0,
|
||||
"medium_precision": 0,
|
||||
"medium_count": 0,
|
||||
"large_recall": 0,
|
||||
"large_precision": 0,
|
||||
"large_count": 0,
|
||||
}
|
||||
|
||||
# Count by size
|
||||
small_gt, medium_gt, large_gt = 0, 0, 0
|
||||
small_tp, medium_tp, large_tp = 0, 0, 0
|
||||
small_det, medium_det, large_det = 0, 0, 0
|
||||
|
||||
# Process all predictions
|
||||
for pred, target in zip(predictions, targets):
|
||||
pred_boxes = pred["boxes"]
|
||||
pred_scores = pred["scores"]
|
||||
gt_boxes = target["boxes"]
|
||||
|
||||
# Skip if no predictions or no ground truth
|
||||
if len(pred_boxes) == 0 or len(gt_boxes) == 0:
|
||||
continue
|
||||
|
||||
# Calculate areas for ground truth
|
||||
gt_areas = target.get("area", None)
|
||||
if gt_areas is None:
|
||||
# Calculate if not provided
|
||||
gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (
|
||||
gt_boxes[:, 3] - gt_boxes[:, 1]
|
||||
)
|
||||
|
||||
# Count ground truth by size
|
||||
small_gt += torch.sum((gt_areas < small_threshold)).item()
|
||||
medium_gt += torch.sum(
|
||||
(gt_areas >= small_threshold) & (gt_areas < medium_threshold)
|
||||
).item()
|
||||
large_gt += torch.sum((gt_areas >= medium_threshold)).item()
|
||||
|
||||
# Calculate areas for predictions
|
||||
pred_areas = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (
|
||||
pred_boxes[:, 3] - pred_boxes[:, 1]
|
||||
)
|
||||
|
||||
# Count predictions by size (with confidence >= 0.5)
|
||||
conf_mask = pred_scores >= 0.5
|
||||
small_mask = (pred_areas < small_threshold) & conf_mask
|
||||
medium_mask = (
|
||||
(pred_areas >= small_threshold)
|
||||
& (pred_areas < medium_threshold)
|
||||
& conf_mask
|
||||
)
|
||||
large_mask = (pred_areas >= medium_threshold) & conf_mask
|
||||
|
||||
small_det += torch.sum(small_mask).item()
|
||||
medium_det += torch.sum(medium_mask).item()
|
||||
large_det += torch.sum(large_mask).item()
|
||||
|
||||
# Calculate IoU between predictions and ground truth
|
||||
iou_matrix = box_iou(pred_boxes, gt_boxes)
|
||||
|
||||
# Initialize array to track which gt boxes have been matched
|
||||
gt_matched = torch.zeros(len(gt_boxes), dtype=torch.bool)
|
||||
|
||||
# Sort predictions by confidence score (high to low)
|
||||
sorted_indices = torch.argsort(pred_scores, descending=True)
|
||||
|
||||
# Count true positives by size
|
||||
for idx in sorted_indices:
|
||||
if pred_scores[idx] < 0.5: # Skip low confidence detections
|
||||
continue
|
||||
|
||||
# Find best matching ground truth box
|
||||
best_iou, best_gt_idx = torch.max(iou_matrix[idx], dim=0)
|
||||
|
||||
# Check if the prediction matches a ground truth box with IoU >= 0.5
|
||||
if best_iou >= 0.5 and not gt_matched[best_gt_idx]:
|
||||
gt_matched[best_gt_idx] = True
|
||||
|
||||
# Categorize true positive by size
|
||||
area = gt_areas[best_gt_idx].item()
|
||||
if area < small_threshold:
|
||||
small_tp += 1
|
||||
elif area < medium_threshold:
|
||||
medium_tp += 1
|
||||
else:
|
||||
large_tp += 1
|
||||
|
||||
# Calculate metrics for each size category
|
||||
size_metrics["small_precision"] = small_tp / max(small_det, 1)
|
||||
size_metrics["small_recall"] = small_tp / max(small_gt, 1)
|
||||
size_metrics["small_count"] = small_gt
|
||||
|
||||
size_metrics["medium_precision"] = medium_tp / max(medium_det, 1)
|
||||
size_metrics["medium_recall"] = medium_tp / max(medium_gt, 1)
|
||||
size_metrics["medium_count"] = medium_gt
|
||||
|
||||
size_metrics["large_precision"] = large_tp / max(large_det, 1)
|
||||
size_metrics["large_recall"] = large_tp / max(large_gt, 1)
|
||||
size_metrics["large_count"] = large_gt
|
||||
|
||||
# Cap metrics for final reporting
|
||||
size_metrics["small_precision"] = min(1.0, size_metrics["small_precision"])
|
||||
size_metrics["small_recall"] = min(1.0, size_metrics["small_recall"])
|
||||
size_metrics["medium_precision"] = min(1.0, size_metrics["medium_precision"])
|
||||
size_metrics["medium_recall"] = min(1.0, size_metrics["medium_recall"])
|
||||
size_metrics["large_precision"] = min(1.0, size_metrics["large_precision"])
|
||||
size_metrics["large_recall"] = min(1.0, size_metrics["large_recall"])
|
||||
|
||||
return size_metrics
|
||||
|
||||
|
||||
# Example usage (can be removed or kept for testing):
|
||||
if __name__ == "__main__":
|
||||
# This is a dummy test and requires a model, dataloader, device
|
||||
|
||||
Reference in New Issue
Block a user