
Visualization Module

Publication-quality figure generation for Foundation PLR.

Overview

The viz module provides Python-based visualization functions for generating figures. For R-based figures, see src/r/.

Computation vs Visualization

The viz module only visualizes data from DuckDB. All metric computation happens during extraction (see scripts/extract_all_configs_to_duckdb.py).

Key Modules

| Module | Description |
|---|---|
| plot_config | Style setup, colors, save functions |
| calibration_plot | STRATOS calibration curves |
| dca_plot | Decision curve analysis |
| cd_diagram | Critical difference diagrams |
| factorial_matrix | Factorial design heatmaps |
| featurization_comparison | Handcrafted vs embeddings |

API Reference

plot_config

Plot Configuration for Foundation PLR Figures

Style: Neue Haas Grotesk / Helvetica Neue-inspired typography; a clean, professional academic visualization aesthetic.

Usage:

    from src.viz.plot_config import setup_style, save_figure, COLORS
    setup_style()  # Call before creating figures

AIDEV-NOTE: This module provides styling and export functionality.

All viz modules should call setup_style() before creating figures.

setup_style module-attribute

setup_style = apply_style

COLORS module-attribute

COLORS = {
    "primary": "#2E5090",
    "secondary": "#D64045",
    "tertiary": "#45B29D",
    "quaternary": "#F5A623",
    "quinary": "#7B68EE",
    "positive": "#45B29D",
    "negative": "#D64045",
    "neutral": "#4A4A4A",
    "reference": "#D64045",
    "background": "#FAFAFA",
    "grid": "#E0E0E0",
    "moment": "#2E5090",
    "units": "#45B29D",
    "traditional": "#7B68EE",
    "ensemble": "#F5A623",
    "ground_truth": "#666666",
    "foundation_model": "#0072B2",
    "deep_learning": "#009E73",
    "handcrafted": "#2E5090",
    "embeddings": "#D64045",
    "catboost": "#2E5090",
    "tabpfn": "#45B29D",
    "xgboost": "#D64045",
    "logreg": "#7B68EE",
    "good": "#45B29D",
    "bad": "#D64045",
    "accent": "#F5A623",
    "highlight": "#F5A623",
    "glaucoma": "#E74C3C",
    "control": "#3498DB",
    "blue_stimulus": "#1f77b4",
    "red_stimulus": "#d62728",
    "background_light": "#f0f0f0",
    "blue_zone": "#cce5ff",
    "red_zone": "#ffcccc",
    "cd_rank1": "#2ecc71",
    "cd_rank2": "#3498db",
    "cd_rank3": "#e74c3c",
    "cd_rank4": "#9b59b6",
    "cd_rank5": "#f39c12",
    "text_primary": "#333333",
    "text_secondary": "#666666",
    "background_neutral": "#F5F5F5",
    "grid_lines": "#CCCCCC",
    "decomp_component_1": "#E69F00",
    "decomp_component_2": "#56B4E9",
    "decomp_component_3": "#009E73",
    "decomp_mean_waveform": "#888888",
    "cycle_brown": "#8B4513",
    "cycle_seagreen": "#20B2AA",
}
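Several semantic keys in COLORS intentionally share a hex value (for example "secondary", "negative", "reference", "bad", "xgboost" and "embeddings" are all #D64045), so two series drawn with different keys can still come out the same color. A minimal stdlib sketch that inverts a color mapping to list such aliases; it uses a hand-copied subset of COLORS rather than importing the module, and the helper name is hypothetical.

```python
from collections import defaultdict
from typing import Dict, List

# Hand-copied subset of COLORS from src/viz/plot_config.py (illustration only).
COLORS_SUBSET: Dict[str, str] = {
    "primary": "#2E5090",
    "secondary": "#D64045",
    "tertiary": "#45B29D",
    "negative": "#D64045",
    "reference": "#D64045",
    "moment": "#2E5090",
    "handcrafted": "#2E5090",
    "embeddings": "#D64045",
    "xgboost": "#D64045",
    "bad": "#D64045",
    "good": "#45B29D",
    "ground_truth": "#666666",
    "text_secondary": "#666666",
}


def find_color_aliases(colors: Dict[str, str]) -> Dict[str, List[str]]:
    """Group keys by hex value (case-insensitive); keep hexes used by 2+ keys."""
    by_hex: Dict[str, List[str]] = defaultdict(list)
    for key, hex_value in colors.items():
        # COLORS mixes upper- and lower-case hex strings, so normalize first
        by_hex[hex_value.upper()].append(key)
    return {h: sorted(keys) for h, keys in by_hex.items() if len(keys) > 1}


for hex_value, keys in find_color_aliases(COLORS_SUBSET).items():
    print(hex_value, keys)
```

Running the same check over the full dictionary before choosing colors for a multi-series figure is a cheap way to avoid two legend entries that look identical.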

save_figure

save_figure(
    fig: Figure,
    name: str,
    data: Optional[Dict[str, Any]] = None,
    formats: Optional[List[str]] = None,
    output_dir: Optional[Path] = None,
    synthetic: Optional[bool] = None,
) -> Path

Save figure to multiple formats and optionally save accompanying JSON data.

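Part of the 4-gate isolation architecture. Synthetic figures are automatically routed to figures/synthetic/ when synthetic=True, or when synthetic is None and is_synthetic_mode() returns True. In synthetic mode, an explicit output_dir whose path does not contain "synthetic" triggers a contamination warning.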

PARAMETER DESCRIPTION
fig

The figure to save

TYPE: Figure

name

Base name for the output file (without extension)

TYPE: str

data

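Data dictionary saved as JSON to <output_dir>/data/<name>.json for reproducibility. In synthetic mode, a copy of the data is tagged with _synthetic_warning=True, _data_source="synthetic" and _do_not_publish=True.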

TYPE: dict DEFAULT: None

formats

Output formats. If None, read from output_settings.formats in configs/VISUALIZATION/figure_layouts.yaml, falling back to ['png'] if the config cannot be loaded. SVG is preferred over PDF for vector output (resolution-independent).

TYPE: list DEFAULT: None

output_dir

Output directory (default: figures/generated/ or figures/synthetic/). Auto-detected from data mode if not specified.

TYPE: Path DEFAULT: None

synthetic

If True, route to the synthetic directory. If None, auto-detect by calling is_synthetic_mode().

TYPE: bool DEFAULT: None

RETURNS DESCRIPTION
Path

Path to the PNG output if 'png' is among the formats, otherwise to the file for the first format.

Source code in src/viz/plot_config.py
def save_figure(
    fig: plt.Figure,
    name: str,
    data: Optional[Dict[str, Any]] = None,
    formats: Optional[List[str]] = None,
    output_dir: Optional[Path] = None,
    synthetic: Optional[bool] = None,
) -> Path:
    """
    Save figure to multiple formats and optionally save accompanying JSON data.

    Part of the 4-gate isolation architecture. Synthetic figures are automatically
    routed to figures/synthetic/ when synthetic=True, or when synthetic is None
    and is_synthetic_mode() returns True.

    Parameters
    ----------
    fig : matplotlib.Figure
        The figure to save
    name : str
        Base name for the output file (without extension)
    data : dict, optional
        Data dictionary to save as JSON for reproducibility.
        If synthetic, a copy is tagged with _synthetic_warning=True,
        _data_source="synthetic" and _do_not_publish=True.
    formats : list, optional
        Output formats. If None, read from output_settings.formats in
        configs/VISUALIZATION/figure_layouts.yaml, falling back to ['png'].
        SVG preferred over PDF for vector output (resolution-independent).
    output_dir : Path, optional
        Output directory (default: figures/generated/ or figures/synthetic/).
        Auto-detected from data mode if not specified.
    synthetic : bool, optional
        If True, route to the synthetic directory. If None, auto-detect
        by calling is_synthetic_mode().

    Returns
    -------
    Path
        Path to the PNG output if 'png' is among the formats, otherwise
        to the file for the first format.
    """
    # Import data mode utilities for synthetic detection
    from src.utils.data_mode import (
        get_figures_dir_for_mode,
        is_synthetic_mode,
    )

    if formats is None:
        # Load from figure_layouts.yaml - SINGLE SOURCE OF TRUTH
        try:
            import yaml

            project_root = Path(__file__).parent.parent.parent
            layouts_path = (
                project_root / "configs" / "VISUALIZATION" / "figure_layouts.yaml"
            )
            with open(layouts_path) as f:
                layouts_config = yaml.safe_load(f)
            formats = layouts_config.get("output_settings", {}).get("formats", ["png"])
        except Exception:
            formats = ["png"]  # Default: PNG only

    # Auto-detect synthetic mode if not explicitly specified
    if synthetic is None:
        synthetic = is_synthetic_mode()

    # Route to appropriate output directory
    if output_dir is None:
        output_dir = get_figures_dir_for_mode(synthetic=synthetic)
    elif synthetic and "synthetic" not in str(output_dir).lower():
        # User specified a directory but we're in synthetic mode - warn
        import warnings

        warnings.warn(
            f"Synthetic mode detected but output_dir={output_dir} doesn't contain 'synthetic'. "
            "This may cause production contamination. Consider using synthetic=False or "
            "passing a synthetic output directory."
        )

    # Ensure output_dir is a Path (callers may pass str)
    output_dir = Path(output_dir)

    output_dir.mkdir(parents=True, exist_ok=True)

    # Save figure in all formats
    primary_path = None
    for fmt in formats:
        path = output_dir / f"{name}.{fmt}"
        fig.savefig(path, dpi=300, bbox_inches="tight", facecolor="white")
        if fmt == "png":
            primary_path = path
        print(f"  Saved: {path}")

    # Save JSON data if provided
    if data is not None:
        # Add synthetic warning metadata if in synthetic mode
        if synthetic:
            data = dict(data)  # Make a copy
            data["_synthetic_warning"] = True
            data["_data_source"] = "synthetic"
            data["_do_not_publish"] = True

        data_dir = output_dir / "data"
        data_dir.mkdir(exist_ok=True)
        json_path = data_dir / f"{name}.json"
        with open(json_path, "w") as f:
            json.dump(data, f, indent=2, default=str)
        print(f"  Saved: {json_path}")

    return primary_path or output_dir / f"{name}.{formats[0]}"
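The routing rules in save_figure are easy to get wrong when calling it with an explicit output_dir. A minimal stdlib sketch of the same decisions (directory choice, contamination warning, primary-path selection, synthetic tagging) that computes paths without writing any files. The helper name plan_figure_outputs is hypothetical, and the two default directories are taken from the docstring; the real get_figures_dir_for_mode() may resolve them differently.

```python
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Assumed defaults, taken from the save_figure docstring.
DEFAULT_DIRS = {False: Path("figures/generated"), True: Path("figures/synthetic")}


def plan_figure_outputs(
    name: str,
    formats: List[str],
    output_dir: Optional[Path] = None,
    synthetic: bool = False,
    data: Optional[Dict[str, Any]] = None,
) -> Tuple[Path, List[Path], Optional[Dict[str, Any]]]:
    """Hypothetical helper: return (primary_path, all_paths, tagged_data)."""
    if output_dir is None:
        output_dir = DEFAULT_DIRS[synthetic]
    elif synthetic and "synthetic" not in str(output_dir).lower():
        warnings.warn(f"Synthetic mode but output_dir={output_dir} lacks 'synthetic'")
    output_dir = Path(output_dir)

    paths = [output_dir / f"{name}.{fmt}" for fmt in formats]
    # The PNG is the primary output when present; otherwise the first format is.
    primary = output_dir / f"{name}.png" if "png" in formats else paths[0]

    if data is not None and synthetic:
        data = dict(data)  # copy, so the caller's dict is not mutated
        data.update(_synthetic_warning=True, _data_source="synthetic", _do_not_publish=True)
    return primary, paths, data
```

For example, plan_figure_outputs("fig_x", ["svg", "png"], synthetic=True) resolves the primary path to figures/synthetic/fig_x.png even though SVG is listed first.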

get_combo_color

get_combo_color(combo_id: str) -> str

Get the color for a specific combo from config.

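Resolves via plot_hyperparam_combos.yaml: combo.color_ref → color_definitions. If the combo or its color_ref cannot be resolved (or the config fails to load), falls back to COLORS[combo_id], then to COLORS["neutral"].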

PARAMETER DESCRIPTION
combo_id

Combo identifier (e.g., 'ground_truth', 'best_single_fm')

TYPE: str

RETURNS DESCRIPTION
str

Hex color string

Source code in src/viz/plot_config.py
def get_combo_color(combo_id: str) -> str:
    """
    Get the color for a specific combo from config.

    Resolves via plot_hyperparam_combos.yaml: combo.color_ref → color_definitions.

    Parameters
    ----------
    combo_id : str
        Combo identifier (e.g., 'ground_truth', 'best_single_fm')

    Returns
    -------
    str
        Hex color string
    """
    try:
        loader = get_config_loader()
        combos_config = loader.get_combos()
        color_defs = combos_config.get("color_definitions", {})

        # Look up color_ref from combo config
        for combo_type in ["standard_combos", "extended_combos"]:
            for combo in combos_config.get(combo_type, []):
                if combo.get("id") == combo_id:
                    color_ref = combo.get("color_ref")
                    if color_ref and color_ref in color_defs:
                        return color_defs[color_ref]

        # Fallback to COLORS dict (e.g., for non-combo color lookups)
        return COLORS.get(combo_id, COLORS["neutral"])
    except Exception:  # also covers ConfigurationError
        return COLORS.get(combo_id, COLORS["neutral"])
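The lookup order in get_combo_color can be checked without the project's config loader. A minimal sketch using a plain dict shaped like the structure the function walks (standard_combos / extended_combos entries with id and color_ref, plus color_definitions); the color_ref names and hex values below are made up, and resolve_combo_color is a hypothetical stand-in, not the module's function.

```python
from typing import Any, Dict

# Hypothetical config mirroring the keys get_combo_color() reads.
COMBOS_CONFIG: Dict[str, Any] = {
    "color_definitions": {"gt_gray": "#666666", "fm_blue": "#0072B2"},
    "standard_combos": [{"id": "ground_truth", "color_ref": "gt_gray"}],
    "extended_combos": [
        {"id": "best_single_fm", "color_ref": "fm_blue"},
        {"id": "broken_ref", "color_ref": "missing_key"},  # unresolvable ref
    ],
}
FALLBACK_COLORS = {"neutral": "#4A4A4A", "catboost": "#2E5090"}


def resolve_combo_color(combo_id: str, config: Dict[str, Any], fallback: Dict[str, str]) -> str:
    """color_definitions[color_ref] if resolvable, else fallback[combo_id], else neutral."""
    color_defs = config.get("color_definitions", {})
    for combo_type in ("standard_combos", "extended_combos"):
        for combo in config.get(combo_type, []):
            if combo.get("id") == combo_id:
                color_ref = combo.get("color_ref")
                if color_ref and color_ref in color_defs:
                    return color_defs[color_ref]
    # Same fallback chain as the source: per-key color, then neutral gray
    return fallback.get(combo_id, fallback["neutral"])
```

A combo whose color_ref is missing from color_definitions silently degrades to the neutral color rather than raising, which is worth keeping in mind when a figure shows an unexpected gray line.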

calibration_plot

Calibration plot visualization module.

Implements STRATOS-compliant smoothed calibration curves with LOESS smoothing. Based on Van Calster et al. 2024 guidelines.

COMPUTATION DECOUPLING: This module performs visualization ONLY.
  • LOESS smoothing and bootstrap CI are visualization rendering (acceptable).
  • Calibration metrics (slope, intercept, Brier, O:E) come from DuckDB.
  • The *_from_db functions read pre-computed metrics from DuckDB.
  • NO sklearn imports. NO src.stats imports.

compute_loess_calibration

compute_loess_calibration(
    y_true: ndarray,
    y_prob: ndarray,
    frac: float = 0.3,
    n_points: int = 100,
) -> Tuple[ndarray, ndarray]

Compute LOESS-smoothed calibration curve.

This is visualization rendering (smoothing for display), NOT metric computation.

PARAMETER DESCRIPTION
y_true

True binary labels

TYPE: array-like

y_prob

Predicted probabilities

TYPE: array-like

frac

Fraction of data used for LOESS smoothing (default 0.3)

TYPE: float DEFAULT: 0.3

n_points

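Only used by the binned fallback when statsmodels is not installed (n_bins = n_points // 5). The LOESS path returns one point per input sample, sorted by predicted probability.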

TYPE: int DEFAULT: 100

RETURNS DESCRIPTION
x_smooth

Sorted probability values

TYPE: ndarray

y_smooth

Smoothed calibration values (observed frequencies)

TYPE: ndarray

Source code in src/viz/calibration_plot.py
def compute_loess_calibration(
    y_true: np.ndarray, y_prob: np.ndarray, frac: float = 0.3, n_points: int = 100
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute LOESS-smoothed calibration curve.

    This is visualization rendering (smoothing for display), NOT metric computation.

    Parameters
    ----------
    y_true : array-like
        True binary labels
    y_prob : array-like
        Predicted probabilities
    frac : float
        Fraction of data used for LOESS smoothing (default 0.3)
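    n_points : int
        Only used by the binned fallback (n_bins = n_points // 5) when
        statsmodels is unavailable; the LOESS path returns one point per sample.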

    Returns
    -------
    x_smooth : ndarray
        Sorted probability values
    y_smooth : ndarray
        Smoothed calibration values (observed frequencies)
    """
    try:
        from statsmodels.nonparametric.smoothers_lowess import lowess

        # LOESS smoothing
        smoothed = lowess(y_true, y_prob, frac=frac, return_sorted=True)
        return smoothed[:, 0], smoothed[:, 1]
    except ImportError:
        # Fallback: binned calibration
        return _compute_binned_calibration(y_true, y_prob, n_bins=n_points // 5)
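The fallback helper _compute_binned_calibration is not shown on this page. A minimal NumPy sketch of what a binned calibration curve typically computes, assuming equal-width bins on [0, 1] and returning the mean predicted probability and the observed event rate for each non-empty bin; this is a hypothetical re-implementation, not the module's code. With the default n_points=100, the fallback above would request 20 bins.

```python
from typing import Tuple

import numpy as np


def binned_calibration(
    y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 20
) -> Tuple[np.ndarray, np.ndarray]:
    """Equal-width binned calibration: (mean predicted prob, event rate) per non-empty bin."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitizing against the interior edges gives bin indices 0..n_bins-1;
    # a probability of exactly 1.0 lands in the last bin.
    bin_idx = np.digitize(y_prob, edges[1:-1])
    x_vals, y_vals = [], []
    for b in range(n_bins):
        mask = bin_idx == b
        if mask.any():  # skip empty bins instead of emitting NaN
            x_vals.append(y_prob[mask].mean())
            y_vals.append(y_true[mask].mean())
    return np.array(x_vals), np.array(y_vals)
```

Compared with LOESS, the binned curve is step-like and sensitive to the bin count, which is why it only serves as a fallback here.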

compute_calibration_ci

compute_calibration_ci(
    y_true: ndarray,
    y_prob: ndarray,
    n_bootstrap: int = 200,
    frac: float = 0.3,
    alpha: float = 0.05,
) -> Tuple[ndarray, ndarray, ndarray]

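Compute percentile bootstrap confidence intervals for the LOESS calibration curve.

This is visualization rendering (CI bands for display), NOT metric computation. Resampling uses a fixed seed (42), so repeated calls return the same band. If fewer than 10 bootstrap fits succeed, the band falls back to the uninformative range [0, 1].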

PARAMETER DESCRIPTION
y_true

True binary labels

TYPE: array-like

y_prob

Predicted probabilities

TYPE: array-like

n_bootstrap

Number of bootstrap iterations

TYPE: int DEFAULT: 200

frac

LOESS smoothing fraction

TYPE: float DEFAULT: 0.3

alpha

Significance level for CI (default 0.05 for 95% CI)

TYPE: float DEFAULT: 0.05

RETURNS DESCRIPTION
x_vals

50 evenly spaced values on [0.05, 0.95], shared by all bootstrap curves

TYPE: ndarray

y_lower

Lower confidence bound

TYPE: ndarray

y_upper

Upper confidence bound

TYPE: ndarray

Source code in src/viz/calibration_plot.py
def compute_calibration_ci(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    n_bootstrap: int = 200,
    frac: float = 0.3,
    alpha: float = 0.05,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute bootstrap confidence intervals for calibration curve.

    This is visualization rendering (CI bands for display), NOT metric computation.

    Parameters
    ----------
    y_true : array-like
        True binary labels
    y_prob : array-like
        Predicted probabilities
    n_bootstrap : int
        Number of bootstrap iterations
    frac : float
        LOESS smoothing fraction
    alpha : float
        Significance level for CI (default 0.05 for 95% CI)

    Returns
    -------
    x_vals : ndarray
        Common x-axis values
    y_lower : ndarray
        Lower confidence bound
    y_upper : ndarray
        Upper confidence bound
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    n = len(y_true)

    # Common x-axis for interpolation
    x_common = np.linspace(0.05, 0.95, 50)
    bootstrap_curves = []

    rng = np.random.default_rng(42)
    for _ in range(n_bootstrap):
        idx = rng.choice(n, size=n, replace=True)
        y_t_boot = y_true[idx]
        y_p_boot = y_prob[idx]

        try:
            x_smooth, y_smooth = compute_loess_calibration(
                y_t_boot, y_p_boot, frac=frac
            )
            if len(x_smooth) > 1:
                # Interpolate to common x-axis
                f = interpolate.interp1d(
                    x_smooth, y_smooth, bounds_error=False, fill_value="extrapolate"
                )
                bootstrap_curves.append(f(x_common))
        except (ValueError, RuntimeError):
            continue

    if len(bootstrap_curves) < 10:
        # Not enough bootstrap samples, return wide CI
        return x_common, np.zeros_like(x_common), np.ones_like(x_common)

    bootstrap_array = np.array(bootstrap_curves)
    y_lower = np.percentile(bootstrap_array, alpha / 2 * 100, axis=0)
    y_upper = np.percentile(bootstrap_array, (1 - alpha / 2) * 100, axis=0)

    # Clip to valid range
    y_lower = np.clip(y_lower, 0, 1)
    y_upper = np.clip(y_upper, 0, 1)

    return x_common, y_lower, y_upper
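The band above is a percentile bootstrap: resample patients with replacement, refit the curve, interpolate each refit onto a shared grid, and take the alpha/2 and 1 - alpha/2 percentiles per grid point. A minimal NumPy-only sketch of that procedure, under two stated substitutions: a simple binned smoother stands in for LOESS (so statsmodels is not needed), and np.interp stands in for scipy's interp1d, which means values outside each refit's range are held flat instead of linearly extrapolated. Function names are hypothetical.

```python
from typing import Tuple

import numpy as np


def _binned_curve(y_true: np.ndarray, y_prob: np.ndarray, n_bins: int = 10) -> Tuple[np.ndarray, np.ndarray]:
    """Stand-in smoother: mean predicted prob and event rate per non-empty equal-width bin."""
    idx = np.clip((y_prob * n_bins).astype(int), 0, n_bins - 1)
    keep = [b for b in range(n_bins) if np.any(idx == b)]
    x = np.array([y_prob[idx == b].mean() for b in keep])
    y = np.array([y_true[idx == b].mean() for b in keep])
    return x, y


def bootstrap_calibration_band(
    y_true, y_prob, n_bootstrap: int = 200, alpha: float = 0.05, seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    n = len(y_true)
    x_common = np.linspace(0.05, 0.95, 50)  # same shared grid as compute_calibration_ci
    rng = np.random.default_rng(seed)
    curves = []
    for _ in range(n_bootstrap):
        idx = rng.choice(n, size=n, replace=True)  # resample patients with replacement
        x_s, y_s = _binned_curve(y_true[idx], y_prob[idx])
        if len(x_s) > 1:
            # np.interp holds end values flat outside [x_s[0], x_s[-1]]
            curves.append(np.interp(x_common, x_s, y_s))
    if len(curves) < 10:
        # Too few usable refits: return the uninformative [0, 1] band
        return x_common, np.zeros_like(x_common), np.ones_like(x_common)
    curves = np.asarray(curves)
    lower = np.percentile(curves, 100 * alpha / 2, axis=0)
    upper = np.percentile(curves, 100 * (1 - alpha / 2), axis=0)
    return x_common, np.clip(lower, 0, 1), np.clip(upper, 0, 1)
```

Because the seed is fixed, the band is reproducible across runs, matching the deterministic behaviour of the source function.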

plot_calibration_curve

plot_calibration_curve(
    y_true: ndarray,
    y_prob: ndarray,
    ax: Optional[Axes] = None,
    label: Optional[str] = None,
    color: Optional[str] = None,
    show_ci: bool = True,
    show_rug: bool = True,
    show_metrics: bool = True,
    metrics: Optional[Dict[str, float]] = None,
    frac: float = 0.3,
    ci_alpha: float = 0.2,
    n_bootstrap: int = 200,
    save_path: Optional[str] = None,
) -> Tuple[Figure, Axes]

Plot smoothed calibration curve with LOESS.

PARAMETER DESCRIPTION
y_true

True binary labels

TYPE: array-like

y_prob

Predicted probabilities

TYPE: array-like

ax

Axes to plot on

TYPE: Axes DEFAULT: None

label

Legend label for the model

TYPE: str DEFAULT: None

color

Line color

TYPE: str DEFAULT: None

show_ci

Whether to show confidence intervals

TYPE: bool DEFAULT: True

show_rug

Whether to draw rug marks of the predicted probabilities below the curve, split by class (controls at y = -0.02, cases at y = -0.04)

TYPE: bool DEFAULT: True

show_metrics

Whether to annotate with calibration metrics

TYPE: bool DEFAULT: True

metrics

Pre-computed calibration metrics from DuckDB. Expected keys: 'calibration_slope' (or 'slope'), 'calibration_intercept' (or 'intercept'). If show_metrics is True and metrics is None, the annotation is skipped.

TYPE: dict DEFAULT: None

frac

LOESS smoothing fraction

TYPE: float DEFAULT: 0.3

ci_alpha

Alpha for CI shading

TYPE: float DEFAULT: 0.2

n_bootstrap

Number of bootstrap samples for CI

TYPE: int DEFAULT: 200

save_path

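If provided, writes the plotted data (inputs, LOESS curve and, when show_ci is True, the CI band) as JSON to save_path with its suffix replaced by .json. The figure itself is not saved.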

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
fig, ax : matplotlib Figure and Axes

Source code in src/viz/calibration_plot.py
def plot_calibration_curve(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    ax: Optional["plt.Axes"] = None,
    label: Optional[str] = None,
    color: Optional[str] = None,
    show_ci: bool = True,
    show_rug: bool = True,
    show_metrics: bool = True,
    metrics: Optional[Dict[str, float]] = None,
    frac: float = 0.3,
    ci_alpha: float = 0.2,
    n_bootstrap: int = 200,
    save_path: Optional[str] = None,
) -> Tuple["plt.Figure", "plt.Axes"]:
    """
    Plot smoothed calibration curve with LOESS.

    Parameters
    ----------
    y_true : array-like
        True binary labels
    y_prob : array-like
        Predicted probabilities
    ax : matplotlib.axes.Axes, optional
        Axes to plot on
    label : str, optional
        Legend label for the model
    color : str, optional
        Line color
    show_ci : bool
        Whether to show confidence intervals
    show_rug : bool
        Whether to draw rug marks of the predicted probabilities below the
        curve, split by class (controls and cases)
    show_metrics : bool
        Whether to annotate with calibration metrics
    metrics : dict, optional
        Pre-computed calibration metrics from DuckDB. Expected keys:
        'calibration_slope' (or 'slope'), 'calibration_intercept' (or 'intercept').
        If show_metrics is True and metrics is None, the annotation is skipped.
    frac : float
        LOESS smoothing fraction
    ci_alpha : float
        Alpha for CI shading
    n_bootstrap : int
        Number of bootstrap samples for CI
    save_path : str, optional
        If provided, writes the plotted data as JSON to save_path with its
        suffix replaced by .json (the figure itself is not saved)

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    import matplotlib.pyplot as plt

    try:
        from src.viz.plot_config import setup_style
    except ImportError:
        from plot_config import setup_style
    setup_style()

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    if ax is None:
        fig, ax = plt.subplots(figsize=get_dimensions("calibration_small"))
    else:
        fig = ax.get_figure()

    # Reference line (perfect calibration)
    ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Perfect calibration")

    # Compute LOESS smoothed curve (visualization rendering)
    x_smooth, y_smooth = compute_loess_calibration(y_true, y_prob, frac=frac)

    # Plot main curve
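    (line,) = ax.plot(
        x_smooth, y_smooth, color=color, label=label or "Model", linewidth=2
    )
    # Reuse the resolved line color so the CI band matches when color is None
    color = line.get_color()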

    # Confidence intervals (visualization rendering)
    if show_ci:
        x_ci, y_lower, y_upper = compute_calibration_ci(
            y_true, y_prob, n_bootstrap=n_bootstrap, frac=frac
        )
        ax.fill_between(x_ci, y_lower, y_upper, alpha=ci_alpha, color=color)

    # Histogram rug
    if show_rug:
        # Show distribution of predictions at bottom
        ax.scatter(
            y_prob[y_true == 0],
            np.full(sum(y_true == 0), -0.02),
            marker="|",
            alpha=0.3,
            color=COLORS["control"],
            s=30,
            label="Controls",
        )
        ax.scatter(
            y_prob[y_true == 1],
            np.full(sum(y_true == 1), -0.04),
            marker="|",
            alpha=0.3,
            color=COLORS["glaucoma"],
            s=30,
            label="Cases",
        )

    # Calibration metrics annotation (from pre-computed metrics, NOT computed here)
    if show_metrics and metrics is not None:
        # Support both naming conventions (DuckDB uses calibration_slope,
        # older code used slope)
        slope = metrics.get("calibration_slope", metrics.get("slope"))
        intercept = metrics.get("calibration_intercept", metrics.get("intercept"))

        parts = []
        if slope is not None:
            parts.append(f"Slope: {slope:.2f}")
        if intercept is not None:
            parts.append(f"Intercept: {intercept:.2f}")

        if parts:
            metrics_text = "\n".join(parts)
            ax.text(
                0.05,
                0.95,
                metrics_text,
                transform=ax.transAxes,
                fontsize=10,
                verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor=COLORS["background"], alpha=0.8),
            )

    # Labels and limits
    ax.set_xlabel("Predicted Probability")
    ax.set_ylabel("Observed Frequency")
    ax.set_xlim(-0.02, 1.02)
    ax.set_ylim(-0.08, 1.02)
    ax.legend(loc="lower right")
    ax.set_aspect("equal", adjustable="box")

    # Save JSON data for reproducibility
    if save_path:
        json_path = Path(save_path).with_suffix(".json")
        json_data = {
            "y_true": y_true.tolist(),
            "y_prob": y_prob.tolist(),
            "loess_frac": frac,
            "x_smooth": x_smooth.tolist(),
            "y_smooth": y_smooth.tolist(),
        }
        if show_ci:
            json_data["x_ci"] = x_ci.tolist()
            json_data["y_lower"] = y_lower.tolist()
            json_data["y_upper"] = y_upper.tolist()

        with open(json_path, "w") as f:
            json.dump(json_data, f, indent=2)

    return fig, ax
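plot_calibration_curve only annotates metrics it is given, and it accepts both the DuckDB column names (calibration_slope, calibration_intercept) and the older short names (slope, intercept). A stdlib sketch of that text-building step, pulled out into a hypothetical helper so it can be exercised without matplotlib. Note that a key present with the value None does not fall back to the short name, matching the dict.get chain in the source.

```python
from typing import Dict, Optional


def format_calibration_annotation(metrics: Optional[Dict[str, float]]) -> Optional[str]:
    """Build the slope/intercept text box; None when there is nothing to show."""
    if metrics is None:
        return None
    # Prefer the DuckDB column names, fall back to the older short names.
    slope = metrics.get("calibration_slope", metrics.get("slope"))
    intercept = metrics.get("calibration_intercept", metrics.get("intercept"))
    parts = []
    if slope is not None:
        parts.append(f"Slope: {slope:.2f}")
    if intercept is not None:
        parts.append(f"Intercept: {intercept:.2f}")
    return "\n".join(parts) if parts else None


print(format_calibration_annotation({"calibration_slope": 0.953, "calibration_intercept": -0.104}))
```

A metrics dict holding only other keys (for example just a Brier score) produces no annotation box at all rather than an empty one.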

plot_calibration_multi_model

plot_calibration_multi_model(
    models_data: Dict[str, Dict],
    ax: Optional[Axes] = None,
    show_ci: bool = False,
    colors: Optional[List[str]] = None,
) -> Tuple[Figure, Axes]

Plot calibration curves for multiple models.

PARAMETER DESCRIPTION
models_data

Dictionary mapping model names to {'y_true': ..., 'y_prob': ...}

TYPE: dict

ax

Axes to plot on. If None, a new figure and axes are created.

TYPE: Axes DEFAULT: None

show_ci

Whether to show confidence intervals

TYPE: bool DEFAULT: False

colors

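Colors for each model (default: evenly spaced samples of the tab10 colormap). Colors are paired with models via zip(), so models beyond len(colors) are not plotted.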

TYPE: list DEFAULT: None

RETURNS DESCRIPTION
(fig, ax)

Source code in src/viz/calibration_plot.py
def plot_calibration_multi_model(
    models_data: Dict[str, Dict],
    ax: Optional["plt.Axes"] = None,
    show_ci: bool = False,
    colors: Optional[List[str]] = None,
) -> Tuple["plt.Figure", "plt.Axes"]:
    """
    Plot calibration curves for multiple models.

    Parameters
    ----------
    models_data : dict
        Dictionary mapping model names to {'y_true': ..., 'y_prob': ...}
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, a new figure and axes are created.
    show_ci : bool
        Whether to show confidence intervals
    colors : list, optional
        Colors for each model (default: tab10). Models beyond
        len(colors) are not plotted.

    Returns
    -------
    fig, ax
    """
    import matplotlib.pyplot as plt

    if ax is None:
        fig, ax = plt.subplots(figsize=get_dimensions("calibration_small"))
    else:
        fig = ax.get_figure()

    # Reference line
    ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Perfect calibration")

    if colors is None:
        colors = plt.cm.tab10(np.linspace(0, 1, len(models_data)))

    for (model_name, data), color in zip(models_data.items(), colors):
        y_true = np.asarray(data["y_true"])
        y_prob = np.asarray(data["y_prob"])

        x_smooth, y_smooth = compute_loess_calibration(y_true, y_prob)
        ax.plot(x_smooth, y_smooth, label=model_name, color=color, linewidth=2)

        if show_ci:
            x_ci, y_lower, y_upper = compute_calibration_ci(y_true, y_prob)
            ax.fill_between(x_ci, y_lower, y_upper, alpha=0.15, color=color)

    ax.set_xlabel("Predicted Probability")
    ax.set_ylabel("Observed Frequency")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend()
    ax.set_aspect("equal", adjustable="box")

    return fig, ax
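plot_calibration_multi_model expects models_data to map each legend label to a dict with "y_true" and "y_prob", and it pairs models with colors via zip(), which stops silently at the shorter sequence. A hypothetical pre-flight check, stdlib only, that reports missing keys, mismatched lengths, and models that a too-short colors list would drop; the model names and values in the usage are made up.

```python
from typing import Dict, List, Optional, Sequence


def check_models_data(
    models_data: Dict[str, Dict[str, Sequence[float]]],
    colors: Optional[List[str]] = None,
) -> List[str]:
    """Hypothetical validator; returns a list of problems (empty if none)."""
    problems = []
    for name, data in models_data.items():
        missing = {"y_true", "y_prob"} - set(data)
        if missing:
            problems.append(f"{name}: missing keys {sorted(missing)}")
        elif len(data["y_true"]) != len(data["y_prob"]):
            problems.append(f"{name}: y_true and y_prob differ in length")
    if colors is not None and len(colors) < len(models_data):
        # zip(models_data.items(), colors) would silently skip these models
        dropped = list(models_data)[len(colors):]
        problems.append(f"{len(colors)} color(s) for {len(models_data)} models; would drop {dropped}")
    return problems


models = {
    "A": {"y_true": [0, 1], "y_prob": [0.2, 0.8]},
    "B": {"y_true": [0, 1], "y_prob": [0.3]},
    "C": {"y_prob": [0.5]},
}
for problem in check_models_data(models, colors=["#2E5090"]):
    print(problem)
```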

save_calibration_extended_json_from_db

save_calibration_extended_json_from_db(
    run_id: str,
    output_path: str,
    db_path: Optional[str] = None,
) -> dict

Save extended calibration metrics to JSON by reading from DuckDB.

CRITICAL: This function reads PRE-COMPUTED metrics from DuckDB. It does NOT compute metrics - all computation happens during extraction.

PARAMETER DESCRIPTION
run_id

Run ID to load metrics for

TYPE: str

output_path

Path to save JSON file

TYPE: str

db_path

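Path to the DuckDB file. If None, the first existing path among data/public/foundation_plr_results.db and data/foundation_plr_results.db (relative to the working directory) is used; FileNotFoundError is raised if neither exists.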

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
dict

The JSON data structure

Source code in src/viz/calibration_plot.py
def save_calibration_extended_json_from_db(
    run_id: str,
    output_path: str,
    db_path: Optional[str] = None,
) -> dict:
    """
    Save extended calibration metrics to JSON by reading from DuckDB.

    CRITICAL: This function reads PRE-COMPUTED metrics from DuckDB.
    It does NOT compute metrics - all computation happens during extraction.

    Parameters
    ----------
    run_id : str
        Run ID to load metrics for
    output_path : str
        Path to save JSON file
    db_path : str, optional
        Path to DuckDB file

    Returns
    -------
    dict
        The JSON data structure
    """
    import duckdb

    # Find database
    if db_path is None:
        db_paths = [
            Path("data/public/foundation_plr_results.db"),
            Path("data/foundation_plr_results.db"),
        ]
        for p in db_paths:
            if p.exists():
                db_path = str(p)
                break
        if db_path is None:
            raise FileNotFoundError("DuckDB not found")

    conn = duckdb.connect(db_path, read_only=True)

    # Read pre-computed scalar metrics
    metrics_row = conn.execute(
        """
        SELECT
            auroc, brier, scaled_brier,
            calibration_slope, calibration_intercept, o_e_ratio
        FROM essential_metrics
        WHERE run_id = ?
    """,
        [run_id],
    ).fetchone()

    if metrics_row is None:
        conn.close()
        raise ValueError(f"Run {run_id} not found in essential_metrics")

    auroc, brier, scaled_brier, cal_slope, cal_intercept, o_e_ratio = metrics_row

    # Read pre-computed calibration curve
    curve_row = conn.execute(
        """
        SELECT x_smooth, y_smooth, ci_lower, ci_upper
        FROM calibration_curves
        WHERE run_id = ?
    """,
        [run_id],
    ).fetchone()

    conn.close()

    if curve_row is None:
        raise ValueError(f"Run {run_id} not found in calibration_curves")

    # Parse JSON arrays
    x_smooth = json.loads(curve_row[0])
    y_smooth = json.loads(curve_row[1])
    ci_lower = json.loads(curve_row[2])
    ci_upper = json.loads(curve_row[3])

    # Build JSON structure
    json_data = {
        "run_id": run_id,
        # STRATOS required metrics (pre-computed)
        "stratos_metrics": {
            "auroc": auroc,
            "brier_score": brier,
            "scaled_brier": scaled_brier,
            "calibration_slope": cal_slope,
            "calibration_intercept": cal_intercept,
            "o_e_ratio": o_e_ratio,
        },
        # Calibration curve for plotting (pre-computed)
        "calibration_curve": {
            "x_smooth": x_smooth,
            "y_smooth": y_smooth,
            "ci_lower": ci_lower,
            "ci_upper": ci_upper,
        },
    }

    # Save to file
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(json_data, f, indent=2)

    return json_data
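The curve arrays in calibration_curves are stored as JSON text and decoded with json.loads after a parameterized lookup by run_id. A minimal sketch of that read path, with one loud substitution: sqlite3 from the standard library stands in for DuckDB so the example runs without extra packages (both accept ? placeholders and expose fetchone()). The table and column names match the query above; the run ID and values are made up.

```python
import json
import sqlite3

# sqlite3 stands in for DuckDB; the values inserted here are made up.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE calibration_curves "
    "(run_id TEXT, x_smooth TEXT, y_smooth TEXT, ci_lower TEXT, ci_upper TEXT)"
)
conn.execute(
    "INSERT INTO calibration_curves VALUES (?, ?, ?, ?, ?)",
    [
        "run_demo",
        json.dumps([0.1, 0.5, 0.9]),
        json.dumps([0.12, 0.48, 0.85]),
        json.dumps([0.05, 0.40, 0.75]),
        json.dumps([0.20, 0.55, 0.93]),
    ],
)


def load_curve(conn, run_id: str) -> dict:
    """Fetch one run's curve and decode the JSON-text array columns."""
    row = conn.execute(
        "SELECT x_smooth, y_smooth, ci_lower, ci_upper FROM calibration_curves WHERE run_id = ?",
        [run_id],
    ).fetchone()
    if row is None:
        raise ValueError(f"Run {run_id} not found in calibration_curves")
    keys = ("x_smooth", "y_smooth", "ci_lower", "ci_upper")
    return {k: json.loads(v) for k, v in zip(keys, row)}


curve = load_curve(conn, "run_demo")
conn.close()
```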

save_calibration_multi_combo_json_from_db

save_calibration_multi_combo_json_from_db(
    run_ids: List[str],
    output_path: str,
    db_path: Optional[str] = None,
) -> dict

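Save calibration metrics for multiple runs to a single JSON file.

CRITICAL: Reads from DuckDB, does NOT compute metrics. Run IDs missing from essential_metrics are skipped silently. Entries are keyed by "{outlier_method}+{imputation_method}", so a later run with the same outlier/imputation pair (for example, a different classifier) overwrites an earlier one, and n_combos counts keys rather than runs.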

PARAMETER DESCRIPTION
run_ids

Run IDs to include

TYPE: list of str

output_path

Path to save JSON file

TYPE: str

db_path

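Path to the DuckDB file. If None, the first existing path among data/public/foundation_plr_results.db and data/foundation_plr_results.db (relative to the working directory) is used; FileNotFoundError is raised if neither exists.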

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
dict

The JSON data structure with all combos

Source code in src/viz/calibration_plot.py
def save_calibration_multi_combo_json_from_db(
    run_ids: List[str],
    output_path: str,
    db_path: Optional[str] = None,
) -> dict:
    """
    Save calibration metrics for multiple runs to a single JSON file.

    CRITICAL: Reads from DuckDB, does NOT compute metrics.

    Parameters
    ----------
    run_ids : list of str
        Run IDs to include
    output_path : str
        Path to save JSON file
    db_path : str, optional
        Path to DuckDB file

    Returns
    -------
    dict
        The JSON data structure with all combos
    """
    import duckdb

    # Find database
    if db_path is None:
        db_paths = [
            Path("data/public/foundation_plr_results.db"),
            Path("data/foundation_plr_results.db"),
        ]
        for p in db_paths:
            if p.exists():
                db_path = str(p)
                break
        if db_path is None:
            raise FileNotFoundError("DuckDB not found")

    conn = duckdb.connect(db_path, read_only=True)

    all_combos = {}
    for run_id in run_ids:
        # Read scalar metrics
        metrics_row = conn.execute(
            """
            SELECT
                outlier_method, imputation_method, classifier,
                calibration_slope, calibration_intercept, o_e_ratio, brier
            FROM essential_metrics
            WHERE run_id = ?
        """,
            [run_id],
        ).fetchone()

        if metrics_row is None:
            continue

        outlier, imputation, classifier, slope, intercept, oe, brier = metrics_row

        # Read curve data
        curve_row = conn.execute(
            """
            SELECT x_smooth, y_smooth
            FROM calibration_curves
            WHERE run_id = ?
        """,
            [run_id],
        ).fetchone()

        combo_key = f"{outlier}+{imputation}"
        all_combos[combo_key] = {
            "run_id": run_id,
            "outlier_method": outlier,
            "imputation_method": imputation,
            "classifier": classifier,
            "calibration_slope": slope,
            "calibration_intercept": intercept,
            "o_e_ratio": oe,
            "brier_score": brier,
        }

        if curve_row:
            all_combos[combo_key]["curve"] = {
                "x": json.loads(curve_row[0]),
                "y": json.loads(curve_row[1]),
            }

    conn.close()

    json_data = {
        "n_combos": len(all_combos),
        "combos": all_combos,
    }

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(json_data, f, indent=2)

    return json_data
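Because save_calibration_multi_combo_json_from_db keys each entry by "{outlier}+{imputation}", runs that differ only in classifier collide and the last one wins. A small stdlib sketch of that behaviour, with a hypothetical include_classifier variant that keeps every run; the run IDs and method names are placeholders, not values from the database.

```python
from typing import Dict, List, Tuple

# Placeholder rows: (run_id, outlier_method, imputation_method, classifier)
rows: List[Tuple[str, str, str, str]] = [
    ("run_1", "outlier_A", "impute_X", "CatBoost"),
    ("run_2", "outlier_A", "impute_X", "XGBoost"),
    ("run_3", "outlier_B", "impute_X", "CatBoost"),
]


def group_runs(rows, include_classifier: bool = False) -> Dict[str, str]:
    """Map combo key -> run_id; later rows overwrite earlier ones with the same key."""
    out: Dict[str, str] = {}
    for run_id, outlier, imputation, classifier in rows:
        key = f"{outlier}+{imputation}"
        if include_classifier:
            key += f"+{classifier}"  # hypothetical variant that keeps every run
        out[key] = run_id
    return out


print(group_runs(rows))
```

Here run_1 disappears from the default grouping because run_2 shares its outlier/imputation pair, so passing run IDs that all use one classifier is the safe way to call the function as written.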

generate_calibration_figure

generate_calibration_figure(
    y_true: ndarray,
    y_prob: ndarray,
    metrics: Optional[Dict[str, float]] = None,
    output_dir: Optional[Path] = None,
    filename: str = "fig_calibration_smoothed",
) -> Tuple[str, str]

Generate calibration plot and save to file.

PARAMETER DESCRIPTION
y_true

True binary labels

TYPE: array-like

y_prob

Predicted probabilities

TYPE: array-like

metrics

Pre-computed calibration metrics from DuckDB (e.g. calibration_slope, calibration_intercept). If None, metrics annotation is skipped.

TYPE: dict DEFAULT: None

output_dir

Output directory (default: save_figure's mode-dependent default, figures/generated/ or figures/synthetic/)

TYPE: Path DEFAULT: None

filename

Base filename (without extension)

TYPE: str DEFAULT: 'fig_calibration_smoothed'

RETURNS DESCRIPTION
png_path, json_path : paths to generated files

Source code in src/viz/calibration_plot.py
def generate_calibration_figure(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    metrics: Optional[Dict[str, float]] = None,
    output_dir: Optional[Path] = None,
    filename: str = "fig_calibration_smoothed",
) -> Tuple[str, str]:
    """
    Generate calibration plot and save to file.

    Parameters
    ----------
    y_true : array-like
        True binary labels
    y_prob : array-like
        Predicted probabilities
    metrics : dict, optional
        Pre-computed calibration metrics from DuckDB (e.g. calibration_slope,
        calibration_intercept). If None, metrics annotation is skipped.
    output_dir : Path, optional
        Output directory (default: uses save_figure default)
    filename : str
        Base filename (without extension)

    Returns
    -------
    png_path, json_path : paths to generated files
    """
    import matplotlib.pyplot as plt

    # LOESS curve for JSON data (visualization rendering)
    x_smooth, y_smooth = compute_loess_calibration(y_true, y_prob)

    json_data = {
        "curve": {"x": x_smooth.tolist(), "y": y_smooth.tolist()},
    }
    if metrics is not None:
        json_data["metrics"] = metrics

    fig, ax = plot_calibration_curve(
        y_true,
        y_prob,
        show_ci=True,
        show_rug=True,
        show_metrics=metrics is not None,
        metrics=metrics,
    )

    # Save using figure system
    png_path = save_figure(fig, filename, data=json_data, output_dir=output_dir)
    plt.close(fig)
    json_path = png_path.parent / "data" / f"{filename}.json"

    return str(png_path), str(json_path)

dca_plot

Decision Curve Analysis (DCA) visualization module.

Implements STRATOS-compliant DCA plots for clinical utility assessment. Based on Vickers & Elkin 2006 and Van Calster et al. 2024 guidelines.

Architecture (CRITICAL-FAILURE-003 compliant):
  • Pure-math net benefit formulas (compute_net_benefit, compute_treat_all_nb, compute_treat_none_nb, compute_dca_curves) are ACCEPTABLE: they are simple TP/FP arithmetic with NO sklearn or src.stats imports.
  • DCA curve data for production figures is loaded from DuckDB via load_dca_curves_from_db(). All metric computation happens in extraction.
  • NO imports from src.stats. NO sklearn imports.

See: https://github.com/petteriTeikari/foundation_PLR/issues/13

compute_net_benefit

compute_net_benefit(
    y_true: ndarray, y_prob: ndarray, threshold: float
) -> float

Compute net benefit at a given threshold probability.

Net Benefit = TP/n - FP/n * (pt / (1-pt))

Where pt is the threshold probability.

This is a pure-math formula (no sklearn, no src.stats). Acceptable in viz.

PARAMETER DESCRIPTION
y_true

True binary labels (0 or 1)

TYPE: array-like

y_prob

Predicted probabilities

TYPE: array-like

threshold

Decision threshold probability

TYPE: float

RETURNS DESCRIPTION
float

Net benefit at the given threshold

Source code in src/viz/dca_plot.py
def compute_net_benefit(
    y_true: np.ndarray, y_prob: np.ndarray, threshold: float
) -> float:
    """
    Compute net benefit at a given threshold probability.

    Net Benefit = TP/n - FP/n * (pt / (1-pt))

    Where pt is the threshold probability.

    This is a pure-math formula (no sklearn, no src.stats). Acceptable in viz.

    Parameters
    ----------
    y_true : array-like
        True binary labels (0 or 1)
    y_prob : array-like
        Predicted probabilities
    threshold : float
        Decision threshold probability

    Returns
    -------
    float
        Net benefit at the given threshold
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    n = len(y_true)

    if n == 0:
        return np.nan

    # Avoid division by zero
    if threshold >= 1.0:
        return 0.0
    if threshold <= 0.0:
        threshold = 1e-10

    y_pred = (y_prob >= threshold).astype(int)

    tp = np.sum((y_pred == 1) & (y_true == 1))
    fp = np.sum((y_pred == 1) & (y_true == 0))

    odds = threshold / (1 - threshold)
    nb = tp / n - fp / n * odds

    return nb
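
A minimal worked sketch of the net-benefit formula above, written independently of src.viz.dca_plot (the helper name net_benefit and the toy data are hypothetical). With 10 patients, 3 of whom have the condition, and a threshold of pt = 0.3, the model flags 4 patients: 2 true positives and 2 false positives.

```python
import numpy as np


def net_benefit(y_true, y_prob, pt):
    """Net benefit = TP/n - FP/n * pt / (1 - pt), treating y_prob >= pt as positive."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    n = len(y_true)
    y_pred = y_prob >= pt
    tp = np.sum(y_pred & (y_true == 1))
    fp = np.sum(y_pred & (y_true == 0))
    return tp / n - fp / n * (pt / (1 - pt))


# Toy data (hypothetical): 3 positives out of 10 patients
y_true = [1, 1, 0, 0, 0, 1, 0, 0, 0, 0]
y_prob = [0.9, 0.4, 0.35, 0.1, 0.05, 0.2, 0.6, 0.15, 0.02, 0.08]

nb = net_benefit(y_true, y_prob, 0.3)
# TP = 2, FP = 2 -> 2/10 - 2/10 * (0.3 / 0.7) = 0.8 / 7
print(round(nb, 4))  # 0.1143
```

Read the result as roughly 11.4 true positives per 100 patients with no false positives: each false positive is discounted by the threshold odds pt / (1 - pt).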

compute_treat_all_nb

compute_treat_all_nb(
    prevalence: float, threshold: float
) -> float

Compute net benefit for treat-all strategy.

NB(treat-all) = prevalence - (1 - prevalence) * (pt / (1-pt))

PARAMETER DESCRIPTION
prevalence

Disease prevalence in the population

TYPE: float

threshold

Decision threshold probability

TYPE: float

RETURNS DESCRIPTION
float

Net benefit for treat-all strategy

Source code in src/viz/dca_plot.py
def compute_treat_all_nb(prevalence: float, threshold: float) -> float:
    """
    Compute net benefit for treat-all strategy.

    NB(treat-all) = prevalence - (1 - prevalence) * (pt / (1-pt))

    Parameters
    ----------
    prevalence : float
        Disease prevalence in the population
    threshold : float
        Decision threshold probability

    Returns
    -------
    float
        Net benefit for treat-all strategy
    """
    if threshold >= 1.0:
        return -np.inf
    if threshold <= 0.0:
        threshold = 1e-10

    odds = threshold / (1 - threshold)
    return prevalence - (1 - prevalence) * odds
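
One way to sanity-check the treat-all formula: a "model" that flags every patient has TP/n = prevalence and FP/n = 1 - prevalence, so its net benefit should coincide with NB(treat-all). The sketch below reimplements both formulas locally (the function names and toy labels are hypothetical) and also shows that treat-all crosses zero exactly at pt = prevalence.

```python
import numpy as np


def net_benefit(y_true, y_prob, pt):
    # Same formula as compute_net_benefit: TP/n - FP/n * pt / (1 - pt)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_prob) >= pt
    n = len(y_true)
    tp = np.sum(y_pred & (y_true == 1))
    fp = np.sum(y_pred & (y_true == 0))
    return tp / n - fp / n * (pt / (1 - pt))


def treat_all_nb(prevalence, pt):
    # NB(treat-all) = prevalence - (1 - prevalence) * pt / (1 - pt)
    return prevalence - (1 - prevalence) * (pt / (1 - pt))


y_true = np.array([1, 1, 1] + [0] * 7)  # toy labels, prevalence = 0.3
everyone_positive = np.ones_like(y_true, dtype=float)  # flags all patients

for pt in (0.05, 0.1, 0.2):
    print(pt, net_benefit(y_true, everyone_positive, pt), treat_all_nb(0.3, pt))

# Treat-all net benefit is (up to rounding) zero at pt = prevalence
print(treat_all_nb(0.3, 0.3))
```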

compute_treat_none_nb

compute_treat_none_nb(threshold: float) -> float

Compute net benefit for treat-none strategy.

NB(treat-none) = 0 (always)

PARAMETER DESCRIPTION
threshold

Decision threshold probability (unused, for interface consistency)

TYPE: float

RETURNS DESCRIPTION
float

Always returns 0.0

Source code in src/viz/dca_plot.py
def compute_treat_none_nb(threshold: float) -> float:
    """
    Compute net benefit for treat-none strategy.

    NB(treat-none) = 0 (always)

    Parameters
    ----------
    threshold : float
        Decision threshold probability (unused, for interface consistency)

    Returns
    -------
    float
        Always returns 0.0
    """
    return 0.0

compute_dca_curves

compute_dca_curves(
    y_true: ndarray,
    y_prob: ndarray,
    thresholds: Optional[ndarray] = None,
    threshold_range: Tuple[float, float] = (0.01, 0.3),
    n_thresholds: int = 50,
) -> Dict[str, ndarray]

Compute DCA curves for model, treat-all, and treat-none strategies.

Uses only pure-math net benefit formulas (no sklearn, no src.stats).

PARAMETER DESCRIPTION
y_true

True binary labels

TYPE: array-like

y_prob

Predicted probabilities

TYPE: array-like

thresholds

Specific thresholds to evaluate. If None, uses threshold_range.

TYPE: array-like DEFAULT: None

threshold_range

(min, max) threshold range (default: 1-30% for glaucoma)

TYPE: tuple DEFAULT: (0.01, 0.3)

n_thresholds

Number of threshold points to evaluate

TYPE: int DEFAULT: 50

RETURNS DESCRIPTION
dict with keys:
  • thresholds: array of threshold probabilities
  • nb_model: net benefit for model at each threshold
  • nb_all: net benefit for treat-all at each threshold
  • nb_none: net benefit for treat-none at each threshold
  • prevalence: disease prevalence
Source code in src/viz/dca_plot.py
def compute_dca_curves(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    thresholds: Optional[np.ndarray] = None,
    threshold_range: Tuple[float, float] = (0.01, 0.30),
    n_thresholds: int = 50,
) -> Dict[str, np.ndarray]:
    """
    Compute DCA curves for model, treat-all, and treat-none strategies.

    Uses only pure-math net benefit formulas (no sklearn, no src.stats).

    Parameters
    ----------
    y_true : array-like
        True binary labels
    y_prob : array-like
        Predicted probabilities
    thresholds : array-like, optional
        Specific thresholds to evaluate. If None, uses threshold_range.
    threshold_range : tuple
        (min, max) threshold range (default: 1-30% for glaucoma)
    n_thresholds : int
        Number of threshold points to evaluate

    Returns
    -------
    dict with keys:
        - thresholds: array of threshold probabilities
        - nb_model: net benefit for model at each threshold
        - nb_all: net benefit for treat-all at each threshold
        - nb_none: net benefit for treat-none at each threshold
        - prevalence: disease prevalence
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    if thresholds is None:
        thresholds = np.linspace(threshold_range[0], threshold_range[1], n_thresholds)

    prevalence = y_true.mean()

    nb_model = np.array([compute_net_benefit(y_true, y_prob, t) for t in thresholds])
    nb_all = np.array([compute_treat_all_nb(prevalence, t) for t in thresholds])
    nb_none = np.array([compute_treat_none_nb(t) for t in thresholds])

    return {
        "thresholds": thresholds,
        "nb_model": nb_model,
        "nb_all": nb_all,
        "nb_none": nb_none,
        "prevalence": prevalence,
    }
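
The dictionary returned by compute_dca_curves is typically read by asking where the model beats both reference strategies. A minimal sketch, assuming a dict with the same keys; the helper useful_thresholds and its margin parameter are hypothetical, and the curve values are hand-made for illustration, not real results.

```python
import numpy as np


def useful_thresholds(curves, margin=0.0):
    """Return thresholds where the model beats both treat-all and treat-none.

    `curves` is assumed to carry the keys returned by compute_dca_curves:
    "thresholds", "nb_model", "nb_all", "nb_none". `margin` is a hypothetical
    minimum net-benefit improvement over the better reference strategy.
    """
    best_reference = np.maximum(curves["nb_all"], curves["nb_none"])
    mask = curves["nb_model"] > best_reference + margin
    return curves["thresholds"][mask]


# Hand-made curves (illustrative only); prevalence 0.3
curves = {
    "thresholds": np.array([0.05, 0.10, 0.20, 0.30]),
    "nb_model": np.array([0.25, 0.23, 0.16, 0.08]),
    "nb_all": np.array([0.263, 0.222, 0.125, 0.0]),
    "nb_none": np.zeros(4),
    "prevalence": 0.3,
}
print(useful_thresholds(curves))  # [0.1 0.2 0.3]
```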

load_dca_curves_from_db

load_dca_curves_from_db(
    db_path: str, combo_ids: Optional[List[str]] = None
) -> Dict[str, Dict[str, ndarray]]

Load pre-computed DCA curves from DuckDB.

Supports two DuckDB schemas:
  1. Streaming schema (config_id + per-row thresholds): joins essential_metrics to find matching configs by outlier/imputation/classifier.
  2. Curve extraction schema (run_id + JSON arrays): joins essential_metrics by run_id to resolve combo matching.

PARAMETER DESCRIPTION
db_path

Path to DuckDB database

TYPE: str

combo_ids

Specific combo IDs to load. If None, loads standard combos from config.

TYPE: list of str DEFAULT: None

RETURNS DESCRIPTION
dict

Maps combo_id to dict with keys:
  • thresholds: np.ndarray of threshold values
  • nb_model: np.ndarray of model net benefits
  • nb_all: np.ndarray of treat-all net benefits
  • nb_none: np.ndarray of treat-none net benefits

Source code in src/viz/dca_plot.py
def load_dca_curves_from_db(
    db_path: str,
    combo_ids: Optional[List[str]] = None,
) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Load pre-computed DCA curves from DuckDB.

    Supports two DuckDB schemas:
    1. Streaming schema (config_id + per-row thresholds): joins essential_metrics
       to find matching configs by outlier/imputation/classifier.
    2. Curve extraction schema (run_id + JSON arrays): joins essential_metrics
       by run_id to resolve combo matching.

    Parameters
    ----------
    db_path : str
        Path to DuckDB database
    combo_ids : list of str, optional
        Specific combo IDs to load. If None, loads standard combos from config.

    Returns
    -------
    dict
        Maps combo_id to dict with keys:
        - thresholds: np.ndarray of threshold values
        - nb_model: np.ndarray of model net benefits
        - nb_all: np.ndarray of treat-all net benefits
        - nb_none: np.ndarray of treat-none net benefits
    """
    import duckdb

    from src.viz.config_loader import get_config_loader

    db_path = Path(db_path)
    if not db_path.exists():
        raise FileNotFoundError(f"Database not found: {db_path}")

    conn = duckdb.connect(str(db_path), read_only=True)

    dca_data = {}

    try:
        # Determine schema by checking column names
        columns = {
            row[0]
            for row in conn.execute(
                "SELECT column_name FROM information_schema.columns "
                "WHERE table_name = 'dca_curves'"
            ).fetchall()
        }

        # Load standard combos from config if not specified
        if combo_ids is None:
            config = get_config_loader()
            standard_combos = config.get_standard_hyperparam_combos()
            combo_ids = [c["id"] for c in standard_combos]

        if "config_id" in columns and "threshold" in columns:
            # Streaming schema: one row per threshold per config_id
            dca_data = _load_dca_streaming_schema(conn, combo_ids)
        elif "run_id" in columns and "thresholds" in columns:
            # Curve extraction schema: JSON arrays per run_id
            dca_data = _load_dca_json_schema(conn, combo_ids)
        else:
            raise ValueError(f"Unrecognized dca_curves schema. Columns: {columns}")

    finally:
        conn.close()

    return dca_data
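
The two schemas differ mainly in shape: the streaming schema stores one row per (config, threshold) pair, while the curve extraction schema stores whole curves as JSON arrays. The sketch below mirrors the column checks above and shows how each shape can be turned into arrays. The private loaders _load_dca_streaming_schema and _load_dca_json_schema are not shown here, so the three-column row layout and the helper names are assumptions for illustration.

```python
import json
from collections import defaultdict

import numpy as np


def detect_dca_schema(columns):
    # Mirrors the column checks in load_dca_curves_from_db
    if "config_id" in columns and "threshold" in columns:
        return "streaming"
    if "run_id" in columns and "thresholds" in columns:
        return "json"
    raise ValueError(f"Unrecognized dca_curves schema. Columns: {columns}")


def pivot_streaming_rows(rows):
    """Group long-format (config_id, threshold, net_benefit) rows into arrays.

    The three-column row layout is a hypothetical simplification.
    """
    grouped = defaultdict(list)
    for config_id, threshold, nb in rows:
        grouped[config_id].append((threshold, nb))
    curves = {}
    for config_id, pairs in grouped.items():
        pairs.sort()  # rows may arrive in any order; sort by threshold
        curves[config_id] = {
            "thresholds": np.array([t for t, _ in pairs]),
            "nb_model": np.array([v for _, v in pairs]),
        }
    return curves


def parse_json_row(thresholds_json, nb_model_json):
    # JSON schema: one row per run_id holding whole curves as JSON arrays
    return {
        "thresholds": np.asarray(json.loads(thresholds_json)),
        "nb_model": np.asarray(json.loads(nb_model_json)),
    }


rows = [("cfg_a", 0.2, 0.15), ("cfg_a", 0.1, 0.22), ("cfg_b", 0.1, 0.18)]
print(detect_dca_schema({"config_id", "threshold", "net_benefit"}))  # streaming
print(pivot_streaming_rows(rows)["cfg_a"]["thresholds"])  # [0.1 0.2]
print(parse_json_row("[0.1, 0.2]", "[0.22, 0.15]")["nb_model"])
```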

plot_dca

plot_dca(
    y_true: ndarray,
    y_prob: ndarray,
    ax: Optional[Axes] = None,
    threshold_range: Tuple[float, float] = (0.01, 0.3),
    n_thresholds: int = 50,
    model_label: str = "Model",
    model_color: Optional[str] = None,
    show_treat_all: bool = True,
    show_treat_none: bool = True,
    save_json_path: Optional[str] = None,
) -> Tuple[Figure, Axes]

Plot Decision Curve Analysis from raw predictions.

Uses pure-math compute_dca_curves (no sklearn, no src.stats).

PARAMETER DESCRIPTION
y_true

True binary labels

TYPE: array-like

y_prob

Predicted probabilities

TYPE: array-like

ax

Axes to plot on

TYPE: Axes DEFAULT: None

threshold_range

(min, max) threshold range (default: 1-30% for glaucoma screening)

TYPE: tuple DEFAULT: (0.01, 0.3)

n_thresholds

Number of threshold points

TYPE: int DEFAULT: 50

model_label

Label for model in legend

TYPE: str DEFAULT: 'Model'

model_color

Color for model line

TYPE: str DEFAULT: None

show_treat_all

Whether to show treat-all reference line

TYPE: bool DEFAULT: True

show_treat_none

Whether to show treat-none reference line

TYPE: bool DEFAULT: True

save_json_path

If provided, saves JSON data for reproducibility

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
fig, ax : matplotlib Figure and Axes
Source code in src/viz/dca_plot.py
def plot_dca(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    ax: Optional["plt.Axes"] = None,
    threshold_range: Tuple[float, float] = (0.01, 0.30),
    n_thresholds: int = 50,
    model_label: str = "Model",
    model_color: Optional[str] = None,
    show_treat_all: bool = True,
    show_treat_none: bool = True,
    save_json_path: Optional[str] = None,
) -> Tuple["plt.Figure", "plt.Axes"]:
    """
    Plot Decision Curve Analysis from raw predictions.

    Uses pure-math compute_dca_curves (no sklearn, no src.stats).

    Parameters
    ----------
    y_true : array-like
        True binary labels
    y_prob : array-like
        Predicted probabilities
    ax : matplotlib.axes.Axes, optional
        Axes to plot on
    threshold_range : tuple
        (min, max) threshold range (default: 1-30% for glaucoma screening)
    n_thresholds : int
        Number of threshold points
    model_label : str
        Label for model in legend
    model_color : str, optional
        Color for model line
    show_treat_all : bool
        Whether to show treat-all reference line
    show_treat_none : bool
        Whether to show treat-none reference line
    save_json_path : str, optional
        If provided, saves JSON data for reproducibility

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    import matplotlib.pyplot as plt

    try:
        from src.viz.plot_config import setup_style
    except ImportError:
        from plot_config import setup_style
    setup_style()

    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    if ax is None:
        fig, ax = plt.subplots(figsize=get_dimensions("single"))
    else:
        fig = ax.get_figure()

    # Compute DCA curves (pure math only)
    dca_data = compute_dca_curves(
        y_true, y_prob, threshold_range=threshold_range, n_thresholds=n_thresholds
    )

    thresholds = dca_data["thresholds"]
    nb_model = dca_data["nb_model"]
    nb_all = dca_data["nb_all"]
    nb_none = dca_data["nb_none"]

    # Plot model curve
    ax.plot(thresholds, nb_model, label=model_label, color=model_color, linewidth=2)

    # Plot treat-all reference
    if show_treat_all:
        ax.plot(
            thresholds,
            nb_all,
            "--",
            color=COLORS["text_secondary"],
            alpha=0.7,
            label="Treat All",
            linewidth=1.5,
        )

    # Plot treat-none reference
    if show_treat_none:
        ax.plot(
            thresholds,
            nb_none,
            ":",
            color=COLORS["text_primary"],
            alpha=0.5,
            label="Treat None",
            linewidth=1.5,
        )

    # Labels and formatting
    ax.set_xlabel("Threshold Probability")
    ax.set_ylabel("Net Benefit")
    ax.set_xlim(threshold_range[0] - 0.01, threshold_range[1] + 0.01)

    # Set reasonable y-axis limits
    all_nb = np.concatenate([nb_model, nb_all[nb_all > -0.5]])
    if len(all_nb) > 0:
        y_min = max(-0.1, np.nanmin(all_nb) - 0.02)
        y_max = np.nanmax(all_nb) + 0.02
        ax.set_ylim(y_min, y_max)

    ax.legend(loc="upper right")
    ax.axhline(y=0, color=COLORS["grid_lines"], linestyle="-", alpha=0.3)

    # Add grid
    ax.grid(True, alpha=0.3)

    # Save JSON data for reproducibility
    if save_json_path:
        json_data = {
            "thresholds": thresholds.tolist(),
            "nb_model": nb_model.tolist(),
            "nb_all": nb_all.tolist(),
            "nb_none": nb_none.tolist(),
            "prevalence": float(dca_data["prevalence"]),
            "threshold_range": list(threshold_range),
            "y_true": y_true.tolist(),
            "y_prob": y_prob.tolist(),
        }
        with open(save_json_path, "w") as f:
            json.dump(json_data, f, indent=2)

    return fig, ax
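
The y-axis rule in plot_dca exists because the treat-all curve falls steeply at high thresholds and would otherwise stretch the axis. A minimal sketch of that rule in isolation; the defaults mirror the constants hard-coded above (-0.5 cutoff, -0.1 floor, 0.02 padding), while the function dca_ylim and its parameter names are hypothetical.

```python
import numpy as np


def dca_ylim(nb_model, nb_all, floor=-0.1, pad=0.02, cutoff=-0.5):
    """Reproduce plot_dca's y-axis limit rule.

    Treat-all values at or below `cutoff` are ignored, the lower limit is
    never below `floor`, and `pad` is added on both ends.
    """
    nb_model = np.asarray(nb_model, dtype=float)
    nb_all = np.asarray(nb_all, dtype=float)
    values = np.concatenate([nb_model, nb_all[nb_all > cutoff]])
    y_min = max(floor, np.nanmin(values) - pad)
    y_max = np.nanmax(values) + pad
    return y_min, y_max


# -0.80 is dropped by the cutoff; -0.30 - 0.02 is clamped to the -0.1 floor
print(dca_ylim([0.20, 0.10], [0.25, -0.30, -0.80]))  # approximately (-0.1, 0.27)
```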

plot_dca_multi_model

plot_dca_multi_model(
    models_data: Dict[str, Dict],
    ax: Optional[Axes] = None,
    threshold_range: Tuple[float, float] = (0.01, 0.3),
    n_thresholds: int = 50,
    colors: Optional[List[str]] = None,
) -> Tuple[Figure, Axes]

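Plot DCA for multiple models on the same axes from raw predictions. The treat-all reference uses the prevalence of the first model's y_true, so all models should be scored on the same labels.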

Uses pure-math compute_net_benefit (no sklearn, no src.stats).

PARAMETER DESCRIPTION
models_data

Dictionary mapping model names to {'y_true': ..., 'y_prob': ...}

TYPE: dict

ax

Axes to plot on; if None, a new figure and axes are created

TYPE: Axes DEFAULT: None

threshold_range

(min, max) threshold range

TYPE: tuple DEFAULT: (0.01, 0.3)

n_thresholds

Number of threshold points

TYPE: int DEFAULT: 50

colors

Colors for each model

TYPE: list DEFAULT: None

RETURNS DESCRIPTION
fig, ax : matplotlib Figure and Axes
Source code in src/viz/dca_plot.py
def plot_dca_multi_model(
    models_data: Dict[str, Dict],
    ax: Optional["plt.Axes"] = None,
    threshold_range: Tuple[float, float] = (0.01, 0.30),
    n_thresholds: int = 50,
    colors: Optional[List[str]] = None,
) -> Tuple["plt.Figure", "plt.Axes"]:
    """
    Plot DCA for multiple models on same axes from raw predictions.

    Uses pure-math compute_net_benefit (no sklearn, no src.stats).

    Parameters
    ----------
    models_data : dict
        Dictionary mapping model names to {'y_true': ..., 'y_prob': ...}
    ax : matplotlib.axes.Axes, optional
    threshold_range : tuple
        (min, max) threshold range
    n_thresholds : int
        Number of threshold points
    colors : list, optional
        Colors for each model

    Returns
    -------
    fig, ax
    """
    import matplotlib.pyplot as plt

    if ax is None:
        fig, ax = plt.subplots(figsize=get_dimensions("single"))
    else:
        fig = ax.get_figure()

    if colors is None:
        colors = plt.cm.tab10(np.linspace(0, 1, len(models_data)))

    # Get prevalence from first model
    first_data = list(models_data.values())[0]
    y_true = np.asarray(first_data["y_true"])
    prevalence = y_true.mean()

    thresholds = np.linspace(threshold_range[0], threshold_range[1], n_thresholds)

    # Plot each model
    for (model_name, data), color in zip(models_data.items(), colors):
        y_t = np.asarray(data["y_true"])
        y_p = np.asarray(data["y_prob"])

        nb_model = np.array([compute_net_benefit(y_t, y_p, t) for t in thresholds])
        ax.plot(thresholds, nb_model, label=model_name, color=color, linewidth=2)

    # Plot references
    nb_all = np.array([compute_treat_all_nb(prevalence, t) for t in thresholds])
    nb_none = np.array([compute_treat_none_nb(t) for t in thresholds])

    ax.plot(
        thresholds,
        nb_all,
        "--",
        color=COLORS["text_secondary"],
        alpha=0.7,
        label="Treat All",
        linewidth=1.5,
    )
    ax.plot(
        thresholds,
        nb_none,
        ":",
        color=COLORS["text_primary"],
        alpha=0.5,
        label="Treat None",
        linewidth=1.5,
    )

    ax.set_xlabel("Threshold Probability")
    ax.set_ylabel("Net Benefit")
    ax.set_xlim(threshold_range[0] - 0.01, threshold_range[1] + 0.01)
    ax.legend()
    ax.grid(True, alpha=0.3)

    return fig, ax
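
Because plot_dca_multi_model takes the treat-all prevalence from the first entry only, it can help to confirm that every entry shares the same labels before plotting. A minimal sketch, assuming the models_data layout documented above; the helper multi_model_nb, its label check, and the toy scores are hypothetical additions, not part of the module.

```python
import numpy as np


def multi_model_nb(models_data, thresholds):
    """Net-benefit curve per model for {'name': {'y_true': ..., 'y_prob': ...}}.

    Raises if the models were not scored on the same labels, because the
    treat-all reference is computed from a single prevalence.
    """
    label_sets = [np.asarray(d["y_true"]) for d in models_data.values()]
    if any(not np.array_equal(label_sets[0], y) for y in label_sets[1:]):
        raise ValueError("models_data entries do not share the same y_true")

    curves = {}
    for name, d in models_data.items():
        y_t = np.asarray(d["y_true"])
        y_p = np.asarray(d["y_prob"])
        n = len(y_t)
        nb = []
        for pt in thresholds:
            pred = y_p >= pt
            tp = np.sum(pred & (y_t == 1))
            fp = np.sum(pred & (y_t == 0))
            nb.append(tp / n - fp / n * pt / (1 - pt))
        curves[name] = np.array(nb)
    return curves


y = np.array([1, 0, 1, 0, 0])  # toy labels, prevalence = 0.4
models_data = {
    "perfect": {"y_true": y, "y_prob": y.astype(float)},
    "uninformative": {"y_true": y, "y_prob": np.full(5, 0.5)},
}
thresholds = np.linspace(0.01, 0.30, 4)
curves = multi_model_nb(models_data, thresholds)
print(curves["perfect"])  # equals the prevalence (0.4) at every threshold
```

The uninformative model flags everyone at these thresholds, so its curve coincides with the treat-all reference.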

plot_dca_from_db

plot_dca_from_db(
    db_path: str,
    combo_ids: Optional[List[str]] = None,
    ax: Optional[Axes] = None,
    output_dir: Optional[Path] = None,
    filename: str = "fig_dca_curves",
) -> Tuple[Figure, Axes]

Plot DCA curves from pre-computed data in DuckDB.

This is the PREFERRED method for production figures. Reads pre-computed DCA curves from the database (no on-the-fly computation).

PARAMETER DESCRIPTION
db_path

Path to DuckDB database

TYPE: str

combo_ids

Specific combo IDs to plot. If None, loads standard combos.

TYPE: list of str DEFAULT: None

ax

Axes to plot on

TYPE: Axes DEFAULT: None

output_dir

Output directory for saving figure

TYPE: Path DEFAULT: None

filename

Base filename for saving

TYPE: str DEFAULT: 'fig_dca_curves'

RETURNS DESCRIPTION
fig, ax : matplotlib Figure and Axes
Source code in src/viz/dca_plot.py
def plot_dca_from_db(
    db_path: str,
    combo_ids: Optional[List[str]] = None,
    ax: Optional["plt.Axes"] = None,
    output_dir: Optional[Path] = None,
    filename: str = "fig_dca_curves",
) -> Tuple["plt.Figure", "plt.Axes"]:
    """
    Plot DCA curves from pre-computed data in DuckDB.

    This is the PREFERRED method for production figures. Reads pre-computed
    DCA curves from the database (no on-the-fly computation).

    Parameters
    ----------
    db_path : str
        Path to DuckDB database
    combo_ids : list of str, optional
        Specific combo IDs to plot. If None, loads standard combos.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on
    output_dir : Path, optional
        Output directory for saving figure
    filename : str
        Base filename for saving

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    import matplotlib.pyplot as plt

    try:
        from src.viz.config_loader import get_config_loader
        from src.viz.plot_config import setup_style
    except ImportError:
        from config_loader import get_config_loader
        from plot_config import setup_style

    setup_style()

    if ax is None:
        fig, ax = plt.subplots(figsize=get_dimensions("dca"))
    else:
        fig = ax.get_figure()

    # Load pre-computed DCA curves from DB
    dca_data = load_dca_curves_from_db(db_path, combo_ids=combo_ids)

    if not dca_data:
        ax.text(
            0.5,
            0.5,
            "No DCA data found in database",
            transform=ax.transAxes,
            ha="center",
            va="center",
        )
        return fig, ax

    # Get colors from config
    config = get_config_loader()
    colors = config.get_colors().get("combo_colors", {})

    # Plot each combo
    for combo_id, curves in dca_data.items():
        color = colors.get(combo_id, None)

        # Get display name
        try:
            combo_config = config.get_combo_config(combo_id)
            display_name = combo_config.get("display_name", combo_id)
        except Exception:
            display_name = combo_id

        ax.plot(
            curves["thresholds"],
            curves["nb_model"],
            label=display_name,
            color=color,
            linewidth=2,
        )

    # Plot treat-all and treat-none from the first combo's data
    first_curves = list(dca_data.values())[0]
    ax.plot(
        first_curves["thresholds"],
        first_curves["nb_all"],
        "--",
        color=COLORS["text_secondary"],
        alpha=0.7,
        label="Treat All",
        linewidth=1.5,
    )
    ax.plot(
        first_curves["thresholds"],
        first_curves["nb_none"],
        ":",
        color=COLORS["text_primary"],
        alpha=0.5,
        label="Treat None",
        linewidth=1.5,
    )

    ax.set_xlabel("Threshold Probability")
    ax.set_ylabel("Net Benefit")
    thresholds = first_curves["thresholds"]
    ax.set_xlim(thresholds.min() - 0.01, thresholds.max() + 0.01)
    ax.legend(loc="upper right")
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color=COLORS["grid_lines"], linestyle="-", alpha=0.3)

    # Set reasonable y-axis limits
    all_nb = []
    for curves in dca_data.values():
        all_nb.extend(curves["nb_model"].tolist())
    nb_all_vals = first_curves["nb_all"]
    all_nb.extend(nb_all_vals[nb_all_vals > -0.5].tolist())
    if all_nb:
        y_min = max(-0.1, np.nanmin(all_nb) - 0.02)
        y_max = np.nanmax(all_nb) + 0.02
        ax.set_ylim(y_min, y_max)

    # Prepare JSON data for reproducibility
    json_data = {
        "thresholds": thresholds.tolist(),
        "nb_all": nb_all_vals.tolist(),
        "nb_none": first_curves["nb_none"].tolist(),
        "combos": {
            combo_id: {
                "net_benefit": curves["nb_model"].tolist(),
            }
            for combo_id, curves in dca_data.items()
        },
    }

    # Save using figure system
    save_figure(fig, filename, data=json_data, output_dir=output_dir)

    return fig, ax
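
The JSON payload assembled above ("thresholds", "nb_all", "nb_none", and per-combo "net_benefit" arrays) is enough to compare combos at a specific clinical threshold without touching DuckDB again. A minimal sketch, assuming that payload layout; the helper rank_combos_at, the use of linear interpolation, and the example numbers are hypothetical.

```python
import json

import numpy as np


def rank_combos_at(json_data, pt):
    """Rank combos by net benefit linearly interpolated at threshold `pt`.

    Assumes the payload layout built in plot_dca_from_db:
    {"thresholds": [...], "combos": {combo_id: {"net_benefit": [...]}}, ...}
    """
    thresholds = np.asarray(json_data["thresholds"], dtype=float)
    scores = {
        combo_id: float(np.interp(pt, thresholds, entry["net_benefit"]))
        for combo_id, entry in json_data["combos"].items()
    }
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


# Hypothetical payload, e.g. read back from the JSON data file written by save_figure
payload = json.loads(
    '{"thresholds": [0.05, 0.15, 0.25],'
    ' "nb_all": [0.26, 0.18, 0.07], "nb_none": [0, 0, 0],'
    ' "combos": {"combo_a": {"net_benefit": [0.25, 0.20, 0.12]},'
    '            "combo_b": {"net_benefit": [0.26, 0.17, 0.10]}}}'
)
print(rank_combos_at(payload, 0.10))
```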

generate_dca_figure

generate_dca_figure(
    y_true: ndarray,
    y_prob: ndarray,
    output_dir: Optional[Path] = None,
    filename: str = "fig_dca_curves",
    threshold_range: Tuple[float, float] = (0.01, 0.3),
) -> Tuple[str, str]

Generate DCA plot from raw predictions and save to file.

For production figures, prefer plot_dca_from_db() which reads pre-computed data from DuckDB.

PARAMETER DESCRIPTION
y_true

True binary labels

TYPE: array-like

y_prob

Predicted probabilities

TYPE: array-like

output_dir

Output directory (default: uses save_figure default)

TYPE: Path DEFAULT: None

filename

Base filename (without extension)

TYPE: str DEFAULT: 'fig_dca_curves'

threshold_range

(min, max) threshold range

TYPE: tuple DEFAULT: (0.01, 0.3)

RETURNS DESCRIPTION
png_path, json_path : paths to generated files
Source code in src/viz/dca_plot.py
def generate_dca_figure(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    output_dir: Optional[Path] = None,
    filename: str = "fig_dca_curves",
    threshold_range: Tuple[float, float] = (0.01, 0.30),
) -> Tuple[str, str]:
    """
    Generate DCA plot from raw predictions and save to file.

    For production figures, prefer plot_dca_from_db() which reads
    pre-computed data from DuckDB.

    Parameters
    ----------
    y_true : array-like
        True binary labels
    y_prob : array-like
        Predicted probabilities
    output_dir : Path, optional
        Output directory (default: uses save_figure default)
    filename : str
        Base filename (without extension)
    threshold_range : tuple
        (min, max) threshold range

    Returns
    -------
    png_path, json_path : paths to generated files
    """
    import matplotlib.pyplot as plt

    # Compute DCA curves for JSON data (pure math only)
    dca_curves_data = compute_dca_curves(
        y_true, y_prob, threshold_range=threshold_range
    )

    json_data = {
        "threshold_range": list(threshold_range),
        "thresholds": dca_curves_data["thresholds"].tolist(),
        "nb_model": dca_curves_data["nb_model"].tolist(),
        "nb_all": dca_curves_data["nb_all"].tolist(),
        "nb_none": dca_curves_data["nb_none"].tolist(),
    }

    fig, ax = plot_dca(y_true, y_prob, threshold_range=threshold_range)

    # Save using figure system
    png_path = save_figure(fig, filename, data=json_data, output_dir=output_dir)
    plt.close(fig)
    json_path = png_path.parent / "data" / f"{filename}.json"

    return str(png_path), str(json_path)

cd_diagram

Critical Difference Diagram for statistical method comparison.

Implements Demšar (2006) CD diagrams using Friedman test + Nemenyi post-hoc.

Cross-references:
  • planning/remaining-duckdb-stats-viz-tasks-plan.md (Figs 8-11)

References:
  • Demšar (2006). Statistical comparisons of classifiers.
  • Nemenyi (1963). Distribution-free multiple comparisons.

The diagram shows:
  1. Methods ranked by average performance
  2. Cliques (groups) of methods that are NOT significantly different
  3. A Critical Difference (CD) bar showing the smallest difference in average rank that counts as significant

friedman_nemenyi_test

friedman_nemenyi_test(
    data: DataFrame, alpha: float = 0.05
) -> Dict

Perform Friedman test with Nemenyi post-hoc analysis.

PARAMETER DESCRIPTION
data

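DataFrame with rows as "datasets" (e.g., preprocessing configs) and columns as "methods" (e.g., classifiers). Values are performance metrics (e.g., AUROC); higher values are treated as better (rank 1 = best). scipy.stats.friedmanchisquare requires at least three methods (columns).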

TYPE: DataFrame

alpha

Significance level

TYPE: float DEFAULT: 0.05

RETURNS DESCRIPTION
dict

Contains:
  • friedman_statistic: Friedman chi-square statistic
  • friedman_pvalue: p-value for the Friedman test
  • average_ranks: dict of method -> average rank (rank 1 = best)
  • critical_difference: CD value for the Nemenyi test
  • pairwise_significant: dict of (method1, method2) -> bool, stored for both orderings of each pair
  • cliques: list of method groups that are NOT significantly different
  • n_datasets, n_methods, alpha: the inputs used for the test

Examples:

>>> # Rows = configs, columns = classifiers
>>> data = pd.DataFrame({
...     'CatBoost': [0.91, 0.89, 0.93],
...     'XGBoost': [0.88, 0.87, 0.90],
...     'LogReg': [0.82, 0.81, 0.84]
... })
>>> result = friedman_nemenyi_test(data)
>>> print(f"Friedman p={result['friedman_pvalue']:.4f}")
Source code in src/viz/cd_diagram.py
def friedman_nemenyi_test(
    data: pd.DataFrame,
    alpha: float = 0.05,
) -> Dict:
    """
    Perform Friedman test with Nemenyi post-hoc analysis.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame with rows as "datasets" (e.g., preprocessing configs)
        and columns as "methods" (e.g., classifiers).
        Values are performance metrics (e.g., AUROC).
    alpha : float, default 0.05
        Significance level

    Returns
    -------
    dict
        Contains:
        - friedman_statistic: Chi-square statistic
        - friedman_pvalue: p-value for Friedman test
        - average_ranks: Dict of method -> average rank
        - critical_difference: CD value for Nemenyi test
        - pairwise_significant: Dict of (method1, method2) -> bool
        - cliques: List of sets of methods NOT significantly different

    Examples
    --------
    >>> # Rows = configs, columns = classifiers
    >>> data = pd.DataFrame({
    ...     'CatBoost': [0.91, 0.89, 0.93],
    ...     'XGBoost': [0.88, 0.87, 0.90],
    ...     'LogReg': [0.82, 0.81, 0.84]
    ... })
    >>> result = friedman_nemenyi_test(data)
    >>> print(f"Friedman p={result['friedman_pvalue']:.4f}")
    """
    n_datasets, n_methods = data.shape
    methods = data.columns.tolist()

    # Compute ranks for each row (dataset)
    # Higher performance = lower rank (rank 1 = best)
    ranks = data.rank(axis=1, ascending=False)
    average_ranks = ranks.mean().to_dict()

    # Friedman test
    # Using scipy's implementation
    friedman_stat, friedman_pvalue = stats.friedmanchisquare(
        *[data[col] for col in methods]
    )

    # Critical Difference (Nemenyi)
    cd = compute_critical_difference(n_methods, n_datasets, alpha)

    # Pairwise comparisons
    pairwise_significant = {}
    for i, m1 in enumerate(methods):
        for m2 in methods[i + 1 :]:
            rank_diff = abs(average_ranks[m1] - average_ranks[m2])
            pairwise_significant[(m1, m2)] = rank_diff > cd
            pairwise_significant[(m2, m1)] = rank_diff > cd

    # Identify cliques
    cliques = identify_cliques(average_ranks, cd)

    return {
        "friedman_statistic": float(friedman_stat),
        "friedman_pvalue": float(friedman_pvalue),
        "average_ranks": average_ranks,
        "critical_difference": float(cd),
        "pairwise_significant": pairwise_significant,
        "cliques": cliques,
        "n_datasets": n_datasets,
        "n_methods": n_methods,
        "alpha": alpha,
    }
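
The Friedman statistic can be recomputed from average ranks as chi2_F = 12N / (k(k+1)) * (sum_j R_j^2 - k(k+1)^2 / 4). The sketch below is a NumPy-only illustration of that formula (the helpers average_ranks_desc and friedman_statistic are hypothetical); it omits the tie correction that scipy.stats.friedmanchisquare applies, so the two agree only when no row contains ties. On the toy table from the docstring example, CatBoost > XGBoost > LogReg in every row, giving average ranks 1, 2, 3 and a statistic of 6.0 (with 2 degrees of freedom, p = exp(-3) ≈ 0.0498).

```python
import numpy as np


def average_ranks_desc(X):
    """Rank each row so the highest score gets rank 1; ties share the average rank."""
    X = np.asarray(X, dtype=float)
    # greater[n, i] = number of methods in row n scoring strictly higher than method i
    greater = (X[:, None, :] > X[:, :, None]).sum(axis=2)
    # equal[n, i] counts methods tied with method i, including method i itself
    equal = (X[:, None, :] == X[:, :, None]).sum(axis=2)
    return 1 + greater + (equal - 1) / 2


def friedman_statistic(X):
    """Friedman chi-square from average ranks (no tie correction)."""
    ranks = average_ranks_desc(X)
    n, k = ranks.shape
    R = ranks.mean(axis=0)  # average rank per method
    chi2 = 12 * n / (k * (k + 1)) * (np.sum(R**2) - k * (k + 1) ** 2 / 4)
    return chi2, R


# Toy table from the docstring example: rows = configs, columns = classifiers
X = [[0.91, 0.88, 0.82],
     [0.89, 0.87, 0.81],
     [0.93, 0.90, 0.84]]
chi2, R = friedman_statistic(X)
print(chi2, R)  # 6.0 [1. 2. 3.]
```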

compute_critical_difference

compute_critical_difference(
    n_methods: int, n_datasets: int, alpha: float = 0.05
) -> float

Compute Nemenyi critical difference.

CD = q_α × sqrt(k(k+1) / (6N))

where:
  • q_α: Nemenyi critical value (Studentized range statistic divided by √2; Demšar (2006), Table 5)
  • k: number of methods
  • N: number of datasets

PARAMETER DESCRIPTION
n_methods

Number of methods being compared (k)

TYPE: int

n_datasets

Number of datasets/configurations (N)

TYPE: int

alpha

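Significance level. Critical values are tabulated only for 0.05 and 0.10; any other value silently falls back to the 0.05 table. For more than 20 methods, the k = 20 critical value is used.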

TYPE: float DEFAULT: 0.05

RETURNS DESCRIPTION
float

Critical difference value

Source code in src/viz/cd_diagram.py
def compute_critical_difference(
    n_methods: int,
    n_datasets: int,
    alpha: float = 0.05,
) -> float:
    """
    Compute Nemenyi critical difference.

    CD = q_α × sqrt(k(k+1) / (6N))

    where:
    - q_α: critical value from Studentized range distribution
    - k: number of methods
    - N: number of datasets

    Parameters
    ----------
    n_methods : int
        Number of methods being compared (k)
    n_datasets : int
        Number of datasets/configurations (N)
    alpha : float
        Significance level

    Returns
    -------
    float
        Critical difference value
    """
    # Critical values for Nemenyi test (Studentized range / sqrt(2))
    # Values for k = 2..10 from Demšar (2006) Table 5; larger k follow the same construction
    # For α = 0.05
    q_alpha_05 = {
        2: 1.960,
        3: 2.343,
        4: 2.569,
        5: 2.728,
        6: 2.850,
        7: 2.949,
        8: 3.031,
        9: 3.102,
        10: 3.164,
        11: 3.219,
        12: 3.268,
        13: 3.313,
        14: 3.354,
        15: 3.391,
        16: 3.426,
        17: 3.458,
        18: 3.489,
        19: 3.517,
        20: 3.544,
    }
    # For α = 0.10
    q_alpha_10 = {
        2: 1.645,
        3: 2.052,
        4: 2.291,
        5: 2.459,
        6: 2.589,
        7: 2.693,
        8: 2.780,
        9: 2.855,
        10: 2.920,
        11: 2.978,
        12: 3.030,
        13: 3.077,
        14: 3.120,
        15: 3.159,
        16: 3.196,
        17: 3.230,
        18: 3.261,
        19: 3.291,
        20: 3.319,
    }

    if alpha == 0.05:
        q_alpha = q_alpha_05
    elif alpha == 0.10:
        q_alpha = q_alpha_10
    else:
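        # Unsupported alpha: fall back to the alpha = 0.05 table, so the returned
        # CD corresponds to alpha = 0.05 rather than the requested level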
        q_alpha = q_alpha_05

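    # Get q value; for n_methods > 20 fall back to the k = 20 value, which
    # underestimates q and therefore makes the CD too small (anti-conservative)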
    k = min(n_methods, 20)
    q = q_alpha.get(k, q_alpha[20])

    cd = q * np.sqrt(n_methods * (n_methods + 1) / (6 * n_datasets))
    return cd
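Here is a worked example of the formula. Take k = 5 methods compared over N = 10 datasets at α = 0.05. The tabulated q for k = 5 is 2.728, so CD = 2.728 × sqrt(5·6 / 60) = 2.728 × sqrt(0.5) ≈ 1.929. Two methods whose average ranks differ by more than about 1.93 are therefore declared significantly different. The sketch below restates the arithmetic as a standalone snippet. The helper name `nemenyi_cd` and the truncated table are illustrative only and are not part of the module. The snippet also shows that quadrupling N halves the CD, because N enters under a square root.

```python
import math

# Subset of the alpha = 0.05 Nemenyi critical values used by compute_critical_difference
Q_ALPHA_05 = {2: 1.960, 3: 2.343, 4: 2.569, 5: 2.728}


def nemenyi_cd(n_methods: int, n_datasets: int) -> float:
    """Hypothetical helper: CD = q_alpha * sqrt(k(k+1) / (6N)) at alpha = 0.05."""
    q = Q_ALPHA_05[n_methods]
    return q * math.sqrt(n_methods * (n_methods + 1) / (6 * n_datasets))


cd_10 = nemenyi_cd(5, 10)  # 5 methods, 10 datasets
cd_40 = nemenyi_cd(5, 40)  # 4x as many datasets
print(f"CD (N=10) = {cd_10:.3f}")
print(f"CD (N=40) = {cd_40:.3f}")
```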

identify_cliques

identify_cliques(
    average_ranks: Dict[str, float], cd: float
) -> List[List[str]]

Identify cliques (groups of methods not significantly different).

A clique is a maximal set of methods in which every pair differs in average rank by at most CD.

PARAMETER DESCRIPTION
average_ranks

Method -> average rank mapping

TYPE: dict

cd

Critical difference

TYPE: float

RETURNS DESCRIPTION
list of lists

Each inner list is a clique of method names

Source code in src/viz/cd_diagram.py
def identify_cliques(
    average_ranks: Dict[str, float],
    cd: float,
) -> List[List[str]]:
    """
    Identify cliques (groups of methods not significantly different).

    A clique is a maximal set of methods in which every pair differs in average
    rank by at most CD.

    Parameters
    ----------
    average_ranks : dict
        Method -> average rank mapping
    cd : float
        Critical difference

    Returns
    -------
    list of lists
        Each inner list is a clique of method names
    """
    methods = list(average_ranks.keys())
    n = len(methods)

    # Sort by rank
    sorted_methods = sorted(methods, key=lambda m: average_ranks[m])

    # Find cliques using greedy approach
    cliques = []
    i = 0
    while i < n:
        clique = [sorted_methods[i]]
        j = i + 1

        # Extend clique while within CD
        while j < n:
            rank_diff = abs(
                average_ranks[sorted_methods[j]] - average_ranks[sorted_methods[i]]
            )
            if rank_diff <= cd:
                clique.append(sorted_methods[j])
                j += 1
            else:
                break

        # Only add if clique has 2+ methods
        if len(clique) >= 2:
            # Check if this clique is subsumed by existing
            is_new = True
            for existing in cliques:
                if set(clique).issubset(set(existing)):
                    is_new = False
                    break
            if is_new:
                cliques.append(clique)

        i += 1

    return cliques
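Methods are processed in rank order, so each candidate only needs to be checked against the first member of the window. The largest gap within a window is always the last rank minus the first. The sketch below restates the same greedy sliding-window logic as a standalone function and runs it on made-up average ranks with CD = 1.0. The name `greedy_cliques` is hypothetical and is not part of the module. Cliques can overlap: B is not significantly different from either A or C, while A and C are significantly different from each other.

```python
from typing import Dict, List


def greedy_cliques(average_ranks: Dict[str, float], cd: float) -> List[List[str]]:
    """Standalone restatement of the greedy clique search in identify_cliques."""
    ordered = sorted(average_ranks, key=lambda m: average_ranks[m])
    cliques: List[List[str]] = []
    for i, start in enumerate(ordered):
        # Grow the window while members stay within CD of the window's best method
        window = [start]
        for method in ordered[i + 1 :]:
            if average_ranks[method] - average_ranks[start] <= cd:
                window.append(method)
            else:
                break
        # Keep windows with 2+ methods that are not contained in an earlier clique
        if len(window) >= 2 and not any(set(window) <= set(c) for c in cliques):
            cliques.append(window)
    return cliques


ranks = {"A": 1.2, "B": 1.9, "C": 2.6, "D": 4.3}
print(greedy_cliques(ranks, cd=1.0))  # [['A', 'B'], ['B', 'C']]
```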

draw_cd_diagram

draw_cd_diagram(
    data: Union[DataFrame, Dict],
    title: str = "Critical Difference Diagram",
    output_path: Optional[str] = None,
    save_data_path: Optional[str] = None,
    figure_id: Optional[str] = None,
    alpha: float = 0.05,
    figsize: Tuple[float, float] = (10, 5),
    text_fontsize: int = 10,
    line_width: float = 2.5,
    marker_size: int = 100,
    highlight_best: bool = True,
) -> Tuple[Figure, Axes]

Draw a Critical Difference diagram.

PARAMETER DESCRIPTION
data

If DataFrame: rows = datasets, columns = methods, values = metrics (higher is better). If dict: the output of friedman_nemenyi_test().

TYPE: DataFrame or dict

title

Plot title

TYPE: str DEFAULT: 'Critical Difference Diagram'

output_path

If provided, save the figure to this path (via save_figure_all_formats) and close it; the returned figure is then already closed.

TYPE: str DEFAULT: None

save_data_path

If provided, save underlying data as JSON to this path

TYPE: str DEFAULT: None

figure_id

Figure identifier for data export (e.g., "fig08"); defaults to "cd_diagram" when None

TYPE: str DEFAULT: None

alpha

Significance level for the Nemenyi test; only used to run the test when data is a DataFrame

TYPE: float DEFAULT: 0.05

figsize

Figure size (width, height)

TYPE: tuple DEFAULT: (10, 5)

text_fontsize

Font size for method names

TYPE: int DEFAULT: 10

line_width

Width of clique bars

TYPE: float DEFAULT: 2.5

marker_size

Size of rank markers

TYPE: int DEFAULT: 100

highlight_best

Whether to highlight the best method

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
fig, ax : matplotlib figure and axes
Source code in src/viz/cd_diagram.py
def draw_cd_diagram(
    data: Union[pd.DataFrame, Dict],
    title: str = "Critical Difference Diagram",
    output_path: Optional[str] = None,
    save_data_path: Optional[str] = None,
    figure_id: Optional[str] = None,
    alpha: float = 0.05,
    figsize: Tuple[float, float] = (10, 5),
    text_fontsize: int = 10,
    line_width: float = 2.5,
    marker_size: int = 100,
    highlight_best: bool = True,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Draw a Critical Difference diagram.

    Parameters
    ----------
    data : pd.DataFrame or dict
        If DataFrame: rows = datasets, columns = methods, values = metrics
        If dict: output from friedman_nemenyi_test()
    title : str
        Plot title
    output_path : str, optional
        If provided, save figure to this path
    save_data_path : str, optional
        If provided, save underlying data as JSON to this path
    figure_id : str, optional
        Figure identifier for data export (e.g., "fig08")
    alpha : float
        Significance level for Nemenyi test
    figsize : tuple
        Figure size (width, height)
    text_fontsize : int
        Font size for method names
    line_width : float
        Width of clique bars
    marker_size : int
        Size of rank markers
    highlight_best : bool
        Whether to highlight the best method

    Returns
    -------
    fig, ax : matplotlib figure and axes
    """
    # Run statistical test if data is DataFrame
    if isinstance(data, pd.DataFrame):
        result = friedman_nemenyi_test(data, alpha=alpha)
    else:
        result = data

    avg_ranks = result["average_ranks"]
    cd = result["critical_difference"]
    cliques = result["cliques"]
    n_methods = result["n_methods"]

    # Sort methods by average rank
    sorted_methods = sorted(avg_ranks.keys(), key=lambda m: avg_ranks[m])

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Set up axes
    rank_min = 1
    rank_max = n_methods
    ax.set_xlim(rank_min - 0.5, rank_max + 0.5)
    ax.set_ylim(-0.5, n_methods + 1)

    # Draw axis line
    ax.axhline(y=n_methods + 0.3, color="black", linewidth=1)

    # Draw tick marks and labels on top
    for rank in range(1, n_methods + 1):
        ax.plot([rank, rank], [n_methods + 0.2, n_methods + 0.4], "k-", lw=1)
        ax.text(
            rank,
            n_methods + 0.6,
            str(rank),
            ha="center",
            va="bottom",
            fontsize=text_fontsize - 2,
        )

    # Draw the CD reference bar above the rank axis, starting at rank 1
    cd_start = 1
    cd_end = 1 + cd
    ax.plot([cd_start, cd_end], [n_methods + 0.8, n_methods + 0.8], "k-", lw=2)
    ax.plot([cd_start, cd_start], [n_methods + 0.7, n_methods + 0.9], "k-", lw=2)
    ax.plot([cd_end, cd_end], [n_methods + 0.7, n_methods + 0.9], "k-", lw=2)
    ax.text(
        (cd_start + cd_end) / 2,
        n_methods + 1.0,
        f"CD = {cd:.2f}",
        ha="center",
        va="bottom",
        fontsize=text_fontsize - 1,
    )

    # Draw methods with their ranks
    y_positions = {}
    for i, method in enumerate(sorted_methods):
        rank = avg_ranks[method]
        y = n_methods - i - 0.5

        # Draw marker at rank position
        color = "gold" if (i == 0 and highlight_best) else "steelblue"
        ax.scatter([rank], [y], c=color, s=marker_size, zorder=5, edgecolors="black")

        # Draw line from method name to marker
        if rank < (n_methods + 1) / 2:
            # Left side - name on left
            ax.plot([rank_min - 0.3, rank], [y, y], "k-", lw=0.5, alpha=0.5)
            ax.text(
                rank_min - 0.4,
                y,
                method,
                ha="right",
                va="center",
                fontsize=text_fontsize,
            )
        else:
            # Right side - name on right
            ax.plot([rank, rank_max + 0.3], [y, y], "k-", lw=0.5, alpha=0.5)
            ax.text(
                rank_max + 0.4,
                y,
                method,
                ha="left",
                va="center",
                fontsize=text_fontsize,
            )

        # Store y position for clique bars
        y_positions[method] = y

    # Draw clique bars (horizontal lines connecting methods not significantly different)
    clique_colors = [
        COLORS["cd_rank1"],
        COLORS["cd_rank2"],
        COLORS["cd_rank3"],
        COLORS["cd_rank4"],
        COLORS["cd_rank5"],
    ]

    for idx, clique in enumerate(cliques):
        if len(clique) < 2:
            continue

        # Horizontal extent: span the clique's rank range with a small margin
        clique_ranks = [avg_ranks[m] for m in clique]
        bar_left = min(clique_ranks) - 0.1
        bar_right = max(clique_ranks) + 0.1

        # Position bar slightly below the methods
        bar_y = min([y_positions[m] for m in clique]) - 0.3 - (idx * 0.15)

        color = clique_colors[idx % len(clique_colors)]

        # Draw horizontal bar
        ax.plot(
            [bar_left, bar_right],
            [bar_y, bar_y],
            color=color,
            linewidth=line_width,
            solid_capstyle="round",
        )

    # Style
    ax.set_title(title, fontsize=text_fontsize + 2, pad=20)
    ax.text(
        (rank_min + rank_max) / 2,
        n_methods + 1.3,
        "← better",
        ha="center",
        va="bottom",
        fontsize=text_fontsize - 2,
        style="italic",
    )

    # Remove frame
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_yticks([])

    plt.tight_layout()

    # Save figure if path provided
    if output_path:
        from .figure_export import save_figure_all_formats

        save_figure_all_formats(fig, output_path)
        plt.close(fig)

    # Save data as JSON if path provided
    if save_data_path:
        from .figure_data import save_figure_data

        # Convert cliques to serializable format
        cliques_serializable = [list(c) for c in cliques]

        data_dict = {
            "methods": sorted_methods,
            "average_ranks": [avg_ranks[m] for m in sorted_methods],
            "critical_difference": cd,
            "cliques": cliques_serializable,
            "friedman_statistic": result["friedman_statistic"],
            "friedman_pvalue": result["friedman_pvalue"],
        }

        save_figure_data(
            figure_id=figure_id or "cd_diagram",
            figure_title=title,
            data=data_dict,
            output_path=save_data_path,
            metadata={
                "n_datasets": result["n_datasets"],
                "n_methods": result["n_methods"],
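                # Record the alpha the result was computed with (a precomputed dict may differ)
                "alpha": result.get("alpha", alpha),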
            },
        )

    return fig, ax
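When `data` is a dict, draw_cd_diagram skips the statistical test and reads the result as given. The `alpha` argument is not used to recompute anything in that case. A precomputed result, for example one rebuilt from saved JSON, therefore needs every key the function accesses:

- `average_ranks`, `critical_difference`, `cliques` and `n_methods` are always read.
- `friedman_statistic`, `friedman_pvalue` and `n_datasets` are also read when `save_data_path` is given.

The check below is a hypothetical helper, not part of the module, that makes those requirements explicit before plotting. The values in `precomputed` are illustrative only.

```python
from typing import Any, Dict, List

# Keys read by draw_cd_diagram when a dict is passed as `data`
PLOT_KEYS = ["average_ranks", "critical_difference", "cliques", "n_methods"]
# Additional keys read only when save_data_path is set
EXPORT_KEYS = ["friedman_statistic", "friedman_pvalue", "n_datasets"]


def missing_cd_keys(result: Dict[str, Any], save_data: bool = False) -> List[str]:
    """Hypothetical helper: list keys draw_cd_diagram would need but cannot find."""
    required = PLOT_KEYS + (EXPORT_KEYS if save_data else [])
    return [key for key in required if key not in result]


# Illustrative precomputed result (ranks and CD are made up but mutually consistent)
precomputed = {
    "average_ranks": {"CatBoost": 1.4, "XGBoost": 2.1, "LogReg": 2.5},
    "critical_difference": 0.96,
    "cliques": [["CatBoost", "XGBoost"], ["XGBoost", "LogReg"]],
    "n_methods": 3,
}
print(missing_cd_keys(precomputed))  # []
print(missing_cd_keys(precomputed, save_data=True))  # ['friedman_statistic', 'friedman_pvalue', 'n_datasets']
```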

prepare_cd_data

prepare_cd_data(
    df: DataFrame,
    config_col: str,
    method_col: str,
    value_col: str,
) -> DataFrame

Prepare data for CD diagram from long-format DataFrame.

PARAMETER DESCRIPTION
df

Long-format data with config, method, and value columns

TYPE: DataFrame

config_col

Column name for configuration/dataset identifier

TYPE: str

method_col

Column name for method identifier

TYPE: str

value_col

Column name for performance metric

TYPE: str

RETURNS DESCRIPTION
DataFrame

Wide-format DataFrame suitable for friedman_nemenyi_test

Examples:

>>> df = pd.DataFrame({
...     'config': ['A', 'A', 'B', 'B'],
...     'classifier': ['Cat', 'XGB', 'Cat', 'XGB'],
...     'auroc': [0.9, 0.85, 0.88, 0.83]
... })
>>> wide_df = prepare_cd_data(df, 'config', 'classifier', 'auroc')
Source code in src/viz/cd_diagram.py
def prepare_cd_data(
    df: pd.DataFrame,
    config_col: str,
    method_col: str,
    value_col: str,
) -> pd.DataFrame:
    """
    Prepare data for CD diagram from long-format DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format data with config, method, and value columns
    config_col : str
        Column name for configuration/dataset identifier
    method_col : str
        Column name for method identifier
    value_col : str
        Column name for performance metric

    Returns
    -------
    pd.DataFrame
        Wide-format DataFrame suitable for friedman_nemenyi_test

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     'config': ['A', 'A', 'B', 'B'],
    ...     'classifier': ['Cat', 'XGB', 'Cat', 'XGB'],
    ...     'auroc': [0.9, 0.85, 0.88, 0.83]
    ... })
    >>> wide_df = prepare_cd_data(df, 'config', 'classifier', 'auroc')
    """
    return df.pivot(index=config_col, columns=method_col, values=value_col)
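prepare_cd_data is a thin wrapper around `DataFrame.pivot`, so it inherits pivot's behaviour in two ways. First, each (configuration, method) pair must appear at most once; otherwise pandas raises a `ValueError`. Second, a method missing for some configuration leaves a NaN cell. The sketch below uses toy values and the same column names as the docstring example. It shows the wide result, dropping incomplete rows before ranking, and collapsing duplicates with `pivot_table`. Dropping incomplete configurations and averaging duplicates are assumptions about what a caller might want; the module itself does neither.

```python
import pandas as pd

df = pd.DataFrame(
    {
        "config": ["A", "A", "B", "B", "C"],
        "classifier": ["Cat", "XGB", "Cat", "XGB", "Cat"],
        "auroc": [0.90, 0.85, 0.88, 0.83, 0.87],
    }
)

# Same call as prepare_cd_data: rows = configs, columns = methods
wide_df = df.pivot(index="config", columns="classifier", values="auroc")
print(wide_df)  # config C has no XGB result, so that cell is NaN

# Rows with a missing method cannot be ranked fairly; drop them before testing
complete = wide_df.dropna()

# A duplicated (config, method) pair makes pivot raise ValueError;
# pivot_table aggregates duplicates instead (here: mean AUROC)
dup = pd.concat([df, df.iloc[[0]].assign(auroc=0.92)], ignore_index=True)
try:
    dup.pivot(index="config", columns="classifier", values="auroc")
except ValueError as err:
    print(f"pivot failed: {err}")
averaged = dup.pivot_table(
    index="config", columns="classifier", values="auroc", aggfunc="mean"
)
```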

factorial_matrix

factorial_matrix.py - Figure M3: Factorial Design Matrix

Visualizes the factorial experimental design structure:
- 2 featurization methods
- 7+ outlier detection methods (including ensembles)
- 5+ imputation methods (including ensembles)
- 5 classifiers

Total: 407 unique configurations

Usage: python src/viz/factorial_matrix.py

fetch_factorial_counts

fetch_factorial_counts() -> Tuple[
    Dict[str, List[str]], int
]

Fetch the distinct levels of each factor and the number of configurations with a non-null AUROC.

Source code in src/viz/factorial_matrix.py
def fetch_factorial_counts() -> Tuple[Dict[str, List[str]], int]:
    """Fetch the distinct levels of each factor and the number of configurations with a non-null AUROC."""
    conn = get_connection()

    # Get unique values per factor
    queries = {
        "outlier": "SELECT DISTINCT outlier_method FROM essential_metrics WHERE outlier_method IS NOT NULL AND outlier_method != 'Unknown'",
        "imputation": "SELECT DISTINCT imputation_method FROM essential_metrics WHERE imputation_method IS NOT NULL AND imputation_method != 'Unknown'",
        "featurization": "SELECT DISTINCT featurization FROM essential_metrics WHERE featurization IS NOT NULL AND featurization != 'Unknown'",
        "classifier": "SELECT DISTINCT classifier FROM essential_metrics WHERE classifier IS NOT NULL AND classifier != 'Unknown'",
    }

    factors = {}
    for name, query in queries.items():
        result = conn.execute(query).fetchall()
        factors[name] = [r[0] for r in result if r[0]]

    # Get total configurations
    total = conn.execute(
        "SELECT COUNT(*) FROM essential_metrics WHERE auroc IS NOT NULL"
    ).fetchone()[0]

    conn.close()

    return factors, total
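The summary box in create_figure shows two numbers. The first is the count of configurations with a non-null AUROC, as returned by fetch_factorial_counts. The second is the product of the per-factor level counts. That product is the size of the fully crossed grid, so the two numbers agree only when every combination was actually run. The sketch below uses hypothetical factor levels and a hypothetical set of completed runs. It enumerates the full grid with `itertools.product` and lists the combinations that are missing.

```python
from itertools import product
from math import prod

# Hypothetical factor levels (the real ones come from essential_metrics)
factors = {
    "outlier": ["LOF", "MOMENT", "ensemble"],
    "imputation": ["linear", "SAITS"],
    "featurization": ["handcrafted", "embeddings"],
    "classifier": ["CatBoost", "XGBoost"],
}

# Full factorial grid: one tuple per (outlier, imputation, featurization, classifier)
grid = list(product(*factors.values()))
grid_size = prod(len(levels) for levels in factors.values())
assert len(grid) == grid_size  # 3 * 2 * 2 * 2 = 24

# Hypothetical completed runs: everything except the embeddings + XGBoost cells
observed = {cfg for cfg in grid if not (cfg[2] == "embeddings" and cfg[3] == "XGBoost")}

missing = [cfg for cfg in grid if cfg not in observed]
print(f"grid = {grid_size}, observed = {len(observed)}, missing = {len(missing)}")
```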

create_figure

create_figure() -> Tuple[Figure, Dict[str, Any]]

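Create the factorial design visualization. Returns the figure and a dict with per-factor level counts, the total configuration count, and the method names per factor.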

Source code in src/viz/factorial_matrix.py
def create_figure() -> Tuple[plt.Figure, Dict[str, Any]]:
    """Create the factorial design visualization."""
    setup_style()

    factors, total_configs = fetch_factorial_counts()

    # Create figure
    fig, ax = plt.subplots(figsize=get_dimensions("matrix"))

    # Design as hierarchical boxes
    # Layout: centered pipeline stages from left to right

    box_height = 0.12
    stage_x = [0.1, 0.35, 0.6, 0.85]  # X positions for 4 stages
    stage_labels = [
        "Outlier\nDetection",
        "Imputation",
        "Featurization",
        "Classification",
    ]
    stage_keys = ["outlier", "imputation", "featurization", "classifier"]

    # Colors for different method types
    def get_method_color(method: str) -> str:
        method_lower = method.lower()
        if "ensemble" in method_lower:
            return COLORS["ensemble"]
        elif any(
            kw in method_lower
            for kw in ["moment", "units", "chronos", "timesnet", "saits", "csdi"]
        ):
            return COLORS["foundation_model"]
        elif "handcrafted" in method_lower:
            return COLORS["handcrafted"]
        elif "embed" in method_lower:
            return COLORS["embeddings"]
        elif "catboost" in method_lower:
            return COLORS["catboost"]
        elif "xgboost" in method_lower:
            return COLORS["xgboost"]
        elif "tabpfn" in method_lower:
            return COLORS["tabpfn"]
        else:
            return COLORS["traditional"]

    # Draw each stage
    for i, (x, label, key) in enumerate(zip(stage_x, stage_labels, stage_keys)):
        methods = factors.get(key, [])
        n_methods = len(methods)

        # Stage header box
        header_rect = mpatches.FancyBboxPatch(
            (x - 0.08, 0.85),
            0.16,
            0.1,
            boxstyle="round,pad=0.02,rounding_size=0.02",
            facecolor=COLORS["grid_lines"],
            edgecolor=COLORS["text_primary"],
            linewidth=2,
        )
        ax.add_patch(header_rect)
        ax.text(
            x, 0.90, label, ha="center", va="center", fontweight="bold", fontsize=10
        )
        ax.text(x, 0.86, f"({n_methods} methods)", ha="center", va="center", fontsize=8)

        # Method boxes (show up to 6, then "...")
        max_show = 6
        methods_to_show = methods[:max_show]
        if len(methods) > max_show:
            methods_to_show.append(f"... +{len(methods) - max_show} more")

        y_start = 0.75
        for j, method in enumerate(methods_to_show):
            y = y_start - j * (box_height + 0.02)

            # Shorten long names
            display_name = method
            if len(display_name) > 15:
                display_name = display_name[:13] + ".."

            color = get_method_color(method) if "..." not in method else "white"

            rect = mpatches.FancyBboxPatch(
                (x - 0.07, y - box_height / 2),
                0.14,
                box_height,
                boxstyle="round,pad=0.01,rounding_size=0.01",
                facecolor=color,
                edgecolor=COLORS["text_primary"],
                linewidth=1,
                alpha=0.8,
            )
            ax.add_patch(rect)
            ax.text(x, y, display_name, ha="center", va="center", fontsize=7)

    # Draw connecting arrows
    arrow_y = 0.90
    for i in range(len(stage_x) - 1):
        ax.annotate(
            "",
            xy=(stage_x[i + 1] - 0.09, arrow_y),
            xytext=(stage_x[i] + 0.09, arrow_y),
            arrowprops=dict(arrowstyle="->", color=COLORS["text_primary"], lw=2),
        )

    # Add multiplication symbols between adjacent stages
    mult_y = 0.78
    for i in range(len(stage_x) - 1):
        ax.text(
            (stage_x[i] + stage_x[i + 1]) / 2,
            mult_y,
            "×",
            ha="center",
            va="center",
            fontsize=16,
            fontweight="bold",
        )

    # Total configurations box
    total_rect = mpatches.FancyBboxPatch(
        (0.3, 0.05),
        0.4,
        0.12,
        boxstyle="round,pad=0.02,rounding_size=0.02",
        facecolor=COLORS["accent"],
        edgecolor=COLORS["text_primary"],
        linewidth=2,
        alpha=0.3,
    )
    ax.add_patch(total_rect)
    ax.text(
        0.5,
        0.11,
        f"Total: {total_configs} configurations",
        ha="center",
        va="center",
        fontweight="bold",
        fontsize=12,
    )
    ax.text(
        0.5,
        0.07,
        f"({len(factors['outlier'])} × {len(factors['imputation'])} × {len(factors['featurization'])} × {len(factors['classifier'])})",
        ha="center",
        va="center",
        fontsize=10,
    )

    # Legend - load display names from config (not hardcoded)
    legend_elements = [
        mpatches.Patch(
            facecolor=COLORS["foundation_model"],
            edgecolor=COLORS["text_primary"],
            label=get_category_display_name("foundation_model"),
        ),
        mpatches.Patch(
            facecolor=COLORS["traditional"],
            edgecolor=COLORS["text_primary"],
            label=get_category_display_name("traditional"),
        ),
        mpatches.Patch(
            facecolor=COLORS["ensemble"],
            edgecolor=COLORS["text_primary"],
            label=get_category_display_name("ensemble"),
        ),
        mpatches.Patch(
            facecolor=COLORS["handcrafted"],
            edgecolor=COLORS["text_primary"],
            label="Handcrafted",  # This is a featurization type, not a category
        ),
    ]
    ax.legend(handles=legend_elements, loc="lower left", fontsize=8, framealpha=0.9)

    # Title
    ax.set_title(
        "Factorial Experimental Design: Pipeline Configuration Space",
        fontweight="bold",
        fontsize=12,
        pad=20,
    )

    # Clean up axes
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")

    plt.tight_layout()

    return fig, {
        "factors": {k: len(v) for k, v in factors.items()},
        "total_configurations": total_configs,
        "methods": factors,
    }

main

main() -> None

Generate and save the figure.

Source code in src/viz/factorial_matrix.py
def main() -> None:
    """Generate and save the figure."""
    print("Generating Figure M3: Factorial Design Matrix...")

    fig, data = create_figure()
    save_figure(fig, "fig_M3_factorial_matrix", data=data)

    plt.close(fig)
    print("Done!")

Usage Example

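import matplotlib.pyplot as plt
import numpy as np

from src.viz.plot_config import setup_style, save_figure, COLORS

# Always set up the style before creating figures
setup_style()

# Example data
x = np.linspace(0, 1, 50)
y = x**2

# Create the figure using semantic colors
fig, ax = plt.subplots()
ax.plot(x, y, color=COLORS["ground_truth"])

# Save with the plotted data as JSON for reproducibility
# (plain lists are JSON-serializable; NumPy arrays are not by default)
save_figure(fig, "fig_my_analysis", data={"x": x.tolist(), "y": y.tolist()})
plt.close(fig)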

Color System

Colors are loaded from configs/VISUALIZATION/combos.yaml:

from src.viz.plot_config import COLORS

# Available colors
COLORS["ground_truth"]   # #666666
COLORS["best_ensemble"]  # #932834
COLORS["traditional"]    # #7B68EE

See Also