Create eval loop and use full train dataset
todo.md (52 changed lines)
@@ -60,36 +60,36 @@ This list outlines the steps required to complete the Torchvision Finetuning pro
 - [x] Call `setup_logging`.
 - [x] Replace `print` with `logging.info`.
 - [x] Log config, device, and training progress/losses.
-- [ ] Implement full training loop in `train.py`.
+- [x] Implement full training loop in `train.py`.
-- [ ] Remove single-step exit.
+- [x] Remove single-step exit.
-- [ ] Add LR scheduler (`torch.optim.lr_scheduler.StepLR`).
+- [x] Add LR scheduler (`torch.optim.lr_scheduler.StepLR`).
-- [ ] Add epoch loop.
+- [x] Add epoch loop.
-- [ ] Add batch loop, integrating the single training step logic.
+- [x] Add batch loop, integrating the single training step logic.
-- [ ] Log loss periodically within the batch loop.
+- [x] Log loss periodically within the batch loop.
-- [ ] Step the LR scheduler at the end of each epoch.
+- [x] Step the LR scheduler at the end of each epoch.
-- [ ] Log total training time.
+- [x] Log total training time.
-- [ ] Implement checkpointing in `train.py`.
+- [x] Implement checkpointing in `train.py`.
-- [ ] Define checkpoint directory.
+- [x] Define checkpoint directory.
-- [ ] Implement logic to find and load the latest checkpoint (resume training).
+- [x] Implement logic to find and load the latest checkpoint (resume training).
-- [ ] Save checkpoints periodically (based on frequency or final epoch).
+- [x] Save checkpoints periodically (based on frequency or final epoch).
-- [ ] Include epoch, model state, optimizer state, scheduler state, config.
+- [x] Include epoch, model state, optimizer state, scheduler state, config.
-- [ ] Log checkpoint loading/saving.
+- [x] Log checkpoint loading/saving.
 
 ## Phase 4: Evaluation & Testing
 
 - [ ] Add evaluation dependencies (`pycocotools` - optional initially).
-- [ ] Create `utils/eval_utils.py` and implement `evaluate` function.
+- [x] Create `utils/eval_utils.py` and implement `evaluate` function.
-- [ ] Set `model.eval()`.
+- [x] Set `model.eval()`.
-- [ ] Use `torch.no_grad()`.
+- [x] Use `torch.no_grad()`.
-- [ ] Loop through validation/test dataloader.
+- [x] Loop through validation/test dataloader.
-- [ ] Perform forward pass.
+- [x] Perform forward pass.
-- [ ] Calculate/aggregate metrics (start with average loss, potentially add mAP later).
+- [x] Calculate/aggregate metrics (start with average loss, potentially add mAP later).
-- [ ] Log evaluation metrics and time.
+- [x] Log evaluation metrics and time.
-- [ ] Return metrics.
+- [x] Return metrics.
-- [ ] Integrate evaluation into `train.py`.
+- [x] Integrate evaluation into `train.py`.
-- [ ] Create validation `Dataset` and `DataLoader` (using `torch.utils.data.Subset`).
+- [x] Create validation `Dataset` and `DataLoader` (using `torch.utils.data.Subset`).
-- [ ] Call `evaluate` at the end of each epoch.
+- [x] Call `evaluate` at the end of each epoch.
-- [ ] Log validation metrics.
+- [x] Log validation metrics.
 - [ ] (Later) Implement logic to save the *best* model based on validation metric.
 - [ ] Implement `test.py` script.
 - [ ] Reuse argument parsing, config loading, device setup, dataset/dataloader (test split), model creation from `train.py`.
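The checked-off items above refer to the `StepLR` scheduler and per-epoch stepping that already live in `train.py` (the scheduler creation itself is outside the hunks in this commit). As a point of reference, a minimal sketch of that pattern; `model`, `data_loader_train`, and `num_epochs` are assumed to exist as in the training script, and the `step_size`/`gamma` values are illustrative rather than taken from the project config:

import torch

# Illustrative hyperparameters; the real values come from the project config.
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(num_epochs):
    model.train()
    for images, targets in data_loader_train:
        loss_dict = model(images, targets)  # torchvision detection models return a loss dict in train mode
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
    lr_scheduler.step()  # step the scheduler once at the end of each epoch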
train.py (199 changed lines)
@@ -13,6 +13,7 @@ import torch.utils.data
 # Project specific imports
 from models.detection import get_maskrcnn_model
 from utils.data_utils import PennFudanDataset, collate_fn, get_transform
+from utils.eval_utils import evaluate  # Import evaluate function
 from utils.log_utils import setup_logging
 
 
@@ -108,31 +109,75 @@ def main(args):
         sys.exit(1)
 
     try:
-        dataset_train = PennFudanDataset(
+        # Create the full training dataset instance first
+        dataset_full = PennFudanDataset(
             root=data_root, transforms=get_transform(train=True)
         )
-        # Note: Validation split will be handled later (Prompt 12)
-        # dataset_val = PennFudanDataset(root=data_root, transforms=get_transform(train=False))
+        logging.info(f"Full dataset size: {len(dataset_full)}")
 
-        # TODO: Implement data splitting (e.g., using torch.utils.data.Subset)
+        # Create validation dataset instance with eval transforms
+        dataset_val_instance = PennFudanDataset(
+            root=data_root, transforms=get_transform(train=False)
+        )
+
+        # Split the dataset indices
+        torch.manual_seed(
+            config.get("seed", 42)
+        )  # Use the same seed for consistent splits
+        indices = torch.randperm(len(dataset_full)).tolist()
+        val_split_ratio = config.get(
+            "val_split_ratio", 0.1
+        )  # Default to 10% validation
+        val_split_count = int(val_split_ratio * len(dataset_full))
+        if val_split_count == 0 and len(dataset_full) > 0:
+            logging.warning(
+                f"Validation split resulted in 0 samples (ratio={val_split_ratio}, total={len(dataset_full)}). Using 1 sample for validation."
+            )
+            val_split_count = 1
+        elif val_split_count >= len(dataset_full):
+            logging.error(
+                f"Validation split ratio ({val_split_ratio}) too high, results in no training samples."
+            )
+            sys.exit(1)
+
+        train_indices = indices[:-val_split_count]
+        val_indices = indices[-val_split_count:]
+
+        # Create Subset datasets
+        dataset_train = torch.utils.data.Subset(dataset_full, train_indices)
+        dataset_val = torch.utils.data.Subset(dataset_val_instance, val_indices)
+
+        logging.info(
+            f"Using {len(train_indices)} samples for training and {len(val_indices)} for validation."
+        )
+
+        # Create DataLoaders
         data_loader_train = torch.utils.data.DataLoader(
             dataset_train,
             batch_size=config.get("batch_size", 2),
+            # Shuffle should be true for the training subset loader
             shuffle=True,
             num_workers=config.get("num_workers", 4),
             collate_fn=collate_fn,
-            pin_memory=config.get(
-                "pin_memory", True
-            ),  # Often improves GPU transfer speed
+            pin_memory=config.get("pin_memory", True),
         )
-        logging.info(f"Training dataset size: {len(dataset_train)}")
-        logging.info(
-            f"Training dataloader configured with batch size {config.get('batch_size', 2)}"
-        )
+        data_loader_val = torch.utils.data.DataLoader(
+            dataset_val,
+            batch_size=config.get(
+                "batch_size", 2
+            ),  # Often use same or larger batch size for validation
+            shuffle=False,  # No need to shuffle validation data
+            num_workers=config.get("num_workers", 4),
+            collate_fn=collate_fn,
+            pin_memory=config.get("pin_memory", True),
+        )
 
-        # Placeholder for validation loader
-        # data_loader_val = torch.utils.data.DataLoader(...)
+        logging.info(
+            f"Training dataloader configured. Est. batches: {len(data_loader_train)}"
+        )
+        logging.info(
+            f"Validation dataloader configured. Est. batches: {len(data_loader_val)}"
+        )
 
     except Exception as e:
         logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
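The hunk above keeps the split reproducible across runs by reseeding the global RNG with `config["seed"]` before calling `torch.randperm`. An equivalent approach, shown here only as a sketch and not what this commit does, confines the seed to a local generator so the global RNG state used elsewhere in training is left untouched:

import torch

seed = 42  # illustrative; the script reads config.get("seed", 42)
generator = torch.Generator().manual_seed(seed)
indices = torch.randperm(10, generator=generator).tolist()  # same order every run
val_count = max(1, int(0.1 * len(indices)))
train_indices, val_indices = indices[:-val_count], indices[-val_count:]
print(train_indices, val_indices)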
@@ -187,12 +232,72 @@ def main(args):
         logging.error(f"Error creating LR scheduler: {e}", exc_info=True)
         sys.exit(1)
 
-    # --- Training Loop (Prompt 10) ---
+    # --- Resume Logic (Prompt 11) ---
+    start_epoch = 0
+    latest_checkpoint_path = None
+    if os.path.isdir(checkpoint_path):
+        checkpoints = sorted(
+            [f for f in os.listdir(checkpoint_path) if f.endswith(".pth")]
+        )
+        if checkpoints:  # Check if list is not empty
+            latest_checkpoint_file = checkpoints[
+                -1
+            ]  # Get the last one (assuming naming convention like epoch_N.pth)
+            latest_checkpoint_path = os.path.join(
+                checkpoint_path, latest_checkpoint_file
+            )
+            logging.info(f"Found latest checkpoint: {latest_checkpoint_path}")
+        else:
+            logging.info("No checkpoints found in directory. Starting from scratch.")
+    else:
+        logging.info("Checkpoint directory not found. Starting from scratch.")
+
+    if latest_checkpoint_path:
+        try:
+            logging.info(f"Loading checkpoint '{latest_checkpoint_path}'")
+            # Ensure loading happens on the correct device
+            checkpoint = torch.load(latest_checkpoint_path, map_location=device)
+
+            # Load model state - handle potential 'module.' prefix if saved with DataParallel
+            model_state_dict = checkpoint["model_state_dict"]
+            # Simple check and correction for DataParallel prefix
+            if all(key.startswith("module.") for key in model_state_dict.keys()):
+                logging.info("Removing 'module.' prefix from checkpoint keys.")
+                model_state_dict = {
+                    k.replace("module.", ""): v for k, v in model_state_dict.items()
+                }
+            model.load_state_dict(model_state_dict)
+
+            # Load optimizer state
+            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+
+            # Load LR scheduler state
+            lr_scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+
+            # Load starting epoch (epoch saved is the one *completed*, so start from next)
+            start_epoch = checkpoint["epoch"]
+            logging.info(f"Resuming training from epoch {start_epoch + 1}")
+
+            # Optionally load and verify config consistency
+            # loaded_config = checkpoint.get('config')
+            # if loaded_config:
+            #     # Perform checks if necessary
+            #     pass
+
+        except Exception as e:
+            logging.error(
+                f"Error loading checkpoint: {e}. Starting training from scratch.",
+                exc_info=True,
+            )
+            start_epoch = 0  # Reset start_epoch if loading fails
+
+    # --- Training Loop (Prompt 10, modified for Prompt 11) ---
     logging.info("--- Starting Training Loop --- ")
     start_time = time.time()
     num_epochs = config.get("num_epochs", 10)
 
-    for epoch in range(num_epochs):
+    # Modify loop to start from start_epoch
+    for epoch in range(start_epoch, num_epochs):
         model.train()  # Set model to training mode for each epoch
         epoch_start_time = time.time()
         logging.info(f"--- Epoch {epoch + 1}/{num_epochs} --- ")
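One caveat in the resume logic above: `sorted()` on filenames is lexicographic, so once epochs reach double digits, "checkpoint_epoch_10.pth" sorts before "checkpoint_epoch_2.pth" and the last element is no longer the newest checkpoint. A small sketch of a numeric sort key that avoids this; `latest_checkpoint` is a hypothetical helper, not part of this commit:

import os
import re


def latest_checkpoint(checkpoint_dir):
    """Return the checkpoint path with the highest epoch number, or None."""
    candidates = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
    if not candidates:
        return None

    def epoch_num(name):
        # Pull the trailing integer out of names like "checkpoint_epoch_12.pth".
        match = re.search(r"(\d+)\.pth$", name)
        return int(match.group(1)) if match else -1

    return os.path.join(checkpoint_dir, max(candidates, key=epoch_num))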
@@ -256,17 +361,63 @@ def main(args):
         logging.info(f"  Learning Rate: {current_lr:.6f}")
         logging.info(f"  Epoch Duration: {epoch_duration:.2f}s")
 
-        # --- Checkpointing (Placeholder for Prompt 11) --- #
-        # Add checkpoint saving logic here, e.g.:
-        # if (epoch + 1) % config.get('checkpoint_freq', 1) == 0 or (epoch + 1) == num_epochs:
-        #     # ... save checkpoint ...
-        #     logging.info(f"Saved checkpoint for epoch {epoch + 1}")
+        # --- Checkpointing (Prompt 11) --- #
+        # Save checkpoint periodically or at the end
+        save_checkpoint = False
+        if (epoch + 1) % config.get("checkpoint_freq", 1) == 0:
+            save_checkpoint = True
+            logging.info(f"Checkpoint frequency met (epoch {epoch + 1})")
+        elif (epoch + 1) == num_epochs:
+            save_checkpoint = True
+            logging.info(f"Final epoch ({epoch + 1}) reached, saving checkpoint.")
 
-        # --- Evaluation (Placeholder for Prompt 12) --- #
-        # Add evaluation logic here, e.g.:
-        # if data_loader_val:
-        #     evaluate(model, data_loader_val, device)
-        #     logging.info(f"Ran evaluation for epoch {epoch + 1}")
+        if save_checkpoint:
+            checkpoint_filename = f"checkpoint_epoch_{epoch + 1}.pth"
+            save_path = os.path.join(checkpoint_path, checkpoint_filename)
+            try:
+                checkpoint_data = {
+                    "epoch": epoch + 1,
+                    "model_state_dict": model.state_dict(),
+                    "optimizer_state_dict": optimizer.state_dict(),
+                    "scheduler_state_dict": lr_scheduler.state_dict(),
+                    "config": config,  # Save config for reference
+                }
+                torch.save(checkpoint_data, save_path)
+                logging.info(f"Checkpoint saved to {save_path}")
+            except Exception as e:
+                logging.error(
+                    f"Failed to save checkpoint for epoch {epoch + 1} to {save_path}: {e}",
+                    exc_info=True,
+                )
+
+        # --- Evaluation (Prompt 12) --- #
+        if data_loader_val:
+            logging.info(f"Starting evaluation for epoch {epoch + 1}...")
+            try:
+                val_metrics = evaluate(model, data_loader_val, device)
+                logging.info(f"Epoch {epoch + 1} Validation Metrics: {val_metrics}")
+
+                # --- Best Model Checkpoint Logic (Optional Add-on) ---
+                # Add logic here to track the best metric (e.g., val_metrics['average_loss'])
+                # and save a separate 'best_model.pth' checkpoint if the current epoch is better.
+                # Example:
+                # if 'average_loss' in val_metrics:
+                #     current_val_loss = val_metrics['average_loss']
+                #     if best_val_loss is None or current_val_loss < best_val_loss:
+                #         best_val_loss = current_val_loss
+                #         best_model_path = os.path.join(output_path, 'best_model.pth')
+                #         try:
+                #             # Save only the model state_dict for the best model
+                #             torch.save(model.state_dict(), best_model_path)
+                #             logging.info(f"Saved NEW BEST model checkpoint to {best_model_path} (Val Loss: {best_val_loss:.4f})")
+                #         except Exception as e:
+                #             logging.error(f"Failed to save best model checkpoint: {e}", exc_info=True)
+
+            except Exception as e:
+                logging.error(
+                    f"Error during evaluation for epoch {epoch + 1}: {e}", exc_info=True
+                )
+                # Decide if this error should stop the entire training process
 
     # --- End of Training --- #
     total_training_time = time.time() - start_time
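The hunks above read their settings through `config.get(...)` with in-code fallbacks; the config file itself is not part of this commit. A sketch of the keys referenced, using those fallback values (the dict form and comments are illustrative, not the project's actual config file):

# Keys match the config.get(...) calls in train.py; values are the fallback defaults.
config = {
    "seed": 42,              # makes the train/val split reproducible across runs
    "val_split_ratio": 0.1,  # fraction of the full dataset held out for validation
    "batch_size": 2,
    "num_workers": 4,
    "pin_memory": True,
    "num_epochs": 10,
    "checkpoint_freq": 1,    # save a checkpoint every N epochs
}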
utils/eval_utils.py (new file, 65 added lines)
@@ -0,0 +1,65 @@
+import logging
+import time
+
+import torch
+
+
+def evaluate(model, data_loader, device):
+    """Performs evaluation on the dataset for one epoch.
+
+    Args:
+        model (torch.nn.Module): The model to evaluate.
+        data_loader (torch.utils.data.DataLoader): DataLoader for the evaluation data.
+        device (torch.device): The device to run evaluation on.
+
+    Returns:
+        dict: A dictionary containing evaluation metrics (e.g., average loss).
+    """
+    model.eval()  # Set model to evaluation mode
+    total_loss = 0.0
+    num_batches = len(data_loader)
+    eval_start_time = time.time()
+    status_interval = max(1, num_batches // 10)  # Log status roughly 10 times
+
+    logging.info("--- Starting Evaluation --- ")
+
+    with torch.no_grad():  # Disable gradient calculations
+        for i, (images, targets) in enumerate(data_loader):
+            images = list(image.to(device) for image in images)
+            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+            # In eval mode with targets, Mask R-CNN should still return losses
+            # If it returned predictions, logic here would change to process predictions
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
+            loss_value = losses.item()
+            total_loss += loss_value
+
+            if (i + 1) % status_interval == 0:
+                logging.info(f"  Evaluated batch {i + 1}/{num_batches}")
+
+    avg_loss = total_loss / num_batches if num_batches > 0 else 0
+    eval_duration = time.time() - eval_start_time
+
+    logging.info("--- Evaluation Finished ---")
+    logging.info(f"  Average Evaluation Loss: {avg_loss:.4f}")
+    logging.info(f"  Evaluation Duration: {eval_duration:.2f}s")
+
+    # Return metrics (currently just average loss)
+    metrics = {"average_loss": avg_loss}
+    return metrics
+
+
+# Example usage (can be removed or kept for testing):
+if __name__ == "__main__":
+    # This is a dummy test and requires a model, dataloader, device
+    print(
+        "This script contains the evaluate function and cannot be run directly for testing without setup."
+    )
+    # Example:
+    # device = torch.device('cpu')
+    # # Create dummy model and dataloader
+    # model = ...
+    # data_loader = ...
+    # metrics = evaluate(model, data_loader, device)
+    # print(f"Dummy evaluation metrics: {metrics}")
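One behavior worth double-checking against the torchvision version in use: its detection models (Mask R-CNN included) normally compute the loss dict only in training mode and return per-image predictions once `model.eval()` has been called, even when targets are passed, in which case the `loss_dict.values()` call above would fail. If that applies here, a loss-based validation pass can keep the model in train mode while still disabling gradients. A minimal sketch of that variant, shown as an alternative rather than as part of this commit:

import torch


def evaluate_loss(model, data_loader, device):
    """Average the training-style loss over a validation loader without updating weights."""
    was_training = model.training
    model.train()  # torchvision detection models only return loss dicts in train mode
    total_loss, num_batches = 0.0, len(data_loader)

    with torch.no_grad():  # no gradients, so no parameter updates and lower memory use
        for images, targets in data_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            total_loss += sum(loss for loss in loss_dict.values()).item()

    if not was_training:
        model.eval()  # restore the caller's mode
    return {"average_loss": total_loss / num_batches if num_batches else 0.0}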