
classification

Model training and evaluation for glaucoma classification.

Overview

This module handles:

  • Bootstrap evaluation with 1000 iterations
  • STRATOS-compliant metric computation
  • Multiple classifier support (CatBoost default)
  • Subject-wise analysis
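The percentile-bootstrap confidence intervals computed throughout this module can be illustrated with a rough sketch (the helper name and metric values here are hypothetical, not part of the module):

```python
import numpy as np

def percentile_ci(values, alpha=0.05):
    # hypothetical helper: 95% percentile-bootstrap CI from per-iteration metric values
    lower = np.percentile(values, 100 * (alpha / 2))
    upper = np.percentile(values, 100 * (1 - alpha / 2))
    return lower, upper

rng = np.random.default_rng(0)
aurocs = rng.normal(0.85, 0.02, size=1000)  # stand-in for 1000 bootstrap AUROCs
lower, upper = percentile_ci(aurocs)
```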

Main Entry Point

flow_classification

flow_classification

flow_classification(cfg: DictConfig) -> None

Main classification flow for glaucoma screening from PLR features.

Orchestrates the classification pipeline including feature-based and time-series classification approaches. Initializes MLflow experiment and delegates to subflows.

PARAMETER DESCRIPTION
cfg

Hydra configuration with PREFECT flow names and settings.

TYPE: DictConfig

Notes

Time-series classification is currently disabled as it showed limited promise after refactoring.

Source code in src/classification/flow_classification.py
def flow_classification(cfg: DictConfig) -> None:
    """
    Main classification flow for glaucoma screening from PLR features.

    Orchestrates the classification pipeline including feature-based
    and time-series classification approaches. Initializes MLflow
    experiment and delegates to subflows.

    Parameters
    ----------
    cfg : DictConfig
        Hydra configuration with PREFECT flow names and settings.

    Notes
    -----
    Time-series classification is currently disabled as it showed
    limited promise after refactoring.
    """
    experiment_name = experiment_name_wrapper(
        experiment_name=cfg["PREFECT"]["FLOW_NAMES"]["CLASSIFICATION"], cfg=cfg
    )
    logger.info("FLOW | Name: {}".format(experiment_name))
    logger.info("=====================")
    prev_experiment_name = experiment_name_wrapper(
        experiment_name=cfg["PREFECT"]["FLOW_NAMES"]["FEATURIZATION"], cfg=cfg
    )

    # Init the MLflow experiment
    init_mlflow_experiment(experiment_name=experiment_name)

    # Classify from hand-crafted features/embeddings
    flow_feature_classification(cfg, prev_experiment_name)

    # Classify from time series
    ts_cls = False
    if ts_cls:
        raise NotImplementedError(
            "Need to be finished, new bug with the refactoring, but did not seem promising"
        )

Bootstrap Evaluation

bootstrap_evaluation

prepare_for_bootstrap

prepare_for_bootstrap(
    dict_arrays: dict, method_cfg: DictConfig
)

Prepare data arrays for bootstrap evaluation.

Sets up index arrays for stratified bootstrap resampling. The train split will be resampled into new train/val splits while test remains untouched.

PARAMETER DESCRIPTION
dict_arrays

Dictionary containing:

  • x_train, y_train: Training features and labels
  • x_test, y_test: Test features and labels
  • subject_codes_train, subject_codes_test: Subject identifiers

TYPE: dict

method_cfg

Bootstrap configuration with 'join_test_and_train' option.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Updated dict_arrays with 'X_idxs' array for bootstrap sampling.

References
  • https://machinelearningmastery.com/calculate-bootstrap-confidence-intervals-machine-learning-results-python/
Source code in src/classification/bootstrap_evaluation.py
def prepare_for_bootstrap(dict_arrays: dict, method_cfg: DictConfig):
    """
    Prepare data arrays for bootstrap evaluation.

    Sets up index arrays for stratified bootstrap resampling. The train split
    will be resampled into new train/val splits while test remains untouched.

    Parameters
    ----------
    dict_arrays : dict
        Dictionary containing:
        - x_train, y_train: Training features and labels
        - x_test, y_test: Test features and labels
        - subject_codes_train, subject_codes_test: Subject identifiers
    method_cfg : DictConfig
        Bootstrap configuration with 'join_test_and_train' option.

    Returns
    -------
    dict
        Updated dict_arrays with 'X_idxs' array for bootstrap sampling.

    References
    ----------
    - https://machinelearningmastery.com/calculate-bootstrap-confidence-intervals-machine-learning-results-python/
    """
    # Bootstrap resampling splits the train split into new train and val splits
    # Test will be untouched
    dict_arrays["X_idxs"] = np.linspace(
        0, dict_arrays["x_train"].shape[0] - 1, dict_arrays["x_train"].shape[0]
    ).astype(int)

    if method_cfg["join_test_and_train"]:
        raise NotImplementedError
        # X = np.concatenate((X_train, X_test), axis=0)
        # y = np.concatenate((y_train, y_test), axis=0)
        # X_test, y_test, codes_test = None, None, None
    else:
        assert dict_arrays["x_test"].shape[0] == dict_arrays["y_test"].shape[0], (
            "X_test and y_test must have the same number of rows"
        )
        assert (
            dict_arrays["x_test"].shape[0] == dict_arrays["subject_codes_test"].shape[0]
        ), "X_test and subject_codes_test must have the same number of rows"

    assert dict_arrays["x_train"].shape[0] == dict_arrays["y_train"].shape[0], (
        "X and y must have the same number of rows"
    )
    assert dict_arrays["x_train"].shape[0] == dict_arrays["X_idxs"].shape[0], (
        "X and X_idxs must have the same number of rows"
    )
    assert (
        dict_arrays["x_train"].shape[0] == dict_arrays["subject_codes_train"].shape[0]
    ), "X and subject_codes_train must have the same number of rows"

    return dict_arrays
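Note that the `np.linspace(...).astype(int)` construction above is just an integer index vector; `np.arange` gives the same result:

```python
import numpy as np

n_rows = 10
idxs_linspace = np.linspace(0, n_rows - 1, n_rows).astype(int)
idxs_arange = np.arange(n_rows)  # equivalent, and arguably clearer
```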

select_bootstrap_samples

select_bootstrap_samples(
    dict_arrays, n_samples, method_cfg
) -> dict

Select bootstrap samples for a single iteration.

Performs stratified bootstrap resampling to create new train/val splits from the original training data.

PARAMETER DESCRIPTION
dict_arrays

Data arrays including X_idxs for sampling.

TYPE: dict

n_samples

Number of samples to draw for training.

TYPE: int

method_cfg

Bootstrap configuration.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Dictionary with resampled train/val data arrays.

Source code in src/classification/bootstrap_evaluation.py
def select_bootstrap_samples(dict_arrays, n_samples, method_cfg) -> dict:
    """
    Select bootstrap samples for a single iteration.

    Performs stratified bootstrap resampling to create new train/val splits
    from the original training data.

    Parameters
    ----------
    dict_arrays : dict
        Data arrays including X_idxs for sampling.
    n_samples : int
        Number of samples to draw for training.
    method_cfg : DictConfig
        Bootstrap configuration.

    Returns
    -------
    dict
        Dictionary with resampled train/val data arrays.
    """

    def resample_split_indices(X_idxs: np.ndarray, n_samples: int, y: np.ndarray):
        train_idxs = resample(X_idxs, n_samples=n_samples, stratify=y)
        val_idxs = np.array(
            [x for x in X_idxs if x.tolist() not in train_idxs.tolist()]
        )
        return train_idxs, val_idxs

    dict_arrays_iter = dict_arrays.copy()

    # Get indices of the new split samples
    dict_arrays_iter["train_idxs"], dict_arrays_iter["val_idxs"] = (
        resample_split_indices(
            X_idxs=dict_arrays_iter["X_idxs"],
            n_samples=n_samples,
            y=dict_arrays_iter["y_train"],
        )
    )

    # TODO! You could obviously try to loop these and parametrize the split(s)
    #  and make this more compact
    # Tmp the original train split
    x_train = dict_arrays_iter["x_train"]
    y_train = dict_arrays_iter["y_train"]
    subject_codes_train = dict_arrays_iter["subject_codes_train"]
    x_train_w = dict_arrays_iter["x_train_w"]

    # Pick the corresponding samples
    dict_arrays_iter["x_train"] = x_train[dict_arrays_iter["train_idxs"]]
    dict_arrays_iter["y_train"] = y_train[dict_arrays_iter["train_idxs"]]
    dict_arrays_iter["x_val"] = x_train[dict_arrays_iter["val_idxs"]]
    dict_arrays_iter["y_val"] = y_train[dict_arrays_iter["val_idxs"]]

    dict_arrays_iter["subject_codes_train"] = subject_codes_train[
        dict_arrays_iter["train_idxs"]
    ]
    dict_arrays_iter["subject_codes_val"] = subject_codes_train[
        dict_arrays_iter["val_idxs"]
    ]
    dict_arrays_iter["x_train_w"] = x_train_w[dict_arrays_iter["train_idxs"]]
    dict_arrays_iter["x_val_w"] = x_train_w[dict_arrays_iter["val_idxs"]]

    assert dict_arrays_iter["x_train"].shape[0] == dict_arrays_iter["y_train"].shape[0]
    assert dict_arrays_iter["x_val"].shape[0] == dict_arrays_iter["y_val"].shape[0]
    assert (
        dict_arrays_iter["x_train_w"].shape[0] == dict_arrays_iter["y_train"].shape[0]
    )
    assert dict_arrays_iter["x_val_w"].shape[0] == dict_arrays_iter["y_val"].shape[0]
    assert (
        dict_arrays_iter["x_train"].shape[0]
        == dict_arrays_iter["subject_codes_train"].shape[0]
    )
    assert (
        dict_arrays_iter["x_val"].shape[0]
        == dict_arrays_iter["subject_codes_val"].shape[0]
    )

    return dict(sorted(dict_arrays_iter.items()))
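The module draws bootstrap samples with `sklearn.utils.resample`; a NumPy-only sketch of the same train/out-of-bag split idea (without the stratification the real code uses):

```python
import numpy as np

rng = np.random.default_rng(42)
X_idxs = np.arange(20)
# Bootstrap: draw training indices with replacement ...
train_idxs = rng.choice(X_idxs, size=15, replace=True)
# ... and use the untouched ("out-of-bag") indices as the validation split
val_idxs = np.setdiff1d(X_idxs, train_idxs)
```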

splits_as_dicts

splits_as_dicts(dict_arrays_iter: dict)

Convert flat array dictionary to nested split-based structure.

PARAMETER DESCRIPTION
dict_arrays_iter

Flat dictionary with keys like 'x_train', 'y_train', etc.

TYPE: dict

RETURNS DESCRIPTION
dict

Nested dictionary with structure: {split: {'X': ..., 'y': ..., 'w': ..., 'codes': ...}}

Source code in src/classification/bootstrap_evaluation.py
def splits_as_dicts(dict_arrays_iter: dict):
    """
    Convert flat array dictionary to nested split-based structure.

    Parameters
    ----------
    dict_arrays_iter : dict
        Flat dictionary with keys like 'x_train', 'y_train', etc.

    Returns
    -------
    dict
        Nested dictionary with structure:
        {split: {'X': ..., 'y': ..., 'w': ..., 'codes': ...}}
    """
    splits = ["train", "val", "test"]
    dict_splits = {}
    for split in splits:
        dict_splits[split] = {
            "X": dict_arrays_iter[f"x_{split}"],
            "y": dict_arrays_iter[f"y_{split}"],
            "w": dict_arrays_iter[f"x_{split}_w"],
            "codes": dict_arrays_iter[f"subject_codes_{split}"],
        }
        assert dict_splits[split]["X"].shape[0] == dict_splits[split]["y"].shape[0]
        assert dict_splits[split]["X"].shape[0] == dict_splits[split]["w"].shape[0]
        assert dict_splits[split]["X"].shape[0] == dict_splits[split]["codes"].shape[0]

    return dict_splits
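A toy input/output pair for this restructuring (shapes assumed for illustration):

```python
import numpy as np

flat = {}
for split in ["train", "val", "test"]:
    flat[f"x_{split}"] = np.zeros((4, 3))  # 4 samples, 3 features
    flat[f"y_{split}"] = np.zeros(4)       # labels
    flat[f"x_{split}_w"] = np.ones(4)      # sample weights
    flat[f"subject_codes_{split}"] = np.array([f"s{i}" for i in range(4)])

# Same nesting as splits_as_dicts produces
nested = {
    split: {
        "X": flat[f"x_{split}"],
        "y": flat[f"y_{split}"],
        "w": flat[f"x_{split}_w"],
        "codes": flat[f"subject_codes_{split}"],
    }
    for split in ["train", "val", "test"]
}
```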

check_bootstrap_iteration_quality

check_bootstrap_iteration_quality(
    metrics_iter, dict_arrays_iter, dict_arrays
)

Validate that bootstrap iteration used all expected samples.

Checks that the train/val splits together used all subject codes from the original training data and that the test sample count matches the original test set.

PARAMETER DESCRIPTION
metrics_iter

Metrics collected from all bootstrap iterations.

TYPE: dict

dict_arrays_iter

Data arrays from the current iteration.

TYPE: dict

dict_arrays

Original data arrays before bootstrap resampling.

TYPE: dict

RAISES DESCRIPTION
AssertionError

If the train/val codes don't match the originals or the test sample count is wrong.

Source code in src/classification/bootstrap_evaluation.py
def check_bootstrap_iteration_quality(metrics_iter, dict_arrays_iter, dict_arrays):
    """
    Validate that bootstrap iteration used all expected samples.

    Checks that the train/val splits together used all subject codes from the
    original training data and that the test sample count matches the original test set.

    Parameters
    ----------
    metrics_iter : dict
        Metrics collected from all bootstrap iterations.
    dict_arrays_iter : dict
        Data arrays from the current iteration.
    dict_arrays : dict
        Original data arrays before bootstrap resampling.

    Raises
    ------
    AssertionError
        If the train/val codes don't match the originals or the test sample count is wrong.
    """
    train_codes_used = list(
        metrics_iter["train"]["preds_dict"]["arrays"]["y_pred_proba"].keys()
    )
    val_codes_used = list(
        metrics_iter["val"]["preds_dict"]["arrays"]["y_pred_proba"].keys()
    )
    # in the bootstrap scenario, after all the iterations, both splits should have had all the codes used
    # from the original train split
    assert len(train_codes_used) == len(val_codes_used), (
        "Train and val codes must have the same length"
    )
    assert len(train_codes_used) == dict_arrays["subject_codes_train"].shape[0], (
        "All codes must have been used"
    )

    # this has a different structure, as test samples are always the same across the bootstrapping
    # so we can just aggregate predictions to a np.ndarray
    no_test_samples_used = metrics_iter["test"]["preds"]["arrays"]["predictions"][
        "y_pred_proba"
    ].shape[0]
    assert no_test_samples_used == dict_arrays["x_test"].shape[0], (
        "All test samples must have been used"
    )

get_ensemble_stats

get_ensemble_stats(
    metrics_iter,
    dict_arrays,
    method_cfg,
    call_from: str = None,
    sort_list: bool = True,
    verbose: bool = True,
)

Compute aggregate statistics from bootstrap iterations.

Aggregates per-iteration metrics into final statistics including:

  • Mean and CI for AUROC, Brier, etc.
  • Per-subject prediction statistics
  • Global uncertainty metrics

PARAMETER DESCRIPTION
metrics_iter

Per-iteration metrics from bootstrap.

TYPE: dict

dict_arrays

Original data arrays.

TYPE: dict

method_cfg

Bootstrap configuration.

TYPE: DictConfig

call_from

Caller identifier for logging.

TYPE: str DEFAULT: None

sort_list

Sort subject statistics.

TYPE: bool DEFAULT: True

verbose

Enable verbose logging.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
tuple

(metrics_stats, subjectwise_stats, subject_global_stats)

Source code in src/classification/bootstrap_evaluation.py
def get_ensemble_stats(
    metrics_iter,
    dict_arrays,
    method_cfg,
    call_from: str = None,
    sort_list: bool = True,
    verbose: bool = True,
):
    """
    Compute aggregate statistics from bootstrap iterations.

    Aggregates per-iteration metrics into final statistics including:
    - Mean and CI for AUROC, Brier, etc.
    - Per-subject prediction statistics
    - Global uncertainty metrics

    Parameters
    ----------
    metrics_iter : dict
        Per-iteration metrics from bootstrap.
    dict_arrays : dict
        Original data arrays.
    method_cfg : DictConfig
        Bootstrap configuration.
    call_from : str, optional
        Caller identifier for logging.
    sort_list : bool, default True
        Sort subject statistics.
    verbose : bool, default True
        Enable verbose logging.

    Returns
    -------
    tuple
        (metrics_stats, subjectwise_stats, subject_global_stats)
    """
    # Compute the final stats of the metrics (scalar AUROC, array ROC curves, etc.)
    try:
        metrics_stats = bootstrap_compute_stats(
            metrics_iter, method_cfg, call_from, verbose=verbose
        )
    except Exception as e:
        logger.error(f"Error in computing metrics stats: {e}")
        metrics_stats = None

    # Compute the stats of the predictions (class probabilities per subject)
    try:
        subjectwise_stats = bootstrap_compute_subject_stats(
            metrics_iter,
            dict_arrays,
            method_cfg,
            sort_list=sort_list,
            call_from=call_from,
            verbose=verbose,
        )

    except Exception as e:
        logger.error(f"Error in computing subjectwise stats: {e}")
        subjectwise_stats = None

    try:
        # Subjectwise Uncertainty metrics ("unks" from Catboost tutorial)
        subjectwise_stats = compute_uq_for_subjectwise_stats(
            metrics_iter, subjectwise_stats, verbose=verbose
        )
    except Exception as e:
        logger.error(f"Error in computing subjectwise uncertainty: {e}")

    # Compute the "mean response" of the subjects, e.g. a scalar mean UQ metric describing the overall model uncertainty
    try:
        subject_global_stats = bootstrap_compute_global_subject_stats(
            subjectwise_stats, method_cfg, verbose=verbose
        )
    except Exception as e:
        logger.error(f"Error in computing global subject stats: {e}")
        subject_global_stats = None

    return metrics_stats, subjectwise_stats, subject_global_stats

append_models_to_list_for_mlflow

append_models_to_list_for_mlflow(
    models: list, model, model_name: str, i: int
)

Add a trained model to the list for MLflow logging.

Handles special cases like moving TabM models from GPU to CPU to avoid memory issues during serialization.

PARAMETER DESCRIPTION
models

List of trained models from previous iterations.

TYPE: list

model

The trained model from current iteration.

TYPE: object

model_name

Name of the classifier (e.g., 'TabM', 'CatBoost').

TYPE: str

i

Current bootstrap iteration index.

TYPE: int

RETURNS DESCRIPTION
list

Updated list of models with the new model appended.

Source code in src/classification/bootstrap_evaluation.py
def append_models_to_list_for_mlflow(models: list, model, model_name: str, i: int):
    """
    Add a trained model to the list for MLflow logging.

    Handles special cases like moving TabM models from GPU to CPU to avoid
    memory issues during serialization.

    Parameters
    ----------
    models : list
        List of trained models from previous iterations.
    model : object
        The trained model from current iteration.
    model_name : str
        Name of the classifier (e.g., 'TabM', 'CatBoost').
    i : int
        Current bootstrap iteration index.

    Returns
    -------
    list
        Updated list of models with the new model appended.
    """
    if model_name == "TabM":
        # the classifier is on CUDA and may cause memory issues during serialization unless moved to CPU
        model = model.to("cpu")
    models.append(deepcopy(model))
    return models
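The `deepcopy` above matters: without it, every list entry would reference the same model object, which is mutated when the next iteration retrains. A minimal demonstration with a stand-in model:

```python
from copy import deepcopy

models = []
model = {"coef": [1.0]}        # stand-in for a fitted classifier
models.append(deepcopy(model))
model["coef"][0] = 99.0        # later mutation does not affect the stored copy
```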

bootstrap_evaluator

bootstrap_evaluator(
    model_name: str,
    run_name: str,
    dict_arrays: dict,
    best_params,
    cls_model_cfg: DictConfig,
    method_cfg: DictConfig,
    hparam_cfg: DictConfig,
    cfg: DictConfig,
    debug_aggregation: bool = False,
)

Run bootstrap evaluation for classifier performance estimation.

Performs n_iterations of bootstrap resampling to estimate:

  • STRATOS-compliant metrics (AUROC, calibration, clinical utility)
  • Confidence intervals via percentile bootstrap
  • Per-subject prediction uncertainty

PARAMETER DESCRIPTION
model_name

Classifier name (e.g., 'CatBoost', 'XGBoost').

TYPE: str

run_name

MLflow run name.

TYPE: str

dict_arrays

Data arrays with train/test splits.

TYPE: dict

best_params

Best hyperparameters from optimization.

TYPE: dict

cls_model_cfg

Classifier model configuration.

TYPE: DictConfig

method_cfg

Bootstrap configuration with 'n_iterations', 'data_ratio'.

TYPE: DictConfig

hparam_cfg

Hyperparameter configuration.

TYPE: DictConfig

cfg

Full Hydra configuration.

TYPE: DictConfig

debug_aggregation

Enable debug logging for metric aggregation.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
tuple

(models, results_dict) where:

  • models: List of trained models (one per iteration)
  • results_dict: Contains metrics_iter, metrics_stats, subjectwise_stats, subject_global_stats

Notes

Uses stratified resampling to maintain class balance across iterations. Test set remains fixed; only train is resampled into train/val.

Source code in src/classification/bootstrap_evaluation.py
def bootstrap_evaluator(
    model_name: str,
    run_name: str,
    dict_arrays: dict,
    best_params,
    cls_model_cfg: DictConfig,
    method_cfg: DictConfig,
    hparam_cfg: DictConfig,
    cfg: DictConfig,
    debug_aggregation: bool = False,
):
    """
    Run bootstrap evaluation for classifier performance estimation.

    Performs n_iterations of bootstrap resampling to estimate:
    - STRATOS-compliant metrics (AUROC, calibration, clinical utility)
    - Confidence intervals via percentile bootstrap
    - Per-subject prediction uncertainty

    Parameters
    ----------
    model_name : str
        Classifier name (e.g., 'CatBoost', 'XGBoost').
    run_name : str
        MLflow run name.
    dict_arrays : dict
        Data arrays with train/test splits.
    best_params : dict
        Best hyperparameters from optimization.
    cls_model_cfg : DictConfig
        Classifier model configuration.
    method_cfg : DictConfig
        Bootstrap configuration with 'n_iterations', 'data_ratio'.
    hparam_cfg : DictConfig
        Hyperparameter configuration.
    cfg : DictConfig
        Full Hydra configuration.
    debug_aggregation : bool, default False
        Enable debug logging for metric aggregation.

    Returns
    -------
    tuple
        (models, results_dict) where:
        - models: List of trained models (one per iteration)
        - results_dict: Contains metrics_iter, metrics_stats,
          subjectwise_stats, subject_global_stats

    Notes
    -----
    Uses stratified resampling to maintain class balance across iterations.
    Test set remains fixed; only train is resampled into train/val.
    """
    warnings.simplefilter("ignore")
    start_time = time.time()
    dict_arrays = prepare_for_bootstrap(dict_arrays, method_cfg)
    n_samples = int(dict_arrays["X_idxs"].shape[0] * method_cfg["data_ratio"])
    metrics_iter = {}
    models = []

    for i in tqdm(
        range(method_cfg["n_iterations"]),
        total=method_cfg["n_iterations"],
        desc="Bootstrap iterations",
    ):
        # What samples to use per iteration (sample weights do not require re-computation)
        dict_arrays_iter = select_bootstrap_samples(dict_arrays, n_samples, method_cfg)

        # Update weights and other dataset dependent params as well
        weights_dict = return_weights_as_dict(dict_arrays_iter, cls_model_cfg)

        # Retrain the model with the bootstrapped samples
        model, results_per_iter = bootstrap_model_selector(
            model_name=model_name,
            cls_model_cfg=cls_model_cfg,
            hparam_cfg=hparam_cfg,
            cfg=cfg,
            best_params=best_params,
            dict_arrays=dict_arrays_iter,
            weights_dict=weights_dict,
        )

        # Calibrate classifier (if desired)
        model = bootstrap_calibrate_classifier(
            i, model, cls_model_cfg, dict_arrays_iter, weights_dict
        )

        # Append to a list to be saved to MLflow
        models = append_models_to_list_for_mlflow(models, model, model_name, i)

        # Easier to evaluate with nested dictionaries instead of the flat one above
        dict_splits = splits_as_dicts(dict_arrays_iter)

        # Compute your scalar metrics (AUROC, etc.), ROC Curve stats, and patient predictions with uncertainty
        metrics_iter = bootstrap_metrics(
            i,
            model,
            dict_splits,
            metrics_iter,
            results_per_iter,
            method_cfg,
            cfg,
            debug_aggregation=debug_aggregation,
            model_name=model_name,
        )

    # Check bootstrap iters
    del model
    check_bootstrap_iteration_quality(metrics_iter, dict_arrays_iter, dict_arrays)

    # Get ensemble stats
    metrics_stats, subjectwise_stats, subject_global_stats = get_ensemble_stats(
        metrics_iter, dict_arrays, method_cfg
    )

    end_time = time.time() - start_time
    logger.info("Bootstrap evaluation in {:.2f} seconds".format(end_time))
    try:
        mlflow.log_param("bootstrap_time", end_time)
    except Exception as e:
        # Might happen during grid search when trying to write over the same value
        logger.debug(f"MLFlow logging failed: {e}")
    warnings.resetwarnings()

    return models, {
        "metrics_iter": metrics_iter,
        "metrics_stats": metrics_stats,
        "subjectwise_stats": subjectwise_stats,
        "subject_global_stats": subject_global_stats,
    }
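Per iteration, the number of resampled training points is the train size scaled by `data_ratio`; a quick sketch with assumed values:

```python
n_train = 120      # rows in x_train (assumed)
data_ratio = 0.8   # hypothetical method_cfg["data_ratio"]
n_samples = int(n_train * data_ratio)  # points drawn per bootstrap iteration
```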

Metrics and Statistics

stats_metric_utils

interpolation_wrapper

interpolation_wrapper(
    x: ndarray,
    y: ndarray,
    x_new: ndarray,
    n_samples: int,
    metric: str,
    _kind: str = "linear",
) -> tuple[ndarray, ndarray]

Interpolate metric curves to a fixed number of points.

PARAMETER DESCRIPTION
x

Original x values.

TYPE: array-like

y

Original y values.

TYPE: array-like

x_new

New x values to interpolate to.

TYPE: array-like

n_samples

Number of samples for interpolation.

TYPE: int

metric

Metric name ('AUROC', 'AUPR', 'calibration_curve').

TYPE: str

_kind

Interpolation method (currently unused).

TYPE: str DEFAULT: 'linear'

RETURNS DESCRIPTION
tuple

(x_new, y_new) interpolated arrays.

Source code in src/classification/stats_metric_utils.py
def interpolation_wrapper(
    x: np.ndarray,
    y: np.ndarray,
    x_new: np.ndarray,
    n_samples: int,
    metric: str,
    _kind: str = "linear",
) -> tuple[np.ndarray, np.ndarray]:
    """
    Interpolate metric curves to a fixed number of points.

    Parameters
    ----------
    x : array-like
        Original x values.
    y : array-like
        Original y values.
    x_new : array-like
        New x values to interpolate to.
    n_samples : int
        Number of samples for interpolation.
    metric : str
        Metric name ('AUROC', 'AUPR', 'calibration_curve').
    _kind : str, default 'linear'
        Interpolation method (currently unused).

    Returns
    -------
    tuple
        (x_new, y_new) interpolated arrays.
    """

    def clip_illegal_calibration_values(y_new: np.ndarray) -> np.ndarray:
        y_new[y_new < 0] = 0
        y_new[y_new > 1] = 1
        return y_new

    warnings.simplefilter("ignore")
    assert len(x) == len(y), "x and y must have the same length"
    if metric == "calibration_curve":
        # so few points and does not behave well
        x_new = np.linspace(0, 1, 10)
        f = interp1d(x, y, kind="linear", fill_value="extrapolate")
        y_new = f(x_new)
        y_new = clip_illegal_calibration_values(y_new)

    else:
        if len(x) > 1:
            f = interp1d(x, y, kind="linear")
            y_new = f(x_new)
        else:
            # you got only one value, not a vector
            y_new = np.repeat(y, n_samples)
    warnings.resetwarnings()

    return x_new, y_new
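The function wraps `scipy.interpolate.interp1d`; for the common linear case, `np.interp` sketches the same idea of resampling a sparse ROC curve onto a fixed grid:

```python
import numpy as np

# A sparse ROC-like curve ...
fpr = np.array([0.0, 0.2, 1.0])
tpr = np.array([0.0, 0.8, 1.0])
# ... resampled linearly onto a fixed 200-point grid so curves from
# different bootstrap iterations can be stacked and averaged
x_new = np.linspace(0.0, 1.0, 200)
tpr_new = np.interp(x_new, fpr, tpr)
```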

bootstrap_get_array_axis_names

bootstrap_get_array_axis_names(
    metric: str,
) -> tuple[str, str]

Get x and y axis names for a given metric curve.

PARAMETER DESCRIPTION
metric

Metric name ('AUROC', 'AUPR', 'calibration_curve').

TYPE: str

RETURNS DESCRIPTION
tuple

(x_name, y_name) axis names for the metric.

RAISES DESCRIPTION
ValueError

If unknown metric specified.

Source code in src/classification/stats_metric_utils.py
def bootstrap_get_array_axis_names(metric: str) -> tuple[str, str]:
    """
    Get x and y axis names for a given metric curve.

    Parameters
    ----------
    metric : str
        Metric name ('AUROC', 'AUPR', 'calibration_curve').

    Returns
    -------
    tuple
        (x_name, y_name) axis names for the metric.

    Raises
    ------
    ValueError
        If unknown metric specified.
    """
    if metric == "AUROC":
        x_name = "fpr"
        y_name = "tpr"
    elif metric == "AUPR":
        x_name = "recall"
        y_name = "precision"
    elif metric == "calibration_curve":
        # https://scikit-learn.org/stable/auto_examples/calibration/plot_calibration_curve.html
        x_name = "prob_pred"
        y_name = "prob_true"
    else:
        logger.error(
            f"Unknown metric: {metric} (only AUROC, AUPR and calibration curve supported)"
        )
        raise ValueError

    return x_name, y_name
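The branching above is equivalent to a small lookup table:

```python
# Same metric-to-axis mapping, expressed as a dict
axis_names = {
    "AUROC": ("fpr", "tpr"),
    "AUPR": ("recall", "precision"),
    "calibration_curve": ("prob_pred", "prob_true"),
}
x_name, y_name = axis_names["AUROC"]
```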

bootstrap_interpolate_metric_arrays

bootstrap_interpolate_metric_arrays(
    arrays: dict[str, Any], n_samples: int = 200
) -> dict[str, Any]

Interpolate all metric arrays to a fixed number of samples.

Enables aggregation of ROC/PR curves across bootstrap iterations by standardizing the x-axis.

PARAMETER DESCRIPTION
arrays

Dictionary of metric arrays with varying lengths.

TYPE: dict

n_samples

Number of points for interpolation.

TYPE: int DEFAULT: 200

RETURNS DESCRIPTION
dict

Dictionary of interpolated metric arrays.

Source code in src/classification/stats_metric_utils.py
def bootstrap_interpolate_metric_arrays(
    arrays: dict[str, Any], n_samples: int = 200
) -> dict[str, Any]:
    """
    Interpolate all metric arrays to a fixed number of samples.

    Enables aggregation of ROC/PR curves across bootstrap iterations
    by standardizing the x-axis.

    Parameters
    ----------
    arrays : dict
        Dictionary of metric arrays with varying lengths.
    n_samples : int, default 200
        Number of points for interpolation.

    Returns
    -------
    dict
        Dictionary of interpolated metric arrays.
    """
    arrays_out = {}
    for metric in arrays.keys():
        arrays_out[metric] = {}
        x_name, y_name = bootstrap_get_array_axis_names(metric)

        # Interpolate (actual metric)
        arrays_out[metric][x_name], arrays_out[metric][y_name] = interpolation_wrapper(
            x=arrays[metric][x_name],
            y=arrays[metric][y_name],
            x_new=np.linspace(
                arrays[metric][x_name][0], arrays[metric][x_name][-1], n_samples
            ),
            n_samples=n_samples,
            metric=metric,
        )

        # Interpolate (thresholds as well)
        if "thresholds" in arrays[metric]:
            # exist for AUROC and AUPR
            _, arrays_out[metric]["thresholds"] = interpolation_wrapper(
                x=np.linspace(0, 1, len(arrays[metric]["thresholds"])),
                y=arrays[metric]["thresholds"],
                x_new=np.linspace(0, 1, n_samples),
                n_samples=n_samples,
                metric="thresholds",
            )

    return arrays_out

bootstrap_aggregate_arrays

bootstrap_aggregate_arrays(
    arrays: dict[str, Any],
    metrics_per_split: dict[str, Any],
    main_key: str = "metrics",
) -> dict[str, Any]

Aggregate array metrics across bootstrap iterations.

Stacks interpolated curves (ROC, PR, calibration) horizontally for later statistical analysis.

PARAMETER DESCRIPTION
arrays

Interpolated metric arrays from current iteration.

TYPE: dict

metrics_per_split

Accumulated metrics from previous iterations.

TYPE: dict

main_key

Key for storing metrics in output dict.

TYPE: str DEFAULT: "metrics"

RETURNS DESCRIPTION
dict

Updated metrics_per_split with new arrays appended.

Source code in src/classification/stats_metric_utils.py
def bootstrap_aggregate_arrays(
    arrays: dict[str, Any], metrics_per_split: dict[str, Any], main_key: str = "metrics"
) -> dict[str, Any]:
    """
    Aggregate array metrics across bootstrap iterations.

    Stacks interpolated curves (ROC, PR, calibration) horizontally
    for later statistical analysis.

    Parameters
    ----------
    arrays : dict
        Interpolated metric arrays from current iteration.
    metrics_per_split : dict
        Accumulated metrics from previous iterations.
    main_key : str, default "metrics"
        Key for storing metrics in output dict.

    Returns
    -------
    dict
        Updated metrics_per_split with new arrays appended.
    """
    # For first bootstrap iteration, initialize the metrics_per_split dict
    if main_key not in metrics_per_split.keys():
        metrics_per_split[main_key] = {}
    if "arrays" not in metrics_per_split[main_key].keys():
        metrics_per_split[main_key]["arrays"] = {}

    # Go through all the array metrics and aggregate them; flexible, so you can add new metrics without changing this
    for metric in arrays.keys():
        if metric not in metrics_per_split[main_key]["arrays"].keys():
            # For first iteration
            metrics_per_split[main_key]["arrays"][metric] = {}
        values = arrays[metric]
        for variable in values.keys():
            array_var: np.ndarray = values[variable][:, np.newaxis]  # e.g. (200, 1)
            if variable not in metrics_per_split[main_key]["arrays"][metric].keys():
                # For first iteration, (curve_length, no_iterations)
                # print(1, ' ', metric, variable, array_var.shape)
                metrics_per_split[main_key]["arrays"][metric][variable] = array_var
            else:
                # When there is something already here
                # print(2, " ", metric, variable, array_var.shape)
                try:
                    metrics_per_split[main_key]["arrays"][metric][variable] = np.hstack(
                        (
                            metrics_per_split[main_key]["arrays"][metric][variable],
                            array_var,
                        )
                    )
                except Exception as e:
                    logger.error(f"Could not stack arrays: {e}")
                    logger.error(
                        f"previous shape: "
                        f"{metrics_per_split[main_key]['arrays'][metric][variable].shape[0]}, "
                        f"and current shape: "
                        f"{array_var.shape}"
                    )
                    raise e

    return metrics_per_split
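
As a minimal, self-contained sketch of the stacking logic above (a single hypothetical "roc" metric with 200-point curves; the names are illustrative, not from the module), each iteration's interpolated curve becomes one more column of a (curve_length, n_iterations) array:

```python
import numpy as np

def stack_iteration(store: dict, curve: np.ndarray) -> dict:
    # Append one iteration's interpolated curve as a new column
    col = curve[:, np.newaxis]  # (curve_length, 1)
    if "roc" not in store:
        store["roc"] = col  # first iteration: (curve_length, 1)
    else:
        store["roc"] = np.hstack((store["roc"], col))
    return store

store: dict = {}
for _ in range(3):  # three mock bootstrap iterations
    store = stack_iteration(store, np.linspace(0.0, 1.0, 200))

print(store["roc"].shape)  # (200, 3)
```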

bootstrap_aggregate_scalars

bootstrap_aggregate_scalars(
    metrics_dict: dict[str, Any],
    metrics_per_split: dict[str, Any],
) -> dict[str, Any]

Aggregate scalar metrics across bootstrap iterations.

Appends scalar values (AUROC, Brier, etc.) to lists for later statistical analysis.

PARAMETER DESCRIPTION
metrics_dict

Metrics from current iteration.

TYPE: dict

metrics_per_split

Accumulated metrics from previous iterations.

TYPE: dict

RETURNS DESCRIPTION
dict

Updated metrics_per_split with new scalars appended.

Source code in src/classification/stats_metric_utils.py
def bootstrap_aggregate_scalars(
    metrics_dict: dict[str, Any], metrics_per_split: dict[str, Any]
) -> dict[str, Any]:
    """
    Aggregate scalar metrics across bootstrap iterations.

    Appends scalar values (AUROC, Brier, etc.) to lists for later
    statistical analysis.

    Parameters
    ----------
    metrics_dict : dict
        Metrics from current iteration.
    metrics_per_split : dict
        Accumulated metrics from previous iterations.

    Returns
    -------
    dict
        Updated metrics_per_split with new scalars appended.
    """
    # For first bootstrap iteration, initialize the metrics_per_split dict
    if "metrics" not in metrics_per_split.keys():
        metrics_per_split["metrics"] = {}
    if "scalars" not in metrics_per_split["metrics"].keys():
        metrics_per_split["metrics"]["scalars"] = {}

    # Go through all the scalar metrics and aggregate them, flexible, so you can add new metrics without changing this
    for metric in metrics_dict["metrics"]["scalars"].keys():
        if metric not in metrics_per_split["metrics"]["scalars"].keys():
            # For first iteration
            metrics_per_split["metrics"]["scalars"][metric] = [
                metrics_dict["metrics"]["scalars"][metric]
            ]
        else:
            # When there is something already here
            metrics_per_split["metrics"]["scalars"][metric].append(
                metrics_dict["metrics"]["scalars"][metric]
            )

    return metrics_per_split
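
A minimal sketch of the same accumulation pattern (the AUROC values here are mock numbers, not real results): each iteration's scalar is appended to a per-metric list under the "scalars" dict:

```python
import numpy as np

metrics_per_split: dict = {}
for auroc in (0.81, 0.79, 0.84):  # mock per-iteration AUROC values
    scalars = metrics_per_split.setdefault("metrics", {}).setdefault("scalars", {})
    scalars.setdefault("AUROC", []).append(auroc)

aurocs = metrics_per_split["metrics"]["scalars"]["AUROC"]
print(aurocs)                            # [0.81, 0.79, 0.84]
print(round(float(np.mean(aurocs)), 3))  # 0.813
```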

bootstrap_aggregate_by_subject_per_split

bootstrap_aggregate_by_subject_per_split(
    arrays: dict[str, Any],
    metrics_per_split: dict[str, Any],
    codes_per_split: ndarray,
    main_key: str,
    subkey: str = "predictions",
    is_init_with_correct_codes: bool = False,
) -> dict[str, Any]

Aggregate predictions by subject code across bootstrap iterations.

Used for train/val splits where subjects vary between iterations due to resampling. Stores predictions keyed by subject code.

PARAMETER DESCRIPTION
arrays

Predictions from current iteration.

TYPE: dict

metrics_per_split

Accumulated predictions from previous iterations.

TYPE: dict

codes_per_split

Subject codes for current iteration.

TYPE: ndarray

main_key

Key for storing predictions.

TYPE: str

subkey

Subkey within arrays.

TYPE: str DEFAULT: "predictions"

is_init_with_correct_codes

If True, expects codes to already exist.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
dict

Updated metrics_per_split with predictions aggregated by subject.

Source code in src/classification/stats_metric_utils.py
def bootstrap_aggregate_by_subject_per_split(
    arrays: dict[str, Any],
    metrics_per_split: dict[str, Any],
    codes_per_split: np.ndarray,
    main_key: str,
    subkey: str = "predictions",
    is_init_with_correct_codes: bool = False,
) -> dict[str, Any]:
    """
    Aggregate predictions by subject code across bootstrap iterations.

    Used for train/val splits where subjects vary between iterations
    due to resampling. Stores predictions keyed by subject code.

    Parameters
    ----------
    arrays : dict
        Predictions from current iteration.
    metrics_per_split : dict
        Accumulated predictions from previous iterations.
    codes_per_split : np.ndarray
        Subject codes for current iteration.
    main_key : str
        Key for storing predictions.
    subkey : str, default "predictions"
        Subkey within arrays.
    is_init_with_correct_codes : bool, default False
        If True, expects codes to already exist.

    Returns
    -------
    dict
        Updated metrics_per_split with predictions aggregated by subject.
    """
    # For first bootstrap iteration, initialize the metrics_per_split dict
    if main_key not in metrics_per_split.keys():
        metrics_per_split[main_key] = {}
    if "arrays" not in metrics_per_split[main_key].keys():
        metrics_per_split[main_key]["arrays"] = {}
    code_warnings = []

    # Aggregate in a dict keyed by subject code. Remember that a bootstrap resample can contain
    # the same code multiple times in each iteration, so you get fewer unique codes per iteration
    # than you have samples (in case you are counting the number of codes added per iteration and are confused)
    for metric in arrays[subkey].keys():
        if metric not in metrics_per_split[main_key]["arrays"].keys():
            metrics_per_split[main_key]["arrays"][metric] = {}
        values: np.ndarray = arrays[subkey][metric]
        assert len(values) == len(codes_per_split), (
            "Values and codes must have the same length"
        )

        for i, code in enumerate(codes_per_split):
            value = values[i]
            if code not in metrics_per_split[main_key]["arrays"][metric].keys():
                # When this code has not yet been added
                if is_init_with_correct_codes:
                    # the dict was pre-initialized with the expected codes,
                    # so a code missing from it indicates a mismatch
                    logger.debug(
                        "You are trying to add a code that is not in the split?, code = {}".format(
                            code
                        )
                    )
                    code_warnings.append(code)
                else:
                    # first time seeing this code, initialize its list
                    metrics_per_split[main_key]["arrays"][metric][code] = [value]
            else:
                # keep appending over the iterations
                metrics_per_split[main_key]["arrays"][metric][code].append(value)

    if len(code_warnings) > 0:
        logger.warning(
            "You tried to add {} codes that are not in the split".format(
                len(code_warnings)
            )
        )
        logger.warning(f"Codes that are not in the split: {code_warnings}")

    return metrics_per_split
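
A self-contained sketch of the per-subject aggregation under bootstrap resampling (subject codes "s1"–"s3" and the probabilities are mock data): the same code can occur several times within one resample and be absent from another, so predictions are collected in a dict keyed by code:

```python
import numpy as np

store: dict[str, list[float]] = {}

# Two mock bootstrap resamples; 's3' is drawn three times in total
iterations = [
    (np.array(["s1", "s1", "s3"]), np.array([0.70, 0.72, 0.20])),
    (np.array(["s2", "s3", "s3"]), np.array([0.50, 0.25, 0.22])),
]
for codes, probs in iterations:
    assert len(probs) == len(codes)  # values and codes must match in length
    for code, p in zip(codes, probs):
        store.setdefault(code, []).append(float(p))

print(sorted(store))     # ['s1', 's2', 's3']
print(len(store["s3"]))  # 3
```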

bootstrap_aggregate_subjects

bootstrap_aggregate_subjects(
    metrics_per_split: dict[str, Any],
    codes_per_split: ndarray,
    split: str,
    preds: dict[str, ndarray],
) -> dict[str, Any]

Aggregate subject predictions based on split type.

For test split, predictions are stacked as arrays (same subjects each iter). For train/val, predictions are stored in dicts keyed by subject code.

PARAMETER DESCRIPTION
metrics_per_split

Accumulated metrics and predictions.

TYPE: dict

codes_per_split

Subject codes for current split.

TYPE: ndarray

split

Split name ('train', 'val', 'test').

TYPE: str

preds

Predictions from current iteration.

TYPE: dict

RETURNS DESCRIPTION
dict

Updated metrics_per_split with aggregated predictions.

RAISES DESCRIPTION
ValueError

If unknown split specified.

Source code in src/classification/stats_metric_utils.py
def bootstrap_aggregate_subjects(
    metrics_per_split: dict[str, Any],
    codes_per_split: np.ndarray,
    split: str,
    preds: dict[str, np.ndarray],
) -> dict[str, Any]:
    """
    Aggregate subject predictions based on split type.

    For test split, predictions are stacked as arrays (same subjects each iter).
    For train/val, predictions are stored in dicts keyed by subject code.

    Parameters
    ----------
    metrics_per_split : dict
        Accumulated metrics and predictions.
    codes_per_split : np.ndarray
        Subject codes for current split.
    split : str
        Split name ('train', 'val', 'test').
    preds : dict
        Predictions from current iteration.

    Returns
    -------
    dict
        Updated metrics_per_split with aggregated predictions.

    Raises
    ------
    ValueError
        If unknown split specified.
    """
    # Aggregate the predictions as well so you can get an average probability per patient,
    # and some uncertainty quantification (aleatoric and epistemic uncertainty)?

    # Add a "dummy key" so that the array code above works for this without modifications
    arrays_preds = {"predictions": preds}
    if split == "test":
        # The subjects are always the same for each iteration
        metrics_per_split = bootstrap_aggregate_arrays(
            arrays=arrays_preds, metrics_per_split=metrics_per_split, main_key="preds"
        )
        from src.ensemble.ensemble_classification import check_metrics_iter_preds

        check_metrics_iter_preds(
            dict_arrays=metrics_per_split["preds"]["arrays"]["predictions"]
        )
    elif split in ("train", "val"):
        # Subjects are not the same for each iteration, so we need to aggregate them
        # by subject code
        metrics_per_split = bootstrap_aggregate_by_subject_per_split(
            arrays=arrays_preds,
            metrics_per_split=metrics_per_split,
            codes_per_split=codes_per_split,
            main_key="preds_dict",
        )
        from src.ensemble.ensemble_classification import check_metrics_iter_preds_dict

        check_metrics_iter_preds_dict(
            dict_arrays=metrics_per_split["preds_dict"]["arrays"]
        )
    else:
        logger.error(f"Unknown split: {split}")
        raise ValueError(f"Unknown split: {split}")

    return metrics_per_split

bootstrap_metrics_per_split

bootstrap_metrics_per_split(
    X: ndarray,
    y_true: ndarray,
    preds: dict[str, ndarray],
    model: Any,
    model_name: str,
    metrics_per_split: dict[str, Any],
    codes_per_split: ndarray,
    method_cfg: DictConfig,
    cfg: DictConfig,
    split: str,
    skip_mlflow: bool = False,
    recompute_for_ensemble: bool = False,
) -> dict[str, Any]

Compute and aggregate metrics for a single split in bootstrap iteration.

Calculates classifier metrics, calibration metrics, interpolates curves, and aggregates all results across iterations.

PARAMETER DESCRIPTION
X

Feature matrix for the split.

TYPE: ndarray

y_true

True labels.

TYPE: ndarray

preds

Model predictions with 'y_pred', 'y_pred_proba'.

TYPE: dict

model

Trained classifier model.

TYPE: object

model_name

Name of the classifier.

TYPE: str

metrics_per_split

Accumulated metrics from previous iterations.

TYPE: dict

codes_per_split

Subject codes for this split.

TYPE: ndarray

method_cfg

Bootstrap method configuration.

TYPE: DictConfig

cfg

Full Hydra configuration.

TYPE: DictConfig

split

Split name ('train', 'val', 'test').

TYPE: str

skip_mlflow

Skip MLflow logging.

TYPE: bool DEFAULT: False

recompute_for_ensemble

If True, skip subject aggregation (already done).

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
dict

Updated metrics_per_split with new iteration's results.

Source code in src/classification/stats_metric_utils.py
def bootstrap_metrics_per_split(
    X: np.ndarray,
    y_true: np.ndarray,
    preds: dict[str, np.ndarray],
    model: Any,
    model_name: str,
    metrics_per_split: dict[str, Any],
    codes_per_split: np.ndarray,
    method_cfg: DictConfig,
    cfg: DictConfig,
    split: str,
    skip_mlflow: bool = False,
    recompute_for_ensemble: bool = False,
) -> dict[str, Any]:
    """
    Compute and aggregate metrics for a single split in bootstrap iteration.

    Calculates classifier metrics, calibration metrics, interpolates curves,
    and aggregates all results across iterations.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix for the split.
    y_true : np.ndarray
        True labels.
    preds : dict
        Model predictions with 'y_pred', 'y_pred_proba'.
    model : object
        Trained classifier model.
    model_name : str
        Name of the classifier.
    metrics_per_split : dict
        Accumulated metrics from previous iterations.
    codes_per_split : np.ndarray
        Subject codes for this split.
    method_cfg : DictConfig
        Bootstrap method configuration.
    cfg : DictConfig
        Full Hydra configuration.
    split : str
        Split name ('train', 'val', 'test').
    skip_mlflow : bool, default False
        Skip MLflow logging.
    recompute_for_ensemble : bool, default False
        If True, skip subject aggregation (already done).

    Returns
    -------
    dict
        Updated metrics_per_split with new iteration's results.
    """
    assert len(y_true) == len(preds["y_pred"]), (
        "y_true and y_pred must have the same length"
    )
    assert len(y_true) == len(codes_per_split), (
        "y_true and codes_per_split must have the same length"
    )

    # Get the basic metrics that you want
    metrics_dict = get_classifier_metrics(
        y_true, preds=preds, cfg=cfg, skip_mlflow=skip_mlflow, model_name=model_name
    )

    # Note: with so few samples, calibration & uncertainty estimates may be
    # unreliable; compute them after the bootstrap instead? Currently doing both.
    # Get calibration metrics
    metrics_dict = get_calibration_metrics(model, metrics_dict, y_true, preds=preds)

    # Interpolate the ROC and PR curves to a shared fixed length so you can aggregate them and do stats
    interpolated_arrays = bootstrap_interpolate_metric_arrays(
        arrays=metrics_dict["metrics"]["arrays"], n_samples=method_cfg["curve_x_length"]
    )

    # aggregate the scalar metrics
    metrics_per_split = bootstrap_aggregate_scalars(
        metrics_dict=metrics_dict, metrics_per_split=metrics_per_split
    )

    # aggregate the array curves
    metrics_per_split = bootstrap_aggregate_arrays(
        arrays=interpolated_arrays, metrics_per_split=metrics_per_split
    )

    # Aggregate the subject predictions (i.e. preds)
    if not recompute_for_ensemble:
        # during "live" bootstrapping (i.e. when not ensembling from results loaded from MLflow);
        # ENSEMBLING: when ensembling, this has basically been created already by the previous ensembling functions
        metrics_per_split = bootstrap_aggregate_subjects(
            metrics_per_split=metrics_per_split,
            codes_per_split=codes_per_split,
            split=split,
            preds=preds,
        )

    return metrics_per_split

bootstrap_predict

bootstrap_predict(
    model: Any,
    X: ndarray,
    i: int,
    split: str,
    debug_aggregation: bool = True,
) -> dict[str, ndarray]

Get predictions from model for bootstrap iteration.

PARAMETER DESCRIPTION
model

Trained classifier with predict_proba() method.

TYPE: object

X

Feature matrix.

TYPE: ndarray

i

Bootstrap iteration index.

TYPE: int

split

Split name for logging.

TYPE: str

debug_aggregation

If True, log debug info for test split.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

Predictions with 'y_pred_proba' and 'y_pred' keys.

RAISES DESCRIPTION
Exception

If model prediction fails.

Source code in src/classification/stats_metric_utils.py
def bootstrap_predict(
    model: Any, X: np.ndarray, i: int, split: str, debug_aggregation: bool = True
) -> dict[str, np.ndarray]:
    """
    Get predictions from model for bootstrap iteration.

    Parameters
    ----------
    model : object
        Trained classifier with predict_proba() method.
    X : np.ndarray
        Feature matrix.
    i : int
        Bootstrap iteration index.
    split : str
        Split name for logging.
    debug_aggregation : bool, default True
        If True, log debug info for test split.

    Returns
    -------
    dict
        Predictions with 'y_pred_proba' and 'y_pred' keys.

    Raises
    ------
    Exception
        If model prediction fails.
    """
    try:
        predict_probs = model.predict_proba(X)  # (n_samples, n_classes), e.g. (72,2)
        preds = {
            "y_pred_proba": predict_probs[
                :, 1
            ],  # (n_samples,), e.g. (72,) for the class 1 (e.g. glaucoma)
            "y_pred": model.predict(X),  # (n_samples,), e.g. (72,)
        }
    except Exception as e:
        logger.error(f"Could not get prediction from the model: {e}")
        raise e

    if debug_aggregation:
        if split == "test":
            logger.info(
                f"DEBUG iter #{i + 1}: 1st sample probs of the test split: {preds['y_pred_proba'][0]}"
            )

    return preds
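
The expected model interface can be illustrated with a stand-in implementing the sklearn-style API (DummyModel and its sigmoid rule are hypothetical, used only to make the sketch runnable):

```python
import numpy as np

class DummyModel:
    # Stand-in exposing the predict_proba()/predict() API bootstrap_predict expects
    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        p1 = 1.0 / (1.0 + np.exp(-X[:, 0]))     # sigmoid of the first feature
        return np.column_stack([1.0 - p1, p1])  # (n_samples, n_classes)

    def predict(self, X: np.ndarray) -> np.ndarray:
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

X = np.linspace(-2.0, 2.0, 72).reshape(72, 1)
model = DummyModel()
preds = {
    "y_pred_proba": model.predict_proba(X)[:, 1],  # P(class 1), shape (72,)
    "y_pred": model.predict(X),                    # hard labels, shape (72,)
}
print(preds["y_pred_proba"].shape, preds["y_pred"].shape)  # (72,) (72,)
```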

tabm_demodata_fix

tabm_demodata_fix(
    preds: dict[str, ndarray],
) -> dict[str, ndarray]

Fix TabM prediction array length mismatch on demo data.

PARAMETER DESCRIPTION
preds

Predictions dictionary with 'y_pred_proba', 'y_pred', 'label'.

TYPE: dict

RETURNS DESCRIPTION
dict

Fixed predictions dictionary.

RAISES DESCRIPTION
RuntimeError

If prediction length doesn't match expected ratio.

Source code in src/classification/stats_metric_utils.py
def tabm_demodata_fix(preds: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """
    Fix TabM prediction array length mismatch on demo data.

    Parameters
    ----------
    preds : dict
        Predictions dictionary with 'y_pred_proba', 'y_pred', 'label'.

    Returns
    -------
    dict
        Fixed predictions dictionary.

    Raises
    ------
    RuntimeError
        If prediction length doesn't match expected ratio.
    """
    no_pred_length = len(preds["y_pred_proba"])
    no_pred_labels = len(preds["label"])

    if no_pred_length != no_pred_labels:
        if no_pred_length == 2 * no_pred_labels:
            preds["y_pred"] = preds["y_pred"][0:no_pred_labels]
            preds["y_pred_proba"] = preds["y_pred_proba"][0:no_pred_labels]
        else:
            logger.error("Number of predictions does not match number of labels")
            raise RuntimeError("Number of predictions does not match number of labels")

    return preds
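
The fix can be reproduced with mock arrays (the shapes are illustrative): when the prediction arrays come back exactly twice as long as the labels, they are truncated to the label length:

```python
import numpy as np

preds = {
    "label": np.zeros(10),
    "y_pred": np.arange(20),               # doubled length, as on the demo data
    "y_pred_proba": np.linspace(0, 1, 20),
}
n_labels = len(preds["label"])
if len(preds["y_pred_proba"]) == 2 * n_labels:
    preds["y_pred"] = preds["y_pred"][:n_labels]
    preds["y_pred_proba"] = preds["y_pred_proba"][:n_labels]

print(len(preds["y_pred"]), len(preds["y_pred_proba"]))  # 10 10
```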

bootstrap_metrics

bootstrap_metrics(
    i: int,
    model: Any,
    dict_splits: dict[str, dict[str, ndarray]],
    metrics: dict[str, dict],
    results_per_iter: dict[str, dict] | None,
    method_cfg: DictConfig,
    cfg: DictConfig,
    debug_aggregation: bool = False,
    model_name: str | None = None,
) -> dict[str, dict]

Compute metrics for all splits in one bootstrap iteration.

PARAMETER DESCRIPTION
i

Bootstrap iteration index, or a submodel index within the ensemble.

TYPE: int

dict_splits

Per-split data, e.g. {'test': {'X': np.ndarray, 'y': np.ndarray, 'w': np.ndarray, 'codes': np.ndarray}}.

TYPE: dict

metrics

Accumulated metrics, e.g. {} on i=0.

TYPE: dict

results_per_iter

Precomputed predictions for custom models, e.g. None on i=0.

TYPE: dict | None

RETURNS DESCRIPTION
dict

Updated metrics with this iteration's results per split.

Source code in src/classification/stats_metric_utils.py
def bootstrap_metrics(
    i: int,
    model: Any,
    dict_splits: dict[str, dict[str, np.ndarray]],
    metrics: dict[str, dict],
    results_per_iter: dict[str, dict] | None,
    method_cfg: DictConfig,
    cfg: DictConfig,
    debug_aggregation: bool = False,
    model_name: str | None = None,
) -> dict[str, dict]:
    """
    Compute metrics for all splits in one bootstrap iteration.

    Parameters
    ----------
    i : int
        Bootstrap iteration index, or a submodel index within the ensemble.
    dict_splits : dict
        Per-split data, e.g. {'test': {'X': np.ndarray, 'y': np.ndarray,
        'w': np.ndarray, 'codes': np.ndarray}}.
    metrics : dict
        Accumulated metrics, e.g. {} on i=0.
    results_per_iter : dict | None
        Precomputed predictions for custom models, e.g. None on i=0.

    Returns
    -------
    dict
        Updated metrics with this iteration's results per split.
    """
    for split in dict_splits.keys():
        if split not in metrics:
            metrics[split] = {}
        X = dict_splits[split]["X"]
        y_true = dict_splits[split]["y"]
        if results_per_iter is None:
            # for sklearn API like models, we can do predict() here
            preds = bootstrap_predict(
                model, X, i, split, debug_aggregation=debug_aggregation
            )
        else:
            # more custom models like TabM have already computed the preds
            if model_name == "TabM":
                preds = get_tabm_preds_from_results_for_bootstrap(
                    split_results=results_per_iter[split]
                )
            elif model_name == "CATBOOST":
                preds = get_catboost_preds_from_results_for_bootstrap(
                    split_results=results_per_iter[split], split=split
                )
            elif model_name == "TabPFN":
                preds = results_per_iter[split]
            else:
                # also catches model_name=None, which would otherwise
                # leave `preds` undefined below
                msg = f"Some novel model name for bootstrap? model_name = {model_name}"
                logger.error(msg)
                raise ValueError(msg)

        # add the label here (if you want to plot some probs distributions as a function of label, apply threshold
        # tuning, or whatever)
        preds["label"] = y_true

        # hacky fix if you are running TabM, on demo data
        if model_name == "TabM":
            preds = tabm_demodata_fix(preds)

        assert preds["label"].shape[0] == preds["y_pred_proba"].shape[0], (
            f"label ({preds['label'].shape[0]}) and "
            f"probs ({preds['y_pred_proba'].shape[0]}) should have the same length"
        )

        # TODO! Add the "granular label", when you add it to the input data?
        #  early, moderate, severe, etc. for glaucoma, or whatever you have

        warnings.simplefilter("ignore")
        metrics[split] = bootstrap_metrics_per_split(
            X,
            y_true,
            preds,
            model,
            model_name,
            metrics_per_split=metrics[split],
            codes_per_split=dict_splits[split]["codes"],
            method_cfg=method_cfg,
            cfg=cfg,
            split=split,
        )
        warnings.resetwarnings()

    return metrics

get_p_from_alpha

get_p_from_alpha(alpha: float = DEFAULT_CI_LEVEL) -> float

Convert confidence level alpha to percentile value.

PARAMETER DESCRIPTION
alpha

Confidence level (e.g., 0.95 for 95% CI).

TYPE: float DEFAULT: DEFAULT_CI_LEVEL

RETURNS DESCRIPTION
float

Percentile value for lower bound (e.g., 2.5 for alpha=0.95).

Source code in src/classification/stats_metric_utils.py
def get_p_from_alpha(alpha: float = DEFAULT_CI_LEVEL) -> float:
    """
    Convert confidence level alpha to percentile value.

    Parameters
    ----------
    alpha : float, default DEFAULT_CI_LEVEL
        Confidence level (e.g., 0.95 for 95% CI).

    Returns
    -------
    float
        Percentile value for lower bound (e.g., 2.5 for alpha=0.95).
    """
    return np.round(
        ((1.0 - alpha) / 2.0) * 100, 1
    )  # e.g. 2.5 with alpha=0.95 (2.5 - 97.5%)
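
For example (reproducing the function body with a literal default so the sketch is self-contained):

```python
import numpy as np

def get_p_from_alpha(alpha: float = 0.95) -> float:
    # Lower-bound percentile of a two-sided CI at confidence level alpha
    return float(np.round(((1.0 - alpha) / 2.0) * 100, 1))

print(get_p_from_alpha(0.95))  # 2.5 -> percentiles [2.5, 97.5]
print(get_p_from_alpha(0.90))  # 5.0 -> percentiles [5.0, 95.0]
```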

bootstrap_scalar_stats_per_metric

bootstrap_scalar_stats_per_metric(
    values: ndarray, method_cfg: DictConfig
) -> dict[str, int | float | ndarray]

Compute summary statistics for a scalar metric across bootstrap iterations.

PARAMETER DESCRIPTION
values

Array of metric values from all iterations.

TYPE: ndarray

method_cfg

Bootstrap configuration with 'alpha_CI'.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Statistics with 'n', 'mean', 'std', 'ci' keys.

Source code in src/classification/stats_metric_utils.py
def bootstrap_scalar_stats_per_metric(
    values: np.ndarray, method_cfg: DictConfig
) -> dict[str, int | float | np.ndarray]:
    """
    Compute summary statistics for a scalar metric across bootstrap iterations.

    Parameters
    ----------
    values : np.ndarray
        Array of metric values from all iterations.
    method_cfg : DictConfig
        Bootstrap configuration with 'alpha_CI'.

    Returns
    -------
    dict
        Statistics with 'n', 'mean', 'std', 'ci' keys.
    """
    dict_out = {}
    dict_out["n"] = len(values)
    warnings.simplefilter("ignore")
    dict_out["mean"] = np.nanmean(values)
    dict_out["std"] = np.nanstd(values)

    # Confidence intervals
    p = get_p_from_alpha(alpha=method_cfg["alpha_CI"])
    if dict_out["std"] != 0:
        dict_out["ci"] = np.nanpercentile(values, [p, 100 - p])
    else:
        # all values are the same; no point estimating a CI, and skipping it
        # avoids a warning clogging up the logs
        dict_out["ci"] = np.array((np.nan, np.nan))
    warnings.resetwarnings()

    return dict_out
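
A self-contained sketch of the computation with mock AUROC values (one NaN stands in for a failed iteration; alpha is assumed to be 0.95):

```python
import numpy as np

values = np.array([0.80, 0.78, 0.85, 0.82, np.nan])  # mock per-iteration AUROCs
alpha = 0.95
p = ((1.0 - alpha) / 2.0) * 100  # 2.5

stats = {
    "n": len(values),
    "mean": float(np.nanmean(values)),  # NaN-aware, ignores the failed iteration
    "std": float(np.nanstd(values)),
    "ci": np.nanpercentile(values, [p, 100 - p]),
}
print(stats["n"])               # 5
print(round(stats["mean"], 4))  # 0.8125
print(stats["ci"].shape)        # (2,)
```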

convert_inf_to_nan

convert_inf_to_nan(values: ndarray) -> ndarray

Replace infinite values with NaN in array.

PARAMETER DESCRIPTION
values

2D array possibly containing inf values.

TYPE: ndarray

RETURNS DESCRIPTION
ndarray

Array with inf replaced by NaN.

RAISES DESCRIPTION
NotImplementedError

If array is not 2D.

Source code in src/classification/stats_metric_utils.py
def convert_inf_to_nan(values: np.ndarray) -> np.ndarray:
    """
    Replace infinite values with NaN in array.

    Parameters
    ----------
    values : np.ndarray
        2D array possibly containing inf values.

    Returns
    -------
    np.ndarray
        Array with inf replaced by NaN.

    Raises
    ------
    NotImplementedError
        If array is not 2D.
    """
    # Vectorized replacement via boolean masking, instead of an elementwise loop
    if np.any(np.isinf(values)):
        if values.ndim == 2:
            values[np.isinf(values)] = np.nan
        else:
            raise NotImplementedError(f"Not implemented, {values.shape}dim array")
    return values
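
The inf-to-NaN replacement can be sketched with boolean-mask assignment, which needs no elementwise loop (the array values here are mock data):

```python
import numpy as np

values = np.array([[0.1, np.inf], [-np.inf, 0.4]])
values[np.isinf(values)] = np.nan  # boolean-mask assignment

print(int(np.isnan(values).sum()))   # 2
print(bool(np.isinf(values).any()))  # False
```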

get_array_stats_per_metric

get_array_stats_per_metric(
    values: ndarray,
    method_cfg: DictConfig,
    inf_to_nan: bool = True,
) -> dict[str, ndarray]

Compute summary statistics for array metrics across bootstrap iterations.

PARAMETER DESCRIPTION
values

2D array of shape (curve_length, n_iterations).

TYPE: ndarray

method_cfg

Bootstrap configuration with 'alpha_CI'.

TYPE: DictConfig

inf_to_nan

Convert infinite values to NaN before computing stats.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

Statistics with 'mean', 'std', 'ci' arrays.

Source code in src/classification/stats_metric_utils.py
def get_array_stats_per_metric(
    values: np.ndarray, method_cfg: DictConfig, inf_to_nan: bool = True
) -> dict[str, np.ndarray]:
    """
    Compute summary statistics for array metrics across bootstrap iterations.

    Parameters
    ----------
    values : np.ndarray
        2D array of shape (curve_length, n_iterations).
    method_cfg : DictConfig
        Bootstrap configuration with 'alpha_CI'.
    inf_to_nan : bool, default True
        Convert infinite values to NaN before computing stats.

    Returns
    -------
    dict
        Statistics with 'mean', 'std', 'ci' arrays.
    """
    if inf_to_nan:
        values = convert_inf_to_nan(values)
    warnings.simplefilter("ignore")
    dict_out = {}
    dict_out["mean"] = np.nanmean(values, axis=1)
    dict_out["std"] = np.nanstd(values, axis=1)

    # Confidence intervals
    p = get_p_from_alpha(alpha=method_cfg["alpha_CI"])
    if not np.all(values.flatten() == values.flatten()[0]):
        dict_out["ci"] = np.nanpercentile(values, [p, 100 - p], axis=1)
    else:
        # if all the values are the same, no point in trying to estimate CI
        dict_out["ci"] = np.array((np.nan, np.nan))
    warnings.resetwarnings()
    return dict_out
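
A self-contained sketch for a mock (curve_length, n_iterations) stack of interpolated curves (the shape, noise level, and alpha=0.95 are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)
# Mock (curve_length=50, n_iterations=30) stack of interpolated curves
values = np.clip(
    np.linspace(0, 1, 50)[:, None] + rng.normal(0.0, 0.05, (50, 30)), 0, 1
)

alpha = 0.95
p = ((1.0 - alpha) / 2.0) * 100
stats = {
    "mean": np.nanmean(values, axis=1),                    # pointwise mean curve
    "std": np.nanstd(values, axis=1),                      # pointwise spread
    "ci": np.nanpercentile(values, [p, 100 - p], axis=1),  # pointwise CI band
}
print(stats["mean"].shape)  # (50,)
print(stats["ci"].shape)    # (2, 50)
```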

bootstrap_scalar_stats

bootstrap_scalar_stats(
    metrics_per_split: dict[str, ndarray],
    method_cfg: DictConfig,
    split: str,
) -> dict[str, dict[str, int | float | ndarray]]

Compute statistics for all scalar metrics in a split.

PARAMETER DESCRIPTION
metrics_per_split

Accumulated scalar metrics per metric name.

TYPE: dict

method_cfg

Bootstrap configuration.

TYPE: DictConfig

split

Split name (unused, for signature compatibility).

TYPE: str

RETURNS DESCRIPTION
dict

Statistics per metric with mean, std, CI.

Source code in src/classification/stats_metric_utils.py
def bootstrap_scalar_stats(
    metrics_per_split: dict[str, np.ndarray], method_cfg: DictConfig, split: str
) -> dict[str, dict[str, int | float | np.ndarray]]:
    """
    Compute statistics for all scalar metrics in a split.

    Parameters
    ----------
    metrics_per_split : dict
        Accumulated scalar metrics per metric name.
    method_cfg : DictConfig
        Bootstrap configuration.
    split : str
        Split name (unused, for signature compatibility).

    Returns
    -------
    dict
        Statistics per metric with mean, std, CI.
    """
    metrics_out = {}
    for metric in metrics_per_split.keys():
        metrics_out[metric] = bootstrap_scalar_stats_per_metric(
            values=metrics_per_split[metric],
            method_cfg=method_cfg,
        )
    return metrics_out

bootstrap_array_stats

bootstrap_array_stats(
    metrics_per_split: dict[str, dict[str, ndarray]],
    method_cfg: DictConfig,
) -> dict[str, dict[str, dict[str, ndarray]]]

Compute statistics for all array metrics (curves) in a split.

PARAMETER DESCRIPTION
metrics_per_split

Accumulated array metrics per metric name.

TYPE: dict

method_cfg

Bootstrap configuration.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Statistics per metric and variable with mean, std, CI arrays.

Source code in src/classification/stats_metric_utils.py
def bootstrap_array_stats(
    metrics_per_split: dict[str, dict[str, np.ndarray]], method_cfg: DictConfig
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
    """
    Compute statistics for all array metrics (curves) in a split.

    Parameters
    ----------
    metrics_per_split : dict
        Accumulated array metrics per metric name.
    method_cfg : DictConfig
        Bootstrap configuration.

    Returns
    -------
    dict
        Statistics per metric and variable with mean, std, CI arrays.
    """
    metrics_out = {}
    for metric in metrics_per_split.keys():
        metrics_out[metric] = {}
        for variable in metrics_per_split[metric].keys():
            metrics_out[metric][variable] = get_array_stats_per_metric(
                values=metrics_per_split[metric][variable], method_cfg=method_cfg
            )

    return metrics_out

check_bootstrap_probability_predictions

check_bootstrap_probability_predictions(
    metrics: dict[str, dict],
) -> None

Validate that bootstrap predictions vary across iterations.

Warns if all predictions are identical, which indicates a bug in model retraining or bootstrap resampling.

PARAMETER DESCRIPTION
metrics

Accumulated metrics with predictions per split.

TYPE: dict

Source code in src/classification/stats_metric_utils.py
def check_bootstrap_probability_predictions(metrics: dict[str, dict]) -> None:
    """
    Validate that bootstrap predictions vary across iterations.

    Warns if all predictions are identical, which indicates a bug
    in model retraining or bootstrap resampling.

    Parameters
    ----------
    metrics : dict
        Accumulated metrics with predictions per split.
    """

    def check_probs_array(probs_array, split):
        if isinstance(probs_array, list):
            # These are lists when coming from train/val
            if len(probs_array) > 1:
                # shape (1, n_bootstrap_iters) for this subject
                probs_array = np.array(probs_array)[np.newaxis, :]
            else:
                # with only a couple of iterations you might have just one
                # sample per subject, and the stdev will trivially be zero
                return
        # stdev = np.nanstd(probs_array, axis=1)
        probs_are_the_same = np.all(probs_array.flatten() == probs_array.flatten()[0])
        if probs_are_the_same:
        if probs_are_the_same:
            logger.warning(
                f"Class probabilities across all the {probs_array.shape[1]} "
                f"bootstrap iterations appear to be identical"
            )
            logger.warning(
                "Either the model is fully deterministic across bootstrap iterations or there is a bug"
            )
            logger.error(
                "The model may not be retrained on each iteration, "
                "or the bootstrap samples may not be used"
            )
            if np.all(probs_array == 0.5):
                logger.warning(
                    "All probabilities across the bootstrap iterations are 0.5; "
                    "the model may have failed to learn anything"
                )
            logger.error(
                "Not raising an exception here, as poor outlier detection + imputation "
                'could produce effectively "unlearnable" input data'
            )

    for split, split_dict in metrics.items():
        if "preds_dict" in split_dict:  # train/val
            probs_dict = split_dict["preds_dict"]["arrays"]["y_pred_proba"]
            for code, probs_array in probs_dict.items():
                check_probs_array(probs_array, split)
        elif "preds" in split_dict:
            probs_array = split_dict["preds"]["arrays"]["predictions"]["y_pred_proba"]
            check_probs_array(probs_array, split)
        else:
            logger.error(f"No predictions found for split '{split}': {split_dict.keys()}")
            raise ValueError("Expected 'preds_dict' or 'preds' in the split metrics")

bootstrap_compute_stats

bootstrap_compute_stats(
    metrics: dict[str, dict],
    method_cfg: DictConfig,
    call_from: str,
    verbose: bool = True,
) -> dict[str, dict[str, dict]]

Compute final statistics from all bootstrap iterations.

Aggregates scalar and array metrics into mean, std, and CI values.

PARAMETER DESCRIPTION
metrics

Accumulated metrics from all bootstrap iterations.

TYPE: dict

method_cfg

Bootstrap configuration.

TYPE: DictConfig

call_from

Caller identifier for conditional checks.

TYPE: str

verbose

Enable logging.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

Statistics per split with scalars and arrays.

Source code in src/classification/stats_metric_utils.py
def bootstrap_compute_stats(
    metrics: dict[str, dict],
    method_cfg: DictConfig,
    call_from: str,
    verbose: bool = True,
) -> dict[str, dict[str, dict]]:
    """
    Compute final statistics from all bootstrap iterations.

    Aggregates scalar and array metrics into mean, std, and CI values.

    Parameters
    ----------
    metrics : dict
        Accumulated metrics from all bootstrap iterations.
    method_cfg : DictConfig
        Bootstrap configuration.
    call_from : str
        Caller identifier for conditional checks.
    verbose : bool, default True
        Enable logging.

    Returns
    -------
    dict
        Statistics per split with scalars and arrays.
    """
    if verbose:
        logger.info("Compute Bootstrap statistics (AUROC, ROC Curves, etc.)")
    warnings.simplefilter("ignore")
    if call_from != "ts_ensemble":
        check_bootstrap_probability_predictions(metrics)
    else:
        logger.info("Skip bootstrap check")
    metrics_stats = {}
    for split in metrics.keys():
        metrics_stats[split] = {}
        metrics_stats[split]["metrics"] = {}
        if "metrics" not in metrics[split]:
            logger.info(
                f"Skipping split {split}: no metrics found (e.g. val is skipped with ensembling)"
            )
            metrics_stats[split]["metrics"]["scalars"] = None
            metrics_stats[split]["metrics"]["arrays"] = None
        else:
            metrics_stats[split]["metrics"]["scalars"] = bootstrap_scalar_stats(
                metrics_per_split=metrics[split]["metrics"]["scalars"],
                method_cfg=method_cfg,
                split=split,
            )
            metrics_stats[split]["metrics"]["arrays"] = bootstrap_array_stats(
                metrics_per_split=metrics[split]["metrics"]["arrays"],
                method_cfg=method_cfg,
            )

    warnings.resetwarnings()

    return metrics_stats

bootstrap_subject_stats_numpy_array

bootstrap_subject_stats_numpy_array(
    preds_per_key: ndarray, labels: ndarray, key: str
) -> dict[str, ndarray]

Compute per-subject statistics from 2D prediction array.

Used for test split where subjects are consistent across iterations.

PARAMETER DESCRIPTION
preds_per_key

Predictions of shape (n_subjects, n_iterations).

TYPE: ndarray

labels

True labels (unused, for signature consistency).

TYPE: ndarray

key

Prediction key (unused, for signature consistency).

TYPE: str

RETURNS DESCRIPTION
dict

Statistics with 'mean' and 'std' arrays (n_subjects,).

Source code in src/classification/stats_metric_utils.py
def bootstrap_subject_stats_numpy_array(
    preds_per_key: np.ndarray, labels: np.ndarray, key: str
) -> dict[str, np.ndarray]:
    """
    Compute per-subject statistics from 2D prediction array.

    Used for test split where subjects are consistent across iterations.

    Parameters
    ----------
    preds_per_key : np.ndarray
        Predictions of shape (n_subjects, n_iterations).
    labels : np.ndarray
        True labels (unused, for signature consistency).
    key : str
        Prediction key (unused, for signature consistency).

    Returns
    -------
    dict
        Statistics with 'mean' and 'std' arrays (n_subjects,).
    """
    dict_out = {}
    warnings.simplefilter("ignore")
    dict_out["mean"] = np.mean(preds_per_key, axis=1)
    dict_out["std"] = np.std(preds_per_key, axis=1)
    warnings.resetwarnings()

    return dict_out
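
A minimal sketch of the aggregation with hypothetical predictions: means and standard deviations are taken per subject across bootstrap iterations (axis 1):

```python
import numpy as np

# Hypothetical predictions: 3 subjects x 4 bootstrap iterations
preds = np.array([
    [0.2, 0.3, 0.25, 0.25],
    [0.8, 0.7, 0.75, 0.75],
    [0.5, 0.5, 0.5, 0.5],
])

# Same aggregation as bootstrap_subject_stats_numpy_array:
# one mean/std per subject, taken across iterations
stats = {"mean": np.mean(preds, axis=1), "std": np.std(preds, axis=1)}
```

A zero std for a subject (as in the last row) is what `check_bootstrap_probability_predictions` warns about when it happens for every subject.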

aggregate_dict_subjects

aggregate_dict_subjects(
    dict_out: dict[str, ndarray],
    stats_per_code: dict[str, ndarray],
) -> dict[str, ndarray]

Aggregate subject statistics by horizontal stacking.

PARAMETER DESCRIPTION
dict_out

Accumulated statistics.

TYPE: dict

stats_per_code

Statistics for a single subject code.

TYPE: dict

RETURNS DESCRIPTION
dict

Updated statistics with new subject appended.

Source code in src/classification/stats_metric_utils.py
def aggregate_dict_subjects(
    dict_out: dict[str, np.ndarray], stats_per_code: dict[str, np.ndarray]
) -> dict[str, np.ndarray]:
    """
    Aggregate subject statistics by horizontal stacking.

    Parameters
    ----------
    dict_out : dict
        Accumulated statistics.
    stats_per_code : dict
        Statistics for a single subject code.

    Returns
    -------
    dict
        Updated statistics with new subject appended.
    """
    for key, scalar_in_array in stats_per_code.items():
        if key not in dict_out.keys():
            dict_out[key] = scalar_in_array
        else:
            dict_out[key] = np.hstack((dict_out[key], scalar_in_array))

    return dict_out
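
A minimal sketch of the accumulation, with a hypothetical inline `aggregate` helper mirroring the function body:

```python
import numpy as np

def aggregate(dict_out, stats_per_code):
    # Mirrors aggregate_dict_subjects: append each subject's scalar stats
    for key, arr in stats_per_code.items():
        dict_out[key] = arr if key not in dict_out else np.hstack((dict_out[key], arr))
    return dict_out

acc = {}
acc = aggregate(acc, {"mean": np.array([0.25]), "std": np.array([0.04])})
acc = aggregate(acc, {"mean": np.array([0.75]), "std": np.array([0.02])})
```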

bootstrap_subject_stats_dict

bootstrap_subject_stats_dict(
    preds_per_key: dict[str, list[float]],
    labels: ndarray,
    _codes_train: ndarray,
    key: str,
    split: str = "train",
    verbose: bool = True,
    check_preds: bool = False,
) -> tuple[dict[str, ndarray], dict[str, Any]]

Compute per-subject statistics from dictionary of predictions.

Used for train/val splits where predictions are stored by subject code. Also computes uncertainty quantification for probability predictions.

PARAMETER DESCRIPTION
preds_per_key

Predictions keyed by subject code.

TYPE: dict

labels

True labels for subjects.

TYPE: ndarray

_codes_train

Subject codes for ordering (currently unused).

TYPE: ndarray

key

Prediction key (e.g., 'y_pred_proba').

TYPE: str

split

Split name for logging.

TYPE: str DEFAULT: 'train'

verbose

Enable logging.

TYPE: bool DEFAULT: True

check_preds

Validate predictions vary per subject.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
tuple

(stats_dict, uq_dict) with per-subject statistics and uncertainty.

Source code in src/classification/stats_metric_utils.py
def bootstrap_subject_stats_dict(
    preds_per_key: dict[str, list[float]],
    labels: np.ndarray,
    _codes_train: np.ndarray,
    key: str,
    split: str = "train",
    verbose: bool = True,
    check_preds: bool = False,
) -> tuple[dict[str, np.ndarray], dict[str, Any]]:
    """
    Compute per-subject statistics from dictionary of predictions.

    Used for train/val splits where predictions are stored by subject code.
    Also computes uncertainty quantification for probability predictions.

    Parameters
    ----------
    preds_per_key : dict
        Predictions keyed by subject code.
    labels : np.ndarray
        True labels for subjects.
    _codes_train : np.ndarray
        Subject codes for ordering (currently unused).
    key : str
        Prediction key (e.g., 'y_pred_proba').
    split : str
        Split name for logging.
    verbose : bool, default True
        Enable logging.
    check_preds : bool, default False
        Validate predictions vary per subject.

    Returns
    -------
    tuple
        (stats_dict, uq_dict) with per-subject statistics and uncertainty.
    """
    warnings.simplefilter("ignore")
    dict_out = {}
    for code in preds_per_key.keys():
        array_per_code = np.array(preds_per_key[code])[np.newaxis, :]
        if check_preds:
            check_indiv_code_for_different_preds(code, list(array_per_code), key)
        stats_per_code = bootstrap_subject_stats_numpy_array(
            array_per_code, labels=labels, key=key
        )
        dict_out = aggregate_dict_subjects(dict_out, stats_per_code)

    # Uncertainty here for the train/val splits
    if key == "y_pred_proba":
        if verbose:
            logger.info(
                "Compute uncertainty quantification, split = {}, key = {}".format(
                    split, key
                )
            )
        assert len(labels) == len(preds_per_key), (
            "label and prediction lengths do not match"
        )
        uq_dict = uncertainty_wrapper_from_subject_codes(
            p_mean=dict_out["mean"], p_std=dict_out["std"], y_true=labels, split=split
        )
        warnings.resetwarnings()
    else:
        uq_dict = {}

    return dict_out, uq_dict

sort_dict_keys_based_on_list

sort_dict_keys_based_on_list(
    dict_to_sort: dict[str, Any],
    list_to_sort_by: list[str],
    sort_list: bool = True,
) -> dict[str, Any]

Sort dictionary keys to match a reference list order.

PARAMETER DESCRIPTION
dict_to_sort

Dictionary to reorder.

TYPE: dict

list_to_sort_by

Reference list defining key order.

TYPE: list

sort_list

If True, reorder to match list. If False, just sort alphabetically.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

Reordered dictionary.

RAISES DESCRIPTION
Exception

If keys don't match reference list.

Source code in src/classification/stats_metric_utils.py
def sort_dict_keys_based_on_list(
    dict_to_sort: dict[str, Any], list_to_sort_by: list[str], sort_list: bool = True
) -> dict[str, Any]:
    """
    Sort dictionary keys to match a reference list order.

    Parameters
    ----------
    dict_to_sort : dict
        Dictionary to reorder.
    list_to_sort_by : list
        Reference list defining key order.
    sort_list : bool, default True
        If True, reorder to match list. If False, just sort alphabetically.

    Returns
    -------
    dict
        Reordered dictionary.

    Raises
    ------
    Exception
        If keys don't match reference list.
    """
    # sort the keys based on original train codes as you now get arrays for the stats
    dict_to_sort = dict(sorted(dict_to_sort.items()))

    if sort_list:
        # Sorting is only needed when bootstrapping; it does not matter for the CatBoost ensemble
        try:
            return {k: dict_to_sort[k] for k in list_to_sort_by}
        except Exception as e:
            logger.error(f"Could not sort the dict: {e}")
            raise e
    else:
        return dict_to_sort
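
A minimal sketch of the reordering with hypothetical subject codes; as in the function, a missing key raises (here a `KeyError`), surfacing a code/prediction mismatch early:

```python
stats = {"S03": 0.7, "S01": 0.2, "S02": 0.5}  # hypothetical per-subject stats
order = ["S01", "S02", "S03"]                  # reference subject-code order

# Same idea as sort_dict_keys_based_on_list with sort_list=True:
# rebuild the dict following the reference list
reordered = {k: stats[k] for k in order}
```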

bootstrap_check_that_samples_different

bootstrap_check_that_samples_different(
    preds_per_key: dict, key: str, check_preds: bool = False
)

Verify predictions vary across bootstrap iterations per subject.

PARAMETER DESCRIPTION
preds_per_key

Predictions keyed by subject code.

TYPE: dict

key

Prediction type key.

TYPE: str

check_preds

If True, perform detailed validation.

TYPE: bool DEFAULT: False

Source code in src/classification/stats_metric_utils.py
def bootstrap_check_that_samples_different(
    preds_per_key: dict, key: str, check_preds: bool = False
):
    """
    Verify predictions vary across bootstrap iterations per subject.

    Parameters
    ----------
    preds_per_key : dict
        Predictions keyed by subject code.
    key : str
        Prediction type key.
    check_preds : bool, default False
        If True, perform detailed validation.
    """
    for subject_code in list(preds_per_key.keys()):
        preds_per_code = preds_per_key[subject_code]
        if check_preds:
            check_indiv_code_for_different_preds(subject_code, preds_per_code, key)

check_indiv_code_for_different_preds

check_indiv_code_for_different_preds(
    subject_code: str, preds_per_code: list, key: str
)

Check if predictions for a subject vary across iterations.

Note: May raise false alarms for garbage input data where model consistently outputs same predictions.

PARAMETER DESCRIPTION
subject_code

Subject identifier.

TYPE: str

preds_per_code

List of predictions for this subject across iterations.

TYPE: list

key

Prediction type (e.g., 'y_pred_proba').

TYPE: str

RAISES DESCRIPTION
ValueError

If all predictions are identical for probability predictions.

Source code in src/classification/stats_metric_utils.py
def check_indiv_code_for_different_preds(
    subject_code: str, preds_per_code: list, key: str
):
    """
    Check if predictions for a subject vary across iterations.

    Note: May raise false alarms for garbage input data where model
    consistently outputs same predictions.

    Parameters
    ----------
    subject_code : str
        Subject identifier.
    preds_per_code : list
        List of predictions for this subject across iterations.
    key : str
        Prediction type (e.g., 'y_pred_proba').

    Raises
    ------
    ValueError
        If all predictions are identical for probability predictions.
    """
    if key == "y_pred_proba":
        # With tiny bootstrap samples the class labels could legitimately be identical,
        # or a badly functioning model may output the same probability for everything
        if len(np.unique(preds_per_code)) == 1:
            logger.warning(f"Subject {subject_code} has the same predictions")
            logger.warning(preds_per_code)
            raise ValueError

bootstrap_compute_subject_stats

bootstrap_compute_subject_stats(
    metrics_iter,
    dict_arrays,
    method_cfg,
    sort_list: bool = True,
    call_from: str = None,
    verbose: bool = True,
)

Compute per-subject statistics from bootstrap iterations.

Aggregates predictions across bootstrap iterations to compute mean predictions and uncertainty per subject.

PARAMETER DESCRIPTION
metrics_iter

Accumulated metrics from all iterations.

TYPE: dict

dict_arrays

Original data arrays with labels and codes.

TYPE: dict

method_cfg

Bootstrap configuration.

TYPE: DictConfig

sort_list

Sort results to match original code order.

TYPE: bool DEFAULT: True

call_from

Caller identifier for special handling.

TYPE: str DEFAULT: None

verbose

Enable logging.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

Per-subject statistics per split.

Source code in src/classification/stats_metric_utils.py
def bootstrap_compute_subject_stats(
    metrics_iter,
    dict_arrays,
    method_cfg,
    sort_list: bool = True,
    call_from: str = None,
    verbose: bool = True,
):
    """
    Compute per-subject statistics from bootstrap iterations.

    Aggregates predictions across bootstrap iterations to compute
    mean predictions and uncertainty per subject.

    Parameters
    ----------
    metrics_iter : dict
        Accumulated metrics from all iterations.
    dict_arrays : dict
        Original data arrays with labels and codes.
    method_cfg : DictConfig
        Bootstrap configuration.
    sort_list : bool, default True
        Sort results to match original code order.
    call_from : str, optional
        Caller identifier for special handling.
    verbose : bool, default True
        Enable logging.

    Returns
    -------
    dict
        Per-subject statistics per split.
    """
    if verbose:
        logger.info(
            "Compute subject-wise Bootstrap statistics (class probabilities, uncertainty quantification, etc.)"
        )
    subject_stats = {}
    for split in metrics_iter.keys():
        labels, _ = get_labels_and_codes(split, dict_arrays, call_from)
        subject_stats[split] = {}
        subject_stats[split]["preds"] = {}

        dict_per_split = metrics_iter[split]
        if "preds_dict" in dict_per_split.keys():
            if call_from == "CATBOOST":
                codes = dict_arrays[f"subject_codes_{split}"]
            elif call_from == "classification_ensemble":
                # TODO: investigate why the metadata codes do not match here. The
                # classification ensembles add little over the internal ensembling
                # of the tree-based methods, so this path is rarely used.
                codes = None  # dict_arrays[f"subject_codes_{split}"]
            else:
                codes = dict_arrays["subject_codes_train"]
            # train/val split with the subject codes
            preds_per_split = dict_per_split["preds_dict"]["arrays"]
            for key in preds_per_split.keys():
                if codes is None:
                    # don't use the metadata codes (there is some glitch?) for the ensemble
                    # use directly the prediction codes for sorting
                    codes = sorted(list(set(list(preds_per_split[key].keys()))))
                preds_per_key = sort_dict_keys_based_on_list(
                    preds_per_split[key], list(codes), sort_list=sort_list
                )
                assert len(labels) == len(preds_per_key), (
                    f"label ({len(labels)}) and pred ({len(preds_per_key)}) "
                    f"lengths do not match"
                )
                bootstrap_check_that_samples_different(preds_per_key, key)
                if key == "y_pred_proba":
                    subject_stats[split]["preds"][key], subject_stats[split]["uq"] = (
                        bootstrap_subject_stats_dict(
                            preds_per_key=preds_per_key,
                            labels=labels,
                            _codes_train=codes,
                            key=key,
                            split=split,
                            verbose=verbose,
                        )
                    )
                else:
                    subject_stats[split]["preds"][key], _ = (
                        bootstrap_subject_stats_dict(
                            preds_per_key=preds_per_key,
                            labels=labels,
                            _codes_train=codes,
                            key=key,
                            split=split,
                            verbose=verbose,
                        )
                    )
            subject_stats[split]["subject_code"] = codes
            subject_stats[split]["labels"] = labels
            assert len(subject_stats[split]["subject_code"]) == len(preds_per_key), (
                "Codes and predictions must have the same length"
            )
            assert len(subject_stats[split]["subject_code"]) == len(labels), (
                "Codes and labels must have the same length"
            )

        elif "preds" in dict_per_split.keys():
            # test split
            codes = dict_arrays["subject_codes_test"]
            preds_per_split = dict_per_split["preds"]["arrays"]["predictions"]
            for key in preds_per_split.keys():
                preds_per_key = preds_per_split[key]
                subject_stats[split]["preds"][key] = (
                    bootstrap_subject_stats_numpy_array(
                        preds_per_key=preds_per_key, labels=labels, key=key
                    )
                )
                # Uncertainty Quantification
                if key == "y_pred_proba":
                    subject_stats[split]["uq"] = uncertainty_wrapper(
                        preds=preds_per_key, y_true=labels, key=key, split=split
                    )

            subject_stats[split]["subject_code"] = codes
            subject_stats[split]["labels"] = labels
            assert len(subject_stats[split]["subject_code"]) == len(preds_per_key), (
                "Codes and predictions must have the same length"
            )
            assert len(subject_stats[split]["subject_code"]) == len(labels), (
                "Codes and labels must have the same length"
            )

        else:
            logger.error(f"No predictions found for split '{split}': {dict_per_split.keys()}")
            raise ValueError("Expected 'preds_dict' or 'preds' in the split metrics")

    return subject_stats

global_subject_stats

global_subject_stats(
    values: ndarray,
    labels: ndarray,
    key: str,
    variable: str,
    method_cfg: DictConfig,
)

Compute global statistics stratified by class label.

PARAMETER DESCRIPTION
values

Per-subject values to aggregate.

TYPE: ndarray

labels

Class labels for stratification.

TYPE: ndarray

key

Prediction key (unused, for logging).

TYPE: str

variable

Variable name (unused, for logging).

TYPE: str

method_cfg

Bootstrap configuration.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Statistics per class label with mean, std, CI.

Source code in src/classification/stats_metric_utils.py
def global_subject_stats(
    values: np.ndarray,
    labels: np.ndarray,
    key: str,
    variable: str,
    method_cfg: DictConfig,
):
    """
    Compute global statistics stratified by class label.

    Parameters
    ----------
    values : np.ndarray
        Per-subject values to aggregate.
    labels : np.ndarray
        Class labels for stratification.
    key : str
        Prediction key (unused, for logging).
    variable : str
        Variable name (unused, for logging).
    method_cfg : DictConfig
        Bootstrap configuration.

    Returns
    -------
    dict
        Statistics per class label with mean, std, CI.
    """
    dict_out = {}
    # not much point in averaging all the subject probabilities together without accounting for the label
    unique_labels = np.unique(labels)
    for label in unique_labels:
        values_per_label = values[labels == label]
        dict_out[label] = bootstrap_scalar_stats_per_metric(
            values_per_label, method_cfg=method_cfg
        )

    return dict_out
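
A minimal sketch of the stratification with hypothetical values, substituting a plain mean for `bootstrap_scalar_stats_per_metric`:

```python
import numpy as np

values = np.array([0.1, 0.2, 0.8, 0.9])  # hypothetical per-subject mean probabilities
labels = np.array([0, 0, 1, 1])          # hypothetical class labels

# Same stratification as global_subject_stats: summarize per class label,
# here with a plain mean in place of bootstrap_scalar_stats_per_metric
per_class = {
    int(lab): round(float(np.mean(values[labels == lab])), 3)
    for lab in np.unique(labels)
}
```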

get_labels_and_codes

get_labels_and_codes(split, dict_arrays, call_from)

Get labels and codes for a split, handling bootstrap vs ensemble cases.

PARAMETER DESCRIPTION
split

Split name ('train', 'val', 'test').

TYPE: str

dict_arrays

Data arrays with labels and codes.

TYPE: dict

call_from

Caller identifier for special handling.

TYPE: str or None

RETURNS DESCRIPTION
tuple

(labels, codes) arrays for the split.

RAISES DESCRIPTION
ValueError

If unknown split specified.

Source code in src/classification/stats_metric_utils.py
def get_labels_and_codes(split, dict_arrays, call_from):
    """
    Get labels and codes for a split, handling bootstrap vs ensemble cases.

    Parameters
    ----------
    split : str
        Split name ('train', 'val', 'test').
    dict_arrays : dict
        Data arrays with labels and codes.
    call_from : str or None
        Caller identifier for special handling.

    Returns
    -------
    tuple
        (labels, codes) arrays for the split.

    Raises
    ------
    ValueError
        If unknown split specified.
    """
    # These splits are now "bootstrapping splits" so the labels, codes are from the original Train
    if split == "train" or split == "val":
        if call_from is None or call_from == "classification_ensemble":
            labels, codes = dict_arrays["y_train"], dict_arrays["subject_codes_train"]
        elif call_from == "CATBOOST":
            # This is actually a call from the ensemble evaluation; the CatBoost
            # bootstrap call is handled by the None condition above
            labels, codes = (
                dict_arrays[f"y_{split}"],
                dict_arrays[f"subject_codes_{split}"],
            )
        else:
            logger.error(f"Unknown call_from: {call_from}")
            raise ValueError
    elif split == "test":
        labels, codes = dict_arrays["y_test"], dict_arrays["subject_codes_test"]
    else:
        logger.error(f"Unknown split: {split}")
        raise ValueError
    assert len(labels) == len(codes), "Labels and codes must have the same length"

    return labels, codes

bootstrap_compute_global_subject_stats

bootstrap_compute_global_subject_stats(
    subjectwise_stats, method_cfg, verbose: bool = True
)

Compute global subject-level statistics across all subjects.

Aggregates per-subject statistics (e.g., mean probability, uncertainty) into population-level summaries stratified by class.

PARAMETER DESCRIPTION
subjectwise_stats

Per-subject statistics from bootstrap_compute_subject_stats.

TYPE: dict

method_cfg

Bootstrap configuration.

TYPE: DictConfig

verbose

Enable logging.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

Global statistics per split, key, variable, and class.

Source code in src/classification/stats_metric_utils.py
def bootstrap_compute_global_subject_stats(
    subjectwise_stats, method_cfg, verbose: bool = True
):
    """
    Compute global subject-level statistics across all subjects.

    Aggregates per-subject statistics (e.g., mean probability, uncertainty)
    into population-level summaries stratified by class.

    Parameters
    ----------
    subjectwise_stats : dict
        Per-subject statistics from bootstrap_compute_subject_stats.
    method_cfg : DictConfig
        Bootstrap configuration.
    verbose : bool, default True
        Enable logging.

    Returns
    -------
    dict
        Global statistics per split, key, variable, and class.
    """
    # Compute the "mean response" of the subjects, e.g. scalar mean UQ metric to describe the whole model uncertainty

    subject_global_stats = {}
    for split in subjectwise_stats.keys():
        subject_global_stats[split] = {}
        labels = subjectwise_stats[split]["labels"]
        for key in subjectwise_stats[split]["preds"].keys():
            preds_per_key = subjectwise_stats[split]["preds"][key]
            subject_global_stats[split][key] = {}
            for variable in preds_per_key.keys():
                # Note! Stats are computed for every variable present here; e.g. UQ epistemic
                # uncertainty computed from y_pred (class) is not necessarily what you need,
                # so pick the sensible combinations when visualizing and logging to MLflow.
                # Easier than cherry-picking here what makes sense and what does not.
                values: np.ndarray = preds_per_key[variable]
                subject_global_stats[split][key][variable] = global_subject_stats(
                    values, labels, key, variable, method_cfg
                )

    return subject_global_stats

compute_uq_unks_from_dict_of_subjects

compute_uq_unks_from_dict_of_subjects(probs_dict: dict)

Compute uncertainty metrics from subject-keyed probability dictionary.

Used for train/val splits where different subjects appear in different bootstrap iterations. Computes ensemble-style uncertainty metrics.

PARAMETER DESCRIPTION
probs_dict

Probabilities keyed by subject code, each a list of predictions.

TYPE: dict

RETURNS DESCRIPTION
dict

Uncertainty metrics (confidence, entropy, mutual_information) per subject.

Notes

Uses ensemble_uncertainties from CatBoost tutorial code for metrics like total uncertainty, data uncertainty, and knowledge uncertainty.

Source code in src/classification/stats_metric_utils.py
def compute_uq_unks_from_dict_of_subjects(probs_dict: dict):
    """
    Compute uncertainty metrics from subject-keyed probability dictionary.

    Used for train/val splits where different subjects appear in different
    bootstrap iterations. Computes ensemble-style uncertainty metrics.

    Parameters
    ----------
    probs_dict : dict
        Probabilities keyed by subject code, each a list of predictions.

    Returns
    -------
    dict
        Uncertainty metrics (confidence, entropy, mutual_information) per subject.

    Notes
    -----
    Uses ensemble_uncertainties from CatBoost tutorial code for metrics like
    total uncertainty, data uncertainty, and knowledge uncertainty.
    """
    uq = None
    iters = []
    for code, list_of_probs in probs_dict.items():
        probs_code = np.array(list_of_probs)[
            :, np.newaxis, np.newaxis
        ]  # (esize/no_iter, 1, 1)
        probs_code = np.concatenate(
            [probs_code, 1 - probs_code], axis=2
        )  # (n_samples, 1, n_classes=2)
        iters.append(probs_code.shape[0])
        uq_code: dict = ensemble_uncertainties(probs=probs_code)
        # TODO! compute AURC here as well? instead of before?
        if uq is None:
            uq = deepcopy(uq_code)
        else:
            for key in uq.keys():
                uq[key] = np.vstack((uq[key], uq_code[key]))

    assert uq["confidence"].shape[0] == len(probs_dict), (
        "You did not compute probs ({}) for all the subjects ({})".format(
            uq["confidence"].shape[0], len(probs_dict)
        )
    )

    return uq
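
Since `ensemble_uncertainties` comes from external CatBoost tutorial code, here is a hedged sketch of the usual entropy-based decomposition it is assumed to implement (the real helper may differ in details):

```python
import numpy as np

def ensemble_uncertainty_sketch(probs):
    """Sketch of the ensemble uncertainty decomposition.

    probs: (n_members, n_samples, n_classes) class probabilities.
    Returns total, data (aleatoric), and knowledge (epistemic)
    uncertainty per sample, via the standard entropy decomposition.
    """
    eps = 1e-12
    mean_p = probs.mean(axis=0)                               # (n_samples, n_classes)
    total = -np.sum(mean_p * np.log(mean_p + eps), axis=1)    # entropy of the mean
    data = -np.sum(probs * np.log(probs + eps), axis=2).mean(axis=0)  # mean entropy
    knowledge = total - data                                  # mutual information
    return total, data, knowledge

# Two ensemble members that agree -> knowledge uncertainty ~ 0
probs = np.array([[[0.7, 0.3]], [[0.7, 0.3]]])  # (2 members, 1 sample, 2 classes)
total, data, knowledge = ensemble_uncertainty_sketch(probs)
```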

compute_uq_for_subjectwise_stats

compute_uq_for_subjectwise_stats(
    metrics_iter, subjectwise_stats, verbose: bool = True
)

Compute and merge uncertainty quantification into subject-wise stats.

Computes ensemble-based uncertainty metrics (confidence, entropy, mutual information) and adds them to subjectwise_stats.

PARAMETER DESCRIPTION
metrics_iter

Accumulated metrics with predictions per split.

TYPE: dict

subjectwise_stats

Per-subject statistics to augment.

TYPE: dict

verbose

Enable logging.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

Updated subjectwise_stats with uncertainty metrics added.

Source code in src/classification/stats_metric_utils.py
def compute_uq_for_subjectwise_stats(
    metrics_iter, subjectwise_stats, verbose: bool = True
):
    """
    Compute and merge uncertainty quantification into subject-wise stats.

    Computes ensemble-based uncertainty metrics (confidence, entropy,
    mutual information) and adds them to subjectwise_stats.

    Parameters
    ----------
    metrics_iter : dict
        Accumulated metrics with predictions per split.
    subjectwise_stats : dict
        Per-subject statistics to augment.
    verbose : bool, default True
        Enable logging.

    Returns
    -------
    dict
        Updated subjectwise_stats with uncertainty metrics added.
    """
    from src.classification.catboost.catboost_main import (
        combine_unks_with_subjectwise_stats,
    )

    uq = {}
    for split in metrics_iter.keys():
        if "preds_dict" in metrics_iter[split].keys():
            probs_dict = metrics_iter[split]["preds_dict"]["arrays"]["y_pred_proba"]
            assert isinstance(probs_dict, dict), "Probs must be a dictionary"
            uq[split] = compute_uq_unks_from_dict_of_subjects(probs_dict)
        elif "preds" in metrics_iter[split].keys():
            probs = metrics_iter[split]["preds"]["arrays"]["predictions"][
                "y_pred_proba"
            ]
            assert isinstance(probs, np.ndarray), "Probs must be a numpy array"
            probs = probs.T[:, :, np.newaxis]  # (n_iter, n_samples, 1)
            probs = np.concatenate(
                [probs, 1 - probs], axis=2
            )  # (n_iter, n_samples, n_classes=2)
            uq[split] = ensemble_uncertainties(probs)
        else:
            logger.error(f"Where are the predictions now? {metrics_iter[split].keys()}")
            raise ValueError

        subjectwise_stats = combine_unks_with_subjectwise_stats(
            subjectwise_stats, unks=uq[split], split=split
        )

    return subjectwise_stats
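The `"preds"` branch above rearranges a 2D array of bootstrap class-1 probabilities into the `(n_iter, n_samples, n_classes)` layout that `ensemble_uncertainties` expects. A minimal numpy sketch of that reshape (toy shapes, not real pipeline data):

```python
import numpy as np

# Hypothetical bootstrap output: class-1 probabilities per sample and
# iteration, shaped (n_samples, n_iter) as stored under y_pred_proba.
n_samples, n_iter = 4, 3
rng = np.random.default_rng(0)
probs = rng.uniform(size=(n_samples, n_iter))

# Same reshape as in compute_uq_for_subjectwise_stats:
stacked = probs.T[:, :, np.newaxis]                       # (n_iter, n_samples, 1)
stacked = np.concatenate([stacked, 1 - stacked], axis=2)  # (n_iter, n_samples, 2)

assert stacked.shape == (n_iter, n_samples, 2)
# Each sample's two class probabilities sum to 1 in every iteration
assert np.allclose(stacked.sum(axis=2), 1.0)
```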

Classifier Utilities

classifier_utils

cls_train_on_this_combo

cls_train_on_this_combo(
    run_name: str, cls_model_name: str
) -> bool

Determine if classifier should be trained on this preprocessing combination.

Applies filtering logic to skip certain combinations (e.g., experimental models on non-ground-truth data, TabPFN on high-dimensional embeddings).

PARAMETER DESCRIPTION
run_name

MLflow run name encoding preprocessing pipeline.

TYPE: str

cls_model_name

Classifier model name.

TYPE: str

RETURNS DESCRIPTION
bool

True if this combination should be trained, False to skip.

Source code in src/classification/classifier_utils.py
def cls_train_on_this_combo(run_name: str, cls_model_name: str) -> bool:
    """
    Determine if classifier should be trained on this preprocessing combination.

    Applies filtering logic to skip certain combinations (e.g., experimental
    models on non-ground-truth data, TabPFN on high-dimensional embeddings).

    Parameters
    ----------
    run_name : str
        MLflow run name encoding preprocessing pipeline.
    cls_model_name : str
        Classifier model name.

    Returns
    -------
    bool
        True if this combination should be trained, False to skip.
    """
    is_on_gt = cls_is_on_ground_truth(run_name=run_name)
    is_experimental = is_experimental_cls_model(cls_model_name=cls_model_name)
    is_on_embeddings = is_trained_on_embeddings(run_name=run_name)

    if is_on_gt:
        # train all the experimental and embeddings models always with the ground truth
        if is_on_embeddings and cls_model_name == "TabPFN":
            # maximum number of features is 100, will crash with the 1,024 dim embedding
            return False
        else:
            return True
    elif is_experimental:
        return False
    elif is_on_embeddings:
        # with 1024 features, this will take quite long
        return False
    else:
        return True
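The branching above boils down to a small decision table. This self-contained sketch (a hypothetical `should_train` helper, not the library function, which takes the three run-name checks as already-computed booleans) makes the table explicit:

```python
def should_train(is_on_gt: bool, is_experimental: bool,
                 is_on_embeddings: bool, cls_model_name: str) -> bool:
    # Mirrors cls_train_on_this_combo after the run-name checks are done
    if is_on_gt:
        # TabPFN caps features at 100; a 1,024-dim embedding would crash it
        return not (is_on_embeddings and cls_model_name == "TabPFN")
    # Off ground truth, experimental and embedding runs are skipped
    return not (is_experimental or is_on_embeddings)

# Ground truth trains everything except TabPFN on embeddings
assert should_train(True, False, True, "TabPFN") is False
assert should_train(True, True, True, "CustomNet") is True
# Non-ground-truth: experimental and embedding combos are filtered out
assert should_train(False, True, False, "CustomNet") is False
assert should_train(False, False, True, "CATBOOST") is False
assert should_train(False, False, False, "CATBOOST") is True
```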

cls_is_on_ground_truth

cls_is_on_ground_truth(run_name: str) -> bool

Check if run uses ground truth preprocessing.

PARAMETER DESCRIPTION
run_name

MLflow run name.

TYPE: str

RETURNS DESCRIPTION
bool

True if using ground truth outlier detection and imputation.

Source code in src/classification/classifier_utils.py
def cls_is_on_ground_truth(run_name: str) -> bool:
    """
    Check if run uses ground truth preprocessing.

    Parameters
    ----------
    run_name : str
        MLflow run name.

    Returns
    -------
    bool
        True if using ground truth outlier detection and imputation.
    """
    if "pupil-gt__pupil-gt" in run_name:
        return True
    else:
        return False

get_cls_baseline_models

get_cls_baseline_models() -> list[str]

Get list of baseline classifier model names.

RETURNS DESCRIPTION
list of str

Names of standard baseline classifiers used in the study.

Source code in src/classification/classifier_utils.py
def get_cls_baseline_models() -> list[str]:
    """
    Get list of baseline classifier model names.

    Returns
    -------
    list of str
        Names of standard baseline classifiers used in the study.
    """
    return ["LogisticRegression", "XGBOOST", "CATBOOST", "TabPFN", "TabM"]

is_experimental_cls_model

is_experimental_cls_model(cls_model_name: str) -> bool

Check if classifier is experimental (not a baseline model).

PARAMETER DESCRIPTION
cls_model_name

Classifier model name.

TYPE: str

RETURNS DESCRIPTION
bool

True if model is not in baseline models list.

Source code in src/classification/classifier_utils.py
def is_experimental_cls_model(cls_model_name: str) -> bool:
    """
    Check if classifier is experimental (not a baseline model).

    Parameters
    ----------
    cls_model_name : str
        Classifier model name.

    Returns
    -------
    bool
        True if model is not in baseline models list.
    """
    for baseline_model in get_cls_baseline_models():
        if baseline_model in cls_model_name:
            return False
    return True
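Note that the check is a substring match against the baseline list, so suffixed variants of a baseline name still count as baseline. A minimal sketch of that behavior (model names here are illustrative):

```python
BASELINES = ["LogisticRegression", "XGBOOST", "CATBOOST", "TabPFN", "TabM"]

def is_experimental(name: str) -> bool:
    # Substring match: "CATBOOST_tuned" contains "CATBOOST" → baseline
    return not any(baseline in name for baseline in BASELINES)

assert is_experimental("CATBOOST_tuned") is False
assert is_experimental("TabPFN") is False
assert is_experimental("MyProtoNet") is True
```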

is_trained_on_embeddings

is_trained_on_embeddings(run_name: str) -> bool

Check if run uses embedding features instead of handcrafted.

PARAMETER DESCRIPTION
run_name

MLflow run name.

TYPE: str

RETURNS DESCRIPTION
bool

True if using foundation model embeddings as features.

Source code in src/classification/classifier_utils.py
def is_trained_on_embeddings(run_name: str) -> bool:
    """
    Check if run uses embedding features instead of handcrafted.

    Parameters
    ----------
    run_name : str
        MLflow run name.

    Returns
    -------
    bool
        True if using foundation model embeddings as features.
    """
    if "embedding" in run_name:
        return True
    else:
        return False

get_dict_array_splits

get_dict_array_splits(
    dict_arrays: dict[str, Any],
) -> list[str]

Extract split names from dict_arrays keys.

PARAMETER DESCRIPTION
dict_arrays

Dictionary with keys like 'y_train', 'y_test', 'y_val'.

TYPE: dict

RETURNS DESCRIPTION
list

Unique split names (e.g., ['train', 'test', 'val']).

Source code in src/classification/classifier_utils.py
def get_dict_array_splits(dict_arrays: dict[str, Any]) -> list[str]:
    """
    Extract split names from dict_arrays keys.

    Parameters
    ----------
    dict_arrays : dict
        Dictionary with keys like 'y_train', 'y_test', 'y_val'.

    Returns
    -------
    list
        Unique split names (e.g., ['train', 'test', 'val']).
    """
    keys = list(dict_arrays.keys())
    splits = []
    for key in keys:
        if key.startswith("y_"):
            key_fields = key.split("_")
            splits.append(key_fields[1])
    return list(set(splits))
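The split extraction is simple enough to show end to end. A sketch with a toy `dict_arrays` (sorted here for a deterministic result; the library function returns an unordered `list(set(...))`):

```python
def splits_from_keys(dict_arrays: dict) -> list[str]:
    # Keys like "y_train" / "y_test" → split name is the second "_" field
    return sorted({key.split("_")[1] for key in dict_arrays if key.startswith("y_")})

arrays = {"x_train": None, "y_train": None, "x_test": None, "y_test": None}
assert splits_from_keys(arrays) == ["test", "train"]
```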

get_classifier_run_name

get_classifier_run_name(imputer_name: str) -> str

Generate classifier run name from imputer name.

PARAMETER DESCRIPTION
imputer_name

Name of the imputation method.

TYPE: str

RETURNS DESCRIPTION
str

Run name for the classifier.

Source code in src/classification/classifier_utils.py
def get_classifier_run_name(imputer_name: str) -> str:
    """
    Generate classifier run name from imputer name.

    Parameters
    ----------
    imputer_name : str
        Name of the imputation method.

    Returns
    -------
    str
        Run name for the classifier.
    """
    return f"{imputer_name}"

get_cls_run_name

get_cls_run_name(
    imputer_mlflow_run: dict[str, Any] | None,
    cls_model_name: str,
    cls_model_cfg: DictConfig,
    source: str,
) -> str

Generate full classifier run name encoding the preprocessing pipeline.

PARAMETER DESCRIPTION
imputer_mlflow_run

MLflow run info for the imputation model.

TYPE: dict or None

cls_model_name

Classifier model name.

TYPE: str

cls_model_cfg

Classifier model configuration.

TYPE: DictConfig

source

Data source identifier (e.g., 'GT', 'Raw').

TYPE: str

RETURNS DESCRIPTION
str

Full run name like 'CatBoost__pupil-gt__pupil-gt'.

Source code in src/classification/classifier_utils.py
def get_cls_run_name(
    imputer_mlflow_run: dict[str, Any] | None,
    cls_model_name: str,
    cls_model_cfg: DictConfig,
    source: str,
) -> str:
    """
    Generate full classifier run name encoding the preprocessing pipeline.

    Parameters
    ----------
    imputer_mlflow_run : dict or None
        MLflow run info for the imputation model.
    cls_model_name : str
        Classifier model name.
    cls_model_cfg : DictConfig
        Classifier model configuration.
    source : str
        Data source identifier (e.g., 'GT', 'Raw').

    Returns
    -------
    str
        Full run name like 'CatBoost__pupil-gt__pupil-gt'.
    """
    if imputer_mlflow_run is not None:
        base_name = imputer_mlflow_run["tags.mlflow.runName"]
    else:
        # No MLflow run available for the non-imputed data sources (e.g. GT, Raw)
        base_name = source
    # TODO! Update the run name to include the classifier model configuration
    run_name = f"{cls_model_name}__{base_name}"
    return run_name
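A sketch of the two naming paths (the imputer run name below is taken from the `'CatBoost__pupil-gt__pupil-gt'` example; the `build_run_name` helper is illustrative, not the library function):

```python
def build_run_name(imputer_mlflow_run, cls_model_name: str, source: str) -> str:
    # Imputed sources use the imputer's MLflow run name; GT/Raw fall back to `source`
    if imputer_mlflow_run is not None:
        base_name = imputer_mlflow_run["tags.mlflow.runName"]
    else:
        base_name = source
    return f"{cls_model_name}__{base_name}"

assert build_run_name(
    {"tags.mlflow.runName": "pupil-gt__pupil-gt"}, "CatBoost", "GT"
) == "CatBoost__pupil-gt__pupil-gt"
assert build_run_name(None, "CatBoost", "Raw") == "CatBoost__Raw"
```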

preprocess_features

preprocess_features(
    train_df: DataFrame,
    val_df: DataFrame,
    _cls_preprocess_cfg: DictConfig,
) -> tuple[DataFrame, DataFrame]

Placeholder for feature preprocessing. Tree-based models do not strictly need standardization, but some preprocessing can still help.

See Hubert Ruczyński and Anna Kozak (2024), "Do Tree-based Models Need Data Preprocessing?" https://openreview.net/forum?id=08Y5sFtRhN, which introduces a preprocessibility measure, based on tunability (Probst et al., 2018), describing how much performance can be gained or lost for a dataset D under various preprocessing strategies.

Source code in src/classification/classifier_utils.py
def preprocess_features(
    train_df: pl.DataFrame, val_df: pl.DataFrame, _cls_preprocess_cfg: DictConfig
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """
    Trees do not really need standardization, but can benefit from something? See below

    See e.g. Hubert Ruczyński and Anna Kozak (2024) Do Tree-based Models Need Data Preprocessing?
    https://openreview.net/forum?id=08Y5sFtRhN
        Furthermore, we introduce the preprocessibility measure, based on tunability from (Probst et al.,
        2018). It describes how much performance can we gain or lose for a dataset 𝐷
        by using various preprocessing strategies.
    """
    logger.info("Placeholder for Preprocessing features")
    return train_df, val_df

logger_remaining_samples

logger_remaining_samples(
    features: dict[str, Any],
    samples_in: dict[str, int],
    source: str,
) -> None

Log the number of remaining samples after filtering.

PARAMETER DESCRIPTION
features

Features dictionary with data per source.

TYPE: dict

samples_in

Original sample counts per split before filtering.

TYPE: dict

source

Data source name.

TYPE: str

Source code in src/classification/classifier_utils.py
def logger_remaining_samples(
    features: dict[str, Any], samples_in: dict[str, int], source: str
) -> None:
    """
    Log the number of remaining samples after filtering.

    Parameters
    ----------
    features : dict
        Features dictionary with data per source.
    samples_in : dict
        Original sample counts per split before filtering.
    source : str
        Data source name.
    """
    data: dict[str, pl.DataFrame] = features[source]["data"]
    for split in data.keys():
        df = data[split]
        logger.info(
            f"{split} | remaining samples = {df.shape[0]} / {samples_in[split]}"
        )

drop_unlabeled_subjects

drop_unlabeled_subjects(
    df: pl.DataFrame | pd.DataFrame,
    cfg: DictConfig,
    label_col_name: str = "metadata_class_label",
) -> DataFrame

Remove rows without classification labels from dataframe.

PARAMETER DESCRIPTION
df

Input dataframe with subject data.

TYPE: pl.DataFrame or pd.DataFrame

cfg

Hydra configuration.

TYPE: DictConfig

label_col_name

Column name containing class labels.

TYPE: str DEFAULT: "metadata_class_label"

RETURNS DESCRIPTION
pl.DataFrame

Filtered dataframe with only labeled subjects.

Source code in src/classification/classifier_utils.py
def drop_unlabeled_subjects(
    df: pl.DataFrame | pd.DataFrame,
    cfg: DictConfig,
    label_col_name: str = "metadata_class_label",
) -> pl.DataFrame:
    """
    Remove rows without classification labels from dataframe.

    Parameters
    ----------
    df : pl.DataFrame or pd.DataFrame
        Input dataframe with subject data.
    cfg : DictConfig
        Hydra configuration.
    label_col_name : str, default "metadata_class_label"
        Column name containing class labels.

    Returns
    -------
    pl.DataFrame
        Filtered dataframe with only labeled subjects.
    """
    if isinstance(df, pd.DataFrame):
        df = pl.from_pandas(df)

    # Drop the rows from Polars dataframe with label_col_name being null or empty
    df = df.filter(pl.any_horizontal(pl.col(label_col_name).is_not_null()))
    # if you happen to have "None" string
    df = df.filter(pl.any_horizontal(pl.col(label_col_name) != "None"))
    return df

drop_useless_cols

drop_useless_cols(
    features: dict[str, Any], cfg: DictConfig
) -> dict[str, Any]

Remove metadata columns not needed for classification.

PARAMETER DESCRIPTION
features

Features dictionary with data per source and split.

TYPE: dict

cfg

Hydra configuration.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Features dictionary with useless columns removed.

Source code in src/classification/classifier_utils.py
def drop_useless_cols(features: dict[str, Any], cfg: DictConfig) -> dict[str, Any]:
    """
    Remove metadata columns not needed for classification.

    Parameters
    ----------
    features : dict
        Features dictionary with data per source and split.
    cfg : DictConfig
        Hydra configuration.

    Returns
    -------
    dict
        Features dictionary with useless columns removed.
    """
    useless_cols = ["metadata_outlier_mask", "metadata_subject_code", "metadata_split"]
    logger.info("Dropping the 'useless columns': {}".format(useless_cols))
    for source in features.keys():
        for split in features[source]["data"].keys():
            for col in useless_cols:
                if col in features[source]["data"][split].columns:
                    features[source]["data"][split].drop_in_place(col)

    return features

check_classification_labels

check_classification_labels(
    features: dict[str, Any],
    source: str,
    split: str,
    features_in: int,
) -> None

Validate that classification labels are present and binary.

PARAMETER DESCRIPTION
features

Features dictionary.

TYPE: dict

source

Data source name.

TYPE: str

split

Data split name ('train', 'test').

TYPE: str

features_in

Expected number of features.

TYPE: int

RAISES DESCRIPTION
AssertionError

If feature count changed or labels are not binary.

Source code in src/classification/classifier_utils.py
def check_classification_labels(
    features: dict[str, Any], source: str, split: str, features_in: int
) -> None:
    """
    Validate that classification labels are present and binary.

    Parameters
    ----------
    features : dict
        Features dictionary.
    source : str
        Data source name.
    split : str
        Data split name ('train', 'test').
    features_in : int
        Expected number of features.

    Raises
    ------
    AssertionError
        If feature count changed or labels are not binary.
    """
    assert features_in == features[source]["data"][split].shape[1], (
        "Number of features changed"
    )

    labels: pl.Series = features[source]["data"][split]["metadata_class_label"]
    unique_labels = set(labels)
    no_unique_labels = len(unique_labels)
    assert no_unique_labels == 2, (
        "We have != 2 unique labels (n={}), not good for a binary classification, something "
        "went wrong\n"
        "labels = {}\n"
        "n_samples = {}".format(no_unique_labels, unique_labels, len(labels))
    )

keep_only_labeled_subjects

keep_only_labeled_subjects(
    features: dict[str, Any],
    cfg: DictConfig,
    data_key: str = "data",
) -> dict[str, Any]

Filter features to keep only subjects with classification labels.

PARAMETER DESCRIPTION
features

Features dictionary with data per source and split.

TYPE: dict

cfg

Hydra configuration.

TYPE: DictConfig

data_key

Key in features dict containing the dataframes.

TYPE: str DEFAULT: "data"

RETURNS DESCRIPTION
dict

Filtered features dictionary.

Source code in src/classification/classifier_utils.py
def keep_only_labeled_subjects(
    features: dict[str, Any], cfg: DictConfig, data_key: str = "data"
) -> dict[str, Any]:
    """
    Filter features to keep only subjects with classification labels.

    Parameters
    ----------
    features : dict
        Features dictionary with data per source and split.
    cfg : DictConfig
        Hydra configuration.
    data_key : str, default "data"
        Key in features dict containing the dataframes.

    Returns
    -------
    dict
        Filtered features dictionary.
    """
    logger.info("Dropping the subjects without a label")
    samples_in = {}
    for source in features.keys():
        for split in features[source][data_key].keys():
            df = features[source][data_key][split]
            samples_in[split], features_in = df.shape
            features[source][data_key][split] = drop_unlabeled_subjects(df, cfg)
            check_classification_labels(features, source, split, features_in)

    logger_remaining_samples(features, samples_in, source)
    return features

get_numpy_boolean_index_for_class_labels

get_numpy_boolean_index_for_class_labels(
    label_array: ndarray,
) -> ndarray

Create boolean index for samples with valid classification labels.

PARAMETER DESCRIPTION
label_array

2D array of labels (n_subjects, n_timepoints).

TYPE: ndarray

RETURNS DESCRIPTION
ndarray

Boolean array where True indicates valid label.

RAISES DESCRIPTION
AssertionError

If input is not 2D numpy array or doesn't have exactly 2 classes.

Source code in src/classification/classifier_utils.py
def get_numpy_boolean_index_for_class_labels(label_array: np.ndarray) -> np.ndarray:
    """
    Create boolean index for samples with valid classification labels.

    Parameters
    ----------
    label_array : np.ndarray
        2D array of labels (n_subjects, n_timepoints).

    Returns
    -------
    np.ndarray
        Boolean array where True indicates valid label.

    Raises
    ------
    AssertionError
        If input is not 2D numpy array or doesn't have exactly 2 classes.
    """
    assert isinstance(label_array, np.ndarray), "label_array must be a Numpy array"
    assert len(label_array.shape) == 2, (
        "Must be a 2D array, (no_subjects, no_timepoints)"
    )
    label_array = label_array[:, 0]  # (no_subjects)
    is_None = []
    for item in label_array:
        is_None.append(item is None)
    is_None = np.array(is_None, dtype=bool)
    is_str_None = label_array == "None"
    is_labeled = ~is_None & ~is_str_None
    labels = label_array[is_labeled]
    assert len(np.unique(labels)) == 2, (
        "Label array must have 2 unique values (as we have a binary classification)"
    )

    return is_labeled
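The label filtering treats both Python `None` and the string `"None"` as unlabeled. A small numpy sketch of the boolean-index construction (toy labels, repeated along the time axis as the function expects):

```python
import numpy as np

# (n_subjects, n_timepoints) object array; None / "None" mark unlabeled subjects
labels = np.array(
    [["glaucoma"] * 3, [None] * 3, ["control"] * 3, ["None"] * 3], dtype=object
)

first_col = labels[:, 0]                                  # one label per subject
is_none = np.array([v is None for v in first_col], dtype=bool)
is_str_none = first_col == "None"                         # catches the string "None"
is_labeled = ~is_none & ~is_str_none

assert is_labeled.tolist() == [True, False, True, False]
# Two unique labels remain, as required for binary classification
assert len(np.unique(first_col[is_labeled])) == 2
```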

index_with_boolean_all_numpys_in_datadict

index_with_boolean_all_numpys_in_datadict(
    data_dict: dict[str, dict[str, ndarray]],
    labeled_boolean: ndarray,
) -> dict[str, dict[str, ndarray]]

Apply boolean indexing to all numpy arrays in nested dictionary.

PARAMETER DESCRIPTION
data_dict

Nested dictionary with numpy arrays as values.

TYPE: dict

labeled_boolean

Boolean index array for filtering.

TYPE: ndarray

RETURNS DESCRIPTION
dict

Copy of data_dict with all arrays filtered by boolean index.

Source code in src/classification/classifier_utils.py
def index_with_boolean_all_numpys_in_datadict(
    data_dict: dict[str, dict[str, np.ndarray]], labeled_boolean: np.ndarray
) -> dict[str, dict[str, np.ndarray]]:
    """
    Apply boolean indexing to all numpy arrays in nested dictionary.

    Parameters
    ----------
    data_dict : dict
        Nested dictionary with numpy arrays as values.
    labeled_boolean : np.ndarray
        Boolean index array for filtering.

    Returns
    -------
    dict
        Copy of data_dict with all arrays filtered by boolean index.
    """
    dict_out = deepcopy(data_dict)
    for category in data_dict.keys():
        for variable in data_dict[category].keys():
            assert isinstance(data_dict[category][variable], np.ndarray), (
                "data_dict[category][variable] must be a numpy array"
            )
            dict_out[category][variable] = data_dict[category][variable][
                labeled_boolean, :
            ]
    return dict_out
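The same filtering can be written as a nested comprehension; this sketch applies one boolean index across every array in a toy two-category `data_dict` (category and variable names are illustrative):

```python
import numpy as np

data_dict = {
    "signals": {"pupil": np.arange(12).reshape(4, 3)},
    "labels": {"class_label": np.array([["a"], ["b"], ["a"], ["b"]], dtype=object)},
}
keep = np.array([True, False, True, False])  # e.g. from the label check above

# Index every (n_subjects, ...) array with the same subject-wise boolean mask
filtered = {
    category: {variable: arr[keep, :] for variable, arr in sub.items()}
    for category, sub in data_dict.items()
}

assert filtered["signals"]["pupil"].shape == (2, 3)
assert filtered["labels"]["class_label"].shape == (2, 1)
```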

keep_only_labeled_subjects_from_source

keep_only_labeled_subjects_from_source(
    data_dicts: dict[str, dict[str, dict[str, ndarray]]],
    cfg: DictConfig,
) -> dict[str, dict[str, dict[str, ndarray]]]

Filter data dictionaries to keep only labeled subjects.

PARAMETER DESCRIPTION
data_dicts

Dictionary with splits as keys, each containing category/variable arrays.

TYPE: dict

cfg

Hydra configuration.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Filtered data dictionaries with only labeled subjects.

Source code in src/classification/classifier_utils.py
def keep_only_labeled_subjects_from_source(
    data_dicts: dict[str, dict[str, dict[str, np.ndarray]]], cfg: DictConfig
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
    """
    Filter data dictionaries to keep only labeled subjects.

    Parameters
    ----------
    data_dicts : dict
        Dictionary with splits as keys, each containing category/variable arrays.
    cfg : DictConfig
        Hydra configuration.

    Returns
    -------
    dict
        Filtered data dictionaries with only labeled subjects.
    """
    for split in data_dicts.keys():
        data_dict = data_dicts[split]
        labeled_boolean = get_numpy_boolean_index_for_class_labels(
            label_array=data_dict["labels"]["class_label"]
        )
        data_dict_cls = index_with_boolean_all_numpys_in_datadict(
            data_dict, labeled_boolean
        )
        data_dicts[split] = data_dict_cls

    return data_dicts

pick_subset_of_features_for_classification

pick_subset_of_features_for_classification(
    features: dict, cfg: DictConfig
) -> dict

Select specific feature subset for classification from full features.

PARAMETER DESCRIPTION
features

Full features dictionary with all sources and splits.

TYPE: dict

cfg

Hydra configuration with DATA_SUBSET settings.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Features dictionary with only selected feature subset.

RAISES DESCRIPTION
ValueError

If unknown source type encountered.

Source code in src/classification/classifier_utils.py
def pick_subset_of_features_for_classification(features: dict, cfg: DictConfig) -> dict:
    """
    Select specific feature subset for classification from full features.

    Parameters
    ----------
    features : dict
        Full features dictionary with all sources and splits.
    cfg : DictConfig
        Hydra configuration with DATA_SUBSET settings.

    Returns
    -------
    dict
        Features dictionary with only selected feature subset.

    Raises
    ------
    ValueError
        If unknown source type encountered.
    """
    features_out = {}
    for source, features_per_source in features.items():
        features_out[source] = {}
        features_out[source]["mlflow_run"] = features_per_source["mlflow_run"]
        features_out[source]["metadata"] = {
            "dummy": "placeholder"
        }  # features_per_source["metadata"]
        features_out[source]["data"] = {}
        for split, features_per_split in features_per_source["data"].items():
            if "BASELINE" in source:
                if "GT" in source:
                    split_key = "gt"
                elif "Raw" in source:
                    split_key = "raw"
                else:
                    logger.error(f"Unknown source: {source}")
                    raise ValueError(f"Unknown source: {source}")
            else:
                split_key = cfg["CLASSIFICATION_SETTINGS"]["DATA_SUBSET"]["split_key"]
            df: pl.DataFrame = features_per_source["data"][split][split_key]
            features_out[source]["data"][split] = df

    return features_out

check_data_for_NaNs

check_data_for_NaNs(
    source: str, features_per_source: dict[str, Any]
) -> bool

Check feature columns for NaN values.

PARAMETER DESCRIPTION
source

Data source name.

TYPE: str

features_per_source

Features for a single source with 'data' key.

TYPE: dict

RETURNS DESCRIPTION
bool

True if no NaNs found in feature columns, False otherwise.

Source code in src/classification/classifier_utils.py
def check_data_for_NaNs(source: str, features_per_source: dict[str, Any]) -> bool:
    """
    Check feature columns for NaN values.

    Parameters
    ----------
    source : str
        Data source name.
    features_per_source : dict
        Features for a single source with 'data' key.

    Returns
    -------
    bool
        True if no NaNs found in feature columns, False otherwise.
    """
    any_col_null_sums = False
    for split, df in features_per_source["data"].items():
        col_null_sums = df.select(pl.all().is_null().sum())
        for col in col_null_sums.columns:
            if col_null_sums[col][0] > 0:
                if "_value" in col:
                    logger.error(f"Found NaNs in {split} split, col_name: {col}")
                    any_col_null_sums = True

    if any_col_null_sums:
        logger.error(f"Found NaNs in feature columns, source = {source}")
        logger.error("This might easily happen for source 'BASELINE_OutlierRemovedRaw'")
        logger.error("As it has missing values")
        return False
    else:
        return True

classifier_hpo_eval

classifier_hpo_eval(
    y_true: ndarray,
    pred_proba: ndarray,
    eval_metric: str,
    model: str,
    hpo_method: str,
) -> float

Evaluate classifier predictions for hyperparameter optimization.

PARAMETER DESCRIPTION
y_true

True class labels.

TYPE: array-like

pred_proba

Predicted class probabilities.

TYPE: array-like

eval_metric

Metric to compute ('logloss', 'auc', 'f1').

TYPE: str

model

Model name for logging.

TYPE: str

hpo_method

HPO method ('hyperopt' negates loss for minimization).

TYPE: str

RETURNS DESCRIPTION
float

Computed metric value (negated for hyperopt).

RAISES DESCRIPTION
ValueError

If unknown eval_metric specified.

Source code in src/classification/classifier_utils.py
def classifier_hpo_eval(
    y_true: np.ndarray,
    pred_proba: np.ndarray,
    eval_metric: str,
    model: str,
    hpo_method: str,
) -> float:
    """
    Evaluate classifier predictions for hyperparameter optimization.

    Parameters
    ----------
    y_true : array-like
        True class labels.
    pred_proba : array-like
        Predicted class probabilities.
    eval_metric : str
        Metric to compute ('logloss', 'auc', 'f1').
    model : str
        Model name for logging.
    hpo_method : str
        HPO method ('hyperopt' negates loss for minimization).

    Returns
    -------
    float
        Computed metric value (negated for hyperopt).

    Raises
    ------
    ValueError
        If unknown eval_metric specified.
    """
    if eval_metric == "logloss":
        loss = log_loss(y_true, pred_proba)
    elif eval_metric == "auc":
        loss = roc_auc_score(y_true, pred_proba)
    elif eval_metric == "f1":
        pred = (pred_proba > 0.5).astype(int)
        loss = f1_score(y_true, pred, average="binary", zero_division=np.nan)
    else:
        logger.error("Unknown loss function {}".format(eval_metric))
        raise ValueError("Unknown loss function {}".format(eval_metric))

    if hpo_method == "hyperopt":
        loss = -1 * loss

    return loss
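Two details worth noting: the `'f1'` branch first thresholds the probabilities at 0.5, and for hyperopt the score is negated because hyperopt minimizes its objective. A numpy-only sketch of both (toy values):

```python
import numpy as np

y_true = np.array([0, 1, 1, 0])
pred_proba = np.array([0.2, 0.8, 0.6, 0.4])

# The 'f1' branch thresholds probabilities at 0.5 before scoring
pred = (pred_proba > 0.5).astype(int)
assert pred.tolist() == [0, 1, 1, 0]

# hyperopt minimizes, so score-type metrics (AUC, F1) are sign-flipped
auc_like_score = 0.85
hyperopt_loss = -1 * auc_like_score
assert hyperopt_loss == -0.85
```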

classifier_evaluation

get_preds

get_preds(model, dict_arrays)

Get predictions from a trained model for train and test splits.

PARAMETER DESCRIPTION
model

Trained classifier with predict() and predict_proba() methods.

TYPE: object

dict_arrays

Dictionary containing 'x_train', 'x_test', 'y_train', 'y_test'.

TYPE: dict

RETURNS DESCRIPTION
dict

Dictionary with 'train' and 'test' keys, each containing:

- 'y_pred_proba': Class 1 probabilities (n_samples,)
- 'y_pred': Predicted class labels (n_samples,)
- 'label': True labels

Source code in src/classification/classifier_evaluation.py
def get_preds(model, dict_arrays):
    """
    Get predictions from a trained model for train and test splits.

    Parameters
    ----------
    model : object
        Trained classifier with predict() and predict_proba() methods.
    dict_arrays : dict
        Dictionary containing 'x_train', 'x_test', 'y_train', 'y_test'.

    Returns
    -------
    dict
        Dictionary with 'train' and 'test' keys, each containing:
        - 'y_pred_proba': Class 1 probabilities (n_samples,)
        - 'y_pred': Predicted class labels (n_samples,)
        - 'label': True labels
    """
    preds = {}
    for split in ["train", "test"]:
        X = dict_arrays[f"x_{split}"]
        predict_probs = model.predict_proba(X)  # (n_samples, n_classes), e.g. (72,2)
        preds[split] = {
            "y_pred_proba": predict_probs[
                :, 1
            ],  # (n_samples,) e.g. (72,) for the class 1 (e.g. glaucoma)
            "y_pred": model.predict(X),  # (n_samples,), e.g. (72,)
            "label": dict_arrays[f"y_{split}"],
        }

    return preds

arrange_to_match_bootstrap_results

arrange_to_match_bootstrap_results(preds, metrics_dict)

Rearrange predictions and metrics to match bootstrap result structure.

PARAMETER DESCRIPTION
preds

Predictions per split from get_preds().

TYPE: dict

metrics_dict

Metrics dictionary per split.

TYPE: dict

RETURNS DESCRIPTION
dict

Results dictionary with 'metrics' key containing nested structure matching bootstrap evaluation output format.

Source code in src/classification/classifier_evaluation.py
def arrange_to_match_bootstrap_results(preds, metrics_dict):
    """
    Rearrange predictions and metrics to match bootstrap result structure.

    Parameters
    ----------
    preds : dict
        Predictions per split from get_preds().
    metrics_dict : dict
        Metrics dictionary per split.

    Returns
    -------
    dict
        Results dictionary with 'metrics' key containing nested structure
        matching bootstrap evaluation output format.
    """
    for split, preds_dict in preds.items():
        metrics_dict[split]["preds"] = {}
        metrics_dict[split]["preds"]["arrays"] = {}
        metrics_dict[split]["preds"]["arrays"]["predictions"] = preds_dict

    baseline_results = {"metrics": metrics_dict}
    return baseline_results
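The restructuring is pure dictionary nesting; a minimal sketch with toy predictions and a placeholder metric shows the resulting layout:

```python
# Toy inputs mimicking get_preds() output and a per-split metrics dict
preds = {"test": {"y_pred_proba": [0.8, 0.3], "y_pred": [1, 0], "label": [1, 0]}}
metrics_dict = {"test": {"AUROC": 1.0}}  # placeholder metric value

# Nest predictions under metrics_dict[split]["preds"]["arrays"]["predictions"]
for split, preds_dict in preds.items():
    metrics_dict[split]["preds"] = {"arrays": {"predictions": preds_dict}}
baseline_results = {"metrics": metrics_dict}

predictions = baseline_results["metrics"]["test"]["preds"]["arrays"]["predictions"]
assert predictions["y_pred"] == [1, 0]
assert baseline_results["metrics"]["test"]["AUROC"] == 1.0
```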

eval_sklearn_baseline_results

eval_sklearn_baseline_results(model, dict_arrays, cfg)

Evaluate sklearn-style model and compute STRATOS-compliant metrics.

Computes classifier metrics (AUROC, etc.) and calibration metrics for a baseline model without bootstrap uncertainty estimation.

PARAMETER DESCRIPTION
model

Trained sklearn-compatible classifier.

TYPE: object

dict_arrays

Data arrays with train/test splits.

TYPE: dict

cfg

Hydra configuration.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Baseline results with metrics structured to match bootstrap output.

Source code in src/classification/classifier_evaluation.py
def eval_sklearn_baseline_results(model, dict_arrays, cfg):
    """
    Evaluate sklearn-style model and compute STRATOS-compliant metrics.

    Computes classifier metrics (AUROC, etc.) and calibration metrics
    for a baseline model without bootstrap uncertainty estimation.

    Parameters
    ----------
    model : object
        Trained sklearn-compatible classifier.
    dict_arrays : dict
        Data arrays with train/test splits.
    cfg : DictConfig
        Hydra configuration.

    Returns
    -------
    dict
        Baseline results with metrics structured to match bootstrap output.
    """
    metrics_dict = {}
    preds = get_preds(model, dict_arrays)
    for split in preds.keys():
        metrics_dict[split] = get_classifier_metrics(
            y_true=dict_arrays[f"y_{split}"], preds=preds[split], cfg=cfg
        )
        metrics_dict[split] = get_calibration_metrics(
            model,
            metrics_dict[split],
            y_true=dict_arrays[f"y_{split}"],
            preds=preds[split],
        )

    baseline_results = arrange_to_match_bootstrap_results(preds, metrics_dict)
    return baseline_results

get_the_baseline_model

get_the_baseline_model(
    model_name: str,
    cls_model_cfg: DictConfig,
    hparam_cfg: DictConfig,
    cfg: DictConfig,
    best_params: dict,
    dict_arrays: dict,
    weights_dict: dict,
)

Train and evaluate a single baseline model without bootstrap.

Used as reference before bootstrap evaluation to get deterministic baseline performance metrics.

PARAMETER DESCRIPTION
model_name

Classifier name (e.g., 'CatBoost', 'XGBoost', 'TabM').

TYPE: str

cls_model_cfg

Classifier model configuration.

TYPE: DictConfig

hparam_cfg

Hyperparameter configuration.

TYPE: DictConfig

cfg

Full Hydra configuration.

TYPE: DictConfig

best_params

Best hyperparameters from optimization.

TYPE: dict

dict_arrays

Data arrays with train/test splits.

TYPE: dict

weights_dict

Sample weights per split.

TYPE: dict

RETURNS DESCRIPTION
tuple

(model, baseline_results) where model is the trained classifier and baseline_results contains metrics in bootstrap-compatible format.

Source code in src/classification/classifier_evaluation.py
def get_the_baseline_model(
    model_name: str,
    cls_model_cfg: DictConfig,
    hparam_cfg: DictConfig,
    cfg: DictConfig,
    best_params: dict,
    dict_arrays: dict,
    weights_dict: dict,
):
    """
    Train and evaluate a single baseline model without bootstrap.

    Used as reference before bootstrap evaluation to get deterministic
    baseline performance metrics.

    Parameters
    ----------
    model_name : str
        Classifier name (e.g., 'CatBoost', 'XGBoost', 'TabM').
    cls_model_cfg : DictConfig
        Classifier model configuration.
    hparam_cfg : DictConfig
        Hyperparameter configuration.
    cfg : DictConfig
        Full Hydra configuration.
    best_params : dict
        Best hyperparameters from optimization.
    dict_arrays : dict
        Data arrays with train/test splits.
    weights_dict : dict
        Sample weights per split.

    Returns
    -------
    tuple
        (model, baseline_results) where model is the trained classifier
        and baseline_results contains metrics in bootstrap-compatible format.
    """
    model, baseline_results = bootstrap_model_selector(
        model_name=model_name,
        cls_model_cfg=cls_model_cfg,
        hparam_cfg=hparam_cfg,
        cfg=cfg,
        best_params=best_params,
        dict_arrays=dict_arrays,
        weights_dict=weights_dict,
    )

    if baseline_results is None:
        # sklearn-style models do not return results, so run the predictions
        # and compute the metrics here
        baseline_results = eval_sklearn_baseline_results(model, dict_arrays, cfg)
        if model_name == "TabM":
            # due to how TabM was implemented, the "val" metrics are duplicated
            # when there is no separate validation split
            baseline_results["metrics"].pop("val", None)

    return model, baseline_results

evaluate_sklearn_classifier

evaluate_sklearn_classifier(
    model_name: str,
    dict_arrays: dict,
    best_params,
    cls_model_cfg: DictConfig,
    eval_cfg: DictConfig,
    cfg: DictConfig,
    run_name: str,
)

Evaluate an sklearn-compatible classifier with bootstrap CI estimation.

Trains a baseline model and then performs bootstrap evaluation to estimate confidence intervals for STRATOS-compliant metrics.

PARAMETER DESCRIPTION
model_name

Classifier name (e.g., 'LogisticRegression', 'XGBoost').

TYPE: str

dict_arrays

Data arrays with train/test splits.

TYPE: dict

best_params

Best hyperparameters from optimization.

TYPE: dict

cls_model_cfg

Classifier model configuration.

TYPE: DictConfig

eval_cfg

Evaluation configuration with method and parameters.

TYPE: DictConfig

cfg

Full Hydra configuration.

TYPE: DictConfig

run_name

MLflow run name for logging.

TYPE: str

RETURNS DESCRIPTION
tuple

(models, metrics) where models is list of bootstrap models and metrics contains aggregated statistics.

RAISES DESCRIPTION
ValueError

If unknown evaluation method specified.

Source code in src/classification/classifier_evaluation.py
def evaluate_sklearn_classifier(
    model_name: str,
    dict_arrays: dict,
    best_params,
    cls_model_cfg: DictConfig,
    eval_cfg: DictConfig,
    cfg: DictConfig,
    run_name: str,
):
    """
    Evaluate an sklearn-compatible classifier with bootstrap CI estimation.

    Trains a baseline model and then performs bootstrap evaluation to
    estimate confidence intervals for STRATOS-compliant metrics.

    Parameters
    ----------
    model_name : str
        Classifier name (e.g., 'LogisticRegression', 'XGBoost').
    dict_arrays : dict
        Data arrays with train/test splits.
    best_params : dict
        Best hyperparameters from optimization.
    cls_model_cfg : DictConfig
        Classifier model configuration.
    eval_cfg : DictConfig
        Evaluation configuration with method and parameters.
    cfg : DictConfig
        Full Hydra configuration.
    run_name : str
        MLflow run name for logging.

    Returns
    -------
    tuple
        (models, metrics) where models is list of bootstrap models
        and metrics contains aggregated statistics.

    Raises
    ------
    ValueError
        If unknown evaluation method specified.
    """
    eval_method = eval_cfg["method"]
    method_cfg = eval_cfg[eval_method]
    hparam_cfg = cfg["CLS_HYPERPARAMS"][model_name]

    # Get the baseline model
    weights_dict = return_weights_as_dict(dict_arrays, cls_model_cfg)
    model, baseline_results = get_the_baseline_model(
        model_name,
        cls_model_cfg,
        hparam_cfg,
        cfg,
        best_params,
        dict_arrays,
        weights_dict,
    )
    logger.info(
        f"Baseline Test AUROC = {baseline_results['metrics']['test']['metrics']['scalars']['AUROC']:.2f}"
    )

    # Bootstrap
    if eval_method == "BOOTSTRAP":
        models, metrics = bootstrap_evaluator(
            model_name,
            run_name,
            dict_arrays,
            best_params,
            cls_model_cfg,
            method_cfg=method_cfg,
            hparam_cfg=hparam_cfg,
            cfg=cfg,
        )  # ~30sec
    else:
        logger.error(f"Unknown evaluation method: {eval_cfg['method']}")
        raise ValueError(f"Unknown evaluation method: {eval_cfg['method']}")
        # - Conformal Prediction for classifier (e.g. https://github.com/donlnz/nonconformist)
        # - https://arxiv.org/abs/2404.19472v1

    # Log best params
    for best_param, best_value in best_params.items():
        if best_value is None:
            mlflow.log_param("hparam_" + best_param, "None")
        else:
            mlflow.log_param("hparam_" + best_param, best_value)

    # Log the extra metrics to MLflow
    classifier_log_cls_evaluation_to_mlflow(
        model,
        baseline_results,
        models,
        metrics,
        dict_arrays,
        cls_model_cfg,
        run_name=run_name,
        model_name=model_name,
    )

    return models, metrics
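Conceptually, the bootstrap confidence intervals referenced above reduce to percentile intervals over per-iteration scores. A minimal, self-contained sketch (illustrative numbers, not the project's `bootstrap_evaluator` API):

```python
import numpy as np

# Pretend these are AUROCs from 1000 bootstrap iterations (synthetic here)
rng = np.random.default_rng(0)
bootstrap_aurocs = rng.normal(loc=0.83, scale=0.02, size=1000)

# 95% percentile bootstrap CI
lo, hi = np.percentile(bootstrap_aurocs, [2.5, 97.5])
print(f"AUROC 95% CI: [{lo:.3f}, {hi:.3f}]")
```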

Other Classifiers

sklearn_simple_classifiers

display_grid_search_results

display_grid_search_results(
    grid_result: GridSearchCV, scoring: str
)

Display and log grid search results.

Prints all parameter combinations with scores and logs best params to MLflow.

PARAMETER DESCRIPTION
grid_result

Completed grid search object.

TYPE: GridSearchCV

scoring

Name of the scoring metric.

TYPE: str

Source code in src/classification/sklearn_simple_classifiers.py
def display_grid_search_results(grid_result: GridSearchCV, scoring: str):
    """
    Display and log grid search results.

    Prints all parameter combinations with scores and logs best params
    to MLflow.

    Parameters
    ----------
    grid_result : GridSearchCV
        Completed grid search object.
    scoring : str
        Name of the scoring metric.
    """
    means = grid_result.cv_results_["mean_test_score"]
    stds = grid_result.cv_results_["std_test_score"]
    params = grid_result.cv_results_["params"]
    logger.debug("Grid search results:")
    for mean, stdev, param in zip(means, stds, params):
        logger.debug("%f (%f) with: %r" % (mean, stdev, param))

    logger.info(
        "Best %s: %f using %s"
        % (scoring, grid_result.best_score_, grid_result.best_params_)
    )

    logger.info("Log best params to MLflow")
    for key, value in grid_result.best_params_.items():
        key = f"hyperparam_{key}"
        mlflow.log_param(key, value)
        logger.debug(f"key {key}: {value}")

standardize_features

standardize_features(X)

Standardize features to zero mean and unit variance.

PARAMETER DESCRIPTION
X

Feature matrix (n_samples, n_features).

TYPE: ndarray

RETURNS DESCRIPTION
tuple

(X_scaled, scaler) where X_scaled is standardized features and scaler is the fitted StandardScaler.

Source code in src/classification/sklearn_simple_classifiers.py
def standardize_features(X):
    """
    Standardize features to zero mean and unit variance.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix (n_samples, n_features).

    Returns
    -------
    tuple
        (X_scaled, scaler) where X_scaled is standardized features and
        scaler is the fitted StandardScaler.
    """
    scaler = preprocessing.StandardScaler().fit(X)
    X = scaler.transform(X)
    logger.info("Standardizing features:")
    logger.info("mean = {}".format(scaler.mean_))
    logger.info("std = {}".format(scaler.scale_))
    return X, scaler
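A self-contained sketch of the scaler round-trip on a toy matrix; the key point is that the scaler fitted on the training split must be reused for held-out data rather than refit:

```python
import numpy as np
from sklearn import preprocessing

X_train = np.array([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
scaler = preprocessing.StandardScaler().fit(X_train)
X_scaled = scaler.transform(X_train)

# each column now has zero mean and unit variance
print(X_scaled.mean(axis=0))  # → [0. 0.]
print(X_scaled.std(axis=0))   # → [1. 1.]

# reuse the same fitted scaler on held-out data
print(scaler.transform(np.array([[3.0, 30.0]])))  # → [[0. 0.]]
```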

prepare_for_logistic_hpo

prepare_for_logistic_hpo(X, y, hparam_cfg)

Prepare data and grid for logistic regression hyperparameter optimization.

Standardizes features and constructs parameter grid from config.

PARAMETER DESCRIPTION
X

Feature matrix.

TYPE: ndarray

y

Labels.

TYPE: ndarray

hparam_cfg

Hyperparameter configuration with SEARCH_SPACE.

TYPE: DictConfig

RETURNS DESCRIPTION
tuple

(X_scaled, y, grid) where grid is dict of hyperparameter lists.

RAISES DESCRIPTION
AssertionError

If X and y have different sample counts or X contains NaNs.

Source code in src/classification/sklearn_simple_classifiers.py
def prepare_for_logistic_hpo(X, y, hparam_cfg):
    """
    Prepare data and grid for logistic regression hyperparameter optimization.

    Standardizes features and constructs parameter grid from config.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix.
    y : np.ndarray
        Labels.
    hparam_cfg : DictConfig
        Hyperparameter configuration with SEARCH_SPACE.

    Returns
    -------
    tuple
        (X_scaled, y, grid) where grid is dict of hyperparameter lists.

    Raises
    ------
    AssertionError
        If X and y have different sample counts or X contains NaNs.
    """
    assert X.shape[0] == y.shape[0], "X and y must have the same number of rows"
    assert np.sum(np.isnan(X)) == 0, "X must not contain NaNs"

    # Standardize features
    X, scaler = standardize_features(X)

    # Not so many params to play with compared to XGBoost
    hparam_method = hparam_cfg["HYPERPARAMS"]["method"]
    if hparam_method is not None:
        hparams = hparam_cfg["SEARCH_SPACE"][hparam_method]
        grid = {}
        for key, value in dict(hparams).items():
            grid[key] = list(value)  # from ListConfig to list
    else:
        grid = None

    return X, y, grid

logistic_regression_hpo_grid_search

logistic_regression_hpo_grid_search(
    X,
    y,
    weights_dict: dict,
    hparam_cfg: DictConfig,
    cls_model_cfg: DictConfig,
)

Perform grid search hyperparameter optimization for logistic regression.

Uses RepeatedStratifiedKFold cross-validation to find optimal hyperparameters.

PARAMETER DESCRIPTION
X

Feature matrix.

TYPE: ndarray

y

Labels.

TYPE: ndarray

weights_dict

Sample weights (currently unused).

TYPE: dict

hparam_cfg

Hyperparameter configuration with SEARCH_SPACE and cv_params.

TYPE: DictConfig

cls_model_cfg

Classifier model configuration with default hyperparams.

TYPE: DictConfig

RETURNS DESCRIPTION
tuple

(grid_result, best_params) where grid_result is GridSearchCV object and best_params is dict of optimal hyperparameters.

Source code in src/classification/sklearn_simple_classifiers.py
def logistic_regression_hpo_grid_search(
    X, y, weights_dict: dict, hparam_cfg: DictConfig, cls_model_cfg: DictConfig
):
    """
    Perform grid search hyperparameter optimization for logistic regression.

    Uses RepeatedStratifiedKFold cross-validation to find optimal hyperparameters.

    Parameters
    ----------
    X : np.ndarray
        Feature matrix.
    y : np.ndarray
        Labels.
    weights_dict : dict
        Sample weights (currently unused).
    hparam_cfg : DictConfig
        Hyperparameter configuration with SEARCH_SPACE and cv_params.
    cls_model_cfg : DictConfig
        Classifier model configuration with default hyperparams.

    Returns
    -------
    tuple
        (grid_result, best_params) where grid_result is GridSearchCV object
        and best_params is dict of optimal hyperparameters.
    """
    # https://machinelearningmastery.com/hyperparameters-for-classification-machine-learning-algorithms/
    X, y, grid = prepare_for_logistic_hpo(X, y, hparam_cfg)

    if grid is not None:
        logger.info("Grid search for Logistic Regression, params: {}".format(grid))
        model = LogisticRegression(max_iter=500)
        logger.info(
            "Cross-validation params: {}".format(hparam_cfg["HYPERPARAMS"]["cv_params"])
        )
        cv = RepeatedStratifiedKFold(**hparam_cfg["HYPERPARAMS"]["cv_params"])
        logger.info(
            "GridSearchCV params: {}".format(hparam_cfg["HYPERPARAMS"]["fit_params"])
        )
        start_time = time.time()
        grid_search = GridSearchCV(
            estimator=model,
            param_grid=grid,
            cv=cv,
            verbose=2,
            **hparam_cfg["HYPERPARAMS"]["fit_params"],
        )
        grid_result = grid_search.fit(X, y)
        display_grid_search_results(
            grid_result, scoring=hparam_cfg["HYPERPARAMS"]["fit_params"]["scoring"]
        )
        logger.info("Grid search time: {:.2f} seconds".format(time.time() - start_time))
        best_params = grid_result.best_params_
    else:
        logger.info(
            "Logistic Regression without grid search, using default hyperparameters"
        )
        grid_result = None
        best_params = cls_model_cfg["MODEL"]["HYPERPARAMS_DEFAULT"]

    return grid_result, best_params
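The search above can be reproduced in isolation on synthetic data (illustrative grid and CV settings, not the values from `hparam_cfg`):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, RepeatedStratifiedKFold

# Synthetic binary classification problem standing in for the PLR features
X, y = make_classification(n_samples=200, n_features=8, random_state=0)

cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=2, random_state=0)
grid = {"C": [0.01, 0.1, 1.0, 10.0]}  # illustrative search space
search = GridSearchCV(
    estimator=LogisticRegression(max_iter=500),
    param_grid=grid,
    cv=cv,
    scoring="roc_auc",
)
result = search.fit(X, y)
print(result.best_params_, round(result.best_score_, 3))
```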

logistic_regression

logistic_regression(
    model_name,
    dict_arrays,
    weights_dict,
    cls_model_cfg,
    hparam_cfg,
    cfg,
    run_name: str,
    features_per_source: dict,
    join_test_and_train: bool = True,
)

Train and evaluate logistic regression classifier with MLflow tracking.

Performs hyperparameter optimization via grid search, trains final model, and evaluates with bootstrap confidence intervals.

PARAMETER DESCRIPTION
model_name

Model name for logging.

TYPE: str

dict_arrays

Data arrays with train/test splits.

TYPE: dict

weights_dict

Sample weights.

TYPE: dict

cls_model_cfg

Classifier model configuration.

TYPE: DictConfig

hparam_cfg

Hyperparameter configuration.

TYPE: DictConfig

cfg

Full Hydra configuration.

TYPE: DictConfig

run_name

MLflow run name.

TYPE: str

features_per_source

Feature source metadata for logging.

TYPE: dict

join_test_and_train

Join train and test for HPO (currently not used for final model).

TYPE: bool DEFAULT: True

Source code in src/classification/sklearn_simple_classifiers.py
def logistic_regression(
    model_name,
    dict_arrays,
    weights_dict,
    cls_model_cfg,
    hparam_cfg,
    cfg,
    run_name: str,
    features_per_source: dict,
    join_test_and_train: bool = True,
):
    """
    Train and evaluate logistic regression classifier with MLflow tracking.

    Performs hyperparameter optimization via grid search, trains final model,
    and evaluates with bootstrap confidence intervals.

    Parameters
    ----------
    model_name : str
        Model name for logging.
    dict_arrays : dict
        Data arrays with train/test splits.
    weights_dict : dict
        Sample weights.
    cls_model_cfg : DictConfig
        Classifier model configuration.
    hparam_cfg : DictConfig
        Hyperparameter configuration.
    cfg : DictConfig
        Full Hydra configuration.
    run_name : str
        MLflow run name.
    features_per_source : dict
        Feature source metadata for logging.
    join_test_and_train : bool, default True
        Join train and test for HPO (currently not used for final model).
    """
    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(hparam_cfg)

        # Log the source params
        log_classifier_sources_as_params(
            features_per_source, dict_arrays, run_name, cfg
        )
        mlflow.log_param("model_name", "LogisticRegression")

        # https://scikit-learn.org/1.5/modules/generated/sklearn.linear_model.LogisticRegression.html
        if join_test_and_train:
            X, y, X_weights = join_test_and_train_arrays(dict_arrays)
        else:
            X = dict_arrays["X_train"]
            y = dict_arrays["y_train"]
            # X_weights = weights_dict["X_train_w"]

        # Find best params with a grid search
        grid_result, best_params = logistic_regression_hpo_grid_search(
            X, y, weights_dict, hparam_cfg=hparam_cfg, cls_model_cfg=cls_model_cfg
        )

        # Evaluate the model performance
        models, metrics = evaluate_sklearn_classifier(
            model_name,
            dict_arrays,
            best_params=best_params,
            cls_model_cfg=cls_model_cfg,
            eval_cfg=cfg["CLS_EVALUATION"],
            cfg=cfg,
            run_name=run_name,
        )

        logger.debug(f"Logistic Regression model evaluation: {metrics}, {models}")
        mlflow.end_run()

sklearn_simple_cls_main

sklearn_simple_cls_main(
    train_df: DataFrame,
    test_df: DataFrame,
    model_name: str,
    cfg: DictConfig,
    cls_model_cfg: DictConfig,
    hparam_cfg: DictConfig,
    run_name: str,
    features_per_source: dict,
)

Main entry point for sklearn-based simple classifiers.

Converts dataframes to arrays and dispatches to appropriate classifier training function.

PARAMETER DESCRIPTION
train_df

Training data as Polars DataFrame.

TYPE: DataFrame

test_df

Test data as Polars DataFrame.

TYPE: DataFrame

model_name

Classifier name (e.g., 'LogisticRegression').

TYPE: str

cfg

Full Hydra configuration.

TYPE: DictConfig

cls_model_cfg

Classifier model configuration.

TYPE: DictConfig

hparam_cfg

Hyperparameter configuration.

TYPE: DictConfig

run_name

MLflow run name.

TYPE: str

features_per_source

Feature source metadata.

TYPE: dict

RAISES DESCRIPTION
ValueError

If unknown model_name specified.

Source code in src/classification/sklearn_simple_classifiers.py
def sklearn_simple_cls_main(
    train_df: pl.DataFrame,
    test_df: pl.DataFrame,
    model_name: str,
    cfg: DictConfig,
    cls_model_cfg: DictConfig,
    hparam_cfg: DictConfig,
    run_name: str,
    features_per_source: dict,
):
    """
    Main entry point for sklearn-based simple classifiers.

    Converts dataframes to arrays and dispatches to appropriate classifier
    training function.

    Parameters
    ----------
    train_df : pl.DataFrame
        Training data as Polars DataFrame.
    test_df : pl.DataFrame
        Test data as Polars DataFrame.
    model_name : str
        Classifier name (e.g., 'LogisticRegression').
    cfg : DictConfig
        Full Hydra configuration.
    cls_model_cfg : DictConfig
        Classifier model configuration.
    hparam_cfg : DictConfig
        Hyperparameter configuration.
    run_name : str
        MLflow run name.
    features_per_source : dict
        Feature source metadata.

    Raises
    ------
    ValueError
        If unknown model_name specified.
    """
    # Convert Polars DataFrames to arrays
    _, _, dict_arrays = data_transform_wrapper(train_df, test_df, cls_model_cfg, None)

    # Get weights for the sklearn classifiers if you want to use these
    weights_dict = weights_dict_wrapper(dict_arrays, cls_model_cfg)

    if model_name == "LogisticRegression":
        logistic_regression(
            model_name,
            dict_arrays,
            weights_dict,
            cls_model_cfg,
            hparam_cfg,
            cfg,
            run_name,
            features_per_source,
        )

    else:
        logger.error(f"Unknown classifier model: {model_name}")
        raise ValueError(f"Unknown classifier model: {model_name}")

tabpfn_main

eval_tabpfn_model

eval_tabpfn_model(
    model: Any, dict_arrays: Dict[str, ndarray]
) -> Tuple[Dict[str, Dict[str, ndarray]], float]

Evaluate TabPFN model on all available splits.

Computes predictions and optionally AUROC for baseline evaluation.

PARAMETER DESCRIPTION
model

Fitted TabPFN model.

TYPE: TabPFNClassifier

dict_arrays

Data arrays with train/test/optionally val splits.

TYPE: dict

RETURNS DESCRIPTION
tuple

(results, auroc) where results contains predictions per split and auroc is test AUROC (or NaN if validation split present).

Source code in src/classification/tabpfn_main.py
def eval_tabpfn_model(
    model: Any, dict_arrays: Dict[str, np.ndarray]
) -> Tuple[Dict[str, Dict[str, np.ndarray]], float]:
    """
    Evaluate TabPFN model on all available splits.

    Computes predictions and optionally AUROC for baseline evaluation.

    Parameters
    ----------
    model : TabPFNClassifier
        Fitted TabPFN model.
    dict_arrays : dict
        Data arrays with train/test/optionally val splits.

    Returns
    -------
    tuple
        (results, auroc) where results contains predictions per split and
        auroc is test AUROC (or NaN if validation split present).
    """
    if "x_val" in dict_arrays:
        splits = ["train", "val", "test"]
    else:
        splits = ["train", "test"]

    results = {}
    auroc = {}
    for split in splits:
        results[split] = {}
        results[split]["y_pred"] = model.predict(dict_arrays[f"x_{split}"])
        results[split]["y_pred_proba"] = model.predict_proba(dict_arrays[f"x_{split}"])[
            :, 1
        ]  # class 1 probs

        if "x_val" not in dict_arrays:
            # baseline model only; skipped during bootstrap to save a few (milli)seconds
            fpr, tpr, thresholds = roc_curve(
                dict_arrays[f"y_{split}"], results[split]["y_pred_proba"]
            )
            auroc[split] = auc(fpr, tpr)

    if "x_val" not in dict_arrays:
        logger.info(
            "TabPFN Baseline | Test AUROC: {:.3f}, Train AUROC: {:.3f}, GAP: {:.3f}".format(
                auroc["test"], auroc["train"], auroc["train"] - auroc["test"]
            )
        )
        return results, auroc["test"]
    else:
        return results, np.nan
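Note that `roc_curve` ranks samples by whatever score is passed in, so continuous probabilities yield a finer-grained curve than thresholded 0/1 predictions. A minimal sketch using the toy arrays from the scikit-learn docs:

```python
import numpy as np
from sklearn.metrics import auc, roc_curve

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])  # predicted class-1 probabilities

fpr, tpr, _ = roc_curve(y_true, y_score)
roc_auc = auc(fpr, tpr)
print(roc_auc)  # → 0.75
```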

train_and_eval_tabpfn

train_and_eval_tabpfn(
    dict_arrays: Dict[str, ndarray], hparams: Dict[str, Any]
) -> Tuple[None, Dict[str, Dict[str, ndarray]], float]

Train and evaluate TabPFN classifier.

TabPFN is a prior-fitted network that requires no training on the target dataset; it uses in-context learning.

PARAMETER DESCRIPTION
dict_arrays

Data arrays with train/test splits.

TYPE: dict

hparams

Hyperparameters for TabPFN (currently unused for v2).

TYPE: dict

RETURNS DESCRIPTION
tuple

(model, results, metric) where model is None (to save RAM), results contains predictions, and metric is test AUROC.

References

TabPFN: https://github.com/automl/TabPFN

Source code in src/classification/tabpfn_main.py
def train_and_eval_tabpfn(
    dict_arrays: Dict[str, np.ndarray], hparams: Dict[str, Any]
) -> Tuple[None, Dict[str, Dict[str, np.ndarray]], float]:
    """
    Train and evaluate TabPFN classifier.

    TabPFN is a prior-fitted network that requires no training on the
    target dataset; it uses in-context learning.

    Parameters
    ----------
    dict_arrays : dict
        Data arrays with train/test splits.
    hparams : dict
        Hyperparameters for TabPFN (currently unused for v2).

    Returns
    -------
    tuple
        (model, results, metric) where model is None (to save RAM),
        results contains predictions, and metric is test AUROC.

    References
    ----------
    TabPFN: https://github.com/automl/TabPFN
    """
    # see https://github.com/automl/TabPFN?tab=readme-ov-file#getting-started
    # When N_ensemble_configurations > #features * #classes, no further
    # averaging is applied (here 17 > 8 x 2)
    warnings.simplefilter("ignore")
    if torch.cuda.is_available():
        # around 6x faster on an RTX 2070 Super than on a laptop CPU, so use GPU if possible
        device = "cuda"
    else:
        device = "cpu"
    mlflow.log_param("device", device)

    # TODO: pass **hparams here if you want to use hyperparameters with TabPFN v2
    classifier = TabPFNClassifier(device=device)
    classifier.fit(dict_arrays["x_train"], dict_arrays["y_train"])
    results, metric = eval_tabpfn_model(classifier, dict_arrays)
    warnings.resetwarnings()
    # to save RAM, do not return the model, if you want these, write them on disk
    model = None

    return model, results, metric

tabpfn_wrapper

tabpfn_wrapper(
    dict_arrays: Dict[str, ndarray],
    cls_model_cfg: DictConfig,
    hparam_cfg: DictConfig,
    cfg: DictConfig,
    run_HPO: bool = False,
) -> Tuple[
    None, Dict[str, Dict[str, ndarray]], Dict[str, Any]
]

Wrapper for TabPFN training with optional hyperparameter optimization.

PARAMETER DESCRIPTION
dict_arrays

Data arrays with train/test splits.

TYPE: dict

cls_model_cfg

TabPFN model configuration.

TYPE: DictConfig

hparam_cfg

Hyperparameter configuration.

TYPE: DictConfig

cfg

Full Hydra configuration.

TYPE: DictConfig

run_HPO

Run hyperparameter optimization (not implemented for v2).

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
tuple

(model, results, best_hparams) from training.

RAISES DESCRIPTION
NotImplementedError

If run_HPO is True (was for TabPFN v1).

Source code in src/classification/tabpfn_main.py
def tabpfn_wrapper(
    dict_arrays: Dict[str, np.ndarray],
    cls_model_cfg: DictConfig,
    hparam_cfg: DictConfig,
    cfg: DictConfig,
    run_HPO: bool = False,
) -> Tuple[None, Dict[str, Dict[str, np.ndarray]], Dict[str, Any]]:
    """
    Wrapper for TabPFN training with optional hyperparameter optimization.

    Parameters
    ----------
    dict_arrays : dict
        Data arrays with train/test splits.
    cls_model_cfg : DictConfig
        TabPFN model configuration.
    hparam_cfg : DictConfig
        Hyperparameter configuration.
    cfg : DictConfig
        Full Hydra configuration.
    run_HPO : bool, default False
        Run hyperparameter optimization (not implemented for v2).

    Returns
    -------
    tuple
        (model, results, best_hparams) from training.

    Raises
    ------
    NotImplementedError
        If run_HPO is True (was for TabPFN v1).
    """
    if run_HPO:
        raise NotImplementedError("HPO was for TabPFN v1")
        # quick'n'dirty one param HPO
        # n_vector = np.linspace(1, 64, 64).astype(int)
        # hparams = []
        # for n in n_vector:
        #     hparams.append({"N_ensemble_configurations": n})
        # best_hparams = None
    else:
        hparams = [dict(cls_model_cfg["MODEL"]["HYPERPARAMS"])]
        best_hparams = hparams[0]

    best_metric = 0
    for hparam_dict in hparams:
        # n=7 was the first value to reach a test AUROC of 0.831 for pupil-gt__pupil-gt:
        # TabPFN Baseline | Test AUROC: 0.831, Train AUROC: 0.925, GAP: 0.094.
        # Train AUROC fluctuates between 0.912 and 0.925 and converges for large n (> 30).
        # TODO: replicate this sweep in bootstrap_evaluator for a better estimate.
        # logger.info(f"Training TabPFN with hyperparameters: {hparam_dict}")
        try:
            model, results, metric = train_and_eval_tabpfn(dict_arrays, hparam_dict)
        except Exception as e:
            logger.error("Failed to train TabPFN due to error: {}".format(e))
            raise e
        if metric > best_metric:
            best_metric = metric
            best_hparams = hparam_dict

    # Log the hyperparams to MLflow
    # for key, value in best_hparams.items():
    #     mlflow.log_param('hparam'+key, value)

    return model, results, best_hparams

tabpfn_main

tabpfn_main(
    train_df: DataFrame,
    test_df: DataFrame,
    run_name: str,
    cfg: DictConfig,
    cls_model_cfg: DictConfig,
    hparam_cfg: DictConfig,
    features_per_source: Dict[str, List[str]],
) -> None

Main entry point for TabPFN classifier training with MLflow tracking.

TabPFN uses in-context learning and doesn't require traditional training. This function handles data preparation, bootstrap evaluation, and MLflow logging.

PARAMETER DESCRIPTION
train_df

Training data as Polars DataFrame.

TYPE: DataFrame

test_df

Test data as Polars DataFrame.

TYPE: DataFrame

run_name

MLflow run name.

TYPE: str

cfg

Full Hydra configuration.

TYPE: DictConfig

cls_model_cfg

TabPFN model configuration.

TYPE: DictConfig

hparam_cfg

Hyperparameter configuration.

TYPE: DictConfig

features_per_source

Feature source metadata for logging.

TYPE: dict

Source code in src/classification/tabpfn_main.py
def tabpfn_main(
    train_df: pl.DataFrame,
    test_df: pl.DataFrame,
    run_name: str,
    cfg: DictConfig,
    cls_model_cfg: DictConfig,
    hparam_cfg: DictConfig,
    features_per_source: Dict[str, List[str]],
) -> None:
    """
    Main entry point for TabPFN classifier training with MLflow tracking.

    TabPFN uses in-context learning and doesn't require traditional training.
    This function handles data preparation, bootstrap evaluation, and
    MLflow logging.

    Parameters
    ----------
    train_df : pl.DataFrame
        Training data as Polars DataFrame.
    test_df : pl.DataFrame
        Test data as Polars DataFrame.
    run_name : str
        MLflow run name.
    cfg : DictConfig
        Full Hydra configuration.
    cls_model_cfg : DictConfig
        TabPFN model configuration.
    hparam_cfg : DictConfig
        Hyperparameter configuration.
    features_per_source : dict
        Feature source metadata for logging.
    """
    with mlflow.start_run(run_name=run_name):
        mlflow.log_param("model_name", "TabPFN")
        for k, v in cls_model_cfg.items():
            if k != "MODEL":
                mlflow.log_param(k, v)

        _, _, dict_arrays = data_transform_wrapper(
            train_df, test_df, cls_model_cfg, None
        )
        log_classifier_sources_as_params(
            features_per_source, dict_arrays, run_name, cfg
        )

        # Single model configuration (placeholder for a hyperparameter search space)
        model_cfgs = [cls_model_cfg]

        # Get the baseline model
        from src.classification.bootstrap_evaluation import bootstrap_evaluator
        from src.classification.classifier_evaluation import get_the_baseline_model

        weights_dict = return_weights_as_dict(dict_arrays, cls_model_cfg)
        model, baseline_results = get_the_baseline_model(
            "TabPFN", cls_model_cfg, hparam_cfg, cfg, None, dict_arrays, weights_dict
        )

        metrics = {}
        for i, cls_model_cfg in enumerate(model_cfgs):
            # logger.info(f"{i+1}/{len(model_cfgs)}: Hyperparameter grid search")
            # e.g. 1000 bootstrap iterations take roughly 10 minutes
            models, metrics[i] = bootstrap_evaluator(
                model_name="TabPFN",
                run_name=run_name,
                dict_arrays=dict_arrays,
                best_params=None,
                cls_model_cfg=cls_model_cfg,
                method_cfg=cfg["CLS_EVALUATION"]["BOOTSTRAP"],
                hparam_cfg=hparam_cfg,
                cfg=cfg,
            )

        if len(metrics) == 1:
            metrics = metrics[0]
        else:
            logger.debug("Get the best set of hyperparameters")
            metrics, best_choice = pick_the_best_hyperparam_metrics(
                metrics, hparam_cfg, model_cfgs, cfg
            )

        # Log the extra metrics to MLflow
        classifier_log_cls_evaluation_to_mlflow(
            model,
            baseline_results,
            models,
            metrics,
            dict_arrays,
            cls_model_cfg,
            run_name=run_name,
            model_name="TabPFN",
        )