callcut.training.CallDetectorModule
- class callcut.training.CallDetectorModule(model, loss, lr=0.001)[source]
Lightning Module wrapping a call detector for training.
This module handles the training loop, validation, and optimizer configuration for any
BaseDetector model.

- Parameters:
  - model : BaseDetector
    The detector model to train.
  - loss : BaseLoss
    Loss function to use. Available loss functions:
    - BCEWithLogitsLoss: standard binary cross-entropy
    - FocalLoss: down-weights easy examples
    - DiceLoss: optimizes overlap directly
    - TverskyLoss: adjustable FP/FN penalties
  - lr : float
    Learning rate for the optimizer.
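To illustrate what "down-weights easy examples" means for FocalLoss: the focal loss scales binary cross-entropy by (1 - p_t)**gamma, where p_t is the predicted probability of the true class. A minimal scalar sketch in pure Python (independent of the library's actual FocalLoss implementation, which is an assumption here):

```python
import math

def bce(p: float, y: float) -> float:
    """Binary cross-entropy for one predicted probability p and target y."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def focal(p: float, y: float, gamma: float = 2.0) -> float:
    """Focal loss: BCE down-weighted by (1 - p_t)**gamma.

    p_t is the predicted probability of the true class, so confident
    correct predictions ("easy examples") contribute almost nothing.
    """
    p_t = p if y == 1.0 else 1.0 - p
    return ((1.0 - p_t) ** gamma) * bce(p, y)

# An easy positive (p = 0.95) is scaled by 0.05**2, while a hard
# positive (p = 0.30) is scaled only by 0.70**2 -- the hard example
# dominates the gradient signal.
easy = focal(0.95, 1.0)
hard = focal(0.30, 1.0)
```

With gamma = 0 the scaling factor is 1 and focal loss reduces to plain BCE, which is why gamma is the knob that controls how aggressively easy examples are suppressed.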
Attributes

model
  The underlying detector model.
Methods

configure_optimizers()
  Configure the optimizer.
forward(x)
  Forward pass through the model.
training_step(batch, batch_idx)
  Perform a single training step.
validation_step(batch, batch_idx)
  Perform a single validation step.
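The step methods follow the standard LightningModule contract: each receives a batch, runs the model forward, and returns the loss (Lightning backpropagates the training loss; the validation loss is only logged). A framework-free sketch of that control flow, with plain callables standing in for BaseDetector and BaseLoss (hypothetical, not the library's code):

```python
class DetectorSteps:
    """Minimal sketch of the Lightning step pattern this module follows."""

    def __init__(self, model, loss):
        self.model = model  # callable: input -> logits
        self.loss = loss    # callable: (logits, targets) -> scalar loss

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch                  # e.g. spectrogram window, frame labels
        logits = self.forward(x)
        return self.loss(logits, y)   # Lightning backpropagates this value

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        return self.loss(logits, y)   # logged, not backpropagated
```

In a real LightningModule the steps would also call `self.log(...)` to record metrics; that bookkeeping is omitted here to keep the data flow visible.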
Examples
>>> from callcut.nn import TinySegCNN
>>> from callcut.training import (
...     CallDetectorModule,
...     CallDataModule,
...     BCEWithLogitsLoss,
... )
>>> import lightning as L
>>>
>>> # Create model and module
>>> model = TinySegCNN(n_bands=8, window_frames=250)
>>> module = CallDetectorModule(model, loss=BCEWithLogitsLoss(), lr=1e-3)
>>>
>>> # Or with a different loss function
>>> from callcut.training import FocalLoss
>>> module = CallDetectorModule(model, loss=FocalLoss(gamma=2.0), lr=1e-3)
>>>
>>> # Create data module and train
>>> dm = CallDataModule(recordings=..., extractor=...)
>>> trainer = L.Trainer(max_epochs=10)
>>> trainer.fit(module, datamodule=dm)
- configure_optimizers()[source]
Configure the optimizer.
- Returns:
  - optimizer : torch.optim.Optimizer
    The Adam optimizer.
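A minimal sketch of what this method presumably does, assuming the standard Lightning pattern of building Adam over the module's parameters at the configured learning rate (the `Sketch` class below is hypothetical, not the library's implementation):

```python
import torch

class Sketch(torch.nn.Module):
    """Hypothetical stand-in showing the configure_optimizers pattern."""

    def __init__(self, model: torch.nn.Module, lr: float = 1e-3):
        super().__init__()
        self.model = model  # registered as a submodule, so its
        self.lr = lr        # parameters appear in self.parameters()

    def configure_optimizers(self):
        # Return a single Adam optimizer at the configured learning rate.
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```

Returning a bare optimizer (rather than an optimizer/scheduler dict) is the simplest form Lightning accepts, and matches the single-Optimizer return type documented above.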
- property model
The underlying detector model.
- Type:
  BaseDetector