callcut.training.CallDetectorModule

class callcut.training.CallDetectorModule(model, loss, lr=0.001)

LightningModule wrapping a call detector for training.

This module handles the training loop, validation, and optimizer configuration for any BaseDetector model.
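The division of labour described here can be sketched without PyTorch or Lightning. The class below is a hypothetical, framework-free stand-in (`SketchDetectorModule` is not part of callcut). It assumes `model` and `loss` are plain callables and that both are stored on private attributes (`_model`, `_loss`, also assumptions). The step methods mirror the documented contract: unpack a `(features, labels)` batch, run the forward pass, and apply the loss. `lr` is only stored, because building the Adam optimizer requires PyTorch.

```python
from typing import Any, Callable, Optional, Tuple


class SketchDetectorModule:
    """Hypothetical, framework-free stand-in for CallDetectorModule.

    `model` maps features to logits and `loss` maps (logits, labels) to a
    scalar; the real class operates on torch Tensors inside Lightning.
    """

    def __init__(self, model: Callable[[Any], Any],
                 loss: Callable[[Any, Any], float], lr: float = 0.001) -> None:
        self._model = model
        self._loss = loss
        self.lr = lr  # the real module hands this to its optimizer
        self.last_val_loss: Optional[float] = None

    @property
    def model(self) -> Callable[[Any], Any]:
        return self._model

    @property
    def loss(self) -> Callable[[Any, Any], float]:
        return self._loss

    def forward(self, x: Any) -> Any:
        # Delegate the forward pass to the wrapped detector.
        return self._model(x)

    def training_step(self, batch: Tuple[Any, Any], batch_idx: int) -> float:
        features, labels = batch
        logits = self.forward(features)
        # Returning the loss is what lets the trainer backpropagate it.
        return self._loss(logits, labels)

    def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> None:
        features, labels = batch
        # Same computation as training, but nothing is returned for
        # optimization; this sketch simply records the value.
        self.last_val_loss = self._loss(self.forward(features), labels)
```

Keeping `forward` as a thin delegate means the wrapped detector can be swapped without touching the training logic.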

Parameters:
model : BaseDetector

The detector model to train.

loss : BaseLoss

Loss function to use, for example BCEWithLogitsLoss or FocalLoss from callcut.training (both are used in the Examples).

lr : float, default 0.001

Learning rate for the Adam optimizer created by configure_optimizers().

Attributes

loss

The loss function.

model

The underlying detector model.

Methods

configure_optimizers()

Configure the optimizer.

forward(x)

Forward pass through the model.

training_step(batch, batch_idx)

Perform a single training step.

validation_step(batch, batch_idx)

Perform a single validation step.

Examples

>>> from callcut.nn import TinySegCNN
>>> from callcut.training import (
...     CallDetectorModule,
...     CallDataModule,
...     BCEWithLogitsLoss,
... )
>>> import lightning as L
>>>
>>> # Create model and module
>>> model = TinySegCNN(n_bands=8, window_frames=250)
>>> module = CallDetectorModule(model, loss=BCEWithLogitsLoss(), lr=1e-3)
>>>
>>> # Or with a different loss function
>>> from callcut.training import FocalLoss
>>> module = CallDetectorModule(model, loss=FocalLoss(gamma=2.0), lr=1e-3)
>>>
>>> # Create data module and train
>>> dm = CallDataModule(recordings=..., extractor=...)
>>> trainer = L.Trainer(max_epochs=10)
>>> trainer.fit(module, datamodule=dm)
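The second example swaps in FocalLoss(gamma=2.0). The sketch below shows the standard binary focal loss computed from logits, assuming callcut's FocalLoss follows that definition with hard 0/1 labels, no class-balancing alpha term, and mean reduction; `binary_focal_loss` and `log_sigmoid` are hypothetical helpers. With gamma=0 it reduces to binary cross-entropy, and larger gamma down-weights frames the model already classifies confidently.

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)) for either sign of x.
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def binary_focal_loss(logits: Sequence[float], targets: Sequence[int],
                      gamma: float = 2.0) -> float:
    """Mean binary focal loss over frames (hypothetical sketch, no alpha term)."""
    total = 0.0
    for x, y in zip(logits, targets):
        # Log-probability the model assigns to the true class (1 = call).
        log_p_t = log_sigmoid(x) if y == 1 else log_sigmoid(-x)
        p_t = math.exp(log_p_t)
        # (1 - p_t) ** gamma shrinks the loss of well-classified frames.
        total += -((1.0 - p_t) ** gamma) * log_p_t
    return total / len(logits)
```

For example, a frame with logit 4.0 and label 1 contributes far less with gamma=2.0 than with gamma=0.0, which is why focal loss is often tried when most frames contain no call.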

configure_optimizers()

Configure the optimizer.

Returns:
optimizer : torch.optim.Optimizer

An Adam optimizer configured with the module's learning rate lr.
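As a reference for what lr controls, the sketch below applies one bias-corrected Adam update to a single scalar parameter in plain Python. It assumes PyTorch's default Adam hyperparameters (betas=(0.9, 0.999), eps=1e-8), since this page does not say whether the module overrides them; `adam_step` is a hypothetical helper, not part of callcut.

```python
import math
from typing import Tuple


def adam_step(theta: float, grad: float, m: float, v: float, t: int,
              lr: float = 1e-3, beta1: float = 0.9, beta2: float = 0.999,
              eps: float = 1e-8) -> Tuple[float, float, float]:
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1.0 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1.0 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1.0 - beta1 ** t)                # bias-corrected estimates
    v_hat = v / (1.0 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v
```

On the first step m_hat / sqrt(v_hat) is close to the sign of the gradient, so each parameter moves by roughly lr regardless of the gradient's magnitude; lr therefore sets the initial step size directly.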

forward(x)

Forward pass through the model.

Parameters:
x : Tensor

Input features of shape (batch, n_bands, time).

Returns:
logits : Tensor

Output logits of shape (batch, time).
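The shape contract of forward collapses the band axis and keeps one logit per time frame. The toy function below illustrates that contract on nested lists by taking a weighted sum over bands at every frame; `toy_forward` and `band_weights` are hypothetical and stand in for neither TinySegCNN nor any other callcut model, which would use learned layers instead.

```python
from typing import List, Sequence


def toy_forward(x: Sequence[Sequence[Sequence[float]]],
                band_weights: Sequence[float]) -> List[List[float]]:
    """Map features of shape (batch, n_bands, time) to logits of shape (batch, time)."""
    logits = []
    for sample in x:  # each sample has shape (n_bands, time)
        if len(sample) != len(band_weights):
            raise ValueError("expected one weight per band")
        n_time = len(sample[0])
        # One logit per frame: weighted sum across the band axis.
        logits.append([
            sum(w * band[t] for w, band in zip(band_weights, sample))
            for t in range(n_time)
        ])
    return logits
```

For a batch of 2 windows with n_bands=8 and 250 frames, the result has 2 rows of 250 logits each, matching the (batch, time) output described above.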

training_step(batch, batch_idx)

Perform a single training step.

Parameters:
batch : tuple of Tensor

Tuple of (features, labels) tensors.

batch_idx : int

Index of the current batch.

Returns:
loss : Tensor

The training loss for this batch.
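With BCEWithLogitsLoss, the value returned here is presumably the frame-wise binary cross-entropy between the (batch, time) logits and 0/1 call labels of the same shape. The sketch below assumes callcut's loss matches PyTorch's BCEWithLogitsLoss with its default mean reduction; `bce_with_logits` is a hypothetical helper that uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)).

```python
import math
from typing import Sequence


def bce_with_logits(logits: Sequence[Sequence[float]],
                    labels: Sequence[Sequence[float]]) -> float:
    """Mean binary cross-entropy over all (batch, time) frames."""
    total, count = 0.0, 0
    for row_x, row_y in zip(logits, labels):
        for x, y in zip(row_x, row_y):
            # Stable form: never exponentiates a large positive number.
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count
```

Working on logits rather than probabilities is what allows very confident predictions (logits of ±1000) to be scored without overflow.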

validation_step(batch, batch_idx)

Perform a single validation step.

Parameters:
batch : tuple of Tensor

Tuple of (features, labels) tensors.

batch_idx : int

Index of the current batch.

property loss

The loss function.

Type:

BaseLoss

property model

The underlying detector model.

Type:

BaseDetector