Skip to content

Callbacks

torchloop.callbacks.base.Callback

Base callback with optional training lifecycle hooks.

Every hook is a no-op by default; subclasses override only the hooks they need.

Source code in src/torchloop/callbacks/base.py
class Callback:
    """Base class for training-loop callbacks.

    Every hook below does nothing by default, so a subclass only has to
    override the lifecycle events it actually cares about.
    """

    def on_train_begin(self, logs: dict) -> None:
        """Hook invoked once, before training starts."""
        return None

    def on_epoch_end(self, epoch: int, logs: dict) -> None:
        """Hook invoked after every epoch with the epoch number and a logs dict."""
        return None

    def on_train_end(self, logs: dict) -> None:
        """Hook invoked once, after training has finished."""
        return None

on_epoch_end(epoch, logs)

Run at the end of each epoch.

Source code in src/torchloop/callbacks/base.py
def on_epoch_end(self, epoch: int, logs: dict) -> None:
    """Default no-op hook called after each epoch; override it to react."""
    return None

on_train_begin(logs)

Run before training starts.

Source code in src/torchloop/callbacks/base.py
def on_train_begin(self, logs: dict) -> None:
    """Default no-op hook called once before training; override it to react."""
    return None

on_train_end(logs)

Run after training finishes.

Source code in src/torchloop/callbacks/base.py
def on_train_end(self, logs: dict) -> None:
    """Default no-op hook called once after training; override it to react."""
    return None

torchloop.callbacks.wandb_logger.WandBLogger

Bases: Callback

Log training metrics to Weights & Biases.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `project` | `str` | Weights & Biases project name. | *required* |
| `name` | `Optional[str]` | Optional run name. | `None` |
| `config` | `Optional[dict[str, Any]]` | Optional run configuration dictionary. | `None` |
Source code in src/torchloop/callbacks/wandb_logger.py
class WandBLogger(Callback):
    """Log training metrics to Weights & Biases.

    ``wandb`` is imported lazily inside each hook, so it is only needed once a
    hook actually runs.

    Args:
        project: Weights & Biases project name.
        name: Optional run name.
        config: Optional run configuration dictionary. A falsy value (``None``
            or an empty dict) is stored as a fresh empty dict.
    """

    def __init__(
        self,
        project: str,
        name: Optional[str] = None,
        config: Optional[dict[str, Any]] = None,
    ) -> None:
        self.project = project
        self.name = name
        self.config = config or {}

    @staticmethod
    def _require_wandb():
        """Import and return the ``wandb`` module.

        Shared by every hook so the install hint lives in one place.

        Raises:
            ImportError: If ``wandb`` cannot be imported. The original import
                error is chained as ``__cause__``.
        """
        try:
            return import_module("wandb")
        except ImportError as exc:
            raise ImportError(
                "wandb is required for WandBLogger. "
                "Install with: pip install torchloop[logging]"
            ) from exc

    def on_train_begin(self, logs: dict) -> None:
        """Initialize a W&B run from ``project``, ``name`` and ``config``.

        Raises:
            ImportError: If ``wandb`` is not installed.
        """
        wandb = self._require_wandb()
        wandb.init(project=self.project, name=self.name, config=self.config)

    def on_epoch_end(self, epoch: int, logs: dict) -> None:
        """Log this epoch's ``logs`` to W&B, using ``epoch`` as the step.

        Raises:
            ImportError: If ``wandb`` is not installed.
        """
        wandb = self._require_wandb()
        wandb.log(logs, step=epoch)

    def on_train_end(self, logs: dict) -> None:
        """Finish the active W&B run.

        Raises:
            ImportError: If ``wandb`` is not installed.
        """
        wandb = self._require_wandb()
        wandb.finish()

on_epoch_end(epoch, logs)

Log epoch metrics to W&B.

Source code in src/torchloop/callbacks/wandb_logger.py
def on_epoch_end(self, epoch: int, logs: dict) -> None:
    """Send this epoch's metrics to the active W&B run, with ``epoch`` as the step."""
    try:
        wandb_module = import_module("wandb")
    except ImportError as missing:
        message = (
            "wandb is required for WandBLogger. "
            "Install with: pip install torchloop[logging]"
        )
        raise ImportError(message) from missing
    wandb_module.log(logs, step=epoch)

on_train_begin(logs)

Initialize a W&B run.

Source code in src/torchloop/callbacks/wandb_logger.py
def on_train_begin(self, logs: dict) -> None:
    """Start a W&B run configured from the logger's project, name and config."""
    try:
        wandb_module = import_module("wandb")
    except ImportError as missing:
        raise ImportError(
            "wandb is required for WandBLogger. "
            "Install with: pip install torchloop[logging]"
        ) from missing
    init_kwargs = {
        "project": self.project,
        "name": self.name,
        "config": self.config,
    }
    wandb_module.init(**init_kwargs)

on_train_end(logs)

Finish the active W&B run.

Source code in src/torchloop/callbacks/wandb_logger.py
def on_train_end(self, logs: dict) -> None:
    """Close out the current W&B run once training is over."""
    try:
        wandb_module = import_module("wandb")
    except ImportError as missing:
        raise ImportError(
            "wandb is required for WandBLogger. "
            "Install with: pip install torchloop[logging]"
        ) from missing
    finish_run = wandb_module.finish
    finish_run()

torchloop.callbacks.mlflow_logger.MLflowLogger

Bases: Callback

Log training metrics to MLflow.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `experiment_name` | `str` | MLflow experiment name. | *required* |
| `tracking_uri` | `Optional[str]` | Optional tracking server URI. | `None` |
| `run_name` | `Optional[str]` | Optional MLflow run name. | `None` |
Source code in src/torchloop/callbacks/mlflow_logger.py
class MLflowLogger(Callback):
    """Log training metrics to MLflow.

    ``mlflow`` is imported lazily inside each hook, so it is only needed once a
    hook actually runs.

    Args:
        experiment_name: MLflow experiment name.
        tracking_uri: Optional tracking server URI. When falsy,
            ``mlflow.set_tracking_uri`` is not called.
        run_name: Optional MLflow run name.
    """

    def __init__(
        self,
        experiment_name: str,
        tracking_uri: Optional[str] = None,
        run_name: Optional[str] = None,
    ) -> None:
        self.experiment_name = experiment_name
        self.tracking_uri = tracking_uri
        self.run_name = run_name

    @staticmethod
    def _require_mlflow():
        """Import and return the ``mlflow`` module.

        Shared by every hook so the install hint lives in one place.

        Raises:
            ImportError: If ``mlflow`` cannot be imported. The original import
                error is chained as ``__cause__``.
        """
        try:
            return import_module("mlflow")
        except ImportError as exc:
            raise ImportError(
                "mlflow is required for MLflowLogger. "
                "Install with: pip install torchloop[logging]"
            ) from exc

    def on_train_begin(self, logs: dict) -> None:
        """Set the tracking URI (if any) and experiment, then start a run.

        Raises:
            ImportError: If ``mlflow`` is not installed.
        """
        mlflow = self._require_mlflow()
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)
        mlflow.set_experiment(self.experiment_name)
        mlflow.start_run(run_name=self.run_name)

    def on_epoch_end(self, epoch: int, logs: dict) -> None:
        """Log the real-valued entries of ``logs`` to MLflow at step ``epoch``.

        Every ``numbers.Real`` value is kept and converted to ``float``. That
        covers plain ``int``/``float`` and also types such as ``numpy.float32``,
        ``numpy.int64`` or ``fractions.Fraction``, which a bare
        ``(int, float)`` check would silently drop. Any other value (strings,
        containers, arrays, ...) is skipped.

        Raises:
            ImportError: If ``mlflow`` is not installed.
        """
        import numbers  # function-local: keeps this change free of module-level imports

        mlflow = self._require_mlflow()
        numeric_logs = {
            key: float(value)
            for key, value in logs.items()
            if isinstance(value, numbers.Real)
        }
        mlflow.log_metrics(numeric_logs, step=epoch)

    def on_train_end(self, logs: dict) -> None:
        """End the active MLflow run.

        Raises:
            ImportError: If ``mlflow`` is not installed.
        """
        mlflow = self._require_mlflow()
        mlflow.end_run()

on_epoch_end(epoch, logs)

Log epoch metrics to MLflow.

Source code in src/torchloop/callbacks/mlflow_logger.py
def on_epoch_end(self, epoch: int, logs: dict) -> None:
    """Log the real-valued entries of ``logs`` to MLflow at step ``epoch``.

    Every ``numbers.Real`` value is kept and converted to ``float``. That
    covers plain ``int``/``float`` and also types such as ``numpy.float32``,
    ``numpy.int64`` or ``fractions.Fraction``, which a bare ``(int, float)``
    check would silently drop. Any other value (strings, containers,
    arrays, ...) is skipped.

    Raises:
        ImportError: If ``mlflow`` is not installed.
    """
    import numbers  # function-local: keeps this change free of module-level imports

    try:
        mlflow = import_module("mlflow")
    except ImportError as exc:
        raise ImportError(
            "mlflow is required for MLflowLogger. "
            "Install with: pip install torchloop[logging]"
        ) from exc

    numeric_logs = {
        key: float(value)
        for key, value in logs.items()
        if isinstance(value, numbers.Real)
    }
    mlflow.log_metrics(numeric_logs, step=epoch)

on_train_begin(logs)

Initialize MLflow experiment and run.

Source code in src/torchloop/callbacks/mlflow_logger.py
def on_train_begin(self, logs: dict) -> None:
    """Open an MLflow run after applying the tracking URI (if set) and experiment."""
    try:
        mlflow_module = import_module("mlflow")
    except ImportError as missing:
        raise ImportError(
            "mlflow is required for MLflowLogger. "
            "Install with: pip install torchloop[logging]"
        ) from missing

    uri = self.tracking_uri
    if uri:
        mlflow_module.set_tracking_uri(uri)
    mlflow_module.set_experiment(self.experiment_name)
    mlflow_module.start_run(run_name=self.run_name)

on_train_end(logs)

End the active MLflow run.

Source code in src/torchloop/callbacks/mlflow_logger.py
def on_train_end(self, logs: dict) -> None:
    """Close the MLflow run that is currently active."""
    try:
        mlflow_module = import_module("mlflow")
    except ImportError as missing:
        message = (
            "mlflow is required for MLflowLogger. "
            "Install with: pip install torchloop[logging]"
        )
        raise ImportError(message) from missing
    close_run = mlflow_module.end_run
    close_run()