
imputation

Signal reconstruction and imputation methods.

Overview

This module provides 7 imputation methods:

  • Ground Truth: Human-corrected signals
  • Deep Learning: SAITS, CSDI, TimesNet
  • Foundation Models: MOMENT
  • Traditional: Linear interpolation, MissForest
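As a minimal sketch (the nested keys are assumptions modeled on the source below), the Hydra configuration selects one of these methods via a single key under `MODELS`, mirroring the `assert len(cfg["MODELS"]) == 1` check in `setup_PLR_worklow`:

```python
# Hypothetical, minimal stand-in for the Hydra DictConfig used by this module.
# Real configs carry many more sections (PREFECT, MLFLOW, EXPERIMENT, ...).
cfg = {
    "MODELS": {
        "SAITS": {"MODEL": {"n_layers": 2, "epochs": 10}},  # illustrative hyperparams
    }
}

# Only one imputation model may be active at a time, as in setup_PLR_worklow().
assert len(cfg["MODELS"]) == 1, "Only one model should be trained at a time"
model_name = list(cfg["MODELS"].keys())[0]
print(model_name)  # SAITS
```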

Main Entry Point

flow_imputation

flow_imputation

flow_imputation(cfg: DictConfig) -> dict

Execute the PLR imputation flow with hyperparameter sweep.

Orchestrates the imputation pipeline by iterating over all combinations of data sources (from outlier detection) and hyperparameter configurations. Also handles ensembling of trained imputation models.

PARAMETER DESCRIPTION
cfg

Full Hydra configuration containing PREFECT flow names, MLFLOW settings, and hyperparameter configurations for imputation models.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Results from the imputation pipeline (implicitly via MLflow logging).

Notes

The flow performs the following steps:

1. Define hyperparameter groups from configuration
2. Download outlier detection outputs from MLflow
3. Define data sources (outlier detection + ground truth)
4. Run imputation for each source x hyperparameter combination
5. Recompute metrics for all submodels
6. Create ensemble models from submodels
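Step 4's sweep is a plain cartesian product of sources and hyperparameter groups; a stand-alone sketch (with made-up source and group names) of the run-name bookkeeping used in the loop below:

```python
from itertools import product

# Illustrative names only; real keys come from outlier detection output
# and the flattened hyperparameter groups in the Hydra config.
sources = ["LOF", "ground_truth"]
hyperparam_groups = ["SAITS_base", "SAITS_wide", "CSDI_base"]

# Same naming convention as the flow: f"{cfg_group_name}__{source_name}"
runs = [f"{group}__{source}" for source, group in product(sources, hyperparam_groups)]

assert len(runs) == len(sources) * len(hyperparam_groups)  # 6 runs in total
print(runs[0])  # SAITS_base__LOF
```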

Source code in src/imputation/flow_imputation.py
def flow_imputation(cfg: DictConfig) -> dict:
    """Execute the PLR imputation flow with hyperparameter sweep.

    Orchestrates the imputation pipeline by iterating over all combinations
    of data sources (from outlier detection) and hyperparameter configurations.
    Also handles ensembling of trained imputation models.

    Parameters
    ----------
    cfg : DictConfig
        Full Hydra configuration containing PREFECT flow names, MLFLOW settings,
        and hyperparameter configurations for imputation models.

    Returns
    -------
    dict
        Results from the imputation pipeline (implicitly via MLflow logging).

    Notes
    -----
    The flow performs the following steps:
    1. Define hyperparameter groups from configuration
    2. Download outlier detection outputs from MLflow
    3. Define data sources (outlier detection + ground truth)
    4. Run imputation for each source x hyperparameter combination
    5. Recompute metrics for all submodels
    6. Create ensemble models from submodels
    """
    # Flatten the hyperparameter groups to a dict
    hyperparams_group = define_hyperparam_group(cfg, task="imputation")

    # Download data from MLflow (output from outlier detection)
    prev_experiment_name = experiment_name_wrapper(
        experiment_name=cfg["PREFECT"]["FLOW_NAMES"]["OUTLIER_DETECTION"], cfg=cfg
    )

    # Define the "sources" for the flow, as in the outlier detection output along with the
    # ground truth of manually annotated imputation masks
    sources = define_sources_for_flow(
        prev_experiment_name=prev_experiment_name, cfg=cfg
    )

    # Set the experiment name
    experiment_name = experiment_name_wrapper(
        experiment_name=cfg["PREFECT"]["FLOW_NAMES"]["IMPUTATION"], cfg=cfg
    )
    init_mlflow_experiment(mlflow_cfg=cfg["MLFLOW"], experiment_name=experiment_name)

    no_of_runs = len(sources) * len(hyperparams_group)
    run_idx = 0
    for source_idx, (source_name, source_data) in enumerate(sources.items()):
        for idx, (cfg_group_name, cfg_group) in enumerate(hyperparams_group.items()):
            logger.info(f"Source #{source_idx + 1}/{len(sources)}: {source_name}")
            logger.info(
                f"Running pipeline for hyperparameter group #{idx + 1}/{len(hyperparams_group)}: {cfg_group_name}"
            )
            run_name = f"{cfg_group_name}__{source_name}"
            logger.info(f"Run name #{run_idx + 1}/{no_of_runs}: {run_name}")
            run_idx += 1
            imputation_PLR_workflow(
                cfg=cfg_group,
                source_name=source_name,
                source_data=source_data,
                experiment_name=experiment_name,
                run_name=run_name,
            )

    # The ensembling part will fetch the trained imputation models from MLflow

    # First re-computing metrics for all the submodels, and making sure that they are correct
    task_ensemble(cfg=cfg, task="imputation", sources=sources, recompute_metrics=True)

    # Then ensembling the submodels
    task_ensemble(cfg=cfg, task="imputation", sources=sources, recompute_metrics=False)

Imputation Core

imputation_main

setup_PLR_worklow

setup_PLR_worklow(
    cfg: DictConfig, run_name: str
) -> Tuple[DictConfig, str, str, bool, Optional[Any], str]

Set up the PLR imputation workflow.

Configures the imputation pipeline:

1. Extract model name from config
2. Apply debug settings if enabled
3. Check for existing model runs

PARAMETER DESCRIPTION
cfg

Full Hydra configuration.

TYPE: DictConfig

run_name

MLflow run name.

TYPE: str

RETURNS DESCRIPTION
tuple

(cfg, model_name, updated_name, train_ON, best_run, artifacts_dir)

  • cfg: Updated configuration
  • model_name: Imputation model name
  • updated_name: Updated run name
  • train_ON: Whether to retrain
  • best_run: Best existing run (if any)
  • artifacts_dir: Directory for artifacts

Source code in src/imputation/imputation_main.py
def setup_PLR_worklow(
    cfg: DictConfig, run_name: str
) -> Tuple[DictConfig, str, str, bool, Optional[Any], str]:
    """
    Set up the PLR imputation workflow.

    Configures the imputation pipeline:
    1. Extract model name from config
    2. Apply debug settings if enabled
    3. Check for existing model runs

    Parameters
    ----------
    cfg : DictConfig
        Full Hydra configuration.
    run_name : str
        MLflow run name.

    Returns
    -------
    tuple
        (cfg, model_name, updated_name, train_ON, best_run, artifacts_dir)
        - cfg: Updated configuration
        - model_name: Imputation model name
        - updated_name: Updated run name
        - train_ON: Whether to retrain
        - best_run: Best existing run (if any)
        - artifacts_dir: Directory for artifacts
    """
    # There should be only one model here atm, TO-OPTIMIZE how to reconcile this later TODO!
    assert len(cfg["MODELS"]) == 1, "Only one model should be trained at a time"

    # Refactor this later, as the model_name used to be looped here
    model_name = list(cfg["MODELS"].keys())[0]  # just pick the first here

    # Debug: set the epochs to 1
    if cfg["EXPERIMENT"]["debug"]:
        cfg = debug_train_only_for_one_epoch(cfg)
        cfg = fix_tree_learners_for_debug(cfg, model_name)

    # Check if you find older models, and if you want to retrain them
    updated_name = update_run_name(
        run_name=run_name, base_run_name=define_run_name(cfg=cfg)
    )
    train_ON, best_run = if_retrain_the_imputation_model(
        cfg=cfg, run_name=updated_name, model_type="imputation"
    )

    # get the artifacts dir
    artifacts_dir = get_artifacts_dir(service_name="imputation")

    return cfg, model_name, updated_name, train_ON, best_run, artifacts_dir

mlflow_log_of_source_for_imputation

mlflow_log_of_source_for_imputation(
    source_data: Dict[str, Any], cfg: DictConfig
) -> None

Log source data information to MLflow for imputation tracking.

Records the outlier detection run information that was used as input for the imputation step. This enables traceability of the preprocessing pipeline through MLflow.

PARAMETER DESCRIPTION
source_data

Source data dictionary containing 'mlflow' key with run metadata and outlier detection results. If 'mlflow' is None, logs ground truth parameters instead.

TYPE: dict

cfg

Full Hydra configuration containing OUTLIER_DETECTION settings.

TYPE: DictConfig

Notes

Logs either the outlier detection run ID and best metric, or None/0 for ground truth data where no upstream outlier detection was used.
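The branching above can be sketched with a stub logger standing in for `mlflow.log_param` (the dict-based `log_param` callback, the hard-coded `OutlierBest_MSE` key, and the sample data are assumptions for illustration; the real metric column name comes from `get_best_dict`):

```python
def log_source_params(source_data, log_param):
    """Simplified sketch of the outlier-detection vs. ground-truth branch."""
    if source_data.get("mlflow") is not None:
        # Upstream outlier detection run: log its id and best metric
        log_param("Outlier_run_id", source_data["mlflow"]["run_id"])
        log_param("OutlierBest_MSE", source_data["mlflow"]["metrics.MSE"])
    else:
        # Ground truth: no upstream run id, and loss to itself is 0
        log_param("Outlier_run_id", None)
        log_param("OutlierBest", 0)

logged = {}
log_source_params({"mlflow": None}, lambda k, v: logged.__setitem__(k, v))
assert logged == {"Outlier_run_id": None, "OutlierBest": 0}
```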

Source code in src/imputation/imputation_main.py
def mlflow_log_of_source_for_imputation(
    source_data: Dict[str, Any], cfg: DictConfig
) -> None:
    """Log source data information to MLflow for imputation tracking.

    Records the outlier detection run information that was used as input
    for the imputation step. This enables traceability of the preprocessing
    pipeline through MLflow.

    Parameters
    ----------
    source_data : dict
        Source data dictionary containing 'mlflow' key with run metadata
        and outlier detection results. If 'mlflow' is None, logs ground
        truth parameters instead.
    cfg : DictConfig
        Full Hydra configuration containing OUTLIER_DETECTION settings.

    Notes
    -----
    Logs either the outlier detection run ID and best metric, or None/0
    for ground truth data where no upstream outlier detection was used.
    """
    if source_data["mlflow"] is not None:
        mlflow.log_param("Outlier_run_id", source_data["mlflow"]["run_id"])
        best_outlier_dict = get_best_dict("outlier_detection", cfg)
        best_outlier_string = best_outlier_dict["string"].replace("metrics.", "")
        col_name = get_best_imputation_col_name(best_metric_cfg=best_outlier_dict)
        try:
            best_value = source_data["mlflow"][col_name]
        except Exception as e:
            logger.error(f"Could not find {best_outlier_string} in {source_data}")
            logger.error(e)
            raise e
        mlflow.log_param(f"OutlierBest_{best_outlier_string}", best_value)
    else:
        # This is now the manually annotated data, so no id, and loss is 0 to itself (as it's the ground truth)
        # MSE (anomaly score) in practice
        mlflow.log_param("Outlier_run_id", None)
        mlflow.log_param("OutlierBest", 0)

imputation_model_selector

imputation_model_selector(
    source_data: Dict[str, Any],
    cfg: DictConfig,
    model_name: str,
    run_name: str,
    artifacts_dir: str,
    experiment_name: str,
) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]

Select and execute imputation method.

Dispatches to the appropriate imputation implementation. Supports:

  • Deep learning: SAITS, CSDI, TimesNet (via PyPOTS)
  • Foundation models: MOMENT
  • Traditional: MissForest

PARAMETER DESCRIPTION
source_data

Data from outlier detection stage with signals and masks.

TYPE: dict

cfg

Full Hydra configuration.

TYPE: DictConfig

model_name

Imputation method name. One of: 'SAITS', 'CSDI', 'TimesNet', 'MISSFOREST', 'MOMENT'.

TYPE: str

run_name

MLflow run name.

TYPE: str

artifacts_dir

Directory for saving artifacts.

TYPE: str

experiment_name

MLflow experiment name.

TYPE: str

RETURNS DESCRIPTION
tuple

(model, imputation_artifacts) where:

  • model: Trained imputation model
  • imputation_artifacts: dict with imputed data and metrics

RAISES DESCRIPTION
NotImplementedError

If model_name is not supported.
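The dispatch amounts to a name-to-wrapper lookup with a `NotImplementedError` fallback; a self-contained sketch with stub wrappers (the real wrappers take `source_data`, `model_cfg`, `cfg`, etc., so these one-argument stubs are assumptions for illustration):

```python
# Stub backends standing in for pypots_wrapper, missforest_main, moment_imputation_main
def pypots_stub(name):      return f"pypots:{name}"
def missforest_stub(name):  return f"missforest:{name}"
def moment_stub(name):      return f"moment:{name}"

DISPATCH = {
    "SAITS": pypots_stub, "CSDI": pypots_stub, "TimesNet": pypots_stub,
    "MISSFOREST": missforest_stub,
    "MOMENT": moment_stub,
}

def select(model_name):
    try:
        return DISPATCH[model_name](model_name)
    except KeyError:
        raise NotImplementedError(f"Model {model_name} not implemented!")

assert select("CSDI") == "pypots:CSDI"
assert select("MOMENT") == "moment:MOMENT"
```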

Source code in src/imputation/imputation_main.py
def imputation_model_selector(
    source_data: Dict[str, Any],
    cfg: DictConfig,
    model_name: str,
    run_name: str,
    artifacts_dir: str,
    experiment_name: str,
) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]:
    """
    Select and execute imputation method.

    Dispatches to the appropriate imputation implementation. Supports:
    - Deep learning: SAITS, CSDI, TimesNet (via PyPOTS)
    - Foundation models: MOMENT
    - Traditional: MissForest

    Parameters
    ----------
    source_data : dict
        Data from outlier detection stage with signals and masks.
    cfg : DictConfig
        Full Hydra configuration.
    model_name : str
        Imputation method name. One of:
        'SAITS', 'CSDI', 'TimesNet', 'MISSFOREST', 'MOMENT'.
    run_name : str
        MLflow run name.
    artifacts_dir : str
        Directory for saving artifacts.
    experiment_name : str
        MLflow experiment name.

    Returns
    -------
    tuple
        (model, imputation_artifacts) where:
        - model: Trained imputation model
        - imputation_artifacts: dict with imputed data and metrics

    Raises
    ------
    NotImplementedError
        If model_name is not supported.
    """
    # MLflow run (init only when training again, no point when reading precomputed results)
    init_mlflow_run(
        mlflow_cfg=cfg["MLFLOW"],
        run_name=run_name,
        cfg=cfg,
        experiment_name=experiment_name,
    )

    # MLflow parameters
    log_mlflow_params(
        mlflow_params=cfg["MODELS"][model_name]["MODEL"], model_name=model_name
    )
    mlflow_log_of_source_for_imputation(source_data, cfg)

    logger.info("Imputation with model {}".format(model_name))
    if model_name == "SAITS" or model_name == "CSDI" or model_name == "TimesNet":
        model, model_artifacts = pypots_wrapper(
            source_data=source_data,
            model_cfg=cfg["MODELS"][model_name],
            cfg=cfg,
            model_name=model_name,
            run_name=run_name,
        )
    elif model_name == "MISSFOREST":
        model, model_artifacts = missforest_main(
            source_data=source_data,
            model_cfg=cfg["MODELS"][model_name],
            cfg=cfg,
            model_name=model_name,
            run_name=run_name,
        )
    elif model_name == "MOMENT":
        model, model_artifacts = moment_imputation_main(
            data_dict=source_data,
            model_cfg=cfg["MODELS"][model_name],
            cfg=cfg,
            model_name=model_name,
            run_name=run_name,
        )
    else:
        logger.error("Model {} not implemented! Typo?".format(model_name))
        raise NotImplementedError("Model {} not implemented!".format(model_name))

    # Save and log all the artifacts created during the training to MLflow
    if model_artifacts is not None:
        model_artifacts["mlflow"] = get_mlflow_info()
        imputation_artifacts = {
            "source_data": source_data,
            "model_artifacts": model_artifacts,
        }

        # with PyPOTS, you only do this here, harmonize later this!
        # Moment computes the metrics already inside the Moment code for example
        imputation_artifacts["model_artifacts"]["metrics"] = compute_metrics_by_model(
            model_name, imputation_artifacts, cfg
        )

        # Log to MLflow
        save_and_log_imputer_artifacts(
            model, imputation_artifacts, artifacts_dir, cfg, model_name, run_name
        )
        return model, imputation_artifacts
    else:
        # This is None, when you hit like all-NaNs in your predictions and you abort the training
        mlflow.end_run()
        return None, None

imputation_PLR_workflow

imputation_PLR_workflow(
    cfg: DictConfig,
    source_name: str,
    source_data: Dict[str, Any],
    run_name: str,
    experiment_name: str,
    _visualize: bool = False,
) -> Optional[Dict[str, Any]]

Execute the PLR imputation workflow.

Main entry point for training or loading imputation models on PLR data. Handles workflow setup, model training/loading, and artifact management.

PARAMETER DESCRIPTION
cfg

Full Hydra configuration including MODELS, EXPERIMENT, and MLFLOW settings.

TYPE: DictConfig

source_name

Name identifier for the data source (e.g., outlier detection method name).

TYPE: str

source_data

Data dictionary from outlier detection containing signals, masks, and metadata.

TYPE: dict

run_name

MLflow run name for this imputation experiment.

TYPE: str

experiment_name

MLflow experiment name to log results under.

TYPE: str

_visualize

Whether to generate visualizations (currently unused). Default is False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
dict

Imputation artifacts containing imputed signals, metrics, and MLflow info. Returns None if training is skipped and no pre-computed results are loaded.

Notes

The workflow checks for existing trained models and can skip retraining if train_ON is False based on configuration and existing MLflow runs.

Source code in src/imputation/imputation_main.py
def imputation_PLR_workflow(
    cfg: DictConfig,
    source_name: str,
    source_data: Dict[str, Any],
    run_name: str,
    experiment_name: str,
    _visualize: bool = False,
) -> Optional[Dict[str, Any]]:
    """Execute the PLR imputation workflow.

    Main entry point for training or loading imputation models on PLR data.
    Handles workflow setup, model training/loading, and artifact management.

    Parameters
    ----------
    cfg : DictConfig
        Full Hydra configuration including MODELS, EXPERIMENT, and MLFLOW settings.
    source_name : str
        Name identifier for the data source (e.g., outlier detection method name).
    source_data : dict
        Data dictionary from outlier detection containing signals, masks, and metadata.
    run_name : str
        MLflow run name for this imputation experiment.
    experiment_name : str
        MLflow experiment name to log results under.
    _visualize : bool, optional
        Whether to generate visualizations (currently unused). Default is False.

    Returns
    -------
    dict
        Imputation artifacts containing imputed signals, metrics, and MLflow info.
        Returns None if training is skipped and no pre-computed results are loaded.

    Notes
    -----
    The workflow checks for existing trained models and can skip retraining
    if `train_ON` is False based on configuration and existing MLflow runs.
    """
    # Set-up the workflow
    cfg, model_name, run_name, train_ON, best_run, artifacts_dir = setup_PLR_worklow(
        cfg, run_name
    )

    # Task 1) Train the model and impute the missing data
    if train_ON:
        # This often can be time-consuming, so this also saves the results to MLflow
        _, imputation_artifacts = imputation_model_selector(
            source_data=source_data,
            cfg=cfg,
            run_name=run_name,
            artifacts_dir=artifacts_dir,
            model_name=model_name,
            experiment_name=experiment_name,
        )

    else:
        # The time-consuming imputation results can be imported here, so you
        # don't have to re-run the same experiments again
        # logging.info("Reading imputation results from MLflow")
        # imputation_artifacts, _ = retrieve_mlflow_artifacts_from_best_run(
        #     best_run, cfg, model_name
        # )
        logger.info("Skipping the re-computation of the imputation metrics")
        logger.debug("Nothing atm when you skip training, implement later here:")
        logger.debug(
            "If you want to compute new metrics or something without running the whole training"
        )

imputation_utils

create_imputation_df

create_imputation_df(
    imputer_artifacts: dict,
    data_df: DataFrame,
    cfg: DictConfig,
)

Create a Polars DataFrame from imputation artifacts for visualization.

Combines baseline PLR data with imputed data and exports to DuckDB for downstream analysis and visualization.

PARAMETER DESCRIPTION
imputer_artifacts

Dictionary containing imputation results per model, with 'mlflow' metadata for each model.

TYPE: dict

data_df

Original PLR data DataFrame with subject codes and time series.

TYPE: DataFrame

cfg

Full Hydra configuration including DATA and MLFLOW settings.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

Combined DataFrame with imputed values, model identifiers, and original data columns.

Source code in src/imputation/imputation_utils.py
def create_imputation_df(
    imputer_artifacts: dict, data_df: pl.DataFrame, cfg: DictConfig
):
    """Create a Polars DataFrame from imputation artifacts for visualization.

    Combines baseline PLR data with imputed data and exports to DuckDB
    for downstream analysis and visualization.

    Parameters
    ----------
    imputer_artifacts : dict
        Dictionary containing imputation results per model, with 'mlflow'
        metadata for each model.
    data_df : pl.DataFrame
        Original PLR data DataFrame with subject codes and time series.
    cfg : DictConfig
        Full Hydra configuration including DATA and MLFLOW settings.

    Returns
    -------
    pl.DataFrame
        Combined DataFrame with imputed values, model identifiers, and
        original data columns.
    """
    # Combine the baseline PLR data with the imputed data
    data_for_features = data_for_featurization_wrapper(
        artifacts=imputer_artifacts, cfg=cfg
    )

    # Create Dataframes per subplot
    df = create_imputation_plot_df(data_for_features, data_df, cfg)

    # Export imputation dataframe as DuckDB, and log as artifact to MLflow as well
    mlflow_cfgs = get_mlflow_cfgs_from_imputation_artifacts(imputer_artifacts, cfg)
    export_imputation_df(df, mlflow_cfgs, cfg)

    return df

create_imputation_plot_df

create_imputation_plot_df(
    data_for_features: dict,
    data_df: DataFrame,
    cfg: DictConfig,
)

Create DataFrame containing imputation results per model, split, and subject.

Iterates through all combinations of models, splits, and split keys to construct a unified DataFrame suitable for plotting and analysis.

PARAMETER DESCRIPTION
data_for_features

Nested dictionary with structure {model: {split: {split_key: data}}}, containing imputed values and metadata.

TYPE: dict

data_df

Original PLR data DataFrame with subject information.

TYPE: DataFrame

cfg

Configuration containing DATA.PLR_length for validation.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

Combined DataFrame with all imputation results, reordered to match initial column order.

RAISES DESCRIPTION
AssertionError

If row count is not a multiple of PLR_length.
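The shape invariant is simply that every appended block must contribute whole PLR traces; a stand-alone sketch of the check (the `PLR_length` value is illustrative, the real one comes from `cfg["DATA"]["PLR_length"]`):

```python
PLR_length = 1981  # illustrative; real value comes from cfg["DATA"]["PLR_length"]

def check_rows(n_rows: int) -> int:
    """Return the number of complete PLR traces, or fail like the assertion above."""
    assert n_rows % PLR_length == 0, (
        f"Row count {n_rows} is not a multiple of PLR_length={PLR_length}"
    )
    return n_rows // PLR_length

assert check_rows(3 * PLR_length) == 3
```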

Source code in src/imputation/imputation_utils.py
def create_imputation_plot_df(
    data_for_features: dict, data_df: pl.DataFrame, cfg: DictConfig
):
    """Create DataFrame containing imputation results per model, split, and subject.

    Iterates through all combinations of models, splits, and split keys to
    construct a unified DataFrame suitable for plotting and analysis.

    Parameters
    ----------
    data_for_features : dict
        Nested dictionary with structure {model: {split: {split_key: data}}},
        containing imputed values and metadata.
    data_df : pl.DataFrame
        Original PLR data DataFrame with subject information.
    cfg : DictConfig
        Configuration containing DATA.PLR_length for validation.

    Returns
    -------
    pl.DataFrame
        Combined DataFrame with all imputation results, reordered to match
        initial column order.

    Raises
    ------
    AssertionError
        If row count is not a multiple of PLR_length.
    """
    logger.info("Creating the imputation dataframe")
    init_cols_saved = False
    df_imputation = pl.DataFrame()
    debug_dfs = {}
    for model in data_for_features.keys():
        for split in data_for_features[model].keys():
            for split_key in data_for_features[model][split].keys():
                logger.debug(f"Creating dataframe for {model} {split} {split_key}")
                df_tmp, size_debug = create_subjects_df(
                    subplot_dict=data_for_features[model][split][split_key],
                    data_df=data_df,
                    cfg=cfg,
                )
                debug_dfs[f"{model}_{split}_{split_key}"] = size_debug
                df_tmp = add_loop_keys(model, split, split_key, df_tmp)
                if not init_cols_saved:
                    init_cols_saved = True
                    colnames_init = df_tmp.columns
                df_imputation = concatenate_imputation_dfs(
                    df_list=[df_imputation, df_tmp]
                )

                assert df_imputation.shape[0] % cfg["DATA"]["PLR_length"] == 0, (
                    "The number of rows in the output DataFrame "
                    "is not correct, shoud be a multiple of the "
                    "length of the time vector (PLR_length={})"
                ).format(cfg["DATA"]["PLR_length"])

    # reorder the columns based on the first subject
    df_imputation = df_imputation.select(colnames_init)
    no_of_PLRs = df_imputation.shape[0] / cfg["DATA"]["PLR_length"]
    assert df_imputation.shape[0] % cfg["DATA"]["PLR_length"] == 0, (
        "The number of rows in the output DataFrame "
        "is not correct, shoud be a multiple of the "
        "length of the time vector (PLR_length={})"
    ).format(cfg["DATA"]["PLR_length"])

    logger.info(
        "Imputation dataframe created, shape: {} ({} of PLRs in total, from {} options)".format(
            df_imputation.shape, int(no_of_PLRs), len(debug_dfs)
        )
    )

    return df_imputation

concatenate_imputation_dfs

concatenate_imputation_dfs(df_list: list)

Concatenate imputation DataFrames with proper type casting.

Processes and combines DataFrames for imputation results, ensuring consistent column types and naming conventions.

PARAMETER DESCRIPTION
df_list

List of two Polars DataFrames [existing_df, new_df] to concatenate. The first may be empty, the second is processed before concatenation.

TYPE: list

RETURNS DESCRIPTION
DataFrame

Vertically concatenated DataFrame with consistent column types.

RAISES DESCRIPTION
Exception

If concatenation fails due to schema mismatches.

Source code in src/imputation/imputation_utils.py
def concatenate_imputation_dfs(df_list: list):
    """Concatenate imputation DataFrames with proper type casting.

    Processes and combines DataFrames for imputation results, ensuring
    consistent column types and naming conventions.

    Parameters
    ----------
    df_list : list
        List of two Polars DataFrames [existing_df, new_df] to concatenate.
        The first may be empty, the second is processed before concatenation.

    Returns
    -------
    pl.DataFrame
        Vertically concatenated DataFrame with consistent column types.

    Raises
    ------
    Exception
        If concatenation fails due to schema mismatches.
    """
    # Now operating only on the df to be added, and we don't need to touch the 1st as
    # it's empty in the beginning, and all the 2nd df's will be added through these operations
    df_list[1] = cast_numeric_polars_cols(df=df_list[1], cast_to="Float64")
    df_list[1] = rename_ci_cols(df=df_list[1])
    df_list[1] = df_list[1].select(sorted(df_list[1].columns))

    try:
        df_imputation = pl.concat(df_list, how="vertical")
    except Exception as e:
        logger.error("Error in concatenating the imputation dataframes: {}".format(e))
        logger.error("Trying to convert the dataframes to Float32")
        raise e

    return df_imputation

create_subjects_df

create_subjects_df(
    subplot_dict: dict, data_df: DataFrame, cfg: DictConfig
)

Create DataFrame for all subjects from imputation subplot data.

Combines time series data with subject metadata for all subjects in the subplot dictionary.

PARAMETER DESCRIPTION
subplot_dict

Dictionary containing 'data' with imputation arrays (mean, CI, etc.) and 'metadata' with subject codes.

TYPE: dict

data_df

Original PLR data DataFrame for looking up subject information.

TYPE: DataFrame

cfg

Configuration (unused but kept for interface consistency).

TYPE: DictConfig

RETURNS DESCRIPTION
tuple

(df_out, size_debug) where df_out is the combined DataFrame and size_debug is a dict with 'no_subjects' and 'no_timepoints'.

RAISES DESCRIPTION
AssertionError

If timepoint counts don't match expected values.

Source code in src/imputation/imputation_utils.py
def create_subjects_df(subplot_dict: dict, data_df: pl.DataFrame, cfg: DictConfig):
    """Create DataFrame for all subjects from imputation subplot data.

    Combines time series data with subject metadata for all subjects in
    the subplot dictionary.

    Parameters
    ----------
    subplot_dict : dict
        Dictionary containing 'data' with imputation arrays (mean, CI, etc.)
        and 'metadata' with subject codes.
    data_df : pl.DataFrame
        Original PLR data DataFrame for looking up subject information.
    cfg : DictConfig
        Configuration (unused but kept for interface consistency).

    Returns
    -------
    tuple
        (df_out, size_debug) where df_out is the combined DataFrame and
        size_debug is a dict with 'no_subjects' and 'no_timepoints'.

    Raises
    ------
    AssertionError
        If timepoint counts don't match expected values.
    """
    # see e.g. compute_features_from_dict() and combine these eventually
    no_subjects, no_timepoints, no_features = subplot_dict["data"]["mean"].shape
    df_out = pl.DataFrame()
    for idx in range(no_subjects):
        df_subject = pl.DataFrame()
        df_subject = add_ts_cols(subplot_dict, df_subject, idx, no_timepoints)
        assert df_subject.shape[0] == no_timepoints, (
            f"df_subject: {df_subject.shape[0]} time points for {idx}th subject "
        )
        subject_code = subplot_dict["metadata"]["metadata_df"]["subject_code"][idx]
        data_subject = get_subject_datadf(data_df, subject_code, no_timepoints)
        df_subject = pl.concat([df_subject, data_subject], how="horizontal")
        assert df_subject.shape[0] == no_timepoints, (
            f"{df_subject.shape[0]} time points for {idx} subject "
        )
        df_out = pandas_concat(df1=df_out, df2=df_subject)
        assert df_out.shape[0] == (idx + 1) * no_timepoints, (
            f"df_out: {df_out.shape[0]} time points for {idx} subject "
        )
        # The column lengths in the DataFrame are not equal.

    assert df_out.shape[0] == no_subjects * no_timepoints, (
        "The number of rows in the output DataFrame is not correct"
    )

    return df_out, {"no_subjects": no_subjects, "no_timepoints": no_timepoints}

get_subject_datadf

get_subject_datadf(
    data_df: DataFrame,
    subject_code: str,
    no_timepoints: int,
)

Extract time series data for a specific subject from the DataFrame.

PARAMETER DESCRIPTION
data_df

Full PLR data DataFrame containing all subjects.

TYPE: DataFrame

subject_code

Unique identifier for the subject to extract.

TYPE: str

no_timepoints

Expected number of timepoints for validation.

TYPE: int

RETURNS DESCRIPTION
DataFrame

DataFrame containing only the specified subject's time series.

RAISES DESCRIPTION
AssertionError

If the number of rows doesn't match expected timepoints.

Source code in src/imputation/imputation_utils.py
def get_subject_datadf(data_df: pl.DataFrame, subject_code: str, no_timepoints: int):
    """Extract time series data for a specific subject from the DataFrame.

    Parameters
    ----------
    data_df : pl.DataFrame
        Full PLR data DataFrame containing all subjects.
    subject_code : str
        Unique identifier for the subject to extract.
    no_timepoints : int
        Expected number of timepoints for validation.

    Returns
    -------
    pl.DataFrame
        DataFrame containing only the specified subject's time series.

    Raises
    ------
    AssertionError
        If the number of rows doesn't match expected timepoints.
    """
    # Pick the time series from Polars DataFrame matching the subject code
    data_subject = data_df.filter(data_df["subject_code"] == subject_code)
    assert data_subject.shape[0] == no_timepoints, (
        f"data_subject: {data_subject.shape[0]} time points for {subject_code} subject "
    )
    # Polars->Pandas->Polars to maybe catch the "The column lengths in the DataFrame are not equal."
    return pl.from_pandas(data_subject.to_pandas())

add_ts_cols

add_ts_cols(
    subplot_dict: dict,
    df_out: DataFrame,
    idx: int,
    no_timepoints: int,
    add_as_list: bool = True,
)

Add time series columns from subplot data to a DataFrame.

Extracts and adds imputation data (mean, CI bounds, etc.) for a specific subject index to the output DataFrame.

PARAMETER DESCRIPTION
subplot_dict

Dictionary containing 'data' with arrays keyed by time series type.

TYPE: dict

df_out

Output DataFrame to add columns to.

TYPE: DataFrame

idx

Subject index to extract data for.

TYPE: int

no_timepoints

Expected number of timepoints for validation.

TYPE: int

add_as_list

If True, convert arrays to lists before adding (helps with Polars compatibility issues). Default is True.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
DataFrame

DataFrame with added time series columns.

RAISES DESCRIPTION
AssertionError

If the number of timepoints doesn't match expected value.

Source code in src/imputation/imputation_utils.py
def add_ts_cols(
    subplot_dict: dict,
    df_out: pl.DataFrame,
    idx: int,
    no_timepoints: int,
    add_as_list: bool = True,
):
    """Add time series columns from subplot data to a DataFrame.

    Extracts and adds imputation data (mean, CI bounds, etc.) for a
    specific subject index to the output DataFrame.

    Parameters
    ----------
    subplot_dict : dict
        Dictionary containing 'data' with arrays keyed by time series type.
    df_out : pl.DataFrame
        Output DataFrame to add columns to.
    idx : int
        Subject index to extract data for.
    no_timepoints : int
        Expected number of timepoints for validation.
    add_as_list : bool, optional
        If True, convert arrays to lists before adding (helps with Polars
        compatibility issues). Default is True.

    Returns
    -------
    pl.DataFrame
        DataFrame with added time series columns.

    Raises
    ------
    AssertionError
        If the number of timepoints doesn't match expected value.
    """
    for ts_key in subplot_dict["data"].keys():
        if subplot_dict["data"][ts_key] is not None:
            # add the array to Polars DataFrame
            array_tmp = subplot_dict["data"][ts_key][idx, :, :].flatten()
            assert len(array_tmp) == no_timepoints, (
                f"array tmp: {len(array_tmp)} time points for {idx} subject "
            )
            if add_as_list:
                # some weird Polars glitch after Numpy 1.25.2 downgrade, getting
                # "The column lengths in the DataFrame are not equal." Maybe list is better?
                list_tmp = list(array_tmp)
                df_out = df_out.with_columns(pl.Series(name=ts_key, values=list_tmp))
            else:
                df_out = df_out.with_columns(pl.lit(array_tmp).alias(ts_key))
        else:
            df_out = df_out.with_columns(pl.lit(None).alias(ts_key))
    assert df_out.shape[0] == no_timepoints, (
        f"df_out: {df_out.shape[0]} time points for {idx} subject "
    )

    return df_out
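The per-subject slicing inside `add_ts_cols` can be illustrated with a small NumPy array: one subject index of a `(subjects, timepoints, features)` array is flattened to a 1-D series whose length is validated against `no_timepoints` (shapes below are illustrative, not from the pipeline):

```python
import numpy as np

# Sketch of the extraction step in add_ts_cols: slice one subject from a
# (subjects, timepoints, features) array and flatten to 1-D.
no_subjects, no_timepoints, no_features = 4, 10, 1
data = {"mean": np.random.rand(no_subjects, no_timepoints, no_features)}

idx = 2  # subject index
array_tmp = data["mean"][idx, :, :].flatten()
assert len(array_tmp) == no_timepoints

# add_as_list=True path: hand Polars a plain list rather than an ndarray
list_tmp = list(array_tmp)
```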

add_loop_keys

add_loop_keys(model, split, split_key, df_tmp)

Add model, split, and split_key identifiers as columns to DataFrame.

PARAMETER DESCRIPTION
model

Name of the imputation model.

TYPE: str

split

Data split identifier (e.g., 'train', 'test').

TYPE: str

split_key

Additional split key identifier.

TYPE: str

df_tmp

DataFrame to add identifier columns to.

TYPE: DataFrame

RETURNS DESCRIPTION
DataFrame

DataFrame with added 'model', 'split', and 'split_key' columns.

Source code in src/imputation/imputation_utils.py
def add_loop_keys(model, split, split_key, df_tmp):
    """Add model, split, and split_key identifiers as columns to DataFrame.

    Parameters
    ----------
    model : str
        Name of the imputation model.
    split : str
        Data split identifier (e.g., 'train', 'test').
    split_key : str
        Additional split key identifier.
    df_tmp : pl.DataFrame
        DataFrame to add identifier columns to.

    Returns
    -------
    pl.DataFrame
        DataFrame with added 'model', 'split', and 'split_key' columns.
    """
    df_tmp = df_tmp.with_columns(pl.lit(model).alias("model"))
    df_tmp = df_tmp.with_columns(pl.lit(split).alias("split"))
    df_tmp = df_tmp.with_columns(pl.lit(split_key).alias("split_key"))

    return df_tmp

rename_ci_cols

rename_ci_cols(df)

Rename confidence interval columns to standard names.

Normalizes column names for imputation confidence intervals to consistent 'ci_pos' and 'ci_neg' names.

PARAMETER DESCRIPTION
df

DataFrame with potentially inconsistent CI column names.

TYPE: DataFrame

RETURNS DESCRIPTION
DataFrame

DataFrame with standardized CI column names.

Notes

This is a temporary fix for column naming inconsistencies that should be harmonized upstream.

Source code in src/imputation/imputation_utils.py
def rename_ci_cols(df):
    """Rename confidence interval columns to standard names.

    Normalizes column names for imputation confidence intervals to
    consistent 'ci_pos' and 'ci_neg' names.

    Parameters
    ----------
    df : pl.DataFrame
        DataFrame with potentially inconsistent CI column names.

    Returns
    -------
    pl.DataFrame
        DataFrame with standardized CI column names.

    Notes
    -----
    This is a temporary fix for column naming inconsistencies that should
    be harmonized upstream.
    """
    # TODO! Hacky way to handle columns, harmonize these so you don't get mixed? or how is this actually happening?
    for col in df.columns:
        if "imputation_ci_pos" in col:
            df = df.rename({"imputation_ci_pos": "ci_pos"})
        elif "imputation_ci_neg" in col:
            df = df.rename({"imputation_ci_neg": "ci_neg"})
    return df
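The normalization performed by `rename_ci_cols` is a fixed name mapping; sketched here over a plain list of column names instead of a Polars DataFrame:

```python
# Sketch of the CI column-name normalization in rename_ci_cols, applied to
# a plain list of column names (the non-CI names are illustrative).
RENAMES = {"imputation_ci_pos": "ci_pos", "imputation_ci_neg": "ci_neg"}

def normalize_ci_names(columns):
    # Replace known inconsistent CI names, leave everything else untouched
    return [RENAMES.get(col, col) for col in columns]

cols = normalize_ci_names(["time", "imputation_ci_pos", "imputation_ci_neg"])
```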

get_mlflow_cfgs_from_imputation_artifacts

get_mlflow_cfgs_from_imputation_artifacts(
    imputer_artifacts: dict, cfg: DictConfig
)

Extract MLflow configurations from imputation artifacts.

PARAMETER DESCRIPTION
imputer_artifacts

Dictionary of imputation results keyed by model name, each containing 'mlflow' metadata.

TYPE: dict

cfg

Configuration (unused but kept for interface consistency).

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Dictionary mapping model names to their MLflow configurations.

Source code in src/imputation/imputation_utils.py
def get_mlflow_cfgs_from_imputation_artifacts(imputer_artifacts: dict, cfg: DictConfig):
    """Extract MLflow configurations from imputation artifacts.

    Parameters
    ----------
    imputer_artifacts : dict
        Dictionary of imputation results keyed by model name, each containing
        'mlflow' metadata.
    cfg : DictConfig
        Configuration (unused but kept for interface consistency).

    Returns
    -------
    dict
        Dictionary mapping model names to their MLflow configurations.
    """
    mlflow_cfgs = {}
    for model in imputer_artifacts.keys():
        mlflow_cfgs[model] = imputer_artifacts[model]["mlflow"]
    return mlflow_cfgs
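The extraction above is a plain per-model key lookup; the artifact contents in this sketch are made-up placeholders:

```python
# Sketch of get_mlflow_cfgs_from_imputation_artifacts: pull the 'mlflow'
# metadata out of each model's artifact dict (run IDs here are fabricated).
imputer_artifacts = {
    "SAITS": {"mlflow": {"run_id": "abc123"}},
    "CSDI": {"mlflow": {"run_id": "def456"}},
}
mlflow_cfgs = {model: art["mlflow"] for model, art in imputer_artifacts.items()}
```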

export_imputation_df

export_imputation_df(
    df: DataFrame, mlflow_cfgs: dict, cfg: DictConfig
)

Export imputation DataFrame to DuckDB and log to MLflow.

Creates a DuckDB database for each model's imputation results and logs it as an artifact to the corresponding MLflow run.

PARAMETER DESCRIPTION
df

Combined imputation DataFrame with model column for filtering.

TYPE: DataFrame

mlflow_cfgs

Dictionary mapping model names to MLflow configurations.

TYPE: dict

cfg

Configuration for export settings.

TYPE: DictConfig

Source code in src/imputation/imputation_utils.py
def export_imputation_df(df: pl.DataFrame, mlflow_cfgs: dict, cfg: DictConfig):
    """Export imputation DataFrame to DuckDB and log to MLflow.

    Creates a DuckDB database for each model's imputation results and
    logs it as an artifact to the corresponding MLflow run.

    Parameters
    ----------
    df : pl.DataFrame
        Combined imputation DataFrame with model column for filtering.
    mlflow_cfgs : dict
        Dictionary mapping model names to MLflow configurations.
    cfg : DictConfig
        Configuration for export settings.
    """
    logger.info("Exporting the imputation dataframe (per model output) to DuckDB")
    for model in mlflow_cfgs.keys():
        db_name = f"imputation_{model}.db"
        # Pick the samples from the given model
        df_subset = df.filter(df["model"] == model)
        # Save as DuckDB database
        db_path = export_dataframe_to_duckdb(
            df=df_subset, db_name=db_name, cfg=cfg, name="imputation"
        )
        # Log as artifact to MLflow
        log_imputation_db_to_mlflow(
            db_path=db_path, mlflow_cfg=mlflow_cfgs[model], model=model, cfg=cfg
        )

get_imputation_results_from_mlflow_for_features

get_imputation_results_from_mlflow_for_features(
    experiment_name: str, cfg: DictConfig
)

Retrieve imputation results from MLflow for feature computation.

Fetches the best hyperparameter configurations and their corresponding imputation results from MLflow for use in downstream featurization.

PARAMETER DESCRIPTION
experiment_name

MLflow experiment name to search for imputation runs.

TYPE: str

cfg

Configuration for MLflow and model settings.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Dictionary mapping model names to their imputation results from MLflow.

Source code in src/imputation/imputation_utils.py
def get_imputation_results_from_mlflow_for_features(
    experiment_name: str, cfg: DictConfig
):
    """Retrieve imputation results from MLflow for feature computation.

    Fetches the best hyperparameter configurations and their corresponding
    imputation results from MLflow for use in downstream featurization.

    Parameters
    ----------
    experiment_name : str
        MLflow experiment name to search for imputation runs.
    cfg : DictConfig
        Configuration for MLflow and model settings.

    Returns
    -------
    dict
        Dictionary mapping model names to their imputation results from MLflow.
    """
    # Gets the best hyperparam
    best_unique_models = get_used_imputation_models_from_mlflow(
        experiment_name, cfg, exclude_ensemble=False
    )

    results_per_model = {}
    for i, model in enumerate(best_unique_models.keys()):
        logger.info(
            f"Getting the results of the model: {model} (#{i + 1}/{len(best_unique_models.keys())})"
        )
        results_per_model[model] = get_imputation_results_from_mlflow(
            mlflow_run=best_unique_models[model], model_name=model, cfg=cfg
        )

    return results_per_model

Model Training

impute_with_models

evaluate_pypots_model

evaluate_pypots_model(
    model, dataset_dict, split: str, cfg: DictConfig
)

Evaluate a PyPOTS model by imputing missing values in the dataset.

Runs the trained PyPOTS model on the provided dataset to impute missing values, handling both deterministic and probabilistic outputs.

PARAMETER DESCRIPTION
model

Trained PyPOTS imputation model (e.g., SAITS, CSDI, TimesNet).

TYPE: PyPOTS model

dataset_dict

Dataset dictionary containing 'X' array with NaN values to impute.

TYPE: dict

split

Data split name ('train' or 'test').

TYPE: str

cfg

Configuration (unused but kept for interface consistency).

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Imputation results containing:

- 'imputation_dict': Dict with 'imputation' (mean, CI bounds) and 'indicating_mask' (boolean mask of originally missing values)
- 'timing': Elapsed time for imputation in seconds

RAISES DESCRIPTION
ValueError

If imputation output has unexpected shape (not 3D or 4D).

Notes

CSDI generates a 4D output with samples dimension, which is reduced to 3D by taking the first sample. Other models produce 3D output directly.

Source code in src/imputation/impute_with_models.py
def evaluate_pypots_model(model, dataset_dict, split: str, cfg: DictConfig):
    """Evaluate a PyPOTS model by imputing missing values in the dataset.

    Runs the trained PyPOTS model on the provided dataset to impute
    missing values, handling both deterministic and probabilistic outputs.

    Parameters
    ----------
    model : PyPOTS model
        Trained PyPOTS imputation model (e.g., SAITS, CSDI, TimesNet).
    dataset_dict : dict
        Dataset dictionary containing 'X' array with NaN values to impute.
    split : str
        Data split name ('train' or 'test').
    cfg : DictConfig
        Configuration (unused but kept for interface consistency).

    Returns
    -------
    dict
        Imputation results containing:
        - 'imputation_dict': Dict with 'imputation' (mean, CI bounds) and
          'indicating_mask' (boolean mask of originally missing values)
        - 'timing': Elapsed time for imputation in seconds

    Raises
    ------
    ValueError
        If imputation output has unexpected shape (not 3D or 4D).

    Notes
    -----
    CSDI generates a 4D output with samples dimension, which is reduced
    to 3D by taking the first sample. Other models produce 3D output directly.
    """
    start_time = time.time()
    imputation = model.impute(dataset_dict)
    end_time = time.time() - start_time

    # indicating mask for imputation error calculation
    indicating_mask = np.isnan(dataset_dict["X"])

    # Save the imputation results
    if len(imputation.shape) == 3:
        # deterministic imputation
        imputation_mean = imputation
        imputation_ci_neg = None
        imputation_ci_pos = None
    elif len(imputation.shape) == 4:
        # CSDI generates an extra samples dimension for the imputation,
        # but by default it contains only one sample? how to get the "spread"?
        imputation_mean = imputation[:, 0, :, :]
        imputation_ci_neg = None
        imputation_ci_pos = None
    else:
        logger.error("Unknown shape of the imputation results")
        raise ValueError("Unknown shape of the imputation results")

    # See "create_imputation_dict()" in src/imputation/train_utils.py
    # Combine these later
    imputation_dict = {
        "imputation_dict": {
            "imputation": {
                "mean": imputation_mean,
                "imputation_ci_neg": imputation_ci_neg,
                "imputation_ci_pos": imputation_ci_pos,
            },
            "indicating_mask": indicating_mask,
        },
        "timing": end_time,
    }

    return imputation_dict
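The output-shape handling above can be demonstrated in isolation: CSDI-style 4-D output `(samples, draws, timepoints, features)` is reduced to 3-D by taking the first draw, while 3-D deterministic output passes through unchanged (shapes below are illustrative):

```python
import numpy as np

# Sketch of the shape dispatch in evaluate_pypots_model.
def reduce_imputation(imputation: np.ndarray) -> np.ndarray:
    if imputation.ndim == 3:
        return imputation           # deterministic model (e.g. SAITS)
    if imputation.ndim == 4:
        return imputation[:, 0, :, :]  # CSDI: keep the first sampled draw
    raise ValueError("Unknown shape of the imputation results")

csdi_out = np.zeros((8, 5, 100, 1))  # (samples, draws, timepoints, features)
mean = reduce_imputation(csdi_out)
```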

log_imputed_artifacts

log_imputed_artifacts(
    imputation: dict,
    model_name: str,
    cfg: DictConfig,
    run_id: str,
)

Save imputation results locally and log to MLflow.

PARAMETER DESCRIPTION
imputation

Imputation results dictionary to save.

TYPE: dict

model_name

Name of the imputation model for file naming.

TYPE: str

cfg

Configuration (unused but kept for interface consistency).

TYPE: DictConfig

run_id

MLflow run ID to log artifacts to.

TYPE: str

RETURNS DESCRIPTION
str

Path to the saved artifacts file.

Source code in src/imputation/impute_with_models.py
def log_imputed_artifacts(
    imputation: dict, model_name: str, cfg: DictConfig, run_id: str
):
    """Save imputation results locally and log to MLflow.

    Parameters
    ----------
    imputation : dict
        Imputation results dictionary to save.
    model_name : str
        Name of the imputation model for file naming.
    cfg : DictConfig
        Configuration (unused but kept for interface consistency).
    run_id : str
        MLflow run ID to log artifacts to.

    Returns
    -------
    str
        Path to the saved artifacts file.
    """
    output_dir, fname, artifacts_path = define_pypots_outputs(
        model_name=model_name, artifact_type="imputation"
    )
    save_results_dict(imputation, artifacts_path, name="imputation")
    # with mlflow.start_run(run_id=run_id):
    mlflow.log_artifact(artifacts_path, artifact_path="results")

    return artifacts_path

pypots_imputer_wrapper

pypots_imputer_wrapper(
    model, model_name, dataset_dicts, source_data, cfg
)

Wrapper to impute data across all splits using a PyPOTS model.

Iterates over data splits and applies the trained PyPOTS model to impute missing values in each split.

PARAMETER DESCRIPTION
model

Trained PyPOTS imputation model.

TYPE: PyPOTS model

model_name

Name of the model for logging.

TYPE: str

dataset_dicts

Dictionary of datasets keyed by split name (e.g., 'train', 'test').

TYPE: dict

source_data

Source data dictionary (unused but kept for interface consistency).

TYPE: dict

cfg

Configuration for imputation settings.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Dictionary mapping split names to their imputation results.

Source code in src/imputation/impute_with_models.py
def pypots_imputer_wrapper(model, model_name, dataset_dicts, source_data, cfg):
    """Wrapper to impute data across all splits using a PyPOTS model.

    Iterates over data splits and applies the trained PyPOTS model
    to impute missing values in each split.

    Parameters
    ----------
    model : PyPOTS model
        Trained PyPOTS imputation model.
    model_name : str
        Name of the model for logging.
    dataset_dicts : dict
        Dictionary of datasets keyed by split name (e.g., 'train', 'test').
    source_data : dict
        Source data dictionary (unused but kept for interface consistency).
    cfg : DictConfig
        Configuration for imputation settings.

    Returns
    -------
    dict
        Dictionary mapping split names to their imputation results.
    """
    logger.info("Evaluate (impute) the model on the data")
    imputed_dict = {}
    for i, split in enumerate(dataset_dicts.keys()):
        imputed_dict[split] = evaluate_pypots_model(
            model=model,
            dataset_dict=dataset_dicts[split],
            split=split,
            cfg=cfg,
        )

    return imputed_dict

train_utils

if_results_file_found

if_results_file_found(results_path: str) -> bool

Check if a results file exists at the specified path.

PARAMETER DESCRIPTION
results_path

Full path to the results file.

TYPE: str

RETURNS DESCRIPTION
bool

True if the file exists, False otherwise.

Source code in src/imputation/train_utils.py
def if_results_file_found(results_path: str) -> bool:
    """Check if a results file exists at the specified path.

    Parameters
    ----------
    results_path : str
        Full path to the results file.

    Returns
    -------
    bool
        True if the file exists, False otherwise.
    """
    if os.path.exists(results_path):
        return True
    else:
        dir, fname = os.path.split(results_path)
        logger.debug(
            "Could not find the results file from the artifact_dir\n"
            "dir = {}, fname = {}".format(dir, fname)
        )
        return False
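A minimal demonstration of the existence check, using a temporary file so the paths are created on the fly:

```python
import os
import tempfile

# Minimal sketch of the check in if_results_file_found.
def if_results_file_found(results_path: str) -> bool:
    return os.path.exists(results_path)

with tempfile.NamedTemporaryFile() as tmp:
    found = if_results_file_found(tmp.name)   # file exists inside the block
missing = if_results_file_found(tmp.name)     # deleted on context exit
```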

create_imputation_dict

create_imputation_dict(
    imputation_mean: ndarray,
    preprocess: Dict[str, Any],
    X_missing: ndarray,
    cfg: DictConfig,
    end_time: Optional[float] = None,
) -> Dict[str, Any]

Create standardized imputation result dictionary.

Formats imputation results into the common structure used across all imputation methods, with optional destandardization.

PARAMETER DESCRIPTION
imputation_mean

Imputed values array, shape (samples, timepoints) or (samples, timepoints, features).

TYPE: ndarray

preprocess

Preprocessing dictionary containing 'standardization' with mean and stdev.

TYPE: dict

X_missing

Original array with NaN values indicating missing points.

TYPE: ndarray

cfg

Configuration with PREPROCESS.standardize flag.

TYPE: DictConfig

end_time

Time taken for imputation in seconds. Default is None.

TYPE: float DEFAULT: None

RETURNS DESCRIPTION
dict

Standardized imputation dictionary containing:

- 'imputation_dict': Dict with 'imputation' (mean, CI bounds) and 'indicating_mask'
- 'timing': Elapsed time if provided

Notes

If input is 2D, a third dimension is added to match expected (samples, timepoints, features) format. Destandardization is applied if configured.

Source code in src/imputation/train_utils.py
def create_imputation_dict(
    imputation_mean: np.ndarray,
    preprocess: Dict[str, Any],
    X_missing: np.ndarray,
    cfg: DictConfig,
    end_time: Optional[float] = None,
) -> Dict[str, Any]:
    """Create standardized imputation result dictionary.

    Formats imputation results into the common structure used across
    all imputation methods, with optional destandardization.

    Parameters
    ----------
    imputation_mean : np.ndarray
        Imputed values array, shape (samples, timepoints) or (samples, timepoints, features).
    preprocess : dict
        Preprocessing dictionary containing 'standardization' with mean and stdev.
    X_missing : np.ndarray
        Original array with NaN values indicating missing points.
    cfg : DictConfig
        Configuration with PREPROCESS.standardize flag.
    end_time : float, optional
        Time taken for imputation in seconds. Default is None.

    Returns
    -------
    dict
        Standardized imputation dictionary containing:
        - 'imputation_dict': Dict with 'imputation' (mean, CI bounds) and 'indicating_mask'
        - 'timing': Elapsed time if provided

    Notes
    -----
    If input is 2D, a third dimension is added to match expected (samples, timepoints, features) format.
    Destandardization is applied if configured.
    """
    # Get boolean mask of missing (NaN) values
    indicating_mask = np.isnan(X_missing)

    if len(imputation_mean.shape) == 2:
        # Add the third dimension, as the downstream code expects 3D arrays
        imputation_mean = np.expand_dims(imputation_mean, axis=2)
        logger.debug("Adding the third dimension to the imputation_mean array")

    assert len(imputation_mean.shape) == 3

    # Destandardize the data (if needed)
    if cfg["PREPROCESS"]["standardize"]:
        logger.debug("Destandardizing the imputed data")
        imputation_mean = destandardize_numpy(
            imputation_mean,
            mean=preprocess["standardization"]["mean"],
            std=preprocess["standardization"]["stdev"],
        )

    imputation_dict = {
        "imputation_dict": {
            "imputation": {
                # (no_samples, no_timepoints, no_features)
                "mean": imputation_mean,
                "imputation_ci_neg": None,
                "imputation_ci_pos": None,
            },
            "indicating_mask": indicating_mask,
        },
        "timing": end_time,
    }

    return imputation_dict
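The core steps of `create_imputation_dict` can be sketched directly in NumPy: derive the missing-value mask from the original data, promote a 2-D imputation to 3-D, and destandardize. The inverse transform shown (`x * std + mean`) is the conventional one and is assumed here to match `destandardize_numpy`; the values are illustrative:

```python
import numpy as np

# Sketch of create_imputation_dict's core steps (values are illustrative).
X_missing = np.array([[1.0, np.nan, 3.0],
                      [np.nan, 2.0, 2.5]])
imputation_mean = np.array([[1.0, 2.0, 3.0],
                            [1.5, 2.0, 2.5]])

indicating_mask = np.isnan(X_missing)  # True where a value was imputed

# Promote to the expected (samples, timepoints, features) layout
imputation_mean = np.expand_dims(imputation_mean, axis=2)

# Assumed inverse of standardization (x - mean) / std
mean, std = 2.0, 0.5
destandardized = imputation_mean * std + mean
```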

create_imputation_dict_from_moment

create_imputation_dict_from_moment(
    imputation_mean: ndarray,
    indicating_mask: ndarray,
    imputation_time: float,
) -> Dict[str, Any]

Create imputation dictionary from MOMENT model outputs.

Formats MOMENT-specific outputs into the common imputation structure.

PARAMETER DESCRIPTION
imputation_mean

Imputed values array from MOMENT model.

TYPE: ndarray

indicating_mask

Boolean mask indicating originally missing values.

TYPE: ndarray

imputation_time

Time taken for imputation in seconds.

TYPE: float

RETURNS DESCRIPTION
dict

Standardized imputation dictionary with imputation values, mask, and timing information.

Source code in src/imputation/train_utils.py
def create_imputation_dict_from_moment(
    imputation_mean: np.ndarray, indicating_mask: np.ndarray, imputation_time: float
) -> Dict[str, Any]:
    """Create imputation dictionary from MOMENT model outputs.

    Formats MOMENT-specific outputs into the common imputation structure.

    Parameters
    ----------
    imputation_mean : np.ndarray
        Imputed values array from MOMENT model.
    indicating_mask : np.ndarray
        Boolean mask indicating originally missing values.
    imputation_time : float
        Time taken for imputation in seconds.

    Returns
    -------
    dict
        Standardized imputation dictionary with imputation values,
        mask, and timing information.
    """
    imputation_dict = {
        "imputation_dict": {
            "imputation": {
                # (no_samples, no_timepoints, no_features)
                "mean": imputation_mean,
                "imputation_ci_neg": None,
                "imputation_ci_pos": None,
            },
            "indicating_mask": indicating_mask,
        },
        "timing": imputation_time,
    }

    return imputation_dict

imputation_per_split_of_dict

imputation_per_split_of_dict(
    data_dicts: Dict[str, Any],
    df: DataFrame,
    preprocess: Dict[str, Any],
    model: Any,
    split: str,
    cfg: DictConfig,
) -> Dict[str, Any]

Apply imputation model to a single data split.

Transforms the input DataFrame using the trained model and creates a standardized imputation result dictionary.

PARAMETER DESCRIPTION
data_dicts

Data dictionaries (unused but kept for interface consistency).

TYPE: dict

df

DataFrame with missing values (NaN) to impute.

TYPE: DataFrame

preprocess

Preprocessing dictionary with standardization statistics.

TYPE: dict

model

Trained imputation model with transform() method.

TYPE: object

split

Split name for logging.

TYPE: str

cfg

Configuration for imputation settings.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Imputation result dictionary with imputed values and timing.

Source code in src/imputation/train_utils.py
def imputation_per_split_of_dict(
    data_dicts: Dict[str, Any],
    df: pd.DataFrame,
    preprocess: Dict[str, Any],
    model: Any,
    split: str,
    cfg: DictConfig,
) -> Dict[str, Any]:
    """Apply imputation model to a single data split.

    Transforms the input DataFrame using the trained model and creates
    a standardized imputation result dictionary.

    Parameters
    ----------
    data_dicts : dict
        Data dictionaries (unused but kept for interface consistency).
    df : pd.DataFrame
        DataFrame with missing values (NaN) to impute.
    preprocess : dict
        Preprocessing dictionary with standardization statistics.
    model : object
        Trained imputation model with transform() method.
    split : str
        Split name for logging.
    cfg : DictConfig
        Configuration for imputation settings.

    Returns
    -------
    dict
        Imputation result dictionary with imputed values and timing.
    """
    X_missing = df.to_numpy()
    logger.debug("Split = {}".format(split))
    start_time = time.time()
    imputation_mean = model.transform(x=df)
    dict_out = create_imputation_dict(
        imputation_mean=imputation_mean,
        preprocess=preprocess,
        X_missing=X_missing,
        end_time=time.time() - start_time,
        cfg=cfg,
    )

    return dict_out

train_torch_utils

create_torch_dataloader

create_torch_dataloader(
    data_dict_df: dict,
    task: str,
    model_cfg: DictConfig,
    split: str,
    cfg: DictConfig,
    model_name: str = None,
)

Create a PyTorch DataLoader for a specific data split.

Creates a TensorDataset from numpy arrays and wraps it in a DataLoader with the specified configuration.

PARAMETER DESCRIPTION
data_dict_df

Data dictionary containing arrays per split.

TYPE: dict

task

Task type ('imputation' or 'outlier_detection').

TYPE: str

model_cfg

Model configuration with TORCH.DATASET and TORCH.DATALOADER settings.

TYPE: DictConfig

split

Split name ('train', 'test', 'outlier_train', 'outlier_test').

TYPE: str

cfg

Full Hydra configuration.

TYPE: DictConfig

model_name

Model name for dataset creation. Default is None.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
DataLoader

PyTorch DataLoader configured for the specified split.

RAISES DESCRIPTION
NotImplementedError

If dataset_type is 'class' (not yet implemented).

ValueError

If dataset_type is unknown.

Source code in src/imputation/train_torch_utils.py
def create_torch_dataloader(
    data_dict_df: dict,
    task: str,
    model_cfg: DictConfig,
    split: str,
    cfg: DictConfig,
    model_name: str = None,
):
    """Create a PyTorch DataLoader for a specific data split.

    Creates a TensorDataset from numpy arrays and wraps it in a DataLoader
    with the specified configuration.

    Parameters
    ----------
    data_dict_df : dict
        Data dictionary containing arrays per split.
    task : str
        Task type ('imputation' or 'outlier_detection').
    model_cfg : DictConfig
        Model configuration with TORCH.DATASET and TORCH.DATALOADER settings.
    split : str
        Split name ('train', 'test', 'outlier_train', 'outlier_test').
    cfg : DictConfig
        Full Hydra configuration.
    model_name : str, optional
        Model name for dataset creation. Default is None.

    Returns
    -------
    DataLoader
        PyTorch DataLoader configured for the specified split.

    Raises
    ------
    NotImplementedError
        If dataset_type is 'class' (not yet implemented).
    ValueError
        If dataset_type is unknown.
    """
    # Create the dataset
    if model_cfg["TORCH"]["DATASET"]["dataset_type"] == "numpy":
        dataset = create_dataset_from_numpy(
            data_dict_df=data_dict_df,
            dataset_cfg=model_cfg["TORCH"]["DATASET"],
            model_cfg=model_cfg,
            split=split,
            task=task,
            model_name=model_name,
        )
    elif model_cfg["TORCH"]["DATASET"]["dataset_type"] == "class":
        raise NotImplementedError(
            "Class based dataset creation is not implemented yet, please use numpy"
        )
        # dataset = AnomalyDetectionPLRDataset(
        #     split, model_artifacts, dataset_cfg=model_cfg["TORCH"]["DATASET"]
        # )
    else:
        logger.error(
            "Unknown Torch Dataset creation method = {}".format(
                model_cfg["TORCH"]["DATASET"]["dataset_type"]
            )
        )
        raise ValueError("Unknown Torch Dataset creation method")

    # Create the dataloader from the dataset
    # Compare to moment-research/moment/data/dataloader.py#L98
    dataloader = DataLoader(
        dataset, **model_cfg["TORCH"]["DATALOADER"]
    )  # create your dataloader

    return dataloader

create_torch_dataloaders

create_torch_dataloaders(
    task: str,
    model_name: str,
    data_dict_df: dict,
    model_cfg: DictConfig,
    cfg: DictConfig,
    create_outlier_dataloaders: bool = True,
)

Create PyTorch DataLoaders for all required data splits.

Creates train and test dataloaders, with optional outlier-specific dataloaders for anomaly detection tasks.

PARAMETER DESCRIPTION
task

Task type ('imputation' or 'outlier_detection').

TYPE: str

model_name

Model name for dataset creation.

TYPE: str

data_dict_df

Data dictionary containing arrays per split.

TYPE: dict

model_cfg

Model configuration with TORCH settings.

TYPE: DictConfig

cfg

Full Hydra configuration.

TYPE: DictConfig

create_outlier_dataloaders

Whether to create outlier_train and outlier_test dataloaders. Default is True.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
dict

Dictionary mapping split names to DataLoaders. Contains 'train' and 'test', plus 'outlier_train' and 'outlier_test' if requested.

Source code in src/imputation/train_torch_utils.py
def create_torch_dataloaders(
    task: str,
    model_name: str,
    data_dict_df: dict,
    model_cfg: DictConfig,
    cfg: DictConfig,
    create_outlier_dataloaders: bool = True,
):
    """Create PyTorch DataLoaders for all required data splits.

    Creates train and test dataloaders, with optional outlier-specific
    dataloaders for anomaly detection tasks.

    Parameters
    ----------
    task : str
        Task type ('imputation' or 'outlier_detection').
    model_name : str
        Model name for dataset creation.
    data_dict_df : dict
        Data dictionary containing arrays per split.
    model_cfg : DictConfig
        Model configuration with TORCH settings.
    cfg : DictConfig
        Full Hydra configuration.
    create_outlier_dataloaders : bool, optional
        Whether to create outlier_train and outlier_test dataloaders.
        Default is True.

    Returns
    -------
    dict
        Dictionary mapping split names to DataLoaders. Contains 'train'
        and 'test', plus 'outlier_train' and 'outlier_test' if requested.
    """
    # Create torch dataloader for zero-shot imputation
    logger.info("Creating torch dataloaders")
    train_dataloader = create_torch_dataloader(
        data_dict_df,
        task=task,
        model_cfg=model_cfg,
        split="train",
        cfg=cfg,
        model_name=model_name,
    )
    test_dataloader = create_torch_dataloader(
        data_dict_df,
        task=task,
        model_cfg=model_cfg,
        split="test",
        cfg=cfg,
        model_name=model_name,
    )
    if create_outlier_dataloaders:
        # Naming note: training happens unsupervised on the clean data, so that
        # MOMENT learns to reconstruct the denoised signal; the anomalies are
        # then picked up from these separate outlier dataloaders. They are
        # redundant only when the outlier splits equal the actual train/test splits.
        outlier_train_dataloader = create_torch_dataloader(
            data_dict_df,
            task=task,
            model_cfg=model_cfg,
            split="outlier_train",
            cfg=cfg,
            model_name=model_name,
        )

        outlier_test_dataloader = create_torch_dataloader(
            data_dict_df,
            task=task,
            model_cfg=model_cfg,
            split="outlier_test",
            cfg=cfg,
            model_name=model_name,
        )

        return {
            "train": train_dataloader,
            "test": test_dataloader,
            "outlier_train": outlier_train_dataloader,
            "outlier_test": outlier_test_dataloader,
        }

    else:
        return {"train": train_dataloader, "test": test_dataloader}
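The split-selection branching above can be sketched independently of PyTorch. The helper name `build_split_loaders` and the factory callable below are illustrative, not part of the codebase:

```python
def build_split_loaders(make_loader, include_outlier_splits=True):
    """Collect per-split loaders, mirroring create_torch_dataloaders' branching.

    make_loader: callable taking a split name and returning a loader-like object.
    """
    loaders = {"train": make_loader("train"), "test": make_loader("test")}
    if include_outlier_splits:
        # Outlier splits reuse the same factory; only the split name differs.
        loaders["outlier_train"] = make_loader("outlier_train")
        loaders["outlier_test"] = make_loader("outlier_test")
    return loaders


# Stand-in factory: a real implementation would wrap a torch DataLoader.
loaders = build_split_loaders(lambda split: f"loader[{split}]")
print(sorted(loaders))
```

With `include_outlier_splits=False`, only the `train` and `test` keys are returned, matching the function's `else` branch.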

MissForest

missforest_main

missforest_create_imputation_dicts

missforest_create_imputation_dicts(
    model, df_dict, source_data, cfg
)

Create imputation dictionaries from MissForest model outputs.

Transforms MissForest imputation results into the standardized format used by PyPOTS models for downstream processing compatibility.

PARAMETER DESCRIPTION
model

Trained MissForest model.

TYPE: MissForest

df_dict

Dictionary of DataFrames keyed by split name ('train', 'test').

TYPE: dict

source_data

Source data containing 'df' with data dictionaries per split and 'preprocess' with standardization statistics.

TYPE: dict

cfg

Configuration for imputation settings.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Dictionary mapping split names to imputation results in PyPOTS-compatible format.

Source code in src/imputation/missforest_main.py
def missforest_create_imputation_dicts(model, df_dict, source_data, cfg):
    """Create imputation dictionaries from MissForest model outputs.

    Transforms MissForest imputation results into the standardized format
    used by PyPOTS models for downstream processing compatibility.

    Parameters
    ----------
    model : MissForest
        Trained MissForest model.
    df_dict : dict
        Dictionary of DataFrames keyed by split name ('train', 'test').
    source_data : dict
        Source data containing 'df' with data dictionaries per split
        and 'preprocess' with standardization statistics.
    cfg : DictConfig
        Configuration for imputation settings.

    Returns
    -------
    dict
        Dictionary mapping split names to imputation results in
        PyPOTS-compatible format.
    """
    # Harmonize the output with the PyPOTS outputs, so that the downstream code works well
    # see pypots_imputer_wrapper()
    imputed_dict = {}
    for i, split in enumerate(source_data["df"]):
        imputed_dict[split] = {}
        data_dicts = source_data["df"][split]["data"]
        # metadata = source_data["df"][split]["metadata"]
        imputed_dict[split] = imputation_per_split_of_dict(
            data_dicts=data_dicts,
            df=df_dict[split],
            preprocess=source_data[
                "preprocess"
            ],  # needs only the standardization stats
            model=model,
            split=split,
            cfg=cfg,
        )

    return imputed_dict

check_df

check_df(df: DataFrame)

Validate and convert DataFrame for MissForest compatibility.

Logs the number of NaN values and ensures all columns are float type to prevent type errors during MissForest fitting.

PARAMETER DESCRIPTION
df

Input DataFrame with potential NaN values.

TYPE: DataFrame

RETURNS DESCRIPTION
DataFrame

DataFrame with all columns cast to float type.

Source code in src/imputation/missforest_main.py
def check_df(df: pd.DataFrame):
    """Validate and convert DataFrame for MissForest compatibility.

    Logs the number of NaN values and ensures all columns are float type
    to prevent type errors during MissForest fitting.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame with potential NaN values.

    Returns
    -------
    pd.DataFrame
        DataFrame with all columns cast to float type.
    """
    no_of_nans_cols = df.isnull().sum()
    no_of_nans = no_of_nans_cols.sum()
    logger.info("Number of NaNs in df: {}".format(no_of_nans))

    # MissForest's internal NRMSE tracking (missforest/_array.py) rejects numpy
    # scalar dtypes, raising e.g.
    #   ValueError: Datatype of new item must <class 'float'>.
    # when an item is float32, so cast every column to plain float before fitting.
    df = df.astype(float)

    return df
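The guard in `check_df` can be illustrated in isolation: a mixed-dtype frame with a NaN is counted and then cast so that every column is `float64` (plain Python float), which is what MissForest's internal checks require. The frame below is a made-up example:

```python
import numpy as np
import pandas as pd

# What check_df guards against: MissForest's internal NRMSE tracking rejects
# numpy scalar dtypes such as float32, so cast every column to float64.
df = pd.DataFrame({
    "a": np.array([1.0, np.nan], dtype=np.float32),
    "b": np.array([3, 4], dtype=np.int64),
})
n_nans = int(df.isnull().sum().sum())  # mirrors the NaN count that gets logged
df = df.astype(float)
print(n_nans, df.dtypes.tolist())
```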

missforest_fit_script

missforest_fit_script(train_df: DataFrame, cfg: DictConfig)

Fit a MissForest model on training data.

PARAMETER DESCRIPTION
train_df

Training DataFrame with NaN values to learn imputation patterns from.

TYPE: DataFrame

cfg

Configuration containing MODELS.MISSFOREST.MODEL parameters.

TYPE: DictConfig

RETURNS DESCRIPTION
tuple

(model, results) where model is the fitted MissForest instance and results is a dict with 'train' timing in seconds.

Source code in src/imputation/missforest_main.py
def missforest_fit_script(train_df: pd.DataFrame, cfg: DictConfig):
    """Fit a MissForest model on training data.

    Parameters
    ----------
    train_df : pd.DataFrame
        Training DataFrame with NaN values to learn imputation patterns from.
    cfg : DictConfig
        Configuration containing MODELS.MISSFOREST.MODEL parameters.

    Returns
    -------
    tuple
        (model, results) where model is the fitted MissForest instance and
        results is a dict with 'train' timing in seconds.
    """
    train_df = check_df(df=train_df)
    logger.info("Fitting the MissForest model")
    start_time = time.time()
    # Default estimators are lgbm classifier and regressor
    params = cfg["MODELS"]["MISSFOREST"]["MODEL"]
    model = MissForest(**params)
    logger.info("MissForest | Model parameters: {}".format(params))
    model.fit(
        x=train_df,
    )
    results = {"train": time.time() - start_time}
    logger.info("Fitting done in {:.2f} seconds".format(results["train"]))

    return model, results
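The fit-and-time pattern used here can be sketched without MissForest itself; `DummyImputer` below is a stand-in, not a real model:

```python
import time


class DummyImputer:
    """Stand-in for MissForest; fit() just records what it saw."""

    def fit(self, x):
        self.n_rows = len(x)


def fit_with_timing(model, train_rows):
    # Same pattern as missforest_fit_script: wall-clock the fit call
    # and return the elapsed seconds under the 'train' key.
    start_time = time.time()
    model.fit(x=train_rows)
    results = {"train": time.time() - start_time}
    return model, results


model, results = fit_with_timing(DummyImputer(), [[1.0], [2.0]])
print(results["train"] >= 0.0)
```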

get_dataframes_from_dict_for_missforest

get_dataframes_from_dict_for_missforest(source_data: dict)

Convert source data dictionaries to DataFrames for MissForest.

Extracts arrays from source data, applies masks by setting masked values to NaN, and converts to pandas DataFrames.

PARAMETER DESCRIPTION
source_data

Source data containing 'df' with 'train' and 'test' splits, each having 'data' with 'X' arrays and 'mask' arrays.

TYPE: dict

RETURNS DESCRIPTION
tuple

(df_train, df_test) as pandas DataFrames with NaN values where mask indicates missing data.

RAISES DESCRIPTION
AssertionError

If input arrays contain unexpected NaN values or masking fails.

Source code in src/imputation/missforest_main.py
def get_dataframes_from_dict_for_missforest(source_data: dict):
    """Convert source data dictionaries to DataFrames for MissForest.

    Extracts arrays from source data, applies masks by setting masked
    values to NaN, and converts to pandas DataFrames.

    Parameters
    ----------
    source_data : dict
        Source data containing 'df' with 'train' and 'test' splits,
        each having 'data' with 'X' arrays and 'mask' arrays.

    Returns
    -------
    tuple
        (df_train, df_test) as pandas DataFrames with NaN values where
        mask indicates missing data.

    Raises
    ------
    AssertionError
        If input arrays contain unexpected NaN values or masking fails.
    """

    def get_df_per_split(split_dict):
        """Extract array from split dict, apply mask, and convert to DataFrame."""
        array = split_dict["data"]["X"]
        no_of_nans = np.sum(np.isnan(array))
        assert no_of_nans == 0, "There are NaNs in the data"
        mask = split_dict["data"]["mask"]
        mask_sum = np.sum(mask == 1)
        array[mask == 1] = np.nan
        masked_sum = np.sum(np.isnan(array))
        assert masked_sum == mask_sum, "Masking issue"
        df = pl.DataFrame(array)
        return df

    df_train = get_df_per_split(split_dict=deepcopy(source_data["df"]["train"]))
    df_test = get_df_per_split(split_dict=deepcopy(source_data["df"]["test"]))

    return df_train.to_pandas(), df_test.to_pandas()
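The mask-to-NaN conversion inside `get_df_per_split`, with its two sanity checks, can be demonstrated on a toy array: the input must be NaN-free, and after masking the NaN count must equal the mask count.

```python
import numpy as np

array = np.array([[1.0, 2.0], [3.0, 4.0]])
mask = np.array([[0, 1], [0, 0]])  # 1 marks values to treat as missing

# Same assertions as in get_df_per_split.
assert np.sum(np.isnan(array)) == 0, "There are NaNs in the data"
mask_sum = np.sum(mask == 1)
array[mask == 1] = np.nan
assert np.sum(np.isnan(array)) == mask_sum, "Masking issue"
print(array)
```

The resulting array is what MissForest then sees: real values plus NaNs to impute.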

missforest_main

missforest_main(
    source_data: dict,
    model_cfg: DictConfig,
    cfg: DictConfig,
    model_name: str = None,
    run_name: str = None,
)

MissForest imputation entry point (currently a placeholder).

See e.g. El Badisy et al. (2024) https://doi.org/10.1186/s12874-024-02305-3
Albu et al. (2024) https://arxiv.org/abs/2407.03379 for "missForestPredict"
Original paper by Stekhoven and Bühlmann (2012) https://doi.org/10.1093/bioinformatics/btr597
Python https://github.com/yuenshingyan/MissForest / https://pypi.org/project/MissForest/

Source code in src/imputation/missforest_main.py
def missforest_main(
    source_data: dict,
    model_cfg: DictConfig,
    cfg: DictConfig,
    model_name: str = None,
    run_name: str = None,
):
    """MissForest imputation entry point (currently a placeholder).

    See e.g. El Badisy et al. (2024) https://doi.org/10.1186/s12874-024-02305-3
    Albu et al. (2024) https://arxiv.org/abs/2407.03379 for "missForestPredict"
    Original paper by Stekhoven and Bühlmann (2012) https://doi.org/10.1093/bioinformatics/btr597
    Python https://github.com/yuenshingyan/MissForest / https://pypi.org/project/MissForest/
    """
    # MissForest does not learn a transferable model; it imputes each dataset
    # in place, so it is a poor fit for production use with new patients.
    # Not implemented - placeholder for future work; see the function
    # docstring for references on the implementation approach.
    raise NotImplementedError(
        "MissForest imputation not implemented. "
        "Check for the output scaling, test/train seems differently scaled/standardized"
    )

Artifacts

imputation_log_artifacts

pypots_model_logger

pypots_model_logger(
    model_obj: Any,
    model_name: str,
    model_info: dict[str, Any],
    artifacts_dir: str,
) -> None

Log a PyPOTS model and its training artifacts to MLflow.

Copies the saved PyPOTS model file and training directory (including TensorBoard logs) to MLflow artifacts.

PARAMETER DESCRIPTION
model_obj

Trained PyPOTS model instance with saving_path attribute.

TYPE: PyPOTS model

model_name

Name of the model for file naming.

TYPE: str

model_info

Model information dictionary containing 'num_params'.

TYPE: dict

artifacts_dir

Directory for artifacts (unused but kept for interface consistency).

TYPE: str

RAISES DESCRIPTION
FileNotFoundError

If the model file is not found at the expected path.

Source code in src/imputation/imputation_log_artifacts.py
def pypots_model_logger(
    model_obj: Any, model_name: str, model_info: dict[str, Any], artifacts_dir: str
) -> None:
    """Log a PyPOTS model and its training artifacts to MLflow.

    Copies the saved PyPOTS model file and training directory (including
    TensorBoard logs) to MLflow artifacts.

    Parameters
    ----------
    model_obj : PyPOTS model
        Trained PyPOTS model instance with saving_path attribute.
    model_name : str
        Name of the model for file naming.
    model_info : dict
        Model information dictionary containing 'num_params'.
    artifacts_dir : str
        Directory for artifacts (unused but kept for interface consistency).

    Raises
    ------
    FileNotFoundError
        If the model file is not found at the expected path.
    """
    model_path = Path(model_obj.saving_path) / f"{model_name}.pypots"
    if not model_path.exists():
        logger.error(f"Could not find the PyPOTS model from {model_path}")
        raise FileNotFoundError(f"Could not find the PyPOTS model from {model_path}")

    logger.debug(
        "Copying saved PyPOTS model to MLflow (from {})".format(model_obj.saving_path)
    )
    mlflow.log_artifact(str(model_path), artifact_path="model")
    logger.debug("Copying PyPOTS directory to MLflow (contains e.g. tensorboard logs)")
    mlflow.log_artifact(model_obj.saving_path, artifact_path="pyPOTS")
    # Log the number of parameters of the model
    try:
        mlflow.log_param("num_params", model_info["num_params"])
    except Exception as e:
        logger.warning(f"Could not log the number of parameters to MLflow: {e}")

generic_pickled_model_logger

generic_pickled_model_logger(
    model_obj: Any, model_name: str, artifacts_dir: str
) -> None

Save a model as pickle and log to MLflow.

Generic model logger for models that don't have specialized saving methods.

PARAMETER DESCRIPTION
model_obj

Trained model instance to pickle.

TYPE: object

model_name

Name of the model for file naming.

TYPE: str

artifacts_dir

Directory to save the pickle file.

TYPE: str

Source code in src/imputation/imputation_log_artifacts.py
def generic_pickled_model_logger(
    model_obj: Any, model_name: str, artifacts_dir: str
) -> None:
    """Save a model as pickle and log to MLflow.

    Generic model logger for models that don't have specialized saving methods.

    Parameters
    ----------
    model_obj : object
        Trained model instance to pickle.
    model_name : str
        Name of the model for file naming.
    artifacts_dir : str
        Directory to save the pickle file.
    """
    logger.debug("Logging the pickled model to local disk")
    fname = get_imputation_pickle_name(model_name)
    path = Path(artifacts_dir) / fname
    save_object_to_pickle(model_obj, str(path))
    logger.debug("Copying saved pickled model to MLflow")
    mlflow.log_artifact(str(path), artifact_path="model")
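The save step reduces to a pickle round-trip, sketched here with only the standard library. `save_object_to_pickle_sketch` and the file name are illustrative stand-ins for the project's helpers:

```python
import pickle
import tempfile
from pathlib import Path


def save_object_to_pickle_sketch(obj, path):
    # Minimal stand-in for save_object_to_pickle: serialize the object to disk.
    with open(path, "wb") as f:
        pickle.dump(obj, f)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "MISSFOREST_imputation.pickle"  # illustrative file name
    save_object_to_pickle_sketch({"model": "stub"}, path)
    with open(path, "rb") as f:
        restored = pickle.load(f)
print(restored)
```

In the real function the pickle is then uploaded with `mlflow.log_artifact` instead of being read back.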

log_imputer_model

log_imputer_model(
    model_obj: Any,
    model_name: str,
    artifacts: dict[str, Any],
    artifacts_dir: str,
) -> None

Log an imputation model to MLflow using the appropriate method.

Dispatches to the correct logging method based on model type (PyPOTS, MissForest, MOMENT, etc.).

PARAMETER DESCRIPTION
model_obj

Trained imputation model instance.

TYPE: object

model_name

Name of the model.

TYPE: str

artifacts

Artifacts dictionary containing 'model_artifacts' with 'model_info'.

TYPE: dict

artifacts_dir

Directory for saving artifacts.

TYPE: str

Notes

PyPOTS models use their specialized save format. MissForest uses pickle. MOMENT models are not currently logged (only results are logged).

Source code in src/imputation/imputation_log_artifacts.py
def log_imputer_model(
    model_obj: Any, model_name: str, artifacts: dict[str, Any], artifacts_dir: str
) -> None:
    """Log an imputation model to MLflow using the appropriate method.

    Dispatches to the correct logging method based on model type
    (PyPOTS, MissForest, MOMENT, etc.).

    Parameters
    ----------
    model_obj : object
        Trained imputation model instance.
    model_name : str
        Name of the model.
    artifacts : dict
        Artifacts dictionary containing 'model_artifacts' with 'model_info'.
    artifacts_dir : str
        Directory for saving artifacts.

    Notes
    -----
    PyPOTS models use their specialized save format. MissForest uses pickle.
    MOMENT models are not currently logged (only results are logged).
    """
    # Log the model
    logger.debug("Logging the model to MLflow")

    if "model_info" in artifacts["model_artifacts"]:
        if "PyPOTS" in artifacts["model_artifacts"]["model_info"]:
            if artifacts["model_artifacts"]["model_info"]["PyPOTS"]:
                # This is now a PyPOTS model
                pypots_model_logger(
                    model_obj=model_obj,
                    model_name=model_name,
                    model_info=artifacts["model_artifacts"]["model_info"],
                    artifacts_dir=artifacts_dir,
                )
            else:
                logger.warning(
                    "Figure out how to log the new model to MLflow, where is the model located?"
                )
                logger.warning("Or is the model_obj still not saved to disk at all?")
                # raise NotImplementedError(
                #     "No non-PyPOTS evaluation/imputation implemented yet!"
                # )
        else:
            logger.warning(
                "Figure out how to log the new model to MLflow, where is the model located?"
            )
            logger.warning("Or is the model_obj still not saved to disk at all?")
            # raise NotImplementedError(
            #     "No non-PyPOTS evaluation/imputation implemented yet!"
            # )
    elif "MISSFOREST" in model_name:
        generic_pickled_model_logger(model_obj, model_name, artifacts_dir)

    elif "MOMENT" in model_name:
        # generic_pickled_model_logger(model_obj, model_name, artifacts_dir)
        logger.warning("Moment model is not logged now, only the results are logged")

    else:
        logger.warning(
            "Figure out how to log the new model to MLflow, where is the model located?"
        )
        logger.warning("Or is the model_obj still not saved to disk at all?")
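The dispatch logic above can be condensed into a pure function for illustration; `choose_logger` and its return labels are hypothetical, not part of the codebase:

```python
def choose_logger(model_name, artifacts):
    """Mirror log_imputer_model's branching: which logging path applies?"""
    info = artifacts.get("model_artifacts", {}).get("model_info")
    if info is not None:
        # PyPOTS models use their specialized save format.
        return "pypots" if info.get("PyPOTS") else "unknown"
    if "MISSFOREST" in model_name:
        return "pickle"  # generic pickled-model logger
    if "MOMENT" in model_name:
        return "skip"  # only the results are logged for MOMENT
    return "unknown"


print(choose_logger("SAITS", {"model_artifacts": {"model_info": {"PyPOTS": True}}}))
print(choose_logger("MISSFOREST", {"model_artifacts": {}}))
```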

log_the_imputation_results

log_the_imputation_results(
    imputation_artifacts: dict[str, Any],
    model_name: str,
    artifacts_dir: str,
    cfg: DictConfig,
    run_name: str,
) -> None

Save imputation results locally and log to MLflow.

PARAMETER DESCRIPTION
imputation_artifacts

Dictionary containing imputation results to save.

TYPE: dict

model_name

Name of the model for file naming.

TYPE: str

artifacts_dir

Directory to save the pickle file.

TYPE: str

cfg

Configuration (unused but kept for interface consistency).

TYPE: DictConfig

run_name

Run name (unused but kept for interface consistency).

TYPE: str

Source code in src/imputation/imputation_log_artifacts.py
def log_the_imputation_results(
    imputation_artifacts: dict[str, Any],
    model_name: str,
    artifacts_dir: str,
    cfg: DictConfig,
    run_name: str,
) -> None:
    """Save imputation results locally and log to MLflow.

    Parameters
    ----------
    imputation_artifacts : dict
        Dictionary containing imputation results to save.
    model_name : str
        Name of the model for file naming.
    artifacts_dir : str
        Directory to save the pickle file.
    cfg : DictConfig
        Configuration (unused but kept for interface consistency).
    run_name : str
        Run name (unused but kept for interface consistency).
    """
    # Log first to disk
    results_path = Path(artifacts_dir) / get_imputation_pickle_name(model_name)
    save_results_dict(
        results_dict=imputation_artifacts,
        results_path=str(results_path),
        name="imputation",
    )

    # And then copy this to MLflow
    logger.debug("Copying the imputation results to MLflow")
    mlflow.log_artifact(str(results_path), artifact_path="imputation")

save_and_log_imputer_artifacts

save_and_log_imputer_artifacts(
    model: Any,
    imputation_artifacts: dict[str, Any],
    artifacts_dir: str,
    cfg: DictConfig,
    model_name: str,
    run_name: str,
) -> None

Save and log all imputation artifacts to MLflow.

Orchestrates the logging of model, results, and Hydra configuration artifacts to the associated MLflow run.

PARAMETER DESCRIPTION
model

Trained imputation model instance.

TYPE: object

imputation_artifacts

Dictionary containing 'model_artifacts' with MLflow info and results.

TYPE: dict

artifacts_dir

Directory for saving artifacts locally.

TYPE: str

cfg

Full Hydra configuration.

TYPE: DictConfig

model_name

Name of the imputation model.

TYPE: str

run_name

MLflow run name.

TYPE: str

Notes

Ends any active MLflow run, then starts a new run context to log artifacts. The run is ended after all artifacts are logged.

Source code in src/imputation/imputation_log_artifacts.py
def save_and_log_imputer_artifacts(
    model: Any,
    imputation_artifacts: dict[str, Any],
    artifacts_dir: str,
    cfg: DictConfig,
    model_name: str,
    run_name: str,
) -> None:
    """Save and log all imputation artifacts to MLflow.

    Orchestrates the logging of model, results, and Hydra configuration
    artifacts to the associated MLflow run.

    Parameters
    ----------
    model : object
        Trained imputation model instance.
    imputation_artifacts : dict
        Dictionary containing 'model_artifacts' with MLflow info and results.
    artifacts_dir : str
        Directory for saving artifacts locally.
    cfg : DictConfig
        Full Hydra configuration.
    model_name : str
        Name of the imputation model.
    run_name : str
        MLflow run name.

    Notes
    -----
    Ends any active MLflow run, then starts a new run context to log
    artifacts. The run is ended after all artifacts are logged.
    """
    logger.info("Logging the imputer artifacts to MLflow")

    # Log the metrics MLflow
    if mlflow.active_run() is not None:
        mlflow.end_run()

    mlflow_info = get_mlflow_info_from_model_dict(
        imputation_artifacts["model_artifacts"]
    )
    experiment_id, run_id = get_mlflow_params(mlflow_info)

    with mlflow.start_run(run_id):
        # Log the model
        log_imputer_model(
            model_obj=model,
            model_name=model_name,
            artifacts=imputation_artifacts,
            artifacts_dir=artifacts_dir,
        )

        # Log the imputation results (forward passes, and other data)
        log_the_imputation_results(
            imputation_artifacts, model_name, artifacts_dir, cfg, run_name
        )

        # Log the Hydra artifacts to MLflow
        log_hydra_artifacts_to_mlflow(artifacts_dir, model_name, cfg, run_name)

        # End the MLflow run; you can still log to the same run later (when
        # evaluating imputation, computing metrics, or logging artifacts) via the run_id
        logger.debug(
            "MLflow | Ending MLflow run named: {}".format(
                mlflow.active_run().info.run_name
            )
        )
        mlflow.end_run()