Skip to content

Evaluator

torchloop.evaluator.Evaluator

Classification model evaluator.

Parameters:

Name Type Description Default
model

Trained nn.Module.

required
device

'cuda', 'cpu', or 'mps'. Auto-detects if None.

None
Source code in src/torchloop/evaluator.py
class Evaluator:
    """
    Classification model evaluator.

    Runs a trained classifier over a DataLoader and produces sklearn-based
    metrics (classification report, per-class F1, confusion matrix).

    Args:
        model       : Trained nn.Module.
        device      : 'cuda', 'cpu', or 'mps'. Auto-detects if None.
    """

    def __init__(self, model: nn.Module, device: Optional[str] = None):
        # Honour an explicit device; otherwise prefer CUDA, then Apple MPS,
        # then CPU. (The previous auto-detect never selected 'mps' even
        # though the docstring advertises it as an option.)
        if device is not None:
            self.device = device
        elif torch.cuda.is_available():
            self.device = "cuda"
        elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        self.model = model.to(self.device)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def report(
        self,
        loader: DataLoader,
        class_names: Optional[list[str]] = None,
    ) -> dict:
        """
        Print full sklearn classification report.

        Args:
            loader      : DataLoader yielding (inputs, targets) batches.
            class_names : Optional labels for each class index; falls back
                          to stringified indices when omitted.

        Returns:
            dict with keys: accuracy, macro_f1, weighted_f1, per_class_f1
        """
        preds, targets = self._infer(loader)
        report = classification_report(
            targets, preds, target_names=class_names, zero_division=0
        )
        print(report)
        # average=None yields one F1 per class, in class-index order.
        per_class = f1_score(targets, preds, average=None, zero_division=0).tolist()
        return {
            "accuracy": float((np.array(preds) == np.array(targets)).mean()),
            "macro_f1": float(
                f1_score(targets, preds, average="macro", zero_division=0)
            ),
            "weighted_f1": float(
                f1_score(targets, preds, average="weighted", zero_division=0)
            ),
            "per_class_f1": {
                (class_names[i] if class_names else str(i)): round(v, 4)
                for i, v in enumerate(per_class)
            },
        }

    def confusion_matrix(
        self,
        loader: DataLoader,
        class_names: Optional[list[str]] = None,
        normalize: Optional[str] = "true",   # 'true' | 'pred' | 'all' | None
        figsize: tuple = (8, 6),
    ) -> plt.Figure:
        """
        Plot and return confusion matrix figure.

        Args:
            loader      : DataLoader yielding (inputs, targets) batches.
            class_names : Optional tick labels for the matrix axes.
            normalize   : Passed straight to sklearn's confusion_matrix.
            figsize     : Matplotlib figure size in inches.

        Returns:
            The matplotlib Figure containing the plotted matrix.
        """
        preds, targets = self._infer(loader)
        # NOTE: the bare name resolves to sklearn's confusion_matrix
        # (module global), not to this method.
        cm = confusion_matrix(targets, preds, normalize=normalize)
        fig, ax = plt.subplots(figsize=figsize)
        disp = ConfusionMatrixDisplay(cm, display_labels=class_names)
        disp.plot(ax=ax, colorbar=True, cmap="Blues")
        ax.set_title("Confusion Matrix")
        plt.tight_layout()
        return fig

    def f1_per_class(
        self,
        loader: DataLoader,
        class_names: Optional[list[str]] = None,
    ) -> dict[str, float]:
        """
        Returns per-class F1 as a dict. Clean for logging to W&B or MLflow.

        Keys are class_names entries when provided, otherwise stringified
        class indices; values are F1 scores rounded to 4 decimals.
        """
        preds, targets = self._infer(loader)
        scores = f1_score(targets, preds, average=None, zero_division=0)
        return {
            (class_names[i] if class_names else str(i)): round(float(s), 4)
            for i, s in enumerate(scores)
        }

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _infer(self, loader: DataLoader) -> tuple[list, list]:
        """Run the model over the loader; return (predictions, targets) lists.

        Predictions are argmax over dim 1 of the model output, so the model
        is expected to emit per-class logits/scores of shape (batch, classes).
        Targets are assumed to be tensors (``.tolist()`` is called on them).
        """
        self.model.eval()
        all_preds, all_targets = [], []
        with torch.no_grad():  # inference only — no autograd bookkeeping
            for inputs, targets in loader:
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)
                preds = outputs.argmax(dim=1).cpu().tolist()
                all_preds.extend(preds)
                all_targets.extend(targets.tolist())
        return all_preds, all_targets

confusion_matrix(loader, class_names=None, normalize='true', figsize=(8, 6))

Plot and return confusion matrix figure.

Source code in src/torchloop/evaluator.py
def confusion_matrix(
    self,
    loader: DataLoader,
    class_names: Optional[list[str]] = None,
    normalize: Optional[str] = "true",   # 'true' | 'pred' | 'all' | None
    figsize: tuple = (8, 6),
) -> plt.Figure:
    """Plot the confusion matrix for *loader* and return the figure."""
    predictions, labels = self._infer(loader)
    # The bare name resolves to sklearn's confusion_matrix (module global),
    # not to this method.
    matrix = confusion_matrix(labels, predictions, normalize=normalize)
    figure, axis = plt.subplots(figsize=figsize)
    display = ConfusionMatrixDisplay(matrix, display_labels=class_names)
    display.plot(ax=axis, colorbar=True, cmap="Blues")
    axis.set_title("Confusion Matrix")
    plt.tight_layout()
    return figure

f1_per_class(loader, class_names=None)

Returns per-class F1 as a dict. Clean for logging to W&B or MLflow.

Source code in src/torchloop/evaluator.py
def f1_per_class(
    self,
    loader: DataLoader,
    class_names: Optional[list[str]] = None,
) -> dict[str, float]:
    """Map each class (name, or index as a string) to its F1, rounded to 4 dp.

    Clean for logging to W&B or MLflow.
    """
    predictions, labels = self._infer(loader)
    per_class = f1_score(labels, predictions, average=None, zero_division=0)
    result: dict[str, float] = {}
    for idx, score in enumerate(per_class):
        key = class_names[idx] if class_names else str(idx)
        result[key] = round(float(score), 4)
    return result

report(loader, class_names=None)

Print full sklearn classification report.

Returns:

Type Description
dict

dict with keys: accuracy, macro_f1, weighted_f1, per_class_f1

Source code in src/torchloop/evaluator.py
def report(
    self,
    loader: DataLoader,
    class_names: Optional[list[str]] = None,
) -> dict:
    """
    Print full sklearn classification report.

    Returns:
        dict with keys: accuracy, macro_f1, weighted_f1, per_class_f1
    """
    predictions, labels = self._infer(loader)
    text = classification_report(
        labels, predictions, target_names=class_names, zero_division=0
    )
    print(text)

    def _averaged_f1(scheme):
        # Aggregate F1 under the requested averaging scheme.
        return float(
            f1_score(labels, predictions, average=scheme, zero_division=0)
        )

    # average=None yields one F1 per class, in class-index order.
    per_class = f1_score(labels, predictions, average=None, zero_division=0).tolist()
    named_f1 = {
        (class_names[idx] if class_names else str(idx)): round(score, 4)
        for idx, score in enumerate(per_class)
    }
    hits = np.array(predictions) == np.array(labels)
    return {
        "accuracy": float(hits.mean()),
        "macro_f1": _averaged_f1("macro"),
        "weighted_f1": _averaged_f1("weighted"),
        "per_class_f1": named_f1,
    }