callcut.training.CallDataModule

class callcut.training.CallDataModule(recordings, extractor, *, train_frac=0.7, val_frac=0.1, test_frac=0.2, window_s=2.0, window_hop_s=0.5, batch_size=32, num_workers=4)

Lightning DataModule for call detection.

Handles data loading, train/val/test splitting (balanced by window count), and DataLoader creation for training and validation. The test split is exposed via test_recordings for use with evaluate_recordings().

Parameters:
recordings : list of Path | str

Paths to audio files. Each audio file should have a matching annotation CSV whose name is the audio file's stem followed by the _annotations.csv suffix (for example, rec01.wav pairs with rec01_annotations.csv).

extractor : BaseExtractor

Feature extractor instance.

train_frac : float

Target fraction of windows for training (default: 0.7).

val_frac : float

Target fraction of windows for validation (default: 0.1).

test_frac : float

Target fraction of windows for testing (default: 0.2).

window_s : float

Window length in seconds for each sample (default: 2.0).

window_hop_s : float

Hop between the start times of consecutive windows, in seconds (default: 0.5).

batch_size : int

Batch size for the DataLoaders (default: 32).

num_workers : int

Number of worker processes for each DataLoader (default: 4).
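The parameters above imply two pieces of per-recording bookkeeping: locating each recording's annotation file and cutting the audio into fixed-length windows. The sketch below illustrates both under stated assumptions: `annotation_path` and `window_starts` are hypothetical helpers (not part of callcut), the CSV is assumed to sit in the same directory as its audio file, and only full-length windows starting at 0 s are assumed to be kept.

```python
from pathlib import Path


def annotation_path(audio_path):
    """Hypothetical helper: map an audio file to its annotation CSV.

    Assumption: the CSV lives in the same directory as the audio file.
    """
    audio_path = Path(audio_path)
    return audio_path.with_name(f"{audio_path.stem}_annotations.csv")


def window_starts(duration_s, window_s=2.0, window_hop_s=0.5):
    """Hypothetical helper: start times (s) of full-length windows.

    Assumption: windows start at 0 s and a partial window at the end
    of the recording is dropped.
    """
    if duration_s < window_s:
        return []  # recording too short for a single full window
    n_windows = int((duration_s - window_s) // window_hop_s) + 1
    return [i * window_hop_s for i in range(n_windows)]
```

Under these assumptions, a 10 s recording with the default window_s=2.0 and window_hop_s=0.5 yields 17 windows starting at 0.0, 0.5, ..., 8.0 s, while a recording shorter than window_s yields none.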

Attributes

extractor

Feature extractor used by this data module.

n_recordings

Total number of valid recordings.

test_recordings

Test recordings (available after setup or split).

train_dataset

Training dataset (available after setup).

train_recordings

Training recordings (available after setup or split).

val_dataset

Validation dataset (available after setup).

val_recordings

Validation recordings (available after setup or split).

Methods

setup(stage)

Set up datasets for the fit stage.

train_dataloader()

Return the training DataLoader.

val_dataloader()

Return the validation DataLoader.

Notes

The split is done at the recording level (each recording goes to exactly one split) but is balanced by window count rather than file count. This keeps each split close to its target fraction of all windows, even when recordings have different durations.
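One way to balance a recording-level split by window count is a greedy pass: visit recordings from longest to shortest and give each one to the split that is furthest below its target share of windows. This is a sketch of that idea, not callcut's implementation; `balanced_split` is a hypothetical function and the window counts are illustrative inputs.

```python
import random


def balanced_split(window_counts, fracs=(0.7, 0.1, 0.2), seed=None):
    """Greedily assign whole recordings to (train, val, test) by window count.

    Illustrative sketch only; callcut's actual algorithm may differ.
    """
    rng = random.Random(seed)
    names = list(window_counts)
    rng.shuffle(names)  # randomize the order among equally long recordings
    # Stable sort: longest recordings first, ties keep the shuffled order.
    names.sort(key=lambda n: window_counts[n], reverse=True)

    total = sum(window_counts.values())
    splits = ([], [], [])
    assigned = [0, 0, 0]  # windows assigned to each split so far
    for name in names:
        # How many windows each split still needs to reach its target share.
        deficits = [f * total - a for f, a in zip(fracs, assigned)]
        i = max(range(3), key=lambda k: deficits[k])
        splits[i].append(name)
        assigned[i] += window_counts[name]
    return splits


counts = {"r1": 300, "r2": 120, "r3": 80, "r4": 60, "r5": 40}
train, val, test = balanced_split(counts, seed=0)
```

Visiting the longest recordings first lets the shorter ones fill the remaining gaps. With the counts above, the pass reaches the 420/60/120 window targets exactly; real data usually lands only close to the targets because a recording is never divided across splits.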

For reproducible splits, call lightning.seed_everything(seed) before instantiating and calling setup().
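Seeding matters because a randomized split only repeats if the random number generator is in the same state when the split runs. `lightning.seed_everything` seeds Python's `random` module along with NumPy and PyTorch, so the sketch below uses `random.seed` as a stand-in. `shuffled_split` is hypothetical and assumes the split draws from the global `random` state, which may not match callcut. The input order must also be fixed; for example, sort a glob result rather than relying on filesystem order.

```python
import random


def shuffled_split(names, train_frac=0.7):
    """Hypothetical split that draws from the global `random` RNG."""
    order = list(names)
    random.shuffle(order)
    k = round(train_frac * len(order))
    return order[:k], order[k:]


names = [f"rec{i:02d}.wav" for i in range(10)]  # fixed input order

random.seed(123)  # stand-in for lightning.seed_everything(123)
first = shuffled_split(names)
random.seed(123)  # re-seed before repeating the split
second = shuffled_split(names)
```

Because both calls start from the same seed and the same input order, `first` and `second` contain identical train and held-out lists.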

Only the "fit" stage is supported for setup(). For evaluation on the held-out test split, use evaluate_recordings() with test_recordings.

Examples

>>> from pathlib import Path
>>> from callcut.extractors import SNRExtractor
>>> from callcut.training import CallDataModule
>>> from callcut.pipeline import evaluate_recordings
>>> import lightning as L
>>>
>>> # `model`, `decoder`, and `matcher` are assumed to be created elsewhere.
>>> extractor = SNRExtractor(sample_rate=32000, hop_ms=8.0, n_bands=8)
>>> recordings = sorted(Path("data").glob("*.wav"))  # sorted for a stable order
>>> dm = CallDataModule(
...     recordings=recordings,
...     extractor=extractor,
...     batch_size=32,
... )
>>> trainer = L.Trainer(max_epochs=10)
>>> trainer.fit(model, datamodule=dm)
>>>
>>> # Evaluate on held-out test recordings
>>> report = evaluate_recordings(
...     model, extractor, dm.test_recordings, decoder, matcher
... )
setup(stage)

Set up datasets for the fit stage.

Parameters:
stage : str

Must be "fit". Other stages are not supported; use evaluate_recordings() for evaluation on held-out test recordings (available via test_recordings).
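Because only "fit" is accepted, the Lightning entry points that call setup() with other stages are presumably unsupported as well: Trainer.validate() passes stage="validate" and Trainer.test() passes stage="test". The sketch below shows a fit-only guard on a hypothetical stand-in class, `FitOnlyDataModule`; raising ValueError is an assumption, since this page does not say how CallDataModule reacts to other stages.

```python
class FitOnlyDataModule:
    """Hypothetical stand-in; not callcut's CallDataModule."""

    def __init__(self):
        # Datasets stay None until setup("fit") runs, mirroring the
        # `CallDataset | None` property types documented below.
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage):
        if stage != "fit":
            # Assumption: other stages are rejected with ValueError.
            raise ValueError(f"only stage='fit' is supported, got {stage!r}")
        # Placeholders standing in for real CallDataset objects.
        self.train_dataset = ["train windows"]
        self.val_dataset = ["val windows"]
```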

train_dataloader()

Return the training DataLoader.

Returns:
dataloader : DataLoader

Training DataLoader with shuffling enabled.

val_dataloader()

Return the validation DataLoader.

Returns:
dataloader : DataLoader | None

Validation DataLoader, or None if there are no validation recordings.

property extractor

Feature extractor used by this data module.

Type:

BaseExtractor

property n_recordings

Total number of valid recordings.

Type:

int

property test_recordings

Test recordings (available after setup or split).

Type:

list of RecordingInfo

property train_dataset

Training dataset (available after setup).

Type:

CallDataset | None

property train_recordings

Training recordings (available after setup or split).

Type:

list of RecordingInfo

property val_dataset

Validation dataset (available after setup).

Type:

CallDataset | None

property val_recordings

Validation recordings (available after setup or split).

Type:

list of RecordingInfo