Formatting
This commit is contained in:
@@ -0,0 +1,55 @@
|
|||||||
|
import torchvision
|
||||||
|
from torchvision.models import ResNet50_Weights
|
||||||
|
|
||||||
|
# Import weights enums for clarity
|
||||||
|
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights
|
||||||
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
|
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
|
||||||
|
|
||||||
|
|
||||||
|
def get_maskrcnn_model(num_classes, pretrained=True, pretrained_backbone=True):
    """Build a Mask R-CNN (ResNet-50-FPN, v2) whose heads predict ``num_classes``.

    Args:
        num_classes (int): Number of output classes (including background).
        pretrained (bool): When truthy, the detector is created with
            ``MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT``.
        pretrained_backbone (bool): Only consulted when ``pretrained`` is
            falsy; when truthy, the backbone is created with
            ``ResNet50_Weights.DEFAULT``.

    Returns:
        The model returned by ``maskrcnn_resnet50_fpn_v2`` with its box and
        mask predictors replaced by freshly constructed ones.
    """
    # Full-detector weights take precedence: backbone-only weights are
    # requested solely when the detector itself is not pre-trained.
    detector_weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT if pretrained else None
    backbone_only = pretrained_backbone and not pretrained
    backbone_weights = ResNet50_Weights.DEFAULT if backbone_only else None

    # The v2 builder is the one compatible with the V2 weights enum.
    detector = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(
        weights=detector_weights,
        weights_backbone=backbone_weights,
    )
    heads = detector.roi_heads

    # Box head: keep the existing input width, resize the output to num_classes.
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: same idea, with a 256-channel hidden layer.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    heads.mask_predictor = MaskRCNNPredictor(
        mask_in_channels, mask_hidden_channels, num_classes
    )

    return detector
|
||||||
|
|||||||
50
todo.md
50
todo.md
@@ -26,34 +26,34 @@ This list outlines the steps required to complete the Torchvision Finetuning pro
|
|||||||
- [x] `__len__`: Return dataset size.
|
- [x] `__len__`: Return dataset size.
|
||||||
- [x] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
|
- [x] Implement `get_transform(train)` function in `utils/data_utils.py` (using `torchvision.transforms.v2`).
|
||||||
- [x] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
|
- [x] Implement `collate_fn(batch)` function in `utils/data_utils.py`.
|
||||||
- [ ] Implement `get_maskrcnn_model(num_classes, ...)` function in `models/detection.py`.
|
- [x] Implement `get_maskrcnn_model(num_classes, ...)` function in `models/detection.py`.
|
||||||
- [ ] Load pre-trained Mask R-CNN (`maskrcnn_resnet50_fpn_v2`).
|
- [x] Load pre-trained Mask R-CNN (`maskrcnn_resnet50_fpn_v2`).
|
||||||
- [ ] Replace box predictor head (`FastRCNNPredictor`).
|
- [x] Replace box predictor head (`FastRCNNPredictor`).
|
||||||
- [ ] Replace mask predictor head (`MaskRCNNPredictor`).
|
- [x] Replace mask predictor head (`MaskRCNNPredictor`).
|
||||||
|
|
||||||
## Phase 3: Training Script & Core Logic
|
## Phase 3: Training Script & Core Logic
|
||||||
|
|
||||||
- [ ] Set up basic `train.py` structure.
|
- [x] Set up basic `train.py` structure.
|
||||||
- [ ] Add imports.
|
- [x] Add imports.
|
||||||
- [ ] Implement `argparse` for `--config` argument.
|
- [x] Implement `argparse` for `--config` argument.
|
||||||
- [ ] Implement dynamic config loading (`importlib`).
|
- [x] Implement dynamic config loading (`importlib`).
|
||||||
- [ ] Set random seeds.
|
- [x] Set random seeds.
|
||||||
- [ ] Determine compute device (`cuda` or `cpu`).
|
- [x] Determine compute device (`cuda` or `cpu`).
|
||||||
- [ ] Create output directory structure (`outputs/<config_name>/checkpoints`).
|
- [x] Create output directory structure (`outputs/<config_name>/checkpoints`).
|
||||||
- [ ] Instantiate `PennFudanDataset` (train).
|
- [x] Instantiate `PennFudanDataset` (train).
|
||||||
- [ ] Instantiate `DataLoader` (train) using `collate_fn`.
|
- [x] Instantiate `DataLoader` (train) using `collate_fn`.
|
||||||
- [ ] Instantiate model using `get_maskrcnn_model`.
|
- [x] Instantiate model using `get_maskrcnn_model`.
|
||||||
- [ ] Move model to device.
|
- [x] Move model to device.
|
||||||
- [ ] Add `if __name__ == "__main__":` guard.
|
- [x] Add `if __name__ == "__main__":` guard.
|
||||||
- [ ] Implement minimal training step in `train.py`.
|
- [x] Implement minimal training step in `train.py`.
|
||||||
- [ ] Instantiate optimizer (`torch.optim.SGD`).
|
- [x] Instantiate optimizer (`torch.optim.SGD`).
|
||||||
- [ ] Set `model.train()`.
|
- [x] Set `model.train()`.
|
||||||
- [ ] Fetch one batch.
|
- [x] Fetch one batch.
|
||||||
- [ ] Move data to device.
|
- [x] Move data to device.
|
||||||
- [ ] Perform forward pass (`loss_dict = model(...)`).
|
- [x] Perform forward pass (`loss_dict = model(...)`).
|
||||||
- [ ] Calculate total loss (`sum(...)`).
|
- [x] Calculate total loss (`sum(...)`).
|
||||||
- [ ] Perform backward pass (`optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`).
|
- [x] Perform backward pass (`optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`).
|
||||||
- [ ] Print/log loss for the single step (and temporarily exit).
|
- [x] Print/log loss for the single step (and temporarily exit).
|
||||||
- [ ] Implement logging setup in `utils/log_utils.py` (`setup_logging` function).
|
- [ ] Implement logging setup in `utils/log_utils.py` (`setup_logging` function).
|
||||||
- [ ] Configure `logging.basicConfig` for file and console output.
|
- [ ] Configure `logging.basicConfig` for file and console output.
|
||||||
- [ ] Integrate logging into `train.py`.
|
- [ ] Integrate logging into `train.py`.
|
||||||
|
|||||||
241
train.py
241
train.py
@@ -0,0 +1,241 @@
|
|||||||
|
import argparse
|
||||||
|
import importlib.util
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
# Project specific imports
|
||||||
|
from models.detection import get_maskrcnn_model
|
||||||
|
from utils.data_utils import PennFudanDataset, collate_fn, get_transform
|
||||||
|
from utils.log_utils import setup_logging
|
||||||
|
|
||||||
|
|
||||||
|
def _load_config(config_arg):
    """Import the configuration module at ``config_arg`` and return its ``config``.

    The file path is turned into a dotted module name relative to the current
    working directory and imported with ``importlib.import_module``. Every
    failure path prints a diagnostic and terminates the process via
    ``sys.exit(1)``.

    Args:
        config_arg (str): Path to the Python configuration file.

    Returns:
        The ``config`` attribute of the imported module.
    """
    # Local import: only needed on failure paths, before logging is configured.
    import traceback

    try:
        config_path = os.path.abspath(config_arg)
        if not os.path.exists(config_path):
            print(f"Error: Config file not found at {config_path}")
            sys.exit(1)

        # Derive module path from file path relative to workspace root
        # (assumes the script is run from the project root).
        workspace_root = os.path.abspath(os.getcwd())
        relative_path = os.path.relpath(config_path, workspace_root)
        # Compare whole path components so that a name which merely starts
        # with two dots is not reported as being outside the project.
        if relative_path == os.pardir or relative_path.startswith(os.pardir + os.sep):
            print(f"Error: Config file {config_arg} is outside the project directory.")
            sys.exit(1)

        module_path_no_ext, _ = os.path.splitext(relative_path)
        module_path_str = module_path_no_ext.replace(os.sep, ".")

        print(f"Attempting to import config module: {module_path_str}")
        config_module = importlib.import_module(module_path_str)
        config = config_module.config

        print(
            f"Loaded configuration from: {config_path} (via module {module_path_str})"
        )
        return config

    except ImportError as e:
        print(f"Error importing config module '{module_path_str}': {e}")
        print(
            "Ensure the config file path is correct and relative imports within it are valid."
        )
        traceback.print_exc()
        sys.exit(1)
    except AttributeError as e:
        print(
            f"Error: Could not find 'config' dictionary in module {module_path_str}. {e}"
        )
        sys.exit(1)
    except Exception as e:
        print(f"Error loading configuration file {config_arg}: {e}")
        traceback.print_exc()
        sys.exit(1)


def main(args):
    """Set up a training run and execute a single optimisation step.

    Loads the configuration, creates the output directories, configures
    logging, seeds the random number generators, builds the training
    dataset/dataloader, the model and an SGD optimizer, then runs one
    forward/backward/step on the first batch and exits.

    The function never returns normally: it ends with ``sys.exit(0)`` after
    the single step, or ``sys.exit(1)`` on any setup or training error.

    Args:
        args (argparse.Namespace): Parsed command-line arguments; only
            ``args.config`` (path to the Python config file) is read.
    """
    # --- Configuration Loading ---
    config = _load_config(args.config)

    # --- Output Directory Setup ---
    output_dir = config.get("output_dir", "outputs")
    config_name = config.get("config_name", "default_run")
    output_path = os.path.join(output_dir, config_name)
    checkpoint_path = os.path.join(output_path, "checkpoints")
    # Creating the checkpoint directory also creates output_path (its parent).
    os.makedirs(checkpoint_path, exist_ok=True)
    print(f"Output will be saved to: {output_path}")

    # --- Logging Setup (Prompt 9) ---
    setup_logging(output_path, config_name)
    logging.info("--- Training Script Started ---")
    logging.info(f"Loaded configuration from: {args.config}")
    logging.info(f"Loaded configuration dictionary: {config}")
    logging.info(f"Output will be saved to: {output_path}")

    # --- Reproducibility ---
    seed = config.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Consider adding these for more determinism, but they might impact performance
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False
    logging.info(f"Set random seed to: {seed}")

    # --- Device Setup ---
    # Parse the device first so indexed forms such as "cuda:0" also take the
    # CPU fallback instead of failing later when the model is moved.
    device = torch.device(config.get("device", "cuda"))
    if device.type == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA requested but not available, falling back to CPU.")
        device = torch.device("cpu")
    logging.info(f"Using device: {device}")

    # --- Dataset and DataLoader ---
    data_root = config.get("data_root")
    if not data_root or not os.path.isdir(data_root):
        logging.error(f"Data root directory not found or not specified: {data_root}")
        sys.exit(1)

    batch_size = config.get("batch_size", 2)
    try:
        dataset_train = PennFudanDataset(
            root=data_root, transforms=get_transform(train=True)
        )
        # Note: Validation split will be handled later (Prompt 12)
        # TODO: Implement data splitting (e.g., using torch.utils.data.Subset)

        data_loader_train = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=batch_size,
            shuffle=True,
            num_workers=config.get("num_workers", 4),
            collate_fn=collate_fn,
            # Pinned memory only helps host-to-GPU copies, so default it off
            # on CPU; an explicit config value is always honoured.
            pin_memory=config.get("pin_memory", device.type == "cuda"),
        )
        logging.info(f"Training dataset size: {len(dataset_train)}")
        logging.info(
            f"Training dataloader configured with batch size {batch_size}"
        )
    except Exception as e:
        logging.error(f"Error setting up dataset/dataloader: {e}", exc_info=True)
        sys.exit(1)

    # --- Model Instantiation ---
    num_classes = config.get("num_classes")
    if num_classes is None:
        logging.error("'num_classes' not specified in configuration.")
        sys.exit(1)

    try:
        model = get_maskrcnn_model(
            num_classes=num_classes,
            pretrained=config.get("pretrained", True),
            pretrained_backbone=config.get("pretrained_backbone", True),
        )
        model.to(device)
        logging.info("Model loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model: {e}", exc_info=True)
        sys.exit(1)

    # --- Optimizer ---
    # Only parameters that require gradients are handed to the optimizer.
    params = [p for p in model.parameters() if p.requires_grad]
    lr = config.get("lr", 0.005)
    momentum = config.get("momentum", 0.9)
    weight_decay = config.get("weight_decay", 0.0005)
    try:
        optimizer = torch.optim.SGD(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        logging.info(
            f"Optimizer SGD configured with lr={lr}, momentum={momentum}, weight_decay={weight_decay}"
        )
    except Exception as e:
        logging.error(f"Error creating optimizer: {e}", exc_info=True)
        sys.exit(1)

    # --- LR Scheduler (Placeholder for Prompt 10) ---
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(
    #     optimizer,
    #     step_size=config.get('lr_step_size', 3),
    #     gamma=config.get('lr_gamma', 0.1)
    # )

    # --- Minimal Training Step (Prompt 8 / Updated for Prompt 9) ---
    logging.info("--- Starting Minimal Training Step --- ")
    model.train()  # In training mode the model returns a dict of losses.

    try:
        # Fetch one batch
        images, targets = next(iter(data_loader_train))

        # Move data to the device; targets is a list of dicts of tensors.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass and total loss
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())
        loss_value = losses.item()

        # Backward pass
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Convert loss tensors to plain numbers for logging
        loss_dict_log = {k: v.item() for k, v in loss_dict.items()}
        logging.info(f"Single step loss dict: {loss_dict_log}")
        logging.info(f"Single step total loss: {loss_value:.4f}")
        logging.info("--- Minimal Training Step Completed Successfully --- ")
    except Exception as e:
        # exc_info=True already records the traceback in the log.
        logging.error(f"Error during minimal training step: {e}", exc_info=True)
        sys.exit(1)

    # Temporarily exit after the single step (as per Prompt 8)
    logging.info("Exiting after single training step.")
    sys.exit(0)

    # --- Full Training Loop (Placeholder for Prompt 10) ---
    # ... loop implementation ...
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: --config is the only (mandatory) option.
    cli = argparse.ArgumentParser(description="Train Mask R-CNN on Penn-Fudan dataset.")
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path to the Python configuration file (e.g., configs/pennfudan_maskrcnn_config.py)",
    )
    main(cli.parse_args())
|
||||||
|
|||||||
44
utils/log_utils.py
Normal file
44
utils/log_utils.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(log_dir, config_name):
    """Configure the root logger to write to both a log file and stdout.

    The log file is ``<log_dir>/<config_name>_train.log``; ``log_dir`` is
    created if it does not exist. Once configured, an INFO record naming the
    log file is emitted.

    Args:
        log_dir (str): The directory where the log file should be saved.
        config_name (str): The name of the configuration run, used for the
            log filename.
    """
    # Ensure log directory exists
    os.makedirs(log_dir, exist_ok=True)

    log_filepath = os.path.join(log_dir, f"{config_name}_train.log")

    logging.basicConfig(
        level=logging.INFO,  # INFO and above (INFO, WARNING, ERROR, CRITICAL)
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            # Pin the encoding so non-ASCII log text does not depend on the
            # platform's default locale encoding.
            logging.FileHandler(log_filepath, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),  # console output (stdout)
        ],
        # force=True replaces any handlers already attached to the root
        # logger (e.g. by a library that called basicConfig earlier). Use
        # with caution if libraries configure logging in complex ways.
        force=True,
    )

    logging.info("Logging configured. Log file: %s", log_filepath)
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage (can be removed or kept for testing):
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: emit one record per level, then inspect the outputs.
    print("Testing logging setup...")
    setup_logging("temp_logs", "test_config")
    for emit, text in (
        (logging.info, "This is an info message."),
        (logging.warning, "This is a warning message."),
        (logging.error, "This is an error message."),
    ):
        emit(text)
    print("Check 'temp_logs/test_config_train.log' and console output.")
|
||||||
Reference in New Issue
Block a user