callcut.training.SaveBestModelCallback

class callcut.training.SaveBestModelCallback(save_path, monitor='val_f1', mode='max')

Callback that saves the best model weights during training.

Monitors a metric at the end of each validation epoch and saves the model’s state_dict whenever that metric improves. After training, load the best weights with torch.load() and model.load_state_dict(), then use save_pipeline() to save the full pipeline.

Parameters:
save_path : Path | str

Path to save the best model weights.

monitor : str

Metric to monitor (default: "val_f1").

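mode : str

One of "min" or "max" (default: "max"). Use "max" when higher values of the monitored metric are better (as for "val_f1") and "min" when lower values are better (as for a loss).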
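The mode parameter decides which direction counts as an improvement. The helpers below are a minimal sketch of that comparison, not the library's implementation: the names is_improvement and initial_best, and the strict-inequality rule (a tie is not treated as an improvement), are assumptions made for illustration.

```python
import math


def is_improvement(current: float, best: float, mode: str = "max") -> bool:
    """Return True if `current` beats `best` under `mode` (hypothetical helper).

    Assumes a strict comparison, so an equal value is NOT an improvement.
    """
    if mode == "max":
        return current > best
    if mode == "min":
        return current < best
    raise ValueError(f'mode must be "min" or "max", got {mode!r}')


def initial_best(mode: str = "max") -> float:
    """Starting value chosen so the first observed metric always counts as an improvement."""
    if mode not in ("min", "max"):
        raise ValueError(f'mode must be "min" or "max", got {mode!r}')
    return -math.inf if mode == "max" else math.inf


# Track an F1-like score (higher is better) over five epochs.
best = initial_best("max")
saved_epochs = []
for epoch, f1 in enumerate([0.61, 0.70, 0.68, 0.70, 0.74]):
    if is_improvement(f1, best, "max"):
        best = f1
        saved_epochs.append(epoch)  # a real callback would write weights here
print(saved_epochs)  # [0, 1, 4]
```

Epoch 3 repeats the best score of 0.70 and, under the assumed strict rule, does not trigger a save.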

Methods

on_validation_epoch_end(trainer, pl_module)

Check whether the monitored metric improved and, if it did, save the weights.

Examples

>>> import torch
>>> import lightning as L
>>> from callcut.training import SaveBestModelCallback
>>>
>>> # Assumes `module` (the LightningModule), `dm` (the datamodule), `model`,
>>> # `extractor`, `decoder`, and `save_pipeline` are already defined or imported.
>>> trainer = L.Trainer(
...     max_epochs=10,
...     callbacks=[SaveBestModelCallback("best_weights.pt", monitor="val_f1")],
... )
>>> trainer.fit(module, datamodule=dm)
>>>
>>> # After training, load the best weights and save the full pipeline
>>> model.load_state_dict(torch.load("best_weights.pt", weights_only=True))
>>> save_pipeline(model, extractor, "pipeline.pt", decoder=decoder)

on_validation_epoch_end(trainer, pl_module)

Check whether the monitored metric improved and, if it did, save the weights.

Parameters:
trainer : Trainer

The Lightning trainer instance.

pl_module : LightningModule

The Lightning module being trained.
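The flow of this hook can be approximated as follows. This is an illustrative sketch under stated assumptions, not the callback's source: it assumes the monitored value is read from Lightning's trainer.callback_metrics dictionary, that a metric which has not been logged yet is skipped rather than raising an error, and that the weights come from pl_module.state_dict() (the real callback may save the wrapped network's weights instead). The class name BestWeightsSketch and its injected save_fn are hypothetical; the real callback presumably writes save_path with torch.save, whereas here the saver is passed in so the sketch runs without PyTorch or Lightning.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict


class BestWeightsSketch:
    """Hypothetical stand-in for the improvement-and-save logic of the hook."""

    def __init__(self, save_fn: Callable[[Dict[str, Any]], None],
                 monitor: str = "val_f1", mode: str = "max") -> None:
        if mode not in ("min", "max"):
            raise ValueError(f'mode must be "min" or "max", got {mode!r}')
        self.save_fn = save_fn  # assumption: stands in for torch.save(..., save_path)
        self.monitor = monitor
        self.mode = mode
        self.best = float("-inf") if mode == "max" else float("inf")

    def on_validation_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        value = trainer.callback_metrics.get(self.monitor)
        if value is None:
            # Metric not logged yet: nothing to compare, so skip saving.
            return
        current = float(value)  # works for Python floats and 0-dim tensors
        if self.mode == "max":
            improved = current > self.best
        else:
            improved = current < self.best
        if improved:
            self.best = current
            self.save_fn(pl_module.state_dict())


# Usage with stand-in objects instead of a real Trainer and LightningModule.
saved = []
callback = BestWeightsSketch(save_fn=saved.append, monitor="val_loss", mode="min")
module = SimpleNamespace(state_dict=lambda: {"weight": 1.0})
for loss in [None, 0.9, 0.7, 0.8]:
    metrics = {} if loss is None else {"val_loss": loss}
    callback.on_validation_epoch_end(SimpleNamespace(callback_metrics=metrics), module)
print(callback.best, len(saved))  # 0.7 2
```

Injecting the saver keeps the comparison logic testable in isolation; in a real training run the save step would serialize the state_dict to save_path.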