Skip to content

data_io

Data loading, preprocessing, and export utilities.

Overview

This module handles:

  • Loading PLR data from DuckDB/SQLite
  • Data preprocessing and validation
  • Export to various formats
  • Stratification for cross-validation

Data Import

data_import

import_PLR_data_wrapper

import_PLR_data_wrapper(
    cfg: DictConfig, data_dir: str = None
)

Import and preprocess PLR data from individual CSV files.

Main import wrapper that handles the complete data loading pipeline: importing raw data, combining with metadata, preparing for imputation, granularizing outlier labels, and stratifying splits.

PARAMETER DESCRIPTION
cfg

Configuration dictionary containing DATA and METADATA settings.

TYPE: DictConfig

data_dir

Directory for data files, by default None.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
tuple

Tuple containing (df_train, df_test) as Polars dataframes.

Source code in src/data_io/data_import.py
def import_PLR_data_wrapper(cfg: DictConfig, data_dir: str = None):
    """Import and preprocess PLR data from individual CSV files.

    Runs the full loading pipeline: raw import, metadata merge, column
    selection for imputation, time-vector sanity checks, imputation of the
    original pupil trace, outlier-label granularization, per-subject outlier
    recount, train/test stratification, and a final DuckDB export.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary containing DATA and METADATA settings.
    data_dir : str, optional
        Directory for data files, by default None.

    Returns
    -------
    tuple
        Tuple containing (df_train, df_test) as Polars dataframes.
    """
    # Raw time series from the individual subject CSVs
    df = import_data(cfg, data_dir)

    # Merge in subject-level metadata
    df_metadata = metadata_wrapper(metadata_cfg=cfg["METADATA"])
    df, code_stats = combine_metadata_with_df_splits(df, df_metadata)

    # Keep only the columns relevant for the imputation task
    df = prepare_dataframe_for_imputation(df, cfg)

    # Sanity-check the time vectors; only the ideal "time" column is fatal on error
    check_for_unique_timepoints(df, cfg, col="time", assert_on_error=True)
    check_for_unique_timepoints(df, cfg, col="time_orig", assert_on_error=False)

    # "pupil_orig" still carries NaNs (outliers), which some algorithms cannot
    # handle — add a companion column with the missing data imputed
    df = impute_orig_for_training(df, cfg)

    # Split outlier labels into easy (blinks) and hard (mostly pupil
    # segmentation algorithm noise, closer to the true trend) classes
    df = granularize_outlier_labels(df, cfg)

    # Refresh the per-subject outlier counts after relabeling
    df = update_number_of_outliers(df, cfg)

    # Stratified split into train and validation sets
    df_train, df_test = stratify_splits(df, cfg)

    # Persist the splits to DuckDB for later runs
    _ = export_dataframes_to_duckdb(
        df_train, df_test, db_name=cfg["DATA"]["filename_DuckDB"], data_dir=data_dir
    )

    return df_train, df_test

import_data

import_data(
    cfg: DictConfig, data_dir: str = None
) -> DataFrame

Import raw PLR data from individual subject CSV files.

PARAMETER DESCRIPTION
cfg

Configuration dictionary containing DATA settings.

TYPE: DictConfig

data_dir

Directory for output data files, by default None.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
DataFrame

Polars dataframe containing all imported subject data.

Source code in src/data_io/data_import.py
def import_data(cfg: DictConfig, data_dir: str = None) -> pl.DataFrame:
    """Import raw PLR data from individual subject CSV files.

    Requires access to the raw per-subject files; when only the DuckDB
    database is available, use the DuckDB import path instead.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary containing DATA settings.
    data_dir : str, optional
        Directory for output data files, by default None.

    Returns
    -------
    pl.DataFrame
        Polars dataframe containing all imported subject data.
    """
    subjects_dir = get_repo_root() / cfg["DATA"]["individual_subjects_path"]
    df_pandas = create_csvs_from_individual_subjects(
        individual_subjects_dir=str(subjects_dir), data_dir=data_dir, cfg=cfg
    )
    # Downstream processing works on Polars dataframes
    return pl.from_pandas(df_pandas)

create_csvs_from_individual_subjects

create_csvs_from_individual_subjects(
    individual_subjects_dir: str,
    data_dir: str,
    no_of_timepoints: int = 1981,
    cfg: DictConfig = None,
) -> DataFrame

Create a combined dataframe from individual subject CSV files.

PARAMETER DESCRIPTION
individual_subjects_dir

Directory containing individual subject CSV files.

TYPE: str

data_dir

Directory for output data files.

TYPE: str

no_of_timepoints

Expected number of timepoints per subject, by default 1981.

TYPE: int DEFAULT: 1981

cfg

Configuration dictionary, by default None.

TYPE: DictConfig DEFAULT: None

RETURNS DESCRIPTION
DataFrame

Combined pandas dataframe with all subjects.

RAISES DESCRIPTION
FileNotFoundError

If the individual subjects directory does not exist.

Source code in src/data_io/data_import.py
def create_csvs_from_individual_subjects(
    individual_subjects_dir: str,
    data_dir: str,
    no_of_timepoints: int = 1981,
    cfg: DictConfig = None,
) -> pd.DataFrame:
    """Create a combined dataframe from individual subject CSV files.

    Subjects whose recording length differs from ``no_of_timepoints`` are
    skipped with a warning so that every kept subject is rectangular.

    Parameters
    ----------
    individual_subjects_dir : str
        Directory containing individual subject CSV files.
    data_dir : str
        Directory for output data files.
    no_of_timepoints : int, optional
        Expected number of timepoints per subject, by default 1981.
    cfg : DictConfig, optional
        Configuration dictionary, by default None.

    Returns
    -------
    pd.DataFrame
        Combined pandas dataframe with all subjects.

    Raises
    ------
    FileNotFoundError
        If the individual subjects directory does not exist.
    ValueError
        If no subject CSV with the expected number of timepoints is found.
    """
    individual_subjects_path = Path(individual_subjects_dir)
    if not individual_subjects_path.exists():
        logger.error(
            "Individual subjects directory does not exist: {}".format(
                individual_subjects_dir
            )
        )
        # carry the offending path in the exception itself, not only in the log
        raise FileNotFoundError(
            "Individual subjects directory does not exist: {}".format(
                individual_subjects_dir
            )
        )
    logger.info(
        'Import individual subjects to CSV(s) from "{}"'.format(
            individual_subjects_dir
        )
    )

    files = list(individual_subjects_path.glob("*.csv"))
    logger.info("Found {} files".format(len(files)))  # Found 507 files

    list_of_dfs, outliers, subject_codes = [], [], []
    for i, file in enumerate(tqdm(files, desc="Importing individual subjects")):
        no_outliers, csv_data, subject_code = import_master_csv(
            i=i, csv_path=file, cfg=cfg
        )
        if csv_data.shape[0] == no_of_timepoints:
            list_of_dfs.append(csv_data)
            outliers.append(no_outliers)
            subject_codes.append(subject_code)
        else:
            logger.warning(
                "Subject {} has {} timepoints instead of {}".format(
                    subject_code, csv_data.shape[0], no_of_timepoints
                )
            )

    if not list_of_dfs:
        # previously max(outliers) below crashed here with a cryptic
        # "max() arg is an empty sequence" — fail with an actionable message
        raise ValueError(
            "No subject CSV with the expected {} timepoints found in {}".format(
                no_of_timepoints, individual_subjects_dir
            )
        )

    # Create dataframe from the list of dataframes
    df_raw = convert_list_of_dfs_to_df(list_of_dfs, outliers, subject_codes)
    max_no_outliers = max(outliers)  # 749 (out of 1981)
    logger.info("Max number of outliers per subject: {}".format(max_no_outliers))
    return df_raw

convert_list_of_dfs_to_df

convert_list_of_dfs_to_df(
    list_of_dfs, outliers, subject_codes
)

Convert a list of subject dataframes into a single combined dataframe.

PARAMETER DESCRIPTION
list_of_dfs

List of pandas dataframes, one per subject.

TYPE: list

outliers

List of outlier counts per subject.

TYPE: list

subject_codes

List of subject code identifiers.

TYPE: list

RETURNS DESCRIPTION
DataFrame

Combined dataframe with subject_code and no_outliers columns added.

Source code in src/data_io/data_import.py
def convert_list_of_dfs_to_df(list_of_dfs, outliers, subject_codes):
    """Convert a list of subject dataframes into a single combined dataframe.

    Parameters
    ----------
    list_of_dfs : list
        List of pandas dataframes, one per subject.
    outliers : list
        List of outlier counts per subject.
    subject_codes : list
        List of subject code identifiers.

    Returns
    -------
    pd.DataFrame
        Combined dataframe with subject_code and no_outliers columns added.
    """
    augmented = []
    for df, no_outliers, code in tqdm(
        zip(list_of_dfs, outliers, subject_codes),
        desc="Converting to single dataframe",
        total=len(list_of_dfs),
    ):
        # add per-subject scalars as constant columns (mutates the input
        # dataframes in place, as the previous implementation did)
        df["subject_code"] = code
        df["no_outliers"] = no_outliers
        augmented.append(df)

    # One concat at the end is O(total rows); the previous concat-per-iteration
    # re-copied the accumulated frame every pass (quadratic).
    return pd.concat(augmented) if augmented else pd.DataFrame()

export_split_dataframes

export_split_dataframes(
    df_train: DataFrame, df_val: DataFrame, data_dir: str
)

Export train and validation dataframes to CSV files.

PARAMETER DESCRIPTION
df_train

Training data dataframe.

TYPE: DataFrame

df_val

Validation data dataframe.

TYPE: DataFrame

data_dir

Directory to save the CSV files.

TYPE: str

Source code in src/data_io/data_import.py
def export_split_dataframes(
    df_train: pd.DataFrame, df_val: pd.DataFrame, data_dir: str
):
    """Export train and validation dataframes to CSV files.

    Parameters
    ----------
    df_train : pd.DataFrame
        Training data dataframe.
    df_val : pd.DataFrame
        Validation data dataframe.
    data_dir : str
        Directory to save the CSV files.
    """
    out_dir = Path(data_dir)
    if not out_dir.exists():
        logger.info("Create output directory: {}".format(data_dir))
        out_dir.mkdir(parents=True, exist_ok=True)

    train_path, val_path = define_split_csv_paths(data_dir=data_dir)

    # Write both splits with the same settings
    for split_name, out_path, frame in (
        ("train", train_path, df_train),
        ("val", val_path, df_val),
    ):
        logger.info("Export {} split to {}".format(split_name, out_path))
        frame.to_csv(out_path, index=False)

define_split

define_split(
    subject_codes: list,
    csv_subsets: list,
    indices: list,
    split: str,
    drop_raw_pupil_values: bool = False,
)

Create a combined dataframe for a specific train/test split.

PARAMETER DESCRIPTION
subject_codes

List of all subject code identifiers.

TYPE: list

csv_subsets

List of dataframes for all subjects.

TYPE: list

indices

Indices of subjects to include in this split.

TYPE: list

split

Name of the split ("train" or "test").

TYPE: str

drop_raw_pupil_values

Whether to drop rows with NaN pupil values, by default False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
DataFrame

Combined dataframe for the split.

Source code in src/data_io/data_import.py
def define_split(
    subject_codes: list,
    csv_subsets: list,
    indices: list,
    split: str,
    drop_raw_pupil_values: bool = False,
):
    """Create a combined dataframe for a specific train/test split.

    Parameters
    ----------
    subject_codes : list
        List of all subject code identifiers.
    csv_subsets : list
        List of dataframes for all subjects.
    indices : list
        Indices of subjects to include in this split.
    split : str
        Name of the split ("train" or "test").
    drop_raw_pupil_values : bool, optional
        Whether to drop rows with NaN pupil values, by default False.

    Returns
    -------
    pd.DataFrame
        Combined dataframe for the split.

    Raises
    ------
    ValueError
        If ``indices`` is empty (no subjects in the split).
    """
    codes = np.array(subject_codes)[indices]
    list_of_df = [csv_subsets[i] for i in indices]
    assert len(codes) == len(list_of_df)
    logger.info('Split "{}" contains {} subjects'.format(split, len(codes)))

    if len(codes) == 0:
        # previously this fell through to an UnboundLocalError on df_out
        raise ValueError('Split "{}" received no subject indices'.format(split))

    per_subject = []
    for i, code in enumerate(codes):
        df_with_code = list_of_df[i]
        no_of_timepoints1 = df_with_code.shape[0]
        df_with_code["subject_code"] = code
        if drop_raw_pupil_values:
            df_with_code = df_with_code.dropna()
        no_of_timepoints2 = df_with_code.shape[0]
        if no_of_timepoints1 != no_of_timepoints2:
            logger.info(
                "Subject {} had {} timepoints dropped due to pupil raw NaNs".format(
                    code, no_of_timepoints1 - no_of_timepoints2
                )
            )
        per_subject.append(df_with_code)

    # Single concat instead of the previous quadratic concat-in-loop
    df_out = pd.concat(per_subject)

    # NOTE: the row count here uses the LAST subject's timepoint count, which
    # matches the original logging behavior (assumes equal-length subjects)
    logger.info(
        'Total of {} timepoints ({}x{}) in split "{}"'.format(
            df_out.shape[0], len(codes), no_of_timepoints1, split
        )
    )

    return df_out

import_master_csv

import_master_csv(i: int, csv_path: str, cfg: DictConfig)

Import and preprocess a single subject's CSV file.

Performs column selection, time vector quality checks, column renaming, and linear interpolation of color channels.

PARAMETER DESCRIPTION
i

Subject index (for first-subject logging).

TYPE: int

csv_path

Path to the subject's CSV file.

TYPE: str

cfg

Configuration dictionary with DATA settings.

TYPE: DictConfig

RETURNS DESCRIPTION
tuple

Tuple containing (no_outliers, csv_subset, subject_code).

Source code in src/data_io/data_import.py
def import_master_csv(i: int, csv_path: str, cfg: DictConfig):
    """Import and preprocess a single subject's CSV file.

    Performs column selection, time vector quality checks, column renaming,
    and linear interpolation of color channels.

    Parameters
    ----------
    i : int
        Subject index (for first-subject logging).
    csv_path : str
        Path to the subject's CSV file.
    cfg : DictConfig
        Configuration dictionary with DATA settings.

    Returns
    -------
    tuple
        Tuple containing (no_outliers, csv_subset, subject_code).
    """
    # Subject code = filename stem up to the first underscore (or first dot)
    subject_code = Path(csv_path).name.split(".")[0].split("_")[0]
    csv_raw = pd.read_csv(csv_path)  # e.g. (1981, 94) # TODO! direct Polars import
    # Keep only the configured column subset
    csv_subset = csv_raw[cfg["DATA"]["COLUMNS_TO_KEEP"]]
    if i == 0:
        # Column selection is identical for all subjects; log it once
        logger.info(
            "Keeping the following column subset: {}".format(
                cfg["DATA"]["COLUMNS_TO_KEEP"]
            )
        )
        logger.info(
            f"{csv_subset.shape[1]} columns out of total of {csv_raw.shape[1]} columns"
        )

    # Check the time vector
    # Just to be safe, use the same time vector for all subjects ("time"), but store the original time vector
    # to a new column ("time_orig") in case you want to have a look at it later, or they are actually really irregular,
    # and your modeling approach could exploit the multiscale nature of the data?
    time_orig, time_ideal, time_checks = check_time_vector_quality(
        subject_code, csv_subset, cfg
    )
    csv_subset = csv_subset.assign(time=pd.Series(time_ideal))
    csv_subset = csv_subset.assign(time_orig=pd.Series(time_orig))
    if not time_checks["OK"]:
        # Failed checks are logged but not fatal; the ideal time vector is used regardless
        logger.warning(
            "Time vector quality checks failed for subject {}".format(subject_code)
        )
        for key, value in time_checks.items():
            logger.warning(f"{key}: {value}")

    # rename color columns for something less ambiguous
    csv_subset = csv_subset.rename(columns={"R": "Red", "B": "Blue"})
    # NOTE(review): "pupil_raw" is re-purposed here — the incoming "pupil_raw"
    # becomes "pupil_orig" and "pupil_toBeImputed" becomes the new "pupil_raw";
    # this mapping must happen in one rename call to avoid a name collision
    csv_subset = csv_subset.rename(
        columns={
            "denoised": "pupil_gt",
            "pupil_raw": "pupil_orig",
            "pupil_toBeImputed": "pupil_raw",
        }
    )

    if i == 0:
        logger.info("Renamed the columns to: {}".format(list(csv_subset.columns)))

    # The color columns have NaNs for outliers?
    # the light was on obviously even during the blinks and there is no ambiguity there
    csv_subset["Red"] = linear_interpolation_of_col(column=csv_subset["Red"])
    csv_subset["Blue"] = linear_interpolation_of_col(column=csv_subset["Blue"])

    # if first or last value is NaN, the interpolation will not work
    csv_subset = fix_for_orphaned_nans(
        subject_code, csv_subset, cfg, cols=("Red", "Blue")
    )

    # Per-subject outlier total (assumes outlier_labels is 0/1 — TODO confirm)
    no_outliers = csv_subset["outlier_labels"].sum()

    return no_outliers, csv_subset, subject_code

linear_interpolation_of_col

linear_interpolation_of_col(column: Series)

Apply linear interpolation to fill NaN values in a pandas Series.

PARAMETER DESCRIPTION
column

Series with potential NaN values.

TYPE: Series

RETURNS DESCRIPTION
Series

Series with NaN values linearly interpolated.

Source code in src/data_io/data_import.py
def linear_interpolation_of_col(column: pd.Series):
    """Fill NaN gaps in a pandas Series by linear interpolation.

    Parameters
    ----------
    column : pd.Series
        Series with potential NaN values.

    Returns
    -------
    pd.Series
        Copy of the series with NaN values linearly interpolated.
    """
    # "linear" is the pandas default; stated explicitly for clarity
    interpolated = column.interpolate(method="linear")
    return interpolated

import_data_from_duckdb

import_data_from_duckdb(
    data_cfg: DictConfig,
    data_dir: str,
    use_demo_data: bool = False,
)

Import PLR data from a DuckDB database file.

PARAMETER DESCRIPTION
data_cfg

Data configuration dictionary with filename_DuckDB setting.

TYPE: DictConfig

data_dir

Directory containing the DuckDB file.

TYPE: str

use_demo_data

Whether to use demo data instead of full dataset, by default False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
tuple

Tuple containing (df_train, df_val) as Polars dataframes.

RAISES DESCRIPTION
FileNotFoundError

If the DuckDB file does not exist.

Source code in src/data_io/data_import.py
def import_data_from_duckdb(
    data_cfg: DictConfig, data_dir: str, use_demo_data: bool = False
):
    """Import PLR data from a DuckDB database file.

    Parameters
    ----------
    data_cfg : DictConfig
        Data configuration dictionary with filename_DuckDB setting.
    data_dir : str
        Directory containing the DuckDB file.
    use_demo_data : bool, optional
        Whether to use demo data instead of full dataset, by default False.

    Returns
    -------
    tuple
        Tuple containing (df_train, df_val) as Polars dataframes.

    Raises
    ------
    FileNotFoundError
        If the DuckDB file does not exist.
    """
    db_path = Path(get_duckdb_file(data_cfg, use_demo_data))
    if not db_path.exists():
        logger.error("DuckDB file not found: {}".format(db_path))
        logger.error("Typo with the filename, or you simply do not have this .db file?")
        logger.error(
            'Set cfg["DATA"]["import_from_DuckDB"] = False so you can import from .CSV files'
        )
        # include the missing path in the exception (the bare raise lost it)
        raise FileNotFoundError("DuckDB file not found: {}".format(db_path))
    logger.info(
        "Importing the PLR data from DuckDB database as Polars dataframes ({})".format(
            data_cfg["filename_DuckDB"]
        )
    )

    df_train, df_val = import_duckdb_as_dataframes(str(db_path))
    # sanity-check the imported splits before handing them to the caller
    check_data_import(df_train, df_val)

    return df_train, df_val

Data Utilities

data_utils

convert_sec_to_date

convert_sec_to_date(
    df: DataFrame,
    time_col: str = "ds",
    seconds_offset: float = 1,
) -> DataFrame

Convert seconds to datetime format in a dataframe column.

PARAMETER DESCRIPTION
df

Input dataframe containing the time column.

TYPE: DataFrame

time_col

Name of the time column to convert, by default "ds".

TYPE: str DEFAULT: 'ds'

seconds_offset

Offset in seconds to add before conversion, by default 1.

TYPE: float DEFAULT: 1

RETURNS DESCRIPTION
DataFrame

Dataframe with the time column converted to datetime.

Source code in src/data_io/data_utils.py
def convert_sec_to_date(
    df: pd.DataFrame, time_col: str = "ds", seconds_offset: float = 1
) -> pd.DataFrame:
    """Convert a seconds column to datetimes, mutating the dataframe in place.

    The column is shifted by ``seconds_offset`` and interpreted as
    milliseconds since the Unix epoch, so 0 s maps to 1970-01-01 00:00:01
    with the default offset.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe containing the time column.
    time_col : str, optional
        Name of the time column to convert, by default "ds".
    seconds_offset : float, optional
        Offset in seconds to add before conversion, by default 1.

    Returns
    -------
    pd.DataFrame
        The same dataframe with the time column converted to datetime.
    """
    # errors="coerce" yields NaT for unparseable values instead of raising
    df[time_col] = pd.to_datetime(
        1000 * (df[time_col] + seconds_offset), unit="ms", errors="coerce"
    )
    return df

convert_sec_to_millisec

convert_sec_to_millisec(
    df: DataFrame,
    time_col: str = "ds",
    seconds_offset: float = 1,
) -> DataFrame

Convert seconds to milliseconds in a dataframe column.

PARAMETER DESCRIPTION
df

Input dataframe containing the time column.

TYPE: DataFrame

time_col

Name of the time column to convert, by default "ds".

TYPE: str DEFAULT: 'ds'

seconds_offset

Offset in seconds to add before conversion, by default 1.

TYPE: float DEFAULT: 1

RETURNS DESCRIPTION
DataFrame

Dataframe with the time column converted to milliseconds.

Source code in src/data_io/data_utils.py
def convert_sec_to_millisec(
    df: pd.DataFrame, time_col: str = "ds", seconds_offset: float = 1
) -> pd.DataFrame:
    """Convert seconds to milliseconds in a dataframe column.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe containing the time column.
    time_col : str, optional
        Name of the time column to convert, by default "ds".
    seconds_offset : float, optional
        Offset in seconds to add before conversion, by default 1.

    Returns
    -------
    pd.DataFrame
        Dataframe with the time column converted to milliseconds.
    """
    # shift first, then scale: (t + offset) * 1000; the column is
    # overwritten in the caller's dataframe
    df[time_col] = (df[time_col] + seconds_offset) * 1000
    return df

split_df_to_samples

split_df_to_samples(
    df: DataFrame,
    split: str = "train",
    subject_col_name: str = "unique_id",
) -> dict[str, DataFrame]

Split a dataframe into a dictionary of single-subject dataframes.

PARAMETER DESCRIPTION
df

Input dataframe containing multiple subjects.

TYPE: DataFrame

split

Name of the data split (for logging), by default "train".

TYPE: str DEFAULT: 'train'

subject_col_name

Name of the column containing subject identifiers, by default "unique_id".

TYPE: str DEFAULT: 'unique_id'

RETURNS DESCRIPTION
dict

Dictionary mapping subject codes to their respective dataframes.

Source code in src/data_io/data_utils.py
def split_df_to_samples(
    df: pd.DataFrame, split: str = "train", subject_col_name: str = "unique_id"
) -> dict[str, pd.DataFrame]:
    """Split a dataframe into a dictionary of single-subject dataframes.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe containing multiple subjects.
    split : str, optional
        Name of the data split (for logging), by default "train".
    subject_col_name : str, optional
        Name of the column containing subject identifiers, by default "unique_id".

    Returns
    -------
    dict
        Dictionary mapping subject codes to their respective dataframes
        (with the subject column dropped).
    """
    subject_codes = df[subject_col_name].unique()
    no_of_unique_subjects = len(subject_codes)

    logger.info(
        "Splitting Pandas Dataframe data into {} single-subject Dataframes".format(
            no_of_unique_subjects
        )
    )

    # groupby makes a single pass over the rows instead of one full-frame scan
    # per subject (the old TOOPTIMIZE); sort=False keeps first-appearance
    # order (same as .unique()), dropna=False keeps a possible NaN code group
    dict_of_dfs = {}
    grouped = df.groupby(subject_col_name, sort=False, dropna=False)
    for code, df_sample in tqdm(
        grouped,
        desc='Split "{}", df to dict of dfs'.format(split),
        total=no_of_unique_subjects,
    ):
        dict_of_dfs[code] = df_sample.drop(subject_col_name, axis=1)

    assert len(dict_of_dfs) == no_of_unique_subjects, (
        "Number of subjects does not match!"
    )

    return dict_of_dfs

get_subset_of_data

get_subset_of_data(
    df_subset: DataFrame,
    t0: float = 18.0,
    t1: float = 19.05,
) -> DataFrame

Extract a time-windowed subset of data from a dataframe.

Filters data to keep only rows within the specified time range and limits to the first 3 subjects (96 rows).

PARAMETER DESCRIPTION
df_subset

Input dataframe with a 'ds' time column.

TYPE: DataFrame

t0

Start time for filtering, by default 18.0.

TYPE: float DEFAULT: 18.0

t1

End time for filtering, by default 19.05.

TYPE: float DEFAULT: 19.05

RETURNS DESCRIPTION
DataFrame

Filtered dataframe containing only data within the time window.

Source code in src/data_io/data_utils.py
def get_subset_of_data(
    df_subset: pd.DataFrame, t0: float = 18.0, t1: float = 19.05
) -> pd.DataFrame:
    """Extract a time-windowed subset of data from a dataframe.

    Keeps only rows with ``t0 <= ds <= t1`` — note the input dataframe is
    filtered *in place* — and then truncates the result to the first 96 rows
    (3 subjects).

    Parameters
    ----------
    df_subset : pd.DataFrame
        Input dataframe with a 'ds' time column.
    t0 : float, optional
        Start time for filtering, by default 18.0.
    t1 : float, optional
        End time for filtering, by default 19.05.

    Returns
    -------
    pd.DataFrame
        Filtered dataframe containing only data within the time window.
    """
    # Drop everything outside [t0, t1]; this mutates the caller's frame
    outside_window = df_subset[(df_subset.ds > t1) | (df_subset.ds < t0)].index
    df_subset.drop(outside_window, inplace=True)
    df_subset.reset_index(drop=True, inplace=True)

    # Manually keep only the 3 first subjects (96 rows)
    return df_subset.iloc[:96]

define_split_csv_paths

define_split_csv_paths(
    data_dir: str, suffix: str = ""
) -> tuple[Path, Path]

Define file paths for train and validation CSV files.

PARAMETER DESCRIPTION
data_dir

Directory containing the data files.

TYPE: str

suffix

Suffix to append to filenames, by default "".

TYPE: str DEFAULT: ''

RETURNS DESCRIPTION
tuple of Path

Tuple containing (train_path, val_path).

Source code in src/data_io/data_utils.py
def define_split_csv_paths(data_dir: str, suffix: str = "") -> tuple[Path, Path]:
    """Define file paths for train and validation CSV files.

    Parameters
    ----------
    data_dir : str
        Directory containing the data files.
    suffix : str, optional
        Suffix to append to filenames, by default "".

    Returns
    -------
    tuple of Path
        Tuple containing (train_path, val_path).
    """
    base = Path(data_dir)
    # filenames follow the "<split>_PLR<suffix>.csv" convention
    return (
        base / "train_PLR{}.csv".format(suffix),
        base / "val_PLR{}.csv".format(suffix),
    )

import_nonnan_data_from_csv

import_nonnan_data_from_csv(
    data_dir: str, suffix: str = "_nonNan"
) -> tuple[DataFrame, DataFrame]

Import PLR data from CSV files with NaN/outlier rows removed.

PARAMETER DESCRIPTION
data_dir

Directory containing the CSV files.

TYPE: str

suffix

Suffix for the CSV filenames, by default "_nonNan".

TYPE: str DEFAULT: '_nonNan'

RETURNS DESCRIPTION
tuple of pl.DataFrame

Tuple containing (df_train, df_val) as Polars dataframes.

Source code in src/data_io/data_utils.py
def import_nonnan_data_from_csv(
    data_dir: str, suffix: str = "_nonNan"
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Import PLR data from CSV files with NaN/outlier rows removed.

    Parameters
    ----------
    data_dir : str
        Directory containing the CSV files.
    suffix : str, optional
        Suffix for the CSV filenames, by default "_nonNan".

    Returns
    -------
    tuple of pl.DataFrame
        Tuple containing (df_train, df_val) as Polars dataframes.
    """
    logger.info("Import PLR data with the NaN/outlier rows removed")
    # BUGFIX: forward the caller's suffix; it was previously ignored in favor
    # of a hard-coded "_nonNan" (same result for the default, wrong otherwise)
    train_path, val_path = define_split_csv_paths(data_dir=data_dir, suffix=suffix)
    logger.info("TRAIN split path = {}".format(train_path))
    df_train = pl.read_csv(train_path)
    logger.info("VAL split path = {}".format(val_path))
    df_val = pl.read_csv(val_path)

    return df_train, df_val

import_PLR_data_from_CSV

import_PLR_data_from_CSV(
    data_dir: str,
) -> tuple[DataFrame, DataFrame]

Import PLR data from train and validation CSV files.

PARAMETER DESCRIPTION
data_dir

Directory containing the CSV files.

TYPE: str

RETURNS DESCRIPTION
tuple of pl.DataFrame

Tuple containing (df_train, df_val) as Polars dataframes.

Source code in src/data_io/data_utils.py
def import_PLR_data_from_CSV(data_dir: str) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Import PLR data from train and validation CSV files.

    Parameters
    ----------
    data_dir : str
        Directory containing the CSV files.

    Returns
    -------
    tuple of pl.DataFrame
        Tuple containing (df_train, df_val) as Polars dataframes.
    """
    logger.info('Import data from CSVs in "{}"'.format(data_dir))

    # read both splits with identical handling
    frames = []
    for split_name, csv_path in zip(
        ("train", "val"), define_split_csv_paths(data_dir=data_dir)
    ):
        logger.info("Import {} split from {}".format(split_name, csv_path))
        frames.append(pl.read_csv(csv_path))

    return frames[0], frames[1]

export_dataframe_to_duckdb

export_dataframe_to_duckdb(
    df: DataFrame,
    db_name: str,
    cfg: DictConfig,
    name: Optional[str] = None,
    service_name: str = "duckdb",
    debug_DuckDBWrite: bool = True,
    copy_orig_db: bool = False,
) -> str

Export a Polars dataframe to a DuckDB database.

PARAMETER DESCRIPTION
df

Polars dataframe to export.

TYPE: DataFrame

db_name

Name of the output database file.

TYPE: str

cfg

Configuration dictionary containing DATA settings.

TYPE: DictConfig

name

Name identifier for the export, by default None.

TYPE: str DEFAULT: None

service_name

Service name for artifact directory, by default "duckdb".

TYPE: str DEFAULT: 'duckdb'

debug_DuckDBWrite

Whether to verify the write by reading back, by default True.

TYPE: bool DEFAULT: True

copy_orig_db

Whether to copy the original database instead of writing new, by default False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
str

Path to the created DuckDB database file.

Source code in src/data_io/data_utils.py
def export_dataframe_to_duckdb(
    df: pl.DataFrame,
    db_name: str,
    cfg: DictConfig,
    name: Optional[str] = None,
    service_name: str = "duckdb",
    debug_DuckDBWrite: bool = True,
    copy_orig_db: bool = False,
) -> Path:
    """Export a Polars dataframe to a DuckDB database.

    Parameters
    ----------
    df : pl.DataFrame
        Polars dataframe to export.
    db_name : str
        Name of the output database file.
    cfg : DictConfig
        Configuration dictionary containing DATA settings.
    name : str, optional
        Name identifier for the export, by default None.
    service_name : str, optional
        Service name for artifact directory, by default "duckdb".
    debug_DuckDBWrite : bool, optional
        Whether to verify the write by reading back, by default True.
    copy_orig_db : bool, optional
        Whether to copy the original database instead of writing new, by default False.

    Returns
    -------
    Path
        Path to the created DuckDB database file.
        (Annotation fixed: a ``Path`` was always returned, not a ``str``.)
    """
    dir_out = get_artifacts_dir(service_name=service_name)
    dir_out.mkdir(parents=True, exist_ok=True)
    db_path_out = dir_out / db_name
    # always start from a clean file
    if db_path_out.exists():
        db_path_out.unlink()

    if copy_orig_db:
        # TODO! debug code, eliminate when you figure out what to do with the anomaly detection
        # the one that flow_import_data() uses
        db_path_in = get_duckdb_file(data_cfg=cfg["DATA"])

        shutil.copyfile(db_path_in, db_path_out)
        logger.info("Copying the DuckDB Database, from {}".format(db_path_in))
        logger.info("to: {}".format(db_path_out))
    else:
        logger.info("Writing dataframe to DuckDB Database: {}".format(db_path_out))
        logger.info(
            "Shape of the dataframe written to disk as DuckDB {}".format(df.shape)
        )
        # (a second exists()/unlink() check used to live here; it was dead code
        # because the file is already removed above)
        with duckdb.connect(database=str(db_path_out), read_only=False) as con:
            # DuckDB's replacement scan resolves the in-scope dataframe `df` by name
            con.execute("""
                        CREATE TABLE IF NOT EXISTS 'train' AS SELECT * FROM df;
                    """)

    if debug_DuckDBWrite:
        logger.debug("Reading back from DuckDB (to test that stuff was written)")
        if copy_orig_db:
            df_train = load_from_duckdb_as_dataframe(
                db_path=db_path_out, cfg=cfg, split="train"
            )
            df_val = load_from_duckdb_as_dataframe(
                db_path=db_path_out, cfg=cfg, split="val"
            )
            df_back = pl.concat([df_train, df_val])
        else:
            df_back = load_from_duckdb_as_dataframe(db_path=db_path_out, cfg=cfg)
        # TODO! figure out why this happens? :o
        #  Saved dataframe shape (1004367, 15) does not match the shape read back: (1004367, 13)
        #  does not write the "new cols" for some reason, "_imputed" suffix ones
        assert df.shape == df_back.shape, (
            f"Saved dataframe shape {df.shape} does not match the "
            f"shape read back: {df_back.shape} (Samples (time points), Features)"
        )
        logger.debug("Read successful!")

    return db_path_out

load_both_splits_from_duckdb

load_both_splits_from_duckdb(
    db_path: str, cfg: DictConfig
) -> DataFrame

Load and concatenate both train and validation splits from DuckDB.

PARAMETER DESCRIPTION
db_path

Path to the DuckDB database file.

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

Concatenated Polars dataframe containing both splits.

Source code in src/data_io/data_utils.py
def load_both_splits_from_duckdb(db_path: str, cfg: DictConfig) -> pl.DataFrame:
    """Load and concatenate both train and validation splits from DuckDB.

    Parameters
    ----------
    db_path : str
        Path to the DuckDB database file.
    cfg : DictConfig
        Configuration dictionary.

    Returns
    -------
    pl.DataFrame
        Concatenated Polars dataframe containing both splits.
    """
    # Load each split table and stack them vertically (train first, then val)
    frames = [
        load_from_duckdb_as_dataframe(db_path=db_path, cfg=cfg, split=name)
        for name in ("train", "val")
    ]
    return pl.concat(frames)

load_from_duckdb_as_dataframe

load_from_duckdb_as_dataframe(
    db_path: str, cfg: DictConfig, split: str = "train"
) -> DataFrame

Load a data split from DuckDB as a Polars dataframe.

PARAMETER DESCRIPTION
db_path

Path to the DuckDB database file.

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

split

Name of the split to load ("train" or "test"), by default "train".

TYPE: str DEFAULT: 'train'

RETURNS DESCRIPTION
DataFrame

Polars dataframe containing the requested split.

RAISES DESCRIPTION
Exception

If there is an error reading from DuckDB.

Source code in src/data_io/data_utils.py
def load_from_duckdb_as_dataframe(
    db_path: str, cfg: DictConfig, split: str = "train"
) -> pl.DataFrame:
    """Load a data split from DuckDB as a Polars dataframe.

    Parameters
    ----------
    db_path : str
        Path to the DuckDB database file.
    cfg : DictConfig
        Configuration dictionary.
    split : str, optional
        Name of the split/table to load ("train" or "test"), by default "train".

    Returns
    -------
    pl.DataFrame
        Polars dataframe containing the requested split.

    Raises
    ------
    Exception
        If there is an error reading from DuckDB (e.g. missing table/file).
    """
    try:
        with duckdb.connect(database=db_path, read_only=True) as con:
            # Table names cannot be bound as parameters; quote the identifier
            # so an unexpected split string cannot alter the statement.
            df_load = con.query(f'SELECT * FROM "{split}"').pl()
    except Exception as e:
        logger.error("Error in reading DuckDB: {}".format(e))
        # Bare raise preserves the original traceback
        raise
    return df_load

export_dataframes_to_duckdb

export_dataframes_to_duckdb(
    df_train: Union[pl.DataFrame, pd.DataFrame],
    df_test: Union[pl.DataFrame, pd.DataFrame],
    db_name: str = "SERI_PLR_GLAUCOMA.db",
    data_dir: Optional[str] = None,
    debug_DuckDBWrite: bool = True,
) -> str

Export train and test dataframes to a DuckDB database.

Creates separate tables for train and test splits in the database.

PARAMETER DESCRIPTION
df_train

Training data dataframe.

TYPE: pl.DataFrame or pd.DataFrame

df_test

Test data dataframe.

TYPE: pl.DataFrame or pd.DataFrame

db_name

Name of the database file, by default "SERI_PLR_GLAUCOMA.db".

TYPE: str DEFAULT: 'SERI_PLR_GLAUCOMA.db'

data_dir

Directory to save the database, by default None.

TYPE: str DEFAULT: None

debug_DuckDBWrite

Whether to verify the write by reading back, by default True.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
str

Path to the created DuckDB database file.

Source code in src/data_io/data_utils.py
def export_dataframes_to_duckdb(
    df_train: Union[pl.DataFrame, pd.DataFrame],
    df_test: Union[pl.DataFrame, pd.DataFrame],
    db_name: str = "SERI_PLR_GLAUCOMA.db",
    data_dir: Optional[str] = None,
    debug_DuckDBWrite: bool = True,
) -> str:
    """Export train and test dataframes to a DuckDB database.

    Creates separate tables ("train" and "test") for the splits. Any
    pre-existing database file at the target path is removed first.

    Parameters
    ----------
    df_train : pl.DataFrame or pd.DataFrame
        Training data dataframe.
    df_test : pl.DataFrame or pd.DataFrame
        Test data dataframe.
    db_name : str, optional
        Name of the database file, by default "SERI_PLR_GLAUCOMA.db".
    data_dir : str, optional
        Directory to save the database, by default None.
    debug_DuckDBWrite : bool, optional
        Whether to verify the write by reading back, by default True.

    Returns
    -------
    str
        Path to the created DuckDB database file.
    """
    # https://duckdb.org/docs/api/python/overview.html#persistent-storage
    db_path = Path(data_dir) / db_name
    logger.info("Writing dataframes to DuckDB Database: {}".format(db_path))

    if db_path.exists():
        logger.warning("DuckDB Database already exists, removing the old one")
        db_path.unlink()

    with duckdb.connect(database=str(db_path), read_only=False) as con:
        # TOOPTIMIZE! Blue and Red are now written as double (could be just uint8)
        # NOTE: DuckDB resolves 'df_train'/'df_test' from the enclosing Python
        # scope (replacement scan), so the SQL must reference these local names.
        con.execute("CREATE TABLE IF NOT EXISTS 'train' AS SELECT * FROM df_train;")
        con.execute("CREATE TABLE IF NOT EXISTS 'test' AS SELECT * FROM df_test;")
    logger.info("Write finished")

    if debug_DuckDBWrite:
        logger.info("Reading back from DuckDB")
        with duckdb.connect(database=str(db_path), read_only=True) as con:
            for table in ("train", "test"):
                logger.info("{} split".format(table.upper()))
                con.query(f"SELECT * FROM {table}").show()
        logger.info("Read successful!")

    return str(db_path)

import_duckdb_as_dataframes

import_duckdb_as_dataframes(
    db_path: str,
) -> tuple[DataFrame, DataFrame]

Import train and test dataframes from a DuckDB database.

PARAMETER DESCRIPTION
db_path

Path to the DuckDB database file.

TYPE: str

RETURNS DESCRIPTION
tuple of pl.DataFrame

Tuple containing (df_train, df_test) as Polars dataframes.

RAISES DESCRIPTION
Exception

If there is an error reading from DuckDB.

Source code in src/data_io/data_utils.py
def import_duckdb_as_dataframes(db_path: str) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Import train and test dataframes from a DuckDB database.

    Parameters
    ----------
    db_path : str
        Path to the DuckDB database file.

    Returns
    -------
    tuple of pl.DataFrame
        Tuple containing (df_train, df_test) as Polars dataframes.

    Raises
    ------
    Exception
        If there is an error reading from DuckDB.
    """
    logger.info("Reading data from DuckDB Database: {}".format(db_path))
    try:
        with duckdb.connect(database=db_path, read_only=True) as con:
            # Materialize each relation straight to a Polars dataframe
            df_train = con.query("SELECT * FROM train").pl()
            df_test = con.query("SELECT * FROM test").pl()
    except Exception as e:
        logger.error("Error in reading DuckDB: {}".format(e))
        # Bare raise preserves the original traceback
        raise
    logger.info("Done with the read from DuckDb to Polars dataframes")

    return df_train, df_test

check_data_import

check_data_import(
    df_train: DataFrame,
    df_val: DataFrame,
    display_outliers: bool = True,
) -> None

Validate imported data splits for data leakage and display statistics.

Checks that no subject appears in both train and validation splits, and optionally displays outlier statistics.

PARAMETER DESCRIPTION
df_train

Training data dataframe.

TYPE: DataFrame

df_val

Validation data dataframe.

TYPE: DataFrame

display_outliers

Whether to display outlier statistics, by default True.

TYPE: bool DEFAULT: True

RAISES DESCRIPTION
ValueError

If data leakage is detected (same subject in both splits).

Source code in src/data_io/data_utils.py
def check_data_import(
    df_train: pl.DataFrame, df_val: pl.DataFrame, display_outliers: bool = True
) -> None:
    """Validate imported data splits for data leakage and display statistics.

    Checks that no subject appears in both train and validation splits,
    and optionally displays missingness (outlier) statistics.

    Parameters
    ----------
    df_train : pl.DataFrame
        Training data dataframe.
    df_val : pl.DataFrame
        Validation data dataframe.
    display_outliers : bool, optional
        Whether to display outlier statistics, by default True.

    Raises
    ------
    ValueError
        If data leakage is detected (same subject in both splits).
    """
    unique_subjects_train = get_unique_polars_rows(
        df_train,
        unique_col="subject_code",
        value_col="pupil_raw",
        split="train",
        df_string="PLR",
    )
    unique_subjects_val = get_unique_polars_rows(
        df_val,
        unique_col="subject_code",
        value_col="pupil_raw",
        split="val",
        df_string="PLR",
    )

    # Check for leakage, you cannot have the same subject both in train and validation
    matching_codes = unique_subjects_train.join(
        unique_subjects_val, on="subject_code", how="inner"
    )
    if len(matching_codes) > 0:
        logger.error(
            "Data leakage detected! The same subject code is found in both train and validation splits. Redefine the splits!"
        )
        raise ValueError(
            "Data leakage detected! The same subject code is found in both train and validation"
        )

    if display_outliers:
        # When you are importing from CSVs, this can be confusing as the outliers are found from both "pupil_raw"
        # and "outlier_labels" so as a quick fix, skip the confusing display, or define the outlier percentage
        # correctly for CSV import, see e.g. prepare_dataframe_for_imputation()
        # Series.null_count() returns a plain int, avoiding the deprecated
        # NumPy array-to-scalar conversion that int(...to_numpy()) relied on
        # (DeprecationWarning since NumPy 1.25).
        train_missing = df_train["pupil_raw"].null_count()
        train_outlier_percentage = train_missing / df_train.shape[0] * 100

        val_missing = df_val["pupil_raw"].null_count()
        val_outlier_percentage = val_missing / df_val.shape[0] * 100

        logger.info(
            "Train split shape: {}, unique subjects = {} ({:.2f}% missing)".format(
                df_train.shape, len(unique_subjects_train), train_outlier_percentage
            )
        )
        logger.info(
            "Val split shape: {}, unique subjects = {} ({:.2f}% missing)".format(
                df_val.shape, len(unique_subjects_val), val_outlier_percentage
            )
        )

prepare_dataframe_for_imputation

prepare_dataframe_for_imputation(
    df: DataFrame, cfg: DictConfig
) -> DataFrame

Prepare a dataframe for imputation by fixing light stimuli and setting outliers.

PARAMETER DESCRIPTION
df

Input dataframe containing PLR data.

TYPE: DataFrame

cfg

Configuration dictionary.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

Prepared dataframe ready for imputation.

Source code in src/data_io/data_utils.py
def prepare_dataframe_for_imputation(df: pl.DataFrame, cfg: DictConfig) -> pl.DataFrame:
    """Prepare a dataframe for imputation by fixing light stimuli and setting outliers.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe containing PLR data.
    cfg : DictConfig
        Configuration dictionary.

    Returns
    -------
    pl.DataFrame
        Prepared dataframe ready for imputation.
    """
    # 1) Merge Red/Blue into a single cleaned light-stimuli vector,
    # 2) then null out the labeled outliers in pupil_raw.
    df = fix_light_stimuli_vector(df, cfg)
    return set_outliers_to_null(df, cfg)

fix_light_stimuli_vector

fix_light_stimuli_vector(
    df: DataFrame,
    cfg: DictConfig,
    drop_colors: bool = False,
) -> DataFrame

Combine Red and Blue channels into a single light stimuli column.

PARAMETER DESCRIPTION
df

Input dataframe with Red and Blue columns.

TYPE: DataFrame

cfg

Configuration dictionary.

TYPE: DictConfig

drop_colors

Whether to drop the original Red and Blue columns, by default False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
DataFrame

Dataframe with combined light_stimuli column.

Source code in src/data_io/data_utils.py
def fix_light_stimuli_vector(
    df: pl.DataFrame, cfg: DictConfig, drop_colors: bool = False
) -> pl.DataFrame:
    """Combine Red and Blue channels into a single light stimuli column.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe with Red and Blue columns.
    cfg : DictConfig
        Configuration dictionary.
    drop_colors : bool, optional
        Whether to drop the original Red and Blue columns, by default False.

    Returns
    -------
    pl.DataFrame
        Dataframe with combined light_stimuli column.
    """
    logger.info("Combining Red and Blue into a single light stimuli column")
    # Sum of the two channel intensities becomes the stimuli vector
    df = df.with_columns((pl.col("Blue") + pl.col("Red")).alias("light_stimuli"))
    df = interpolate_missing_light_stimuli_values(df, cfg, col_name="light_stimuli")

    if drop_colors:
        logger.debug("Dropping the original Red and Blue columns")
        df = df.drop(["Blue", "Red"])
    else:
        logger.debug("Keeping the original Red and Blue columns")
        for channel in ("Blue", "Red"):
            df = interpolate_missing_light_stimuli_values(df, cfg, col_name=channel)

    return df

interpolate_missing_light_stimuli_values

interpolate_missing_light_stimuli_values(
    df: DataFrame,
    cfg: DictConfig,
    col_name: str = "light_stimuli",
) -> DataFrame

Interpolate missing values in a light stimuli column.

PARAMETER DESCRIPTION
df

Input dataframe.

TYPE: DataFrame

cfg

Configuration dictionary.

TYPE: DictConfig

col_name

Name of the column to interpolate, by default "light_stimuli".

TYPE: str DEFAULT: 'light_stimuli'

RETURNS DESCRIPTION
DataFrame

Dataframe with interpolated values.

Source code in src/data_io/data_utils.py
def interpolate_missing_light_stimuli_values(
    df: pl.DataFrame, cfg: DictConfig, col_name: str = "light_stimuli"
) -> pl.DataFrame:
    """Interpolate missing values in a light stimuli column.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe.
    cfg : DictConfig
        Configuration dictionary.
    col_name : str, optional
        Name of the column to interpolate, by default "light_stimuli".

    Returns
    -------
    pl.DataFrame
        Dataframe with interpolated values written back to ``col_name``.
    """
    # Series.null_count() returns a plain int (no deprecated NumPy scalar cast)
    no_nulls = df[col_name].null_count()
    if no_nulls > 0:
        logger.debug(
            "Interpolating missing {} values (n = {})".format(col_name, no_nulls)
        )
        # BUGFIX: write the result back to `col_name` itself. The previous
        # keyword form (light_stimuli=...) always assigned to "light_stimuli",
        # so interpolating "Blue"/"Red" silently overwrote the combined
        # light_stimuli column and left the requested column untouched.
        df = df.with_columns(
            pl.when(pl.col(col_name).is_null())
            .then(pl.col(col_name).interpolate())
            .otherwise(pl.col(col_name))
            .alias(col_name)
        )

    return df

set_outliers_to_null

set_outliers_to_null(
    df: DataFrame, cfg: DictConfig
) -> DataFrame

Set outlier values to null in the pupil_raw column based on outlier labels.

PARAMETER DESCRIPTION
df

Input dataframe with pupil_raw and outlier_labels columns.

TYPE: DataFrame

cfg

Configuration dictionary.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

Dataframe with outliers set to null in pupil_raw column.

Source code in src/data_io/data_utils.py
def set_outliers_to_null(df: pl.DataFrame, cfg: DictConfig) -> pl.DataFrame:
    """Set outlier values to null in the pupil_raw column based on outlier labels.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe with pupil_raw and outlier_labels columns.
    cfg : DictConfig
        Configuration dictionary.

    Returns
    -------
    pl.DataFrame
        Dataframe with outliers set to null in pupil_raw column; the
        outlier_labels column is dropped (outlierness is now encoded as nulls).
    """
    # Raw should have this correct already, whereas "_orig" not so much necessarily

    # Set 'null' values to the outliers in the pupil_raw (input data)
    df = df.with_columns(
        pupil_raw=pl.when(pl.col("outlier_labels") == 1)
        .then(pl.lit(None))
        .otherwise(pl.col("pupil_raw"))
    )

    # Drop again unnecessary column as this is encoded to the pupil_raw column
    df = df.drop(["outlier_labels"])

    # Count the amount of outliers (null in Polars). Series.null_count()
    # returns a plain int, avoiding the deprecated NumPy array-to-scalar cast.
    no_outliers = df["pupil_raw"].null_count()
    outlier_percentage = no_outliers / df.shape[0] * 100
    logger.info(
        "Number of outliers set to null = {} ({:.2f}% of all samples)".format(
            no_outliers, outlier_percentage
        )
    )

    return df

set_missing_in_data

set_missing_in_data(
    df: DataFrame,
    X: ndarray,
    _missingness_cfg: DictConfig,
    col_name: str = "pupil_raw",
    split: str = "train",
) -> ndarray

Set missing values in numpy array based on dataframe null values.

Transfers the missingness pattern from a dataframe column to a numpy array.

PARAMETER DESCRIPTION
df

Input dataframe containing the column with missing values.

TYPE: DataFrame

X

Numpy array to apply missingness pattern to.

TYPE: ndarray

col_name

Name of the column to get missingness from, by default "pupil_raw".

TYPE: str DEFAULT: 'pupil_raw'

split

Name of the data split (for logging), by default "train".

TYPE: str DEFAULT: 'train'

RETURNS DESCRIPTION
ndarray

Array with missing values set to NaN where the dataframe has nulls.

Source code in src/data_io/data_utils.py
def set_missing_in_data(
    df: pl.DataFrame,
    X: np.ndarray,
    _missingness_cfg: DictConfig,
    col_name: str = "pupil_raw",
    split: str = "train",
) -> np.ndarray:
    """Copy the missingness pattern of a dataframe column into a numpy array.

    Wherever ``df[col_name]`` is null/NaN, the corresponding entry of ``X``
    is set to NaN (in place).

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe containing the column with missing values.
    X : np.ndarray
        Numpy array to apply missingness pattern to.
    _missingness_cfg : DictConfig
        Unused here; kept for interface compatibility.
    col_name : str, optional
        Name of the column to get missingness from, by default "pupil_raw".
    split : str, optional
        Name of the data split (for logging), by default "train".

    Returns
    -------
    np.ndarray
        Array with missing values set to NaN where the dataframe has nulls.
    """
    # Reshape the flat column to the same 3D layout as X before masking
    pattern = df.select(col_name).to_numpy()
    pattern = pattern.reshape(X.shape[0], X.shape[1], X.shape[2])
    nan_mask = np.isnan(pattern)
    X[nan_mask] = np.nan

    # Sanity check: X must now contain exactly the masked NaNs
    n_masked = nan_mask.sum()
    assert n_masked == np.isnan(X).sum(), (
        "Number of missing values "
        "in the mask and the output data do not match! (for some weird reason)"
    )

    logger.info(
        "Percentage of missing values in the data ({}) = {:.2f}%".format(
            split, n_masked / pattern.size * 100
        )
    )

    return X

combine_metadata_with_df_splits

combine_metadata_with_df_splits(
    df_raw: DataFrame, df_metadata: DataFrame
) -> tuple[DataFrame, dict]

Combine metadata dataframe with the PLR data splits.

PARAMETER DESCRIPTION
df_raw

Raw PLR data dataframe.

TYPE: DataFrame

df_metadata

Metadata dataframe containing subject information.

TYPE: DataFrame

RETURNS DESCRIPTION
tuple

Tuple containing (combined_df, code_stats) where code_stats contains information about matching, extra, and missing subject codes.

Source code in src/data_io/data_utils.py
def combine_metadata_with_df_splits(
    df_raw: pl.DataFrame, df_metadata: pl.DataFrame
) -> tuple[pl.DataFrame, dict]:
    """Combine metadata dataframe with the PLR data splits.

    Parameters
    ----------
    df_raw : pl.DataFrame
        Raw PLR data dataframe.
    df_metadata : pl.DataFrame
        Metadata dataframe containing subject information.

    Returns
    -------
    tuple
        Tuple containing (combined_df, code_stats) where code_stats contains
        information about matching, extra, and missing subject codes.
    """
    logger.info("Combining metadata with the data splits")
    # Thin wrapper: all rows are treated as one "all data" split
    combined, code_stats = combine_metadata_with_df(
        df=df_raw, df_metadata=df_metadata, split="all data"
    )
    return combined, code_stats

combine_metadata_with_df

combine_metadata_with_df(
    df: DataFrame, df_metadata: DataFrame, split: str
) -> tuple[DataFrame, dict]

Combine metadata with a PLR dataframe for a specific split.

PARAMETER DESCRIPTION
df

PLR data dataframe.

TYPE: DataFrame

df_metadata

Metadata dataframe containing subject information.

TYPE: DataFrame

split

Name of the data split.

TYPE: str

RETURNS DESCRIPTION
tuple

Tuple containing (combined_df, code_stats).

Source code in src/data_io/data_utils.py
def combine_metadata_with_df(
    df: pl.DataFrame, df_metadata: pl.DataFrame, split: str
) -> tuple[pl.DataFrame, dict]:
    """Combine metadata with a PLR dataframe for a specific split.

    Parameters
    ----------
    df : pl.DataFrame
        PLR data dataframe.
    df_metadata : pl.DataFrame
        Metadata dataframe containing subject information.
    split : str
        Name of the data split.

    Returns
    -------
    tuple
        Tuple containing (combined_df, code_stats).
    """
    # One representative row per subject, on both sides of the join
    plr_subjects = get_unique_polars_rows(
        df, unique_col="subject_code", value_col="time", split=split, df_string="PLR"
    )
    metadata_subjects = get_unique_polars_rows(
        df_metadata,
        unique_col="subject_code",
        value_col="class_label",
        split=split,
        df_string="metadata",
    )

    # Which subject codes match, which are extra, which lack metadata
    code_stats = get_missing_labels(plr_subjects, metadata_subjects, split)

    # Note! When training the imputation we don't need the class labels (in metadata),
    # but when analyzing the downstream effects of the imputation we do need them.
    # Attach metadata columns for every subject present in both sources.
    df = add_labels_for_matching_codes(
        df, df_metadata, matching_codes=code_stats["matching"]
    )

    return df, code_stats

get_missing_labels

get_missing_labels(
    unique_PLR: DataFrame,
    unique_metadata: DataFrame,
    split: str,
) -> dict

Identify matching, extra, and missing subject codes between PLR and metadata.

PARAMETER DESCRIPTION
unique_PLR

Dataframe with unique PLR subject codes.

TYPE: DataFrame

unique_metadata

Dataframe with unique metadata subject codes.

TYPE: DataFrame

split

Name of the data split (for logging).

TYPE: str

RETURNS DESCRIPTION
dict

Dictionary containing lists of 'matching', 'extra_metadata', and 'missing_from_PLR' subject codes.

Source code in src/data_io/data_utils.py
def get_missing_labels(
    unique_PLR: pl.DataFrame, unique_metadata: pl.DataFrame, split: str
) -> dict:
    """Identify matching, extra, and missing subject codes between PLR and metadata.

    Parameters
    ----------
    unique_PLR : pl.DataFrame
        Dataframe with unique PLR subject codes.
    unique_metadata : pl.DataFrame
        Dataframe with unique metadata subject codes.
    split : str
        Name of the data split (for logging).

    Returns
    -------
    dict
        Dictionary containing lists of 'matching', 'extra_metadata', and
        'missing_from_PLR' subject codes.
    """

    def log_and_sort(df: pl.DataFrame, string: str) -> list[str]:
        # One subject code per row; log them all, return a sorted copy
        list_of_codes = [row["subject_code"] for row in df.rows(named=True)]
        logger.info(f"{string} ({split} split), number of labels: {len(list_of_codes)}")
        for code in list_of_codes:
            logger.debug(code)
        return sorted(list_of_codes)

    codes_PLR = unique_PLR.select(pl.col("subject_code"))

    # Codes present on both sides -> usable for classification
    matching_codes = unique_PLR.join(unique_metadata, on="subject_code", how="inner")
    # Codes in the metadata XLSX without a corresponding PLR recording
    extra_metadata = unique_metadata.join(matching_codes, on="subject_code", how="anti")
    # PLR recordings whose class_labels are still missing from the XLSX
    missing_from_PLR = unique_PLR.join(matching_codes, on="subject_code", how="anti")

    # Each side must be fully accounted for by its two partitions
    assert len(matching_codes) + len(extra_metadata) == len(unique_metadata)
    assert len(matching_codes) + len(missing_from_PLR) == len(codes_PLR)

    return {
        "matching": log_and_sort(df=matching_codes, string="Matching codes"),
        "extra_metadata": log_and_sort(df=extra_metadata, string="Extra metadata"),
        "missing_from_PLR": log_and_sort(
            df=missing_from_PLR, string="Missing from PLR"
        ),
    }

add_labels_for_matching_codes

add_labels_for_matching_codes(
    df: DataFrame,
    df_metadata: DataFrame,
    matching_codes: list[str],
) -> DataFrame

Add metadata labels to dataframe for subjects with matching codes.

PARAMETER DESCRIPTION
df

PLR data dataframe.

TYPE: DataFrame

df_metadata

Metadata dataframe containing subject information.

TYPE: DataFrame

matching_codes

List of subject codes that exist in both dataframes.

TYPE: list

RETURNS DESCRIPTION
DataFrame

Dataframe with metadata columns added for matching subjects.

Source code in src/data_io/data_utils.py
def add_labels_for_matching_codes(
    df: pl.DataFrame, df_metadata: pl.DataFrame, matching_codes: list[str]
) -> pl.DataFrame:
    """Add metadata labels to dataframe for subjects with matching codes.

    Parameters
    ----------
    df : pl.DataFrame
        PLR data dataframe.
    df_metadata : pl.DataFrame
        Metadata dataframe containing subject information.
    matching_codes : list
        List of subject codes that exist in both dataframes.

    Returns
    -------
    pl.DataFrame
        Dataframe with metadata columns added for matching subjects.
    """

    def add_empty_cols(df: pl.DataFrame, df_metadata: pl.DataFrame) -> pl.DataFrame:
        # Ensure every metadata column exists (all-null) before filling per code
        for col in df_metadata.columns:
            if col not in df.columns:
                logger.debug(f"Adding empty column: {col}")
                df = df.with_columns(pl.lit(None).alias(col))
        return df

    df = add_empty_cols(df, df_metadata)

    # (was enumerate(); the index was unused)
    for code in matching_codes:
        logger.debug(f"Adding metadata for code {code}")
        df = add_label_per_code(code, df, df_metadata)

    check_post_metadata_add(df)

    return df

add_label_per_code

add_label_per_code(
    code: str,
    df: DataFrame,
    df_metadata: DataFrame,
    code_col: str = "subject_code",
    length_PLR: int = 1981,
) -> DataFrame

Add metadata labels for a single subject code.

PARAMETER DESCRIPTION
code

Subject code to add labels for.

TYPE: str

df

PLR data dataframe.

TYPE: DataFrame

df_metadata

Metadata dataframe.

TYPE: DataFrame

code_col

Name of the subject code column, by default "subject_code".

TYPE: str DEFAULT: 'subject_code'

length_PLR

Expected number of timepoints per subject, by default 1981.

TYPE: int DEFAULT: 1981

RETURNS DESCRIPTION
DataFrame

Dataframe with metadata added for the specified subject.

Source code in src/data_io/data_utils.py
def add_label_per_code(
    code: str,
    df: pl.DataFrame,
    df_metadata: pl.DataFrame,
    code_col: str = "subject_code",
    length_PLR: int = 1981,
) -> pl.DataFrame:
    """Add metadata labels for a single subject code.

    Parameters
    ----------
    code : str
        Subject code to add labels for.
    df : pl.DataFrame
        PLR data dataframe.
    df_metadata : pl.DataFrame
        Metadata dataframe.
    code_col : str, optional
        Name of the subject code column, by default "subject_code".
    length_PLR : int, optional
        Expected number of timepoints per subject, by default 1981.

    Returns
    -------
    pl.DataFrame
        Dataframe with metadata added for the specified subject.
    """
    df_per_code = df.filter(pl.col(code_col) == code)
    # Obviously this will break if you start doing some custom recordings
    assert len(df_per_code) == length_PLR, (
        f"Length of the PLR data for {code} is not {length_PLR}"
    )

    # Loop through all the metadata columns and add them based on the subject code
    for col in df_metadata.columns:
        # Skip the subject_code column itself (don't replace it).
        # BUGFIX: compare string values with != instead of identity (`is not`),
        # which only worked by accident of CPython string interning.
        if col != code_col:
            value = df_metadata.filter(pl.col(code_col) == code)[col].to_numpy()[0]

            df = df.with_columns(
                (
                    pl.when(pl.col(code_col) == code)
                    .then(pl.lit(value))
                    .otherwise(pl.col(col))
                ).alias(col)
            )

    return df

get_unique_labels

get_unique_labels(
    df: DataFrame,
    unique_col: str = "class_label",
    value_col: str = "time",
) -> list[str]

Get list of unique non-null labels from a dataframe column.

PARAMETER DESCRIPTION
df

Input dataframe.

TYPE: DataFrame

unique_col

Column to get unique values from, by default "class_label".

TYPE: str DEFAULT: 'class_label'

value_col

Column to use for selecting representative rows, by default "time".

TYPE: str DEFAULT: 'time'

RETURNS DESCRIPTION
list

List of unique label values.

Source code in src/data_io/data_utils.py
def get_unique_labels(
    df: pl.DataFrame, unique_col: str = "class_label", value_col: str = "time"
) -> list[str]:
    """Get list of unique non-null labels from a dataframe column.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe.
    unique_col : str, optional
        Column to get unique values from, by default "class_label".
    value_col : str, optional
        Column to use for selecting representative rows, by default "time".

    Returns
    -------
    list
        List of unique label values.
    """
    # One representative row per unique label, then drop null labels
    labels_df = get_unique_polars_rows(df, unique_col=unique_col, value_col=value_col)
    labels_df = labels_df.filter(pl.col(unique_col).is_not_null())
    return list(labels_df[:, unique_col].to_numpy())

pick_per_label

pick_per_label(
    df: DataFrame, label: str, cfg: DictConfig
) -> DataFrame

Filter dataframe to keep only rows with a specific class label.

PARAMETER DESCRIPTION
df

Input dataframe with class_label column.

TYPE: DataFrame

label

Class label to filter for.

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

Filtered dataframe containing only rows with the specified label.

Source code in src/data_io/data_utils.py
def pick_per_label(df: pl.DataFrame, label: str, cfg: DictConfig) -> pl.DataFrame:
    """Filter dataframe to keep only rows with a specific class label.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe with class_label column.
    label : str
        Class label to filter for.
    cfg : DictConfig
        Configuration dictionary.

    Returns
    -------
    pl.DataFrame
        Filtered dataframe containing only rows with the specified label.
    """
    # Filter via pandas as a workaround for a weird Polars filtering issue.
    # (Previously the Polars frame was constructed twice from the same
    # pandas selection; build it once and return that object.)
    df_pd = df.to_pandas()
    df_label = pl.DataFrame(df_pd[df_pd["class_label"] == label])
    check_for_data_lengths(df_label, cfg)
    return df_label

get_outlier_count_per_code

get_outlier_count_per_code(
    unique_value: ndarray, df_pd: DataFrame
) -> tuple[ndarray, ndarray]

Get outlier counts per subject code, sorted by count.

PARAMETER DESCRIPTION
unique_value

Array of unique subject codes.

TYPE: ndarray

df_pd

Dataframe with subject_code and no_outliers columns.

TYPE: DataFrame

RETURNS DESCRIPTION
tuple

Tuple containing (sorted_counts, sorted_codes) for subjects with outlier counts above the median.

Source code in src/data_io/data_utils.py
def get_outlier_count_per_code(
    unique_value: np.ndarray, df_pd: pd.DataFrame
) -> tuple[np.ndarray, np.ndarray]:
    """Get outlier counts per subject code, sorted by count.

    Parameters
    ----------
    unique_value : np.ndarray
        Array of unique subject codes.
    df_pd : pd.DataFrame
        Dataframe with subject_code and no_outliers columns.

    Returns
    -------
    tuple
        Tuple containing (sorted_counts, sorted_codes) for subjects with
        outlier counts above the median.
    """
    # One outlier count per subject code; every row of a subject carries
    # the same no_outliers value, so the first row suffices.
    counts = np.zeros_like(unique_value)
    for idx, code in enumerate(unique_value):
        subject_rows = df_pd[df_pd["subject_code"] == code]
        counts[idx] = subject_rows["no_outliers"].iloc[0]

    order = counts.argsort()
    counts_sorted = counts[order]
    codes_sorted = unique_value[order]

    # Keep only subjects strictly above the median outlier count.
    above_median = counts_sorted > np.median(counts)

    return counts_sorted[above_median], codes_sorted[above_median]

pick_random_subjects_with_outlier_no_cutoff

pick_random_subjects_with_outlier_no_cutoff(
    unique_value: ndarray, df_pd: DataFrame, n: int
) -> ndarray

Pick random subjects from those with above-median outlier counts.

PARAMETER DESCRIPTION
unique_value

Array of unique subject codes.

TYPE: ndarray

df_pd

Dataframe with subject_code and no_outliers columns.

TYPE: DataFrame

n

Number of subjects to pick.

TYPE: int

RETURNS DESCRIPTION
ndarray

Array of randomly selected subject codes.

Source code in src/data_io/data_utils.py
def pick_random_subjects_with_outlier_no_cutoff(
    unique_value: np.ndarray, df_pd: pd.DataFrame, n: int
) -> np.ndarray:
    """Pick random subjects from those with above-median outlier counts.

    Parameters
    ----------
    unique_value : np.ndarray
        Array of unique subject codes.
    df_pd : pd.DataFrame
        Dataframe with subject_code and no_outliers columns.
    n : int
        Number of subjects to pick.

    Returns
    -------
    np.ndarray
        Array of randomly selected subject codes.
    """
    outlier_counts, codes_left = get_outlier_count_per_code(unique_value, df_pd)
    # Sample over the FULL candidate range. The previous
    # range(0, len(codes_left) - 1) excluded the last index, so the subject
    # with the highest outlier count could never be selected, and sampling
    # n == len(codes_left) subjects raised a ValueError.
    random_idx = random.sample(range(len(codes_left)), n)
    random_codes = codes_left[random_idx]

    return random_codes

pick_n_subjects_per_label_pandas

pick_n_subjects_per_label_pandas(
    df: DataFrame,
    n: int,
    PLR_length: int = 1981,
    col_select: str = "subject_code",
    pick_random: bool = False,
) -> DataFrame

Pick n subjects from a dataframe, optionally with outlier-based selection.

PARAMETER DESCRIPTION
df

Input dataframe.

TYPE: DataFrame

n

Number of subjects to pick.

TYPE: int

PLR_length

Expected number of timepoints per subject, by default 1981.

TYPE: int DEFAULT: 1981

col_select

Column containing subject identifiers, by default "subject_code".

TYPE: str DEFAULT: 'subject_code'

pick_random

If True, pick the first n subjects; if False, pick randomly from subjects with above-median outlier counts, by default False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
DataFrame

Dataframe containing data for the selected n subjects.

Source code in src/data_io/data_utils.py
def pick_n_subjects_per_label_pandas(
    df: pl.DataFrame,
    n: int,
    PLR_length: int = 1981,
    col_select: str = "subject_code",
    pick_random: bool = False,
) -> pl.DataFrame:
    """Pick n subjects from a dataframe, optionally with outlier-based selection.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe.
    n : int
        Number of subjects to pick.
    PLR_length : int, optional
        Expected number of timepoints per subject, by default 1981.
    col_select : str, optional
        Column containing subject identifiers, by default "subject_code".
    pick_random : bool, optional
        If True, pick first n subjects; if False, pick from high-outlier subjects,
        by default False.
        NOTE(review): the naming is counterintuitive (True = deterministic
        first-n, False = random high-outlier sample); kept as-is to preserve
        caller behavior.

    Returns
    -------
    pl.DataFrame
        Dataframe containing data for the selected n subjects.
    """
    # Pandas alternative due to a weird Polars issue
    df_pd = df.to_pandas()
    unique_value = df_pd[col_select].unique()
    if pick_random:
        first_n_subjects = unique_value[:n]
    else:
        first_n_subjects = pick_random_subjects_with_outlier_no_cutoff(
            unique_value, df_pd, n
        )

    # Collect per-subject frames and concatenate once at the end; growing
    # a dataframe with pd.concat inside the loop is quadratic.
    subject_frames = []
    for code in first_n_subjects:
        df_code = df_pd[df_pd[col_select] == code]
        assert df_code.shape[0] == PLR_length, (
            'Length ({}) of the PLR data for "{}" is not {}'.format(
                df_code.shape[0], code, PLR_length
            )
        )
        subject_frames.append(df_code)
    df_out = pd.concat(subject_frames) if subject_frames else pd.DataFrame()

    assert (
        df_out.shape[0] == n * PLR_length
    )  # TODO! check with other n values than 8 for debug
    return pl.DataFrame(df_out)

pick_n_subjects_per_label

pick_n_subjects_per_label(
    df: DataFrame,
    label: str,
    n: int,
    PLR_length: int = 1981,
) -> DataFrame

Pick n subjects with a specific class label from a dataframe.

PARAMETER DESCRIPTION
df

Input dataframe with class_label column.

TYPE: DataFrame

label

Class label to filter for.

TYPE: str

n

Number of subjects to pick.

TYPE: int

PLR_length

Expected number of timepoints per subject, by default 1981.

TYPE: int DEFAULT: 1981

RETURNS DESCRIPTION
DataFrame

Dataframe containing data for the selected n subjects with the given label.

Source code in src/data_io/data_utils.py
def pick_n_subjects_per_label(
    df: pl.DataFrame, label: str, n: int, PLR_length: int = 1981
) -> pl.DataFrame:
    """Select the first n subjects carrying a given class label.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe with class_label column.
    label : str
        Class label to filter for.
    n : int
        Number of subjects to pick.
    PLR_length : int, optional
        Expected number of timepoints per subject, by default 1981.

    Returns
    -------
    pl.DataFrame
        Dataframe containing data for the selected n subjects with the given label.
    """
    labelled = df.filter(pl.col("class_label") == label)
    chosen_codes = get_list_of_unique_subjects(labelled)[:n]

    df_out = pl.DataFrame()
    for i, code in enumerate(chosen_codes):
        df_code = labelled.filter(pl.col("subject_code") == code)
        assert len(df_code) == PLR_length, (
            "Length of the PLR data for {} is not {}".format(code, PLR_length)
        )
        df_out = pl.concat([df_out, df_code])
        # Guard against pl.concat silently dropping rows (a Polars issue;
        # run with RUST_BACKTRACE=1 to display a backtrace — observed panic:
        # "The column lengths in the DataFrame are not equal." in fmt.rs)
        assert df_out.shape[0] == (i + 1) * PLR_length, (
            "Seems like subject was not added with pl.concat?\n"
            "{}, {}, {} (i, df_code, df_out)"
        ).format(i, df_code.shape, df_out.shape)

    df_n_subjects = int(df_out.shape[0] / PLR_length)
    assert df_out.shape[0] == n * PLR_length, (
        f"Number of rows ({df_n_subjects}) in the output dataframe "
        f"is not equal to the number of subjects ({n}) requested"
    )
    assert len(get_list_of_unique_subjects(df_out)) == n, (
        "Number of subjects in the output dataframe is not equal to the number of subjects requested"
    )

    return df_out

get_list_of_unique_subjects

get_list_of_unique_subjects(
    df_label: DataFrame, unique_col: str = "subject_code"
) -> list[str]

Get a list of unique subject codes from a dataframe.

PARAMETER DESCRIPTION
df_label

Input dataframe.

TYPE: DataFrame

unique_col

Column containing subject identifiers, by default "subject_code".

TYPE: str DEFAULT: 'subject_code'

RETURNS DESCRIPTION
list

List of unique subject codes.

Source code in src/data_io/data_utils.py
def get_list_of_unique_subjects(
    df_label: pl.DataFrame, unique_col: str = "subject_code"
) -> list[str]:
    """Return the unique subject codes present in a dataframe.

    Parameters
    ----------
    df_label : pl.DataFrame
        Input dataframe.
    unique_col : str, optional
        Column containing subject identifiers, by default "subject_code".

    Returns
    -------
    list
        List of unique subject codes.
    """
    deduplicated = get_unique_polars_rows(df_label, unique_col=unique_col)
    codes = deduplicated[:, unique_col].to_numpy()
    return list(codes)

get_unique_polars_rows

get_unique_polars_rows(
    df: DataFrame,
    unique_col: str = "subject_code",
    value_col: str = "time",
    split: Optional[str] = None,
    df_string: Optional[str] = None,
    pandas_fix: bool = True,
) -> DataFrame

Get unique rows from a Polars dataframe based on a column.

Deduplicates the dataframe to get one row per unique value in the specified column.

PARAMETER DESCRIPTION
df

Input Polars dataframe.

TYPE: DataFrame

unique_col

Column to use for identifying unique rows, by default "subject_code".

TYPE: str DEFAULT: 'subject_code'

value_col

Column to use for selecting representative rows, by default "time".

TYPE: str DEFAULT: 'time'

split

Name of the data split (for logging), by default None.

TYPE: str DEFAULT: None

df_string

Description string for logging, by default None.

TYPE: str DEFAULT: None

pandas_fix

Whether to use pandas for deduplication (more reliable), by default True.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
DataFrame

Dataframe with one row per unique value in unique_col.

Source code in src/data_io/data_utils.py
def get_unique_polars_rows(
    df: pl.DataFrame,
    unique_col: str = "subject_code",
    value_col: str = "time",
    split: Optional[str] = None,
    df_string: Optional[str] = None,
    pandas_fix: bool = True,
) -> pl.DataFrame:
    """Get unique rows from a Polars dataframe based on a column.

    Deduplicates the dataframe to get one row per unique value in the
    specified column.

    Parameters
    ----------
    df : pl.DataFrame
        Input Polars dataframe.
    unique_col : str, optional
        Column to use for identifying unique rows, by default "subject_code".
    value_col : str, optional
        Column to use for selecting representative rows, by default "time".
    split : str, optional
        Name of the data split (for logging), by default None.
    df_string : str, optional
        Description string for logging, by default None.
    pandas_fix : bool, optional
        Whether to use pandas for deduplication (more reliable), by default True.

    Returns
    -------
    pl.DataFrame
        Dataframe with one row per unique value in unique_col.
    """
    # Get one value per subject code to check if the metadata is there.
    # Initialize before the try: if df[unique_col] itself raises the KeyError,
    # the except handler below would otherwise hit a NameError on
    # unique_values, masking the real error.
    unique_values = None
    try:
        unique_values = list(df[unique_col].unique())
        if pandas_fix:
            df = df.to_pandas()
            df = df.drop_duplicates(subset=[unique_col])
            df = pl.DataFrame(df)
        else:
            # polars.exceptions.ShapeError: series used as keys should have the same length as the DataFrame
            df = df.select(
                pl.all()
                .top_k_by(value_col, k=1)
                .over(unique_col, mapping_strategy="explode")
            )
    except KeyError as e:
        logger.error(f"Unique values = {unique_values}")
        # NOTE(review): 1981 is the assumed PLR length here — confirm it
        # matches cfg["DATA"]["PLR_length"] used elsewhere.
        logger.error(f"Number of subjects (in df): {int(df.shape[0] / 1981)}")
        logger.error("Error in getting unique rows from the dataframe: {}".format(e))
        raise e

    logger.debug(
        f"{split} split: number of unique subjects ({df_string}) = {df.shape[0]}"
    )

    return df

check_post_metadata_add

check_post_metadata_add(
    df: DataFrame, length_PLR: int = 1981
) -> None

Validate dataframe after metadata addition.

Checks that the number of rows with and without class labels are multiples of the PLR length.

PARAMETER DESCRIPTION
df

Dataframe to validate.

TYPE: DataFrame

length_PLR

Expected number of timepoints per subject, by default 1981.

TYPE: int DEFAULT: 1981

RAISES DESCRIPTION
AssertionError

If row counts are not multiples of PLR length.

Source code in src/data_io/data_utils.py
def check_post_metadata_add(df: pl.DataFrame, length_PLR: int = 1981) -> None:
    """Sanity-check the dataframe after metadata has been merged in.

    Both the labelled rows (non-null ``class_label``) and the unlabelled
    rows must divide evenly into whole subjects of ``length_PLR``
    timepoints each.

    Parameters
    ----------
    df : pl.DataFrame
        Dataframe to validate.
    length_PLR : int, optional
        Expected number of timepoints per subject, by default 1981.

    Raises
    ------
    AssertionError
        If row counts are not multiples of PLR length.
    """
    # Rows that carry a class label
    labelled_rows = df.select(pl.count("class_label")).to_numpy()[0]
    labelled_subjects = float(labelled_rows / length_PLR)
    assert labelled_subjects.is_integer(), (
        "Number of non-null rows is not a multiple of the PLR length, "
        "no_nonnull_subjects = {}"
    ).format(labelled_subjects)

    # would be bizarre if this failed if the one above was ok
    unlabelled_rows = df.select(pl.col("class_label").is_null().sum()).to_numpy()[0]
    unlabelled_subjects = float(unlabelled_rows / length_PLR)
    assert unlabelled_subjects.is_integer(), (
        "Number of null rows is not a multiple of the PLR length, no_null_subjects = {}"
    ).format(unlabelled_subjects)

    logger.info(
        "After adding metadata to the PLR data, "
        "we have {} subjects with a class label (control vs glaucoma), and {} with no class labels".format(
            int(labelled_subjects), int(unlabelled_subjects)
        )
    )

pick_debug_data

pick_debug_data(
    df: DataFrame,
    string: str,
    cfg: DictConfig,
    n: int = 4,
    pick_random: bool = False,
) -> DataFrame

Pick a small subset of data for debugging purposes.

Selects n subjects per unique label for faster debugging runs.

PARAMETER DESCRIPTION
df

Input dataframe.

TYPE: DataFrame

string

Description string for logging.

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

n

Number of subjects to pick per label, by default 4.

TYPE: int DEFAULT: 4

pick_random

Whether to pick subjects randomly, by default False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
DataFrame

Subset dataframe for debugging.

Source code in src/data_io/data_utils.py
def pick_debug_data(
    df: pl.DataFrame,
    string: str,
    cfg: DictConfig,
    n: int = 4,
    pick_random: bool = False,
) -> pl.DataFrame:
    """Reduce the data to a small per-label subset for debugging.

    Keeps only n subjects for each unique label so debug/test runs
    finish quickly.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe.
    string : str
        Description string for logging.
    cfg : DictConfig
        Configuration dictionary.
    n : int, optional
        Number of subjects to pick per label, by default 4.
    pick_random : bool, optional
        Whether to pick subjects randomly, by default False.

    Returns
    -------
    pl.DataFrame
        Subset dataframe for debugging.
    """
    logger.warning(
        'You have a debug mode on, picking a subset of the "{}" data!'.format(string)
    )
    logger.warning("Number of subjects to pick per label = {}".format(n))
    labels = get_unique_labels(df)
    check_for_data_lengths(df, cfg)

    collected = pl.DataFrame()
    for label_idx, label in enumerate(labels):
        per_label = pick_per_label(df, label, cfg)
        per_label = pick_n_subjects_per_label_pandas(
            per_label, n, pick_random=pick_random
        )
        get_list_of_unique_subjects(per_label)
        collected = pandas_concat(collected, per_label)
        logger.info(
            f'{label_idx} ({label}): {int(collected.shape[0] / cfg["DATA"]["PLR_length"])} subjects for the "{label}" label'
        )
        logger.info(get_list_of_unique_subjects(per_label))

    check_for_data_lengths(collected, cfg)
    return collected

combine_split_dataframes

combine_split_dataframes(
    df_train: DataFrame,
    df_val: DataFrame,
    cfg: DictConfig,
    debug_mode: bool = False,
    debug_n: int = 4,
    pick_random: bool = False,
    demo_mode: bool = False,
) -> DataFrame

Combine train and validation dataframes with a split indicator column.

PARAMETER DESCRIPTION
df_train

Training dataframe.

TYPE: DataFrame

df_val

Validation dataframe.

TYPE: DataFrame

cfg

Configuration dictionary.

TYPE: DictConfig

debug_mode

Whether to use debug mode (subset of data), by default False.

TYPE: bool DEFAULT: False

debug_n

Number of subjects per label in debug mode, by default 4.

TYPE: int DEFAULT: 4

pick_random

Whether to pick subjects randomly in debug mode, by default False.

TYPE: bool DEFAULT: False

demo_mode

Whether demo mode is enabled, by default False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
DataFrame

Combined dataframe with 'split' column indicating train/test.

Source code in src/data_io/data_utils.py
def combine_split_dataframes(
    df_train: pl.DataFrame,
    df_val: pl.DataFrame,
    cfg: DictConfig,
    debug_mode: bool = False,
    debug_n: int = 4,
    pick_random: bool = False,
    demo_mode: bool = False,
) -> pl.DataFrame:
    """Combine train and validation dataframes with a split indicator column.

    Parameters
    ----------
    df_train : pl.DataFrame
        Training dataframe.
    df_val : pl.DataFrame
        Validation dataframe.
    cfg : DictConfig
        Configuration dictionary.
    debug_mode : bool, optional
        Whether to use debug mode (subset of data), by default False.
    debug_n : int, optional
        Number of subjects per label in debug mode, by default 4.
    pick_random : bool, optional
        Whether to pick subjects randomly in debug mode, by default False.
    demo_mode : bool, optional
        Whether demo mode is enabled, by default False.

    Returns
    -------
    pl.DataFrame
        Combined dataframe with 'split' column indicating train/test.
    """

    def _tag_with_split(frame, split_name):
        # Annotate each row with its split via pandas (Polars assignment
        # workaround), then optionally shrink to a debug subset.
        check_for_data_lengths(frame, cfg)
        as_pd = frame.to_pandas()
        as_pd["split"] = split_name
        tagged = pl.DataFrame(as_pd)
        check_for_data_lengths(tagged, cfg)
        if demo_mode:
            logger.info("Demo mode is on, not picking any debug subjects here")
        elif debug_mode:
            tagged = pick_debug_data(
                tagged, split_name, cfg, n=debug_n, pick_random=pick_random
            )
        return tagged

    df_train = _tag_with_split(df_train, "train")
    check_for_data_lengths(df_train, cfg)

    df_test = _tag_with_split(df_val, "test")
    logger.info(
        "Combining the train ({}) and test ({}) splits".format(
            df_train.shape, df_test.shape
        )
    )
    check_for_data_lengths(df_test, cfg)

    df = pl.concat([df_train, df_test])
    check_for_data_lengths(df, cfg)
    logger.info("Combined dataframe shape = {}".format(df.shape))
    logger.info(f"Number of time points = {df.shape[0]:,}")

    return df

define_desired_timevector

define_desired_timevector(
    PLR_length: int = 1981, fps: int = 30
) -> ndarray

Generate an ideal time vector for PLR recordings.

PARAMETER DESCRIPTION
PLR_length

Number of timepoints in the recording, by default 1981.

TYPE: int DEFAULT: 1981

fps

Frames per second of the recording, by default 30.

TYPE: int DEFAULT: 30

RETURNS DESCRIPTION
ndarray

Time vector in seconds.

Source code in src/data_io/data_utils.py
def define_desired_timevector(PLR_length: int = 1981, fps: int = 30) -> np.ndarray:
    """Build the ideal, evenly spaced time axis for a PLR recording.

    Parameters
    ----------
    PLR_length : int, optional
        Number of timepoints in the recording, by default 1981.
    fps : int, optional
        Frames per second of the recording, by default 30.

    Returns
    -------
    np.ndarray
        Time vector in seconds.
    """
    # Timestamp of the last frame: (N - 1) frame intervals at 1/fps seconds
    last_timestamp = (PLR_length - 1) / fps
    return np.linspace(0, last_timestamp, num=PLR_length)

check_time_similarity

check_time_similarity(
    time_vec_in: ndarray, time_vec_ideal: ndarray
) -> dict[str, bool | float]

Check if two time vectors are similar.

PARAMETER DESCRIPTION
time_vec_in

Input time vector to check.

TYPE: ndarray

time_vec_ideal

Ideal/reference time vector.

TYPE: ndarray

RETURNS DESCRIPTION
dict

Dictionary containing check results including 'allclose', 'min_same', 'max_same', and overall 'OK' status.

Source code in src/data_io/data_utils.py
def check_time_similarity(
    time_vec_in: np.ndarray, time_vec_ideal: np.ndarray
) -> dict[str, bool | float]:
    """Check if two time vectors are similar.

    Parameters
    ----------
    time_vec_in : np.ndarray
        Input time vector to check.
    time_vec_ideal : np.ndarray
        Ideal/reference time vector.

    Returns
    -------
    dict
        Dictionary containing check results including 'allclose', 'min_same',
        'max_same', and overall 'OK' status.
    """
    time_checks = {}
    # check if the time vectors are similar (within a tolerance), picks up rounding off issues
    # and if there is some jitter in the original recording?
    time_checks["allclose"] = np.allclose(time_vec_in, time_vec_ideal, atol=0)
    # check that min and max are the same
    time_checks["min_in"] = np.min(time_vec_in)
    time_checks["min_same"] = np.min(time_vec_in) == np.min(time_vec_ideal)
    time_checks["max_in"] = np.max(time_vec_in)
    time_checks["max_same"] = np.max(time_vec_in) == np.max(time_vec_ideal)
    time_checks["OK"] = (
        time_checks["allclose"] and time_checks["min_same"] and time_checks["max_same"]
    )
    return time_checks

check_time_vector_quality

check_time_vector_quality(
    subject_code: str,
    csv_subset: DataFrame,
    cfg: DictConfig,
) -> tuple[ndarray, ndarray, dict[str, bool | float]]

Check the quality of a subject's time vector against the ideal.

PARAMETER DESCRIPTION
subject_code

Subject identifier.

TYPE: str

csv_subset

Subject's data containing a 'time' column.

TYPE: DataFrame

cfg

Configuration dictionary with PLR_length setting.

TYPE: DictConfig

RETURNS DESCRIPTION
tuple

Tuple containing (time_vec_in, time_vec_ideal, time_checks).

Source code in src/data_io/data_utils.py
def check_time_vector_quality(
    subject_code: str, csv_subset: pd.DataFrame, cfg: DictConfig
) -> tuple[np.ndarray, np.ndarray, dict[str, bool | float]]:
    """Compare a subject's recorded time vector to the ideal one.

    Parameters
    ----------
    subject_code : str
        Subject identifier.
    csv_subset : pd.DataFrame
        Subject's data containing a 'time' column.
    cfg : DictConfig
        Configuration dictionary with PLR_length setting.

    Returns
    -------
    tuple
        Tuple containing (time_vec_in, time_vec_ideal, time_checks).
    """
    recorded = csv_subset["time"].to_numpy()
    ideal = define_desired_timevector(PLR_length=cfg["DATA"]["PLR_length"])
    assert recorded.shape[0] == ideal.shape[0], (
        "Time vector length does not match"
    )
    return recorded, ideal, check_time_similarity(recorded, ideal)

check_for_unique_timepoints

check_for_unique_timepoints(
    df: DataFrame,
    cfg: DictConfig,
    col: str = "time",
    assert_on_error: bool = True,
) -> None

Check that all subjects have the same time vector.

PARAMETER DESCRIPTION
df

Input dataframe.

TYPE: DataFrame

cfg

Configuration dictionary with PLR_length setting.

TYPE: DictConfig

col

Name of the time column, by default "time".

TYPE: str DEFAULT: 'time'

assert_on_error

Whether to raise an error if check fails, by default True.

TYPE: bool DEFAULT: True

RAISES DESCRIPTION
AssertionError

If subjects have different time vectors and assert_on_error is True.

Source code in src/data_io/data_utils.py
def check_for_unique_timepoints(
    df: pl.DataFrame, cfg: DictConfig, col: str = "time", assert_on_error: bool = True
) -> None:
    """Check that all subjects have the same time vector.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe.
    cfg : DictConfig
        Configuration dictionary with PLR_length setting.
    col : str, optional
        Name of the time column, by default "time".
    assert_on_error : bool, optional
        Whether to raise an error if check fails, by default True.

    Raises
    ------
    AssertionError
        If subjects have different time vectors and assert_on_error is True.
    """
    time_df = get_unique_polars_rows(df, unique_col=col)
    time_col = time_df["time"].to_numpy()
    if assert_on_error:
        # Fixed the assert message: the original implicit string
        # concatenation was missing a space ("differenttimevectors").
        assert len(time_col) == cfg["DATA"]["PLR_length"], (
            f"Number of unique time points {len(time_col)} is "
            f"not {cfg['DATA']['PLR_length']} in the imported dataframe "
            f"(all the subjects)\nWhich means that subjects had different "
            f"timevectors"
        )
    else:
        logger.warning(
            f"Number of unique time points {len(time_col)} is not {cfg['DATA']['PLR_length']} in the"
        )
        logger.warning(
            "imported dataframe (all the subjects)\nWhich means that subjects had different timevectors"
        )
        logger.warning(
            'This is still ok for "time_raw" col as we are using "ideal time vector" for the modeling and'
        )
        logger.warning("visualization to account for small rounding off errors")

fix_for_orphaned_nans

fix_for_orphaned_nans(
    subject_code: str,
    csv_subset: DataFrame,
    cfg: DictConfig,
    cols: tuple = ("Red", "Blue"),
)

Fix orphaned NaN values by replacing with zeros.

Orphaned NaNs are NaN values that remain after interpolation, typically at the edges of the data.

PARAMETER DESCRIPTION
subject_code

Subject identifier for logging.

TYPE: str

csv_subset

Subject's data containing the columns to fix.

TYPE: DataFrame

cfg

Configuration dictionary with PLR_length setting.

TYPE: DictConfig

cols

Columns to check and fix, by default ("Red", "Blue").

TYPE: tuple DEFAULT: ('Red', 'Blue')

RETURNS DESCRIPTION
DataFrame

Dataframe with orphaned NaNs replaced by zeros.

RAISES DESCRIPTION
ValueError

If NaNs remain after the fix.

Source code in src/data_io/data_utils.py
def fix_for_orphaned_nans(
    subject_code: str,
    csv_subset: pd.DataFrame,
    cfg: DictConfig,
    cols: tuple = ("Red", "Blue"),
):
    """Fix orphaned NaN values by replacing with zeros.

    Orphaned NaNs are NaN values that remain after interpolation,
    typically at the edges of the data.

    Parameters
    ----------
    subject_code : str
        Subject identifier for logging.
    csv_subset : pd.DataFrame
        Subject's data containing the columns to fix.
    cfg : DictConfig
        Configuration dictionary with PLR_length setting.
    cols : tuple, optional
        Columns to check and fix, by default ("Red", "Blue").

    Returns
    -------
    pd.DataFrame
        Dataframe with orphaned NaNs replaced by zeros.

    Raises
    ------
    ValueError
        If NaNs remain after the fix.
    """
    # Check for orphaned NaNs
    for col in cols:
        no_orphaned_nans = csv_subset[col].isnull().sum()
        if no_orphaned_nans > 0:
            # e.g. PLR4199, PLR4195, PLR4194, PLR1127, PLR4204, PLR1081, PLR4140, PLR4164
            logger.warning(
                f"Subject {subject_code} has {no_orphaned_nans} orphaned NaNs in the {col} column after "
                f"interpolation (replacing with 0)"
            )
            csv_subset[col] = csv_subset[col].fillna(0)
            no_orphaned_nans_after = csv_subset[col].isnull().sum()
            if no_orphaned_nans_after > 0:
                logger.error(
                    f"Subject {subject_code} still has {no_orphaned_nans_after} orphaned NaNs in the {col} column"
                )
                raise ValueError(
                    f"Subject {subject_code} still has {no_orphaned_nans_after} orphaned NaNs in the {col} column"
                )

    assert csv_subset.shape[0] == cfg["DATA"]["PLR_length"], (
        'Length of the PLR data for "{}" is not {}'.format(
            subject_code, cfg["DATA"]["PLR_length"]
        )
    )

    return csv_subset

check_for_data_lengths

check_for_data_lengths(
    df: DataFrame, cfg: DictConfig
) -> None

Verify that all subjects have the expected PLR data length.

PARAMETER DESCRIPTION
df

Input dataframe.

TYPE: DataFrame

cfg

Configuration dictionary with PLR_length setting.

TYPE: DictConfig

RAISES DESCRIPTION
AssertionError

If any subject has a different number of timepoints than expected.

Source code in src/data_io/data_utils.py
def check_for_data_lengths(df: pl.DataFrame, cfg: DictConfig) -> None:
    """Assert that every subject has exactly the expected PLR length.

    Parameters
    ----------
    df : pl.DataFrame
        Input dataframe.
    cfg : DictConfig
        Configuration dictionary with PLR_length setting.

    Raises
    ------
    AssertionError
        If any subject has a different number of timepoints than expected.
    """
    expected = cfg["DATA"]["PLR_length"]
    codes = list(get_unique_polars_rows(df)["subject_code"].to_numpy())

    for code in codes:
        per_subject = df.filter(pl.col("subject_code") == code)
        assert len(per_subject) == expected, (
            'Length ({}) of the PLR data for "{}" is not {}'.format(
                per_subject.shape[0], code, expected
            )
        )

transform_data_for_momentfm

transform_data_for_momentfm(
    X: ndarray,
    mask: ndarray,
    dataset_cfg: DictConfig,
    model_name: str,
) -> tuple[ndarray, ndarray, ndarray]

Transform data arrays for MOMENT foundation model input.

Applies trimming, padding, and downsampling to prepare PLR data for the MOMENT time series foundation model.

PARAMETER DESCRIPTION
X

Input data array of shape (n_subjects, n_timepoints).

TYPE: ndarray

mask

Outlier mask array of shape (n_subjects, n_timepoints).

TYPE: ndarray

dataset_cfg

Dataset configuration with transform parameters.

TYPE: DictConfig

model_name

Name of the model (e.g., "MOMENT", "UniTS", "TimesNet").

TYPE: str

RETURNS DESCRIPTION
tuple

Tuple containing (X_transformed, mask_transformed, input_mask).

Source code in src/data_io/data_utils.py
def transform_data_for_momentfm(
    X: np.ndarray, mask: np.ndarray, dataset_cfg: DictConfig, model_name: str
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Transform data arrays for MOMENT foundation model input.

    Applies trimming, padding, and downsampling to prepare PLR data
    for the MOMENT time series foundation model.

    Parameters
    ----------
    X : np.ndarray
        Input data array of shape (n_subjects, n_timepoints).
    mask : np.ndarray
        Outlier mask array of shape (n_subjects, n_timepoints).
    dataset_cfg : DictConfig
        Dataset configuration with transform parameters.
    model_name : str
        Name of the model (e.g., "MOMENT", "UniTS", "TimesNet").

    Returns
    -------
    tuple
        Tuple containing (X_transformed, mask_transformed, input_mask).
    """
    logger.debug("Trimming the data for MomentFM")
    logger.debug(f"Trimming to size = {dataset_cfg.trim_to_size}")
    logger.debug(f"Downsample factor = {dataset_cfg.downsample_factor}")

    # Length-transform parameters shared by the signal and the outlier mask
    shared_kwargs = dict(
        trim_to_size=dataset_cfg.trim_to_size,
        pad_ts=dataset_cfg.pad_ts,
        downsample_factor=dataset_cfg.downsample_factor,
        resample_method=dataset_cfg.resample_method,
        split_subjects_to_windows=dataset_cfg.split_subjects_to_windows,
        model_name=model_name,
    )

    # The signal itself, e.g. standardized pupil size (PLR)
    X = transform_for_moment_fm_length(
        data_array=X,
        fill_na=dataset_cfg.fill_na,
        **shared_kwargs,
    )

    # Outlier labels: zero-fill any padding and keep the output binary
    mask = transform_for_moment_fm_length(
        data_array=mask,
        fill_na="0",
        binarize_output=True,
        **shared_kwargs,
    )
    assert mask.shape == X.shape, "Mask and data shapes do not match"

    # Input mask for MomentFM: as we have some NaNs padded, we can tell the model
    # to attend only where the mask is 1.
    # "The input mask is utilized to regulate the time steps or patches that the model should attend to.
    #  For instance, in the case of shorter time series, you may opt not to attend to padding. To implement this,
    #  you can provide an input mask with zeros in the padded locations."
    input_mask = np.zeros((X.shape[0], X.shape[1]))
    input_mask[~np.isnan(X)] = 1

    return X.copy(), mask.copy(), input_mask.copy()

fill_na_in_array_before_windowing

fill_na_in_array_before_windowing(
    array: ndarray,
    fill_na: Optional[str],
    trim_to_size: int,
    model_name: Optional[str],
) -> ndarray

Fill NaN values in array before splitting into windows.

Different models have different requirements for handling NaN values. This function applies model-specific filling strategies.

PARAMETER DESCRIPTION
array

Input array of shape (batch_size, time_points).

TYPE: ndarray

fill_na

Strategy for filling NaN values ("median", "0", or None).

TYPE: str or None

trim_to_size

Target window size after trimming.

TYPE: int

model_name

Name of the model ("TimesNet", "UniTS", etc.).

TYPE: str

RETURNS DESCRIPTION
ndarray

Array with NaN values filled according to the strategy.

RAISES DESCRIPTION
ValueError

If model_name is unknown.

NotImplementedError

If fill_na strategy is not implemented.

Source code in src/data_io/data_utils.py
def fill_na_in_array_before_windowing(
    array: np.ndarray,
    fill_na: Optional[str],
    trim_to_size: int,
    model_name: Optional[str],
) -> np.ndarray:
    """Fill NaN values in array before splitting into windows.

    Different models have different requirements for handling NaN values.
    This function applies model-specific filling strategies. For models
    other than TimesNet/UniTS (e.g. MOMENT, which tolerates NaNs via its
    input_mask), no filling is requested and an explicit error is raised
    for unknown model names.

    Parameters
    ----------
    array : np.ndarray
        Input array of shape (batch_size, time_points), modified in place.
    fill_na : str or None
        Strategy for filling NaN values ("median", "0", or None).
    trim_to_size : int
        Target window size after trimming (kept for API compatibility;
        not used by the current filling logic).
    model_name : str or None
        Name of the model ("TimesNet", "UniTS", etc.).

    Returns
    -------
    np.ndarray
        Array with NaN values filled according to the strategy.

    Raises
    ------
    ValueError
        If model_name is unknown.
    NotImplementedError
        If fill_na strategy is not implemented.
    """
    logger.debug(f"Filling NaNs in the array before windowing with {fill_na}")

    def fill_na_per_subject(
        array_subj: np.ndarray,
        fill_na: str = "median",
        start_idxs=(9, 12),
        end_idxs=(1987, 1990),
    ) -> np.ndarray:
        # Fill the padded head/tail of a single subject's trace using either
        # the median of the nearest valid samples or a constant zero.
        if fill_na == "median":
            fillna_start = np.nanmedian(array_subj[start_idxs[0] : start_idxs[1]])
            fillna_end = np.nanmedian(array_subj[end_idxs[0] : end_idxs[1]])
        elif fill_na == "0":
            fillna_start = 0
            fillna_end = 0
        else:
            logger.error(f"fill_na = {fill_na} not implemented")
            raise NotImplementedError(f"fill_na = {fill_na} not implemented")
        array_subj[: start_idxs[0]] = fillna_start
        array_subj[end_idxs[1] :] = fillna_end
        return array_subj

    if model_name is not None and fill_na is not None:
        if model_name == "TimesNet" or model_name == "UniTS":
            # TimesNet does not like NaNs in the input data so you can have this quick hacky fix
            # Assumed that the padding is now a multiple of 100 (or 500) -> giving 2,000 as PLR length
            # Moment used 512*4=2048 in contrast
            no_subjects = array.shape[0]
            for subj_idx in range(no_subjects):
                array[subj_idx, :] = fill_na_per_subject(
                    array_subj=array[subj_idx, :],
                    fill_na=fill_na,
                )
        else:
            # e.g. MOMENT is okay with NaNs (masked via input_mask), so filling
            # should not be requested for it; any other name is unexpected here.
            logger.warning("Unknown model_name = {}".format(model_name))
            raise ValueError(f"Unknown model_name = {model_name}")

        # The assertion fires when NaNs REMAIN, so the message must say so
        # (the previous message claimed the opposite).
        no_of_nans = np.sum(np.isnan(array))
        assert no_of_nans == 0, "NaNs remain after padding and filling"

    return array

transform_for_moment_fm_length

transform_for_moment_fm_length(
    data_array: ndarray,
    trim_to_size: int = 512,
    pad_ts: bool = True,
    downsample_factor: int = 4,
    resample_method: str = "cubic",
    split_subjects_to_windows: bool = True,
    _binarize_output: bool = False,
    fill_na: Optional[str] = None,
    model_name: Optional[str] = None,
) -> ndarray

Transform data array to required length for foundation models.

Applies padding or trimming, optional downsampling, and optional window splitting to prepare data for time series foundation models.

PARAMETER DESCRIPTION
data_array

Input array of shape (n_subjects, n_timepoints).

TYPE: ndarray

trim_to_size

Target size for trimming/padding, by default 512.

TYPE: int DEFAULT: 512

pad_ts

Whether to pad the time series, by default True.

TYPE: bool DEFAULT: True

downsample_factor

Factor for downsampling, by default 4.

TYPE: int DEFAULT: 4

resample_method

Interpolation method for resampling, by default "cubic".

TYPE: str DEFAULT: 'cubic'

split_subjects_to_windows

Whether to split into fixed-size windows, by default True.

TYPE: bool DEFAULT: True

fill_na

Strategy for filling NaN values, by default None.

TYPE: str DEFAULT: None

model_name

Name of the target model, by default None.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
ndarray

Transformed data array.

Source code in src/data_io/data_utils.py
def transform_for_moment_fm_length(
    data_array: np.ndarray,
    trim_to_size: int = 512,
    pad_ts: bool = True,
    downsample_factor: int = 4,
    resample_method: str = "cubic",
    split_subjects_to_windows: bool = True,
    binarize_output: bool = False,
    fill_na: Optional[str] = None,
    model_name: Optional[str] = None,
) -> np.ndarray:
    """Transform data array to required length for foundation models.

    Applies padding or trimming, optional downsampling, and optional
    window splitting to prepare data for time series foundation models.

    Parameters
    ----------
    data_array : np.ndarray
        Input array of shape (n_subjects, n_timepoints).
    trim_to_size : int, optional
        Target size for trimming/padding, by default 512.
    pad_ts : bool, optional
        Whether to pad the time series, by default True.
    downsample_factor : int, optional
        Factor for downsampling, by default 4.
    resample_method : str, optional
        Interpolation method for resampling, by default "cubic".
    split_subjects_to_windows : bool, optional
        Whether to split into fixed-size windows, by default True.
    binarize_output : bool, optional
        Accepted so callers passing ``binarize_output=True`` (as the mask
        transform in transform_data_for_momentfm does) do not fail with a
        TypeError. NOTE(review): currently not applied inside this function —
        confirm whether an explicit binarization step is intended.
    fill_na : str, optional
        Strategy for filling NaN values, by default None.
    model_name : str, optional
        Name of the target model, by default None.

    Returns
    -------
    np.ndarray
        Transformed data array.
    """
    if pad_ts:
        # Pad to the next multiple of 512, e.g. (1981,) -> (2048,) with NaNs for the padding
        array = pad_glaucoma_PLR(data_array=data_array, trim_to_size=trim_to_size)
    else:
        # Trim to the multiple of trim_to_size (e.g. 96): (1981,) -> (1920) = 20*96
        array = trim_to_multiple_of(data_array=data_array, window_size=trim_to_size)
        assert array.shape[1] % trim_to_size == 0, (
            "Something funky happened with the trim? "
            "Length ({}) should be a multiple of trim_to_size ({})".format(
                array.shape[1], trim_to_size
            )
        )

    # Make new pseudosubjects
    if split_subjects_to_windows:
        # e.g. (355,1981) -> (7100,100) for TimesNet
        array = fill_na_in_array_before_windowing(
            array, fill_na, trim_to_size, model_name
        )
        array = split_subjects_to_windows_PLR(array=array, window_size=trim_to_size)

    else:
        if downsample_factor is not None:
            array = downsample_PLR(
                array=array,
                downsample_factor=downsample_factor,
                resample_method=resample_method,
            )

    return array

trim_to_multiple_of

trim_to_multiple_of(
    data_array: ndarray, window_size: int = 96
) -> ndarray

Trim array length to a multiple of window_size by removing edge samples.

PARAMETER DESCRIPTION
data_array

Input array of shape (n_subjects, n_timepoints).

TYPE: ndarray

window_size

Target multiple size, by default 96.

TYPE: int DEFAULT: 96

RETURNS DESCRIPTION
ndarray

Trimmed array with length as a multiple of window_size.

Source code in src/data_io/data_utils.py
def trim_to_multiple_of(data_array: np.ndarray, window_size: int = 96) -> np.ndarray:
    """Trim array length to a multiple of window_size by removing edge samples.

    Samples are removed symmetrically from both ends. When the number of
    samples to drop is odd, one extra sample is removed from the beginning
    (assumed to carry less important data).

    Parameters
    ----------
    data_array : np.ndarray
        Input array of shape (n_subjects, n_timepoints).
    window_size : int, optional
        Target multiple size, by default 96.

    Returns
    -------
    np.ndarray
        Trimmed array with length as a multiple of window_size.
    """
    length_in = data_array.shape[1]
    no_of_windows = np.floor(length_in / window_size).astype(int)
    new_length = no_of_windows * window_size
    to_trim = length_in - new_length
    # NOTE: the branch comments were previously swapped; this branch handles
    # an EVEN trim count, split equally between both ends.
    if (to_trim % 2) == 0:  # for even trim
        i1 = to_trim // 2
        i2 = length_in - i1
    else:  # for odd trim, remove 1 point more from beginning with less important data
        i1 = to_trim // 2
        i2 = length_in - i1
        i1 += 1

    return data_array[:, i1:i2]

get_no_of_windows

get_no_of_windows(
    length_PLR: int = 1981, window_size: int = 512
)

Calculate the number of windows needed to cover the PLR signal.

PARAMETER DESCRIPTION
length_PLR

Length of the PLR signal, by default 1981.

TYPE: int DEFAULT: 1981

window_size

Size of each window, by default 512.

TYPE: int DEFAULT: 512

RETURNS DESCRIPTION
int

Number of windows needed (rounded up).

Source code in src/data_io/data_utils.py
def get_no_of_windows(length_PLR: int = 1981, window_size: int = 512):
    """Return how many windows of `window_size` are needed to cover the signal.

    Parameters
    ----------
    length_PLR : int, optional
        Length of the PLR signal, by default 1981.
    window_size : int, optional
        Size of each window, by default 512.

    Returns
    -------
    int
        Number of windows needed, rounded up so the whole signal is covered.
    """
    windows_fractional = length_PLR / window_size
    return np.ceil(windows_fractional).astype(int)

split_subjects_to_windows_PLR

split_subjects_to_windows_PLR(
    array: ndarray, window_size: int = 512
)

Split subject data into fixed-size windows.

Reshapes the array so each subject's time series is split into multiple windows, creating pseudo-subjects.

PARAMETER DESCRIPTION
array

Input array of shape (n_subjects, n_timepoints).

TYPE: ndarray

window_size

Size of each window, by default 512.

TYPE: int DEFAULT: 512

RETURNS DESCRIPTION
ndarray

Reshaped array of shape (n_subjects * windows_per_subject, window_size).

Source code in src/data_io/data_utils.py
def split_subjects_to_windows_PLR(array: np.ndarray, window_size: int = 512):
    """Split subject data into fixed-size windows.

    Each subject's time series is reshaped into consecutive windows of
    `window_size` samples, so every window becomes its own pseudo-subject row.

    Parameters
    ----------
    array : np.ndarray
        Input array of shape (n_subjects, n_timepoints).
    window_size : int, optional
        Size of each window, by default 512.

    Returns
    -------
    np.ndarray
        Reshaped array of shape (n_subjects * windows_per_subject, window_size).
    """
    n_subjects = array.shape[0]
    n_windows_per_subject = array.shape[1] // window_size
    return array.reshape(n_subjects * n_windows_per_subject, window_size)

downsample_PLR

downsample_PLR(
    array: ndarray,
    downsample_factor: int = 4,
    resample_method: str = "cubic",
)

Downsample PLR signals by a given factor using interpolation.

PARAMETER DESCRIPTION
array

Input array of shape (n_subjects, n_timepoints).

TYPE: ndarray

downsample_factor

Factor by which to reduce the number of samples, by default 4.

TYPE: int DEFAULT: 4

resample_method

Interpolation method ("cubic", "linear", etc.), by default "cubic".

TYPE: str DEFAULT: 'cubic'

RETURNS DESCRIPTION
ndarray

Downsampled array of shape (n_subjects, n_timepoints // downsample_factor).

RAISES DESCRIPTION
AssertionError

If NaN ratio increases significantly after resampling or all values are NaN.

Source code in src/data_io/data_utils.py
def downsample_PLR(
    array: np.ndarray, downsample_factor: int = 4, resample_method: str = "cubic"
):
    """Downsample PLR signals by a given factor using interpolation.

    Parameters
    ----------
    array : np.ndarray
        Input array of shape (n_subjects, n_timepoints).
    downsample_factor : int, optional
        Factor by which to reduce the number of samples, by default 4.
    resample_method : str, optional
        Interpolation method ("cubic", "linear", etc.), by default "cubic".

    Returns
    -------
    np.ndarray
        Downsampled array of shape (n_subjects, n_timepoints // downsample_factor).

    Raises
    ------
    AssertionError
        If NaN ratio increases significantly after resampling or all values are NaN.
    """

    def downsample_subject(x, y, downsample_factor, resample_method):
        nan_ratio = np.isnan(y).sum() / len(y)
        x_new = np.linspace(x[0], x[-1], len(x) // downsample_factor)
        f = interpolate.interp1d(x, y, kind=resample_method)
        y_resampled = f(x_new)
        nan_ratio_resampled = np.isnan(y_resampled).sum() / len(y_resampled)
        # we now assume that you only have NaN padding and no NaNs in the signal, so if you get some new
        # NaNs, it is an issue. And this might happen also with non-nice downsample factors?
        # Use <= so that completely NaN-free input (both ratios 0) passes;
        # the previous strict < made 0 < 0 fail spuriously.
        safety_factor = 1.5
        assert nan_ratio_resampled <= nan_ratio * safety_factor, (
            f"NaN ratio before ({nan_ratio}) and after ({nan_ratio_resampled}) resampling do not match"
        )
        return y_resampled

    samples_out: int = int(array.shape[1] // downsample_factor)
    no_subjects, no_timepoints = array.shape
    # Preallocate instead of np.vstack in a loop (quadratic copies); this also
    # guarantees a 2-D result for a single subject, where the previous version
    # crashed on y_out.shape[1].
    y_out = np.zeros((no_subjects, samples_out))
    x = np.linspace(0, no_timepoints, no_timepoints)  # same grid for all subjects
    for i in range(no_subjects):
        y_out[i, :] = downsample_subject(x, array[i, :], downsample_factor, resample_method)

    assert y_out.shape[1] == samples_out, (
        f"Downsampled array length is not {samples_out}"
    )

    nan_ratio = np.isnan(y_out).sum() / y_out.size
    assert nan_ratio != 1, "All your values seem NaN now"

    return y_out

unpad_glaucoma_PLR

unpad_glaucoma_PLR(array: ndarray, length_PLR: int = 1981)

Remove padding from PLR array to restore original length.

PARAMETER DESCRIPTION
array

Padded array of shape (n_subjects, padded_length).

TYPE: ndarray

length_PLR

Original PLR signal length, by default 1981.

TYPE: int DEFAULT: 1981

RETURNS DESCRIPTION
ndarray

Unpadded array of shape (n_subjects, length_PLR).

Source code in src/data_io/data_utils.py
def unpad_glaucoma_PLR(array: np.ndarray, length_PLR: int = 1981):
    """Strip the centered padding from a PLR array, restoring its original length.

    Parameters
    ----------
    array : np.ndarray
        Padded array of shape (n_subjects, padded_length).
    length_PLR : int, optional
        Original PLR signal length, by default 1981.

    Returns
    -------
    np.ndarray
        Unpadded array of shape (n_subjects, length_PLR).
    """
    i_start, i_end = get_padding_indices(
        length_orig=length_PLR, length_padded=array.shape[1]
    )
    return array[:, i_start:i_end]

get_padding_indices

get_padding_indices(
    length_orig: int = 1981, length_padded: int = 2048
)

Calculate start and end indices for centered padding/unpadding.

PARAMETER DESCRIPTION
length_orig

Original signal length, by default 1981.

TYPE: int DEFAULT: 1981

length_padded

Padded signal length, by default 2048.

TYPE: int DEFAULT: 2048

RETURNS DESCRIPTION
tuple

Tuple containing (start_idx, end_idx) for slicing.

Source code in src/data_io/data_utils.py
def get_padding_indices(length_orig: int = 1981, length_padded: int = 2048):
    """Calculate start and end indices for centered padding/unpadding.

    Parameters
    ----------
    length_orig : int, optional
        Original signal length, by default 1981.
    length_padded : int, optional
        Padded signal length, by default 2048.

    Returns
    -------
    tuple
        Tuple containing (start_idx, end_idx) for slicing, e.g.
        (33, 2014) for the defaults (67 points of padding in total).
    """
    pad_total = length_padded - length_orig
    start_idx = pad_total // 2
    return start_idx, start_idx + length_orig

pad_glaucoma_PLR

pad_glaucoma_PLR(
    data_array: ndarray, trim_to_size: int = 512
)

Pad PLR array with NaN values to reach a multiple of trim_to_size.

Centers the original data within the padded array.

PARAMETER DESCRIPTION
data_array

Input array of shape (n_subjects, n_timepoints).

TYPE: ndarray

trim_to_size

Target multiple for the padded length, by default 512.

TYPE: int DEFAULT: 512

RETURNS DESCRIPTION
ndarray

Padded array of shape (n_subjects, ceil(n_timepoints/trim_to_size) * trim_to_size).

RAISES DESCRIPTION
AssertionError

If the padded array contains only NaN values.

Source code in src/data_io/data_utils.py
def pad_glaucoma_PLR(data_array: np.ndarray, trim_to_size: int = 512):
    """Pad PLR array with NaN values to reach a multiple of trim_to_size.

    The original data is centered within the padded array; the NaN padding
    is split between the beginning and the end.

    Parameters
    ----------
    data_array : np.ndarray
        Input array of shape (n_subjects, n_timepoints).
    trim_to_size : int, optional
        Target multiple for the padded length, by default 512.

    Returns
    -------
    np.ndarray
        Padded array of shape (n_subjects, ceil(n_timepoints/trim_to_size) * trim_to_size).

    Raises
    ------
    AssertionError
        If the padded array contains only NaN values.
    """
    n_subjects, length_in = data_array.shape  # e.g. 1981 timepoints
    # Round up to the next multiple of trim_to_size, e.g. 1981 -> 2048
    new_length = int(np.ceil(length_in / trim_to_size)) * trim_to_size
    start_idx, end_idx = get_padding_indices(length_in, new_length)

    # All-NaN canvas, then drop the data into its centered slot
    array_out = np.full((n_subjects, new_length), np.nan)
    nan_sum = np.isnan(array_out).sum()
    array_out[:, start_idx:end_idx] = data_array

    assert nan_sum != np.isnan(array_out).sum(), "it seems that all you have is NaNs?"
    assert array_out.shape[1] == new_length, f"Padded array length is not {new_length}"
    return array_out

Data Wrangling

data_wrangler

convert_datadict_to_dict_arrays

convert_datadict_to_dict_arrays(
    data_dict: dict[str, Any], cls_model_cfg: DictConfig
) -> dict[str, Any]

Convert hierarchical data dictionary to flat arrays structure.

Some models require this flat structure. See data_transform_wrapper() -> create_dmatrices_and_dict_arrays().

PARAMETER DESCRIPTION
data_dict

Hierarchical data dictionary with train/test splits.

TYPE: dict

cls_model_cfg

Classification model configuration.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Flat dictionary with x_train, y_train, x_test, y_test, etc.

Source code in src/data_io/data_wrangler.py
def convert_datadict_to_dict_arrays(
    data_dict: dict[str, Any], cls_model_cfg: DictConfig
) -> dict[str, Any]:
    """Convert hierarchical data dictionary to flat arrays structure.

    Some models require this flat layout; see
    data_transform_wrapper() -> create_dmatrices_and_dict_arrays().

    Parameters
    ----------
    data_dict : dict
        Hierarchical data dictionary with train/test splits.
    cls_model_cfg : DictConfig
        Classification model configuration.

    Returns
    -------
    dict
        Flat dictionary with x_train, y_train, x_test, y_test, etc.
        Sample weights default to ones; feature_names is left as None.
    """
    train_split = data_dict["train"]
    test_split = data_dict["test"]
    x_train = train_split["data"]["X"]
    x_test = test_split["data"]["X"]

    return {
        "x_train": x_train,
        "x_train_w": np.ones_like(x_train),
        "y_train": train_split["labels"]["class_label"][:, 0],
        "x_test": x_test,
        "x_test_w": np.ones_like(x_test),
        "y_test": test_split["labels"]["class_label"][:, 0],
        "feature_names": None,
        "subject_codes_train": train_split["metadata"]["subject_code"][:, 0],
        "subject_codes_test": test_split["metadata"]["subject_code"][:, 0],
    }

fix_pl_schema

fix_pl_schema(df_metadata: DataFrame) -> DataFrame

Cast object types in a Polars dataframe to appropriate types.

Handles conversion of decimal.Decimal to float and ensures string types are properly cast.

PARAMETER DESCRIPTION
df_metadata

Polars dataframe with potential Object dtype columns.

TYPE: DataFrame

RETURNS DESCRIPTION
DataFrame

Dataframe with Object types cast to appropriate types.

See Also

convert_object_type : Similar function for numpy arrays.

References

https://docs.pola.rs/user-guide/expressions/casting/#basic-example

Source code in src/data_io/data_wrangler.py
def fix_pl_schema(df_metadata: pl.DataFrame) -> pl.DataFrame:
    """Cast object types in a Polars dataframe to appropriate types.

    Handles conversion of decimal.Decimal to float and ensures string types
    are properly cast.

    Parameters
    ----------
    df_metadata : pl.DataFrame
        Polars dataframe with potential Object dtype columns.

    Returns
    -------
    pl.DataFrame
        Dataframe with Object types cast to appropriate types.

    See Also
    --------
    convert_object_type : Similar function for numpy arrays.

    References
    ----------
    https://docs.pola.rs/user-guide/expressions/casting/#basic-example
    """

    def first_non_null(series: pl.Series) -> Any:
        # First non-None entry of the column, or None if all entries are null.
        for value in series:
            if value is not None:
                return value
        return None

    for col, dtype in zip(df_metadata.columns, df_metadata.dtypes):
        if dtype != pl.Object:
            continue
        sample_value = first_non_null(df_metadata[col])
        if isinstance(sample_value, decimal.Decimal):
            # e.g. "Age" arrives as decimal.Decimal -> cast via numpy float
            cast_values = df_metadata[col].to_numpy().astype(float)
        elif isinstance(sample_value, str):
            # e.g. "subject_code" arrives as an Object'd string column
            cast_values = df_metadata[col].to_numpy().astype(str)
        else:
            logger.warning("Casting issue with dtype = {}".format(dtype))
            continue
        df_metadata = df_metadata.with_columns(
            pl.Series(name=col, values=cast_values)
        )

    return df_metadata

convert_subject_dict_of_arrays_to_df

convert_subject_dict_of_arrays_to_df(
    subject_dict: dict[str, dict[str, ndarray]],
    wildcard_categories: list[str] | None = None,
) -> DataFrame

Convert a subject dictionary of arrays to a Polars dataframe.

PARAMETER DESCRIPTION
subject_dict

Dictionary with category names as keys containing sub-dictionaries with array data.

TYPE: dict

wildcard_categories

If provided, only include categories in this list, by default None.

TYPE: list DEFAULT: None

RETURNS DESCRIPTION
DataFrame

Polars dataframe with arrays as columns.

RAISES DESCRIPTION
AssertionError

If any array is not 1D.

Source code in src/data_io/data_wrangler.py
def convert_subject_dict_of_arrays_to_df(
    subject_dict: dict[str, dict[str, np.ndarray]],
    wildcard_categories: list[str] | None = None,
) -> pl.DataFrame:
    """Convert a subject dictionary of arrays to a Polars dataframe.

    Parameters
    ----------
    subject_dict : dict
        Dictionary with category names as keys containing sub-dictionaries
        with array data.
    wildcard_categories : list, optional
        If provided, only include categories in this list, by default None
        (include everything).

    Returns
    -------
    pl.DataFrame
        Polars dataframe with arrays as columns.

    Raises
    ------
    AssertionError
        If any array is not 1D.
    """
    df = pl.DataFrame()
    for category_name, category_dict in subject_dict.items():
        # The two original branches duplicated the same inner loop verbatim;
        # a single guard clause expresses the filtering without duplication.
        if wildcard_categories is not None and category_name not in wildcard_categories:
            continue
        for subkey, array in category_dict.items():
            assert len(array.shape) == 1, f"Array shape is not 1D: {array.shape}"
            # Object dtypes would cause downstream issues
            array = convert_object_type(array)
            df = df.with_columns(pl.Series(name=subkey, values=array))
    return df

get_subject_dict_for_featurization

get_subject_dict_for_featurization(
    split_dict: dict[str, dict[str, ndarray]],
    i: int,
    cfg: DictConfig,
    return_1st_value: bool = False,
) -> dict[str, dict[str, Any]]

Extract a single subject's data from a split dictionary.

PARAMETER DESCRIPTION
split_dict

Dictionary containing data for all subjects in a split.

TYPE: dict

i

Subject index to extract.

TYPE: int

cfg

Configuration dictionary.

TYPE: DictConfig

return_1st_value

If True, return only the first value; otherwise return the full row, by default False.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
dict

Dictionary containing only the specified subject's data.

Source code in src/data_io/data_wrangler.py
def get_subject_dict_for_featurization(
    split_dict: dict[str, dict[str, np.ndarray]],
    i: int,
    cfg: DictConfig,
    return_1st_value: bool = False,
) -> dict[str, dict[str, Any]]:
    """Extract a single subject's data from a split dictionary.

    Parameters
    ----------
    split_dict : dict
        Dictionary containing data for all subjects in a split.
    i : int
        Subject index to extract.
    cfg : DictConfig
        Configuration dictionary (kept for API compatibility; not used here).
    return_1st_value : bool, optional
        If True, return only the first value; otherwise return the full row,
        by default False.

    Returns
    -------
    dict
        Dictionary containing only the specified subject's data. As before,
        the leaves are views/items of the original arrays, not copies.
    """

    def pick_subject(array: np.ndarray):
        # Either the scalar in the first column, or the subject's whole row.
        return array[i, 0] if return_1st_value else array[i, :]

    # Build the per-subject dict directly instead of deep-copying the whole
    # split (all subjects' arrays) only to overwrite every leaf afterwards.
    return {
        category_name: {
            subkey: pick_subject(array) for subkey, array in category_dict.items()
        }
        for category_name, category_dict in split_dict.items()
    }

pick_correct_data_and_label_for_experiment

pick_correct_data_and_label_for_experiment(
    data_dict: dict[str, Any],
    cfg: DictConfig,
    task: str,
    _task_cfg: DictConfig,
) -> None

Select appropriate data columns for a specific experiment task.

PARAMETER DESCRIPTION
data_dict

Full data dictionary.

TYPE: dict

cfg

Configuration dictionary.

TYPE: DictConfig

task

Task name.

TYPE: str

_task_cfg

Task-specific configuration (currently unused).

TYPE: DictConfig

Notes

Currently a placeholder function.

Source code in src/data_io/data_wrangler.py
def pick_correct_data_and_label_for_experiment(
    data_dict: dict[str, Any], cfg: DictConfig, task: str, _task_cfg: DictConfig
) -> None:
    """Select appropriate data columns for a specific experiment task.

    Parameters
    ----------
    data_dict : dict
        Full data dictionary.
    cfg : DictConfig
        Configuration dictionary.
    task : str
        Task name.
    _task_cfg : DictConfig
        Task-specific configuration (currently unused).

    Notes
    -----
    Currently a placeholder function; no selection is performed yet.
    """
    print("placeholder")

get_dict_with_wildcard

get_dict_with_wildcard(
    df: DataFrame, wildcard: str = "pupil"
) -> dict[str, Series]

Extract columns matching a wildcard pattern as a dictionary.

PARAMETER DESCRIPTION
df

Input Polars dataframe.

TYPE: DataFrame

wildcard

Pattern to match in column names, by default "pupil".

TYPE: str DEFAULT: 'pupil'

RETURNS DESCRIPTION
dict

Dictionary mapping column names to Polars Series.

Source code in src/data_io/data_wrangler.py
def get_dict_with_wildcard(
    df: pl.DataFrame, wildcard: str = "pupil"
) -> dict[str, pl.Series]:
    """Extract columns matching a wildcard pattern as a dictionary.

    Parameters
    ----------
    df : pl.DataFrame
        Input Polars dataframe.
    wildcard : str, optional
        Pattern to match in column names, by default "pupil".
        This is a plain substring test, not a glob/regex.

    Returns
    -------
    dict
        Dictionary mapping column names to Polars Series.
    """
    # Single comprehension replaces the filter-then-assign loop; keeps
    # the dataframe's original column order.
    return {col: df[col] for col in df.columns if wildcard in col}

get_dict_with_list_of_cols

get_dict_with_list_of_cols(
    df: DataFrame, cols: list[str]
) -> dict[str, Series]

Extract specific columns as a dictionary.

PARAMETER DESCRIPTION
df

Input Polars dataframe.

TYPE: DataFrame

cols

List of column names to extract.

TYPE: list

RETURNS DESCRIPTION
dict

Dictionary mapping column names to Polars Series.

Source code in src/data_io/data_wrangler.py
def get_dict_with_list_of_cols(
    df: pl.DataFrame, cols: list[str]
) -> dict[str, pl.Series]:
    """Extract specific columns as a dictionary.

    Parameters
    ----------
    df : pl.DataFrame
        Input Polars dataframe.
    cols : list
        List of column names to extract. A missing column raises the
        dataframe's usual lookup error, same as before.

    Returns
    -------
    dict
        Dictionary mapping column names to Polars Series.
    """
    # Dict comprehension over the explicitly requested columns
    # (previous comment wrongly said "wildcard columns" — this function
    # takes an explicit column list).
    return {col: df[col] for col in cols}

get_dict_with_remaining_cols

get_dict_with_remaining_cols(
    df_split: DataFrame,
    data_dict: dict[str, dict[str, Any]],
) -> dict[str, Series]

Extract columns not already present in data_dict as a dictionary.

PARAMETER DESCRIPTION
df_split

Input Polars dataframe.

TYPE: DataFrame

data_dict

Existing data dictionary to check for used columns.

TYPE: dict

RETURNS DESCRIPTION
dict

Dictionary mapping remaining column names to Polars Series.

Source code in src/data_io/data_wrangler.py
def get_dict_with_remaining_cols(
    df_split: pl.DataFrame, data_dict: dict[str, dict[str, Any]]
) -> dict[str, pl.Series]:
    """Extract columns not already present in data_dict as a dictionary.

    Parameters
    ----------
    df_split : pl.DataFrame
        Input Polars dataframe.
    data_dict : dict
        Existing data dictionary to check for used columns.

    Returns
    -------
    dict
        Dictionary mapping remaining column names to Polars Series.
    """
    # Collect every column name already claimed by some category.
    # A set gives O(1) membership tests (the old list was O(n) per column)
    # and the rewrite no longer shadows the `dict` builtin.
    used_cols = {
        col_name for category in data_dict.values() for col_name in category
    }

    # Everything not yet claimed becomes part of the result,
    # preserving the dataframe's column order.
    return {
        col: df_split[col] for col in df_split.columns if col not in used_cols
    }

convert_object_type

convert_object_type(array_tmp: ndarray) -> ndarray

Convert numpy object dtype arrays to appropriate types.

Handles conversion of decimal.Decimal to float and str to string dtype.

PARAMETER DESCRIPTION
array_tmp

Array potentially with object dtype.

TYPE: ndarray

RETURNS DESCRIPTION
ndarray

Array with appropriate dtype (float or str).

Source code in src/data_io/data_wrangler.py
def convert_object_type(array_tmp: np.ndarray) -> np.ndarray:
    """Convert numpy object dtype arrays to appropriate types.

    Handles conversion of decimal.Decimal to float and str to string dtype.
    Non-object arrays, and object arrays holding other payloads, pass
    through unchanged.

    Parameters
    ----------
    array_tmp : np.ndarray
        Array potentially with object dtype.

    Returns
    -------
    np.ndarray
        Array with appropriate dtype (float or str).
    """
    # Object-dtype arrays might cause weird unintuitive issues downstream,
    # so sniff the first element to pick a concrete dtype. The size guard
    # fixes an IndexError the old code raised on empty object arrays.
    if array_tmp.dtype == "object" and array_tmp.size > 0:
        first_value = array_tmp[0]
        if isinstance(first_value, decimal.Decimal):
            # e.g. Age type comes out like this
            array_tmp = array_tmp.astype(float)
        elif isinstance(first_value, str):
            # e.g. "subject_code" type comes out like this
            array_tmp = array_tmp.astype(str)

    return array_tmp

reshape_flat_series_to_2d_arrays

reshape_flat_series_to_2d_arrays(
    dict_series: dict[str, dict[str, Series]],
    length_PLR: int = 1981,
) -> dict[str, dict[str, ndarray]]

Reshape flat Polars Series to 2D numpy arrays.

Converts a dictionary of Series to 2D arrays with shape (n_subjects, n_timepoints).

PARAMETER DESCRIPTION
dict_series

Dictionary of dictionaries containing Polars Series.

TYPE: dict

length_PLR

Number of timepoints per subject, by default 1981.

TYPE: int DEFAULT: 1981

RETURNS DESCRIPTION
dict

Dictionary with same structure but 2D numpy arrays.

Source code in src/data_io/data_wrangler.py
def reshape_flat_series_to_2d_arrays(
    dict_series: dict[str, dict[str, pl.Series]], length_PLR: int = 1981
) -> dict[str, dict[str, np.ndarray]]:
    """Reshape flat Polars Series to 2D numpy arrays.

    Converts a dictionary of Series to 2D arrays with shape
    (n_subjects, n_timepoints).

    Parameters
    ----------
    dict_series : dict
        Dictionary of dictionaries containing Polars Series.
    length_PLR : int, optional
        Number of timepoints per subject, by default 1981.

    Returns
    -------
    dict
        Dictionary with same structure but 2D numpy arrays.
    """
    dict_arrays: dict[str, dict[str, np.ndarray]] = {}
    # Renamed the inner loop variable: the old code shadowed the `dict`
    # builtin, which silently broke any later use of `dict(...)` in scope.
    for category, series_by_name in dict_series.items():
        dict_arrays[category] = {}
        for name, series in series_by_name.items():
            # Normalize object dtype (Decimal/str) before reshaping.
            array_tmp = convert_object_type(array_tmp=series.to_numpy())
            # Rows = subjects, columns = the fixed-length PLR recording;
            # reshape raises if the total length is not a multiple of
            # `length_PLR`, which is the desired sanity check.
            dict_arrays[category][name] = array_tmp.reshape(-1, length_PLR)

    return dict_arrays

split_df_to_dict

split_df_to_dict(
    df_split: DataFrame, cfg: DictConfig, split: str
) -> dict[str, dict[str, ndarray]]

Convert a split dataframe to a hierarchical dictionary of 2D arrays.

Creates a structured dictionary with categories: time, data, labels, light, and metadata. All values are reshaped to (n_subjects, n_timepoints).

PARAMETER DESCRIPTION
df_split

Polars dataframe for a single split.

TYPE: DataFrame

cfg

Configuration dictionary with DATA settings.

TYPE: DictConfig

split

Name of the split (for logging).

TYPE: str

RETURNS DESCRIPTION
dict

Hierarchical dictionary with 2D numpy arrays.

Source code in src/data_io/data_wrangler.py
def split_df_to_dict(
    df_split: pl.DataFrame, cfg: DictConfig, split: str
) -> dict[str, dict[str, np.ndarray]]:
    """Convert one split's dataframe into a nested dictionary of 2D arrays.

    The result has the categories ``time``, ``data``, ``labels``, ``light``
    and ``metadata``; every leaf is reshaped to (n_subjects, n_timepoints).

    Parameters
    ----------
    df_split : pl.DataFrame
        Polars dataframe for a single split.
    cfg : DictConfig
        Configuration dictionary with DATA settings.
    split : str
        Name of the split (for logging).

    Returns
    -------
    dict
        Hierarchical dictionary with 2D numpy arrays.
    """
    label_cols = [
        "class_label",
        "outlier_mask",
        "imputation_mask",
        "outlier_mask_easy",
        "outlier_mask_medium",
    ]
    light_cols = ["Red", "Blue", "light_stimuli"]

    # Hierarchical dictionary so more categories can be added later;
    # all leaves are pl.Series at this point.
    data_dict = {
        "time": get_dict_with_wildcard(df_split, wildcard="time"),
        "data": get_dict_with_wildcard(df_split, wildcard="pupil"),
        "labels": get_dict_with_list_of_cols(df_split, cols=label_cols),
        "light": get_dict_with_list_of_cols(df_split, cols=light_cols),
    }
    # Every column not claimed above is treated as per-subject metadata.
    data_dict["metadata"] = get_dict_with_remaining_cols(df_split, data_dict)

    # pl.Series -> 2D numpy arrays of shape (no_subjects, no_timepoints).
    return reshape_flat_series_to_2d_arrays(
        dict_series=data_dict, length_PLR=cfg["DATA"]["PLR_length"]
    )

convert_df_to_dict

convert_df_to_dict(
    data_df: DataFrame, cfg: DictConfig
) -> dict[str, Any]

Convert a Polars dataframe to a hierarchical dictionary for model input.

Converts the combined dataframe into a structured dictionary that can be used with various ML frameworks: - sklearn: (X_train, X_val, y_train, y_val) - PyTorch: (dataloader, dataset)

PARAMETER DESCRIPTION
data_df

Combined Polars dataframe with 'split' column.

TYPE: DataFrame

cfg

Configuration dictionary with DATA settings.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Dictionary with 'df' key containing split dictionaries and 'preprocess' key with preprocessing parameters.

Source code in src/data_io/data_wrangler.py
def convert_df_to_dict(data_df: pl.DataFrame, cfg: DictConfig) -> dict[str, Any]:
    """Convert a Polars dataframe to a hierarchical dictionary for model input.

    Converts the combined dataframe into a structured dictionary that can
    be used with various ML frameworks:
    - sklearn: (X_train, X_val, y_train, y_val)
    - PyTorch: (dataloader, dataset)

    Parameters
    ----------
    data_df : pl.DataFrame
        Combined Polars dataframe with 'split' column.
    cfg : DictConfig
        Configuration dictionary with DATA settings.

    Returns
    -------
    dict
        Dictionary with 'df' key containing split dictionaries and
        'preprocess' key with preprocessing parameters.
    """
    data_dicts: dict[str, Any] = {"df": {}}
    # One nested dictionary per split: categories of 2D numpy arrays
    # shaped (no_subjects, no_timepoints). (Dropped the unused
    # enumerate() index from the original loop.)
    for split in data_df["split"].unique().to_list():
        data_dicts["df"][split] = split_df_to_dict(
            df_split=data_df.filter(pl.col("split") == split), cfg=cfg, split=split
        )

    # Preprocess if desired
    data_dicts = preprocess_data_dicts(data_dicts=data_dicts, cfg=cfg)

    return data_dicts

Flow Data

flow_data

flow_import_data

flow_import_data(cfg: DictConfig) -> DataFrame

Import PLR data from either raw CSVs or DuckDB database.

Main data import flow that handles loading, combining splits, optional debug subsetting, and visualization of PLR pupillometry data.

PARAMETER DESCRIPTION
cfg

Configuration dictionary containing DATA, DEBUG, EXPERIMENT, and PREFECT settings.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

Combined Polars dataframe with train and test splits indicated by the 'split' column.

RAISES DESCRIPTION
NotImplementedError

If debug mode is on but no subject subset is specified.

Source code in src/data_io/flow_data.py
def flow_import_data(cfg: DictConfig) -> pl.DataFrame:
    """Import PLR data from either raw CSVs or DuckDB database.

    Main data import flow that handles loading, combining splits, optional
    debug subsetting, and visualization of PLR pupillometry data.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary containing DATA, DEBUG, EXPERIMENT,
        and PREFECT settings.

    Returns
    -------
    pl.DataFrame
        Combined Polars dataframe with train and test splits indicated
        by the 'split' column.

    Raises
    ------
    NotImplementedError
        If debug mode is on but no subject subset is specified.
    """
    # Resolve the flow name (a demo suffix may be appended; checked below
    # to infer demo_mode).
    experiment_name = experiment_name_wrapper(
        experiment_name=cfg["PREFECT"]["FLOW_NAMES"]["DATA_IMPORT"], cfg=cfg
    )
    logger.info("FLOW | Name: {}".format(experiment_name))
    logger.info("=====================")
    data_dir = get_data_dir(data_path=cfg["DATA"]["data_path"])

    if cfg["DATA"]["import_from_DuckDB"]:
        # TODO! Change this a bit later, when you know if this DuckDB can be somewhere online
        df_train, df_test = import_data_from_duckdb(
            data_cfg=cfg["DATA"],
            data_dir=data_dir,
            use_demo_data=cfg["EXPERIMENT"]["use_demo_data"],
        )
        # Compute granularized outlier masks if not in database
        # (these columns are computed on-the-fly from outlier_mask)
        if "outlier_mask_easy" not in df_train.columns:
            df_train = granularize_outlier_labels(df_train, cfg)
            df_test = granularize_outlier_labels(df_test, cfg)
    else:
        # Task 1) Import the Polars dataframe
        df_train, df_test = import_PLR_data_wrapper(cfg, data_dir=data_dir)

    # check that each subject has good data lengths
    check_for_data_lengths(df_train, cfg)
    check_for_data_lengths(df_test, cfg)

    # If you have DEBUG MODE on, and you only want to use a subset of data
    # Combine splits to one dataframe with the column indicating the split
    # NOTE(review): this condition subsets whenever `debug_n_subjects` is
    # set, regardless of the debug flag itself, and the `else` branch below
    # always raises — so a run with `debug_n_subjects=None` can never
    # complete. The warning text in the `else` branch suggests the branches
    # may be inverted; confirm the intended condition (e.g. checking
    # `cfg["EXPERIMENT"]["debug"]` instead).
    if cfg["DEBUG"]["debug_n_subjects"] is not None:
        # Demo mode is inferred from the demo suffix in the resolved name.
        if get_demo_string_to_add() in experiment_name:
            demo_mode = True
        else:
            demo_mode = False

        df = combine_split_dataframes(
            df_train,
            df_test,
            cfg,
            debug_mode=cfg["EXPERIMENT"]["debug"],
            debug_n=cfg["DEBUG"]["debug_n_subjects"],
            pick_random=cfg["DEBUG"].get("pick_random", False),
            demo_mode=demo_mode,
        )
    else:
        logger.warning("DEBUG MODE is ON, but you did not take a subset of the data")
        logger.warning(
            "This typically reserved for Github Actions (or something) to run an end-to-end test"
        )
        raise NotImplementedError

    # Check that all the subjects that have equal number of samples
    check_for_data_lengths(df, cfg)

    if cfg["DATA"]["VISUALIZE"]["visualize_input_subjects"]:
        # Whether to visualize the input data or not
        # The "import_PLR_data_wrapper" contains some heuristics for data quality, but you could obviously
        # implement something more sophisticated here as well
        logger.info("Visualization of the input data")

        # Task 2) Visualize the data
        fig_paths = visualize_input_data(df=df, cfg=cfg)

        # Task 3) Create a MP4 video from the figures
        create_video_from_figures_on_disk(fig_paths=fig_paths, cfg=cfg)

    else:
        logger.info("Skipping the visualization of the input data")

    return df

DuckDB Export

duckdb_export

Memory-efficient export of features and classifier results to DuckDB.

This module creates shareable DuckDB databases that enable:

1. Reproducibility without raw clinical data access
2. Memory-efficient analysis (target: <16 GB RAM)
3. Continuation from intermediate artifacts

Cross-references: - planning/share-features-and-classifier-outputs.md - planning/statistics-implementation.md (Memory Optimization section)

Output Files:

- foundation_plr_features.db: Hand-crafted PLR features (shareable)
- foundation_plr_results.db: All classifier outputs (shareable)

Usage: # Export from mlruns python -m src.data_io.duckdb_export export --mlruns ./mlruns

# Continue analysis from features.db
python -m src.data_io.duckdb_export analyze --from-features features.db

# Continue analysis from results.db (re-run only statistics)
python -m src.data_io.duckdb_export analyze --from-results results.db

DuckDBAnalysisPipeline dataclass

DuckDBAnalysisPipeline(
    features_db: Optional[Path] = None,
    results_db: Optional[Path] = None,
    output_dir: Path = (lambda: Path("outputs/analysis"))(),
    _features: Optional[ndarray] = None,
    _labels: Optional[ndarray] = None,
    _feature_names: Optional[List[str]] = None,
    _predictions_df: Optional[DataFrame] = None,
    _metrics_df: Optional[DataFrame] = None,
)

Pipeline for running analysis from DuckDB artifacts.

Supports continuation from:

1. features.db — re-run classification and statistics
2. results.db — re-run only statistics

Usage: # From features (re-run classification) pipeline = DuckDBAnalysisPipeline.from_features("features.db") pipeline.run_classification() pipeline.run_statistics()

# From results (re-run only statistics)
pipeline = DuckDBAnalysisPipeline.from_results("results.db")
pipeline.run_statistics()

from_features classmethod

from_features(
    features_db: Union[str, Path],
    output_dir: Optional[Union[str, Path]] = None,
) -> DuckDBAnalysisPipeline

Create pipeline from features database (will run classification).

Source code in src/data_io/duckdb_export.py
@classmethod
def from_features(
    cls,
    features_db: Union[str, Path],
    output_dir: Optional[Union[str, Path]] = None,
) -> "DuckDBAnalysisPipeline":
    """Build a pipeline from a features database (classification will run)."""
    out_dir = Path(output_dir) if output_dir else Path("outputs/analysis")
    pipeline = cls(features_db=Path(features_db), output_dir=out_dir)
    # Eagerly pull the feature matrix so can_run_classification() is true.
    pipeline._load_features()
    return pipeline

from_results classmethod

from_results(
    results_db: Union[str, Path],
    output_dir: Optional[Union[str, Path]] = None,
) -> DuckDBAnalysisPipeline

Create pipeline from results database (statistics only).

Source code in src/data_io/duckdb_export.py
@classmethod
def from_results(
    cls, results_db: Union[str, Path], output_dir: Optional[Union[str, Path]] = None
) -> "DuckDBAnalysisPipeline":
    """Build a pipeline from a results database (statistics only)."""
    out_dir = Path(output_dir) if output_dir else Path("outputs/analysis")
    pipeline = cls(results_db=Path(results_db), output_dir=out_dir)
    # Eagerly load predictions/metrics so can_run_statistics() is true.
    pipeline._load_results()
    return pipeline

can_run_classification

can_run_classification() -> bool

Check if classification can be run (requires features).

Source code in src/data_io/duckdb_export.py
def can_run_classification(self) -> bool:
    """Classification is possible only once features have been loaded."""
    has_features = self._features is not None
    return has_features

can_run_statistics

can_run_statistics() -> bool

Check if statistics can be run (requires results).

Source code in src/data_io/duckdb_export.py
def can_run_statistics(self) -> bool:
    """Statistics need either aggregate metrics or raw predictions loaded."""
    return any(
        obj is not None for obj in (self._metrics_df, self._predictions_df)
    )

run_classification

run_classification(
    classifiers: Optional[List[str]] = None,
    n_folds: int = 5,
) -> DataFrame

Run classification on loaded features.

PARAMETER DESCRIPTION
classifiers

Classifier names to use (default: LogReg, XGBoost, CatBoost)

TYPE: List[str] DEFAULT: None

n_folds

Number of CV folds

TYPE: int DEFAULT: 5

RETURNS DESCRIPTION
DataFrame

Classification results

Source code in src/data_io/duckdb_export.py
def run_classification(
    self,
    classifiers: Optional[List[str]] = None,
    n_folds: int = 5,
) -> pd.DataFrame:
    """
    Run classification on loaded features.

    Parameters
    ----------
    classifiers : List[str], optional
        Classifier names to use (default: LogReg, XGBoost, CatBoost)
    n_folds : int, default 5
        Number of CV folds

    Returns
    -------
    pd.DataFrame
        Classification results
    """
    if not self.can_run_classification():
        raise ValueError(
            "Cannot run classification: features not loaded. "
            "Use from_features() to create pipeline."
        )

    logger.info("Running classification...")

    # Import here to avoid circular deps
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import roc_auc_score
    from sklearn.model_selection import StratifiedKFold

    try:
        from xgboost import XGBClassifier

        has_xgboost = True
    except ImportError:
        has_xgboost = False
        logger.warning("XGBoost not available")

    if classifiers is None:
        classifiers = ["LogisticRegression"]
        if has_xgboost:
            classifiers.append("XGBoost")

    results = []
    # Fixed seed keeps fold assignment reproducible across runs.
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

    for clf_name in classifiers:
        logger.info(f"  Training {clf_name}...")

        for fold, (train_idx, test_idx) in enumerate(
            skf.split(self._features, self._labels)
        ):
            X_train = self._features[train_idx]
            X_test = self._features[test_idx]
            y_train = self._labels[train_idx]
            y_test = self._labels[test_idx]

            # Initialize classifier
            if clf_name == "LogisticRegression":
                clf = LogisticRegression(max_iter=1000, random_state=42)
            elif clf_name == "XGBoost" and has_xgboost:
                # NOTE(review): `use_label_encoder` is deprecated (and later
                # removed) in newer XGBoost releases — confirm against the
                # pinned xgboost version.
                clf = XGBClassifier(
                    n_estimators=100,
                    use_label_encoder=False,
                    eval_metric="logloss",
                    random_state=42,
                )
            else:
                logger.warning(f"Unknown classifier: {clf_name}, skipping")
                continue

            # Train and predict
            clf.fit(X_train, y_train)
            y_prob = clf.predict_proba(X_test)[:, 1]
            y_pred = (y_prob >= 0.5).astype(int)

            # Compute AUROC (stored but not used directly - predictions are logged below)
            _ = roc_auc_score(y_test, y_prob)

            # One row per test-set prediction; `prediction_id` is a running
            # counter across all classifiers and folds. (Dropped the unused
            # enumerate() index from the original loop.)
            for idx, prob, pred, true in zip(test_idx, y_prob, y_pred, y_test):
                results.append(
                    {
                        "prediction_id": len(results),
                        "subject_id": f"S{idx:03d}",
                        "eye": "OD",  # Placeholder
                        "fold": fold,
                        "bootstrap_iter": 0,
                        "outlier_method": "unknown",
                        "imputation_method": "unknown",
                        "featurization": "unknown",
                        "classifier": clf_name,
                        "source_name": "duckdb_pipeline",
                        "y_true": int(true),
                        "y_pred": int(pred),
                        "y_prob": float(prob),
                        "mlflow_run_id": None,
                    }
                )

    self._predictions_df = pd.DataFrame(results)
    logger.info(f"Classification complete: {len(self._predictions_df)} predictions")

    # Compute aggregate metrics
    self._compute_aggregate_metrics()

    return self._predictions_df

run_statistics

run_statistics(
    output_dir: Optional[Path] = None,
) -> Dict[str, Any]

Run statistical analysis on loaded/computed results.

PARAMETER DESCRIPTION
output_dir

Override output directory

TYPE: Path DEFAULT: None

RETURNS DESCRIPTION
Dict[str, Any]

Statistical results

Source code in src/data_io/duckdb_export.py
def run_statistics(self, output_dir: Optional[Path] = None) -> Dict[str, Any]:
    """
    Run statistical analysis on loaded/computed results.

    Parameters
    ----------
    output_dir : Path, optional
        Override output directory

    Returns
    -------
    Dict[str, Any]
        Statistical results
    """
    if not self.can_run_statistics():
        raise ValueError(
            "Cannot run statistics: results not loaded. "
            "Use from_results() or run_classification() first."
        )

    target_dir = output_dir or self.output_dir
    target_dir.mkdir(parents=True, exist_ok=True)

    logger.info("Running statistical analysis...")

    if self._metrics_df is not None:
        metrics_summary = self._metrics_df.to_dict()
    else:
        metrics_summary = {}
    if self._predictions_df is not None:
        n_predictions = len(self._predictions_df)
    else:
        n_predictions = 0
    results = {
        "metrics_summary": metrics_summary,
        "n_predictions": n_predictions,
    }

    # Per-classifier calibration metrics, when raw predictions exist.
    if self._predictions_df is not None:
        from ..stats.calibration_extended import calibration_slope_intercept

        for clf_name, group in self._predictions_df.groupby("classifier"):
            try:
                cal = calibration_slope_intercept(
                    group["y_true"].values, group["y_prob"].values
                )
                results[f"calibration_{clf_name}"] = {
                    "slope": cal.slope,
                    "intercept": cal.intercept,
                    "e_o_ratio": cal.e_o_ratio,
                    "brier_score": cal.brier_score,
                }
            except Exception as e:
                logger.warning(f"Calibration failed for {clf_name}: {e}")

    # Persist as JSON; default=str stringifies anything non-serializable.
    results_path = target_dir / "statistics_results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)

    logger.info(f"Statistics saved to {results_path}")
    return results

export_for_reproduction

export_for_reproduction(
    output_path: Union[str, Path],
) -> Path

Export current state to DuckDB for future reproduction.

Source code in src/data_io/duckdb_export.py
def export_for_reproduction(
    self,
    output_path: Union[str, Path],
) -> Path:
    """Persist current predictions and metrics to a DuckDB file."""
    db_path = Path(output_path)

    if self._predictions_df is None:
        raise ValueError("No predictions to export. Run classification first.")

    return export_results_to_duckdb(
        predictions_df=self._predictions_df,
        metrics_per_fold=None,  # Could compute from predictions
        metrics_aggregate=self._metrics_df,
        output_path=db_path,
    )

load_artifact_safe

load_artifact_safe(
    artifact_path: Path,
) -> Generator[Any, None, None]

Context manager for safe artifact loading with cleanup.

Usage: with load_artifact_safe(path) as artifact: # Use artifact # Artifact is automatically deleted and garbage collected

Source code in src/data_io/duckdb_export.py
@contextmanager
def load_artifact_safe(artifact_path: Path) -> Generator[Any, None, None]:
    """
    Context manager yielding a pickled artifact, with guaranteed cleanup.

    Usage:
        with load_artifact_safe(path) as artifact:
            # Use artifact
        # Artifact is automatically deleted and garbage collected
    """
    loaded = None
    try:
        with open(artifact_path, "rb") as handle:
            loaded = pickle.load(handle)
        yield loaded
    finally:
        # Drop our reference and force a collection pass so large
        # artifacts do not linger past the `with` block.
        del loaded
        gc.collect()

iter_artifacts_chunked

iter_artifacts_chunked(
    artifact_paths: List[Path], batch_size: int = 5
) -> Generator[List[Any], None, None]

Iterate over artifacts in batches with explicit cleanup.

Prevents memory accumulation when processing many artifacts.

Source code in src/data_io/duckdb_export.py
def iter_artifacts_chunked(
    artifact_paths: List[Path], batch_size: int = 5
) -> Generator[List[Any], None, None]:
    """
    Yield unpickled artifacts in batches, forcing GC between batches.

    Prevents memory accumulation when processing many artifacts.
    """
    n_total = len(artifact_paths)
    for start in range(0, n_total, batch_size):
        batch = []
        for path in artifact_paths[start : start + batch_size]:
            with open(path, "rb") as handle:
                batch.append(pickle.load(handle))

        yield batch

        # Explicit cleanup once the consumer resumes us.
        del batch
        gc.collect()
        logger.debug(f"Processed batch {start // batch_size + 1}, memory cleaned")

concat_dataframes_efficient

concat_dataframes_efficient(
    dfs: List[DataFrame],
) -> DataFrame

Efficiently concatenate multiple DataFrames.

Uses pandas concat with copy=False for O(n) performance instead of O(n^2) that would occur with iterative concatenation.

PARAMETER DESCRIPTION
dfs

List of DataFrames to concatenate.

TYPE: List[DataFrame]

RETURNS DESCRIPTION
DataFrame

Concatenated DataFrame with reset index.

Source code in src/data_io/duckdb_export.py
def concat_dataframes_efficient(dfs: List[pd.DataFrame]) -> pd.DataFrame:
    """Concatenate DataFrames in a single O(n) pass.

    One ``pd.concat`` call (with ``copy=False``) avoids the quadratic cost
    of appending frames one at a time.

    Parameters
    ----------
    dfs : List[pd.DataFrame]
        List of DataFrames to concatenate.

    Returns
    -------
    pd.DataFrame
        Concatenated DataFrame with a fresh RangeIndex; empty DataFrame
        when the input list is empty.
    """
    return pd.concat(dfs, ignore_index=True, copy=False) if dfs else pd.DataFrame()

export_features_to_duckdb

export_features_to_duckdb(
    features_data: Dict[str, DataFrame],
    metadata: DataFrame,
    output_path: Union[str, Path],
    provenance: Optional[Dict[str, Dict]] = None,
    chunk_size: int = 1000,
) -> Path

Export hand-crafted features to DuckDB.

PARAMETER DESCRIPTION
features_data

Mapping from source_name to features DataFrame

TYPE: Dict[str, DataFrame]

metadata

Subject metadata (subject_id, eye, split, has_glaucoma)

TYPE: DataFrame

output_path

Output .db file path

TYPE: str or Path

provenance

Mapping from source_name to provenance info

TYPE: Dict[str, Dict] DEFAULT: None

chunk_size

Rows per insert batch

TYPE: int DEFAULT: 1000

RETURNS DESCRIPTION
Path

Path to created database

Source code in src/data_io/duckdb_export.py
def export_features_to_duckdb(
    features_data: Dict[str, pd.DataFrame],
    metadata: pd.DataFrame,
    output_path: Union[str, Path],
    provenance: Optional[Dict[str, Dict]] = None,
    chunk_size: int = 1000,
) -> Path:
    """
    Export hand-crafted features to DuckDB.

    Parameters
    ----------
    features_data : Dict[str, pd.DataFrame]
        Mapping from source_name to features DataFrame
    metadata : pd.DataFrame
        Subject metadata (subject_id, eye, split, has_glaucoma)
    output_path : str or Path
        Output .db file path
    provenance : Dict[str, Dict], optional
        Mapping from source_name to provenance info
    chunk_size : int, default 1000
        Rows per insert batch

    Returns
    -------
    Path
        Path to created database
    """
    output_path = Path(output_path)
    logger.info(f"Exporting features to {output_path}")

    # Remove existing file if present
    # (otherwise DuckDB would open the old file and the DDL below could
    # conflict with its existing schema).
    if output_path.exists():
        output_path.unlink()

    with duckdb.connect(str(output_path)) as con:
        # Create schema
        # NOTE(review): FEATURES_SCHEMA is presumably a module-level DDL
        # string defined elsewhere in this file; confirm it creates the
        # feature_metadata, plr_features and feature_provenance tables
        # referenced below.
        con.execute(FEATURES_SCHEMA)

        # Insert metadata
        if len(metadata) > 0:
            logger.info(f"Inserting {len(metadata)} metadata rows")
            # Register DataFrame and insert
            # (register() exposes the in-memory DataFrame as a SQL view,
            # so INSERT..SELECT copies it without an intermediate file).
            con.register("metadata_df", metadata)
            con.execute("INSERT INTO feature_metadata SELECT * FROM metadata_df")
            con.unregister("metadata_df")

        # Get expected column order from schema
        # (INSERT INTO ... SELECT * is positional, so the DataFrame column
        # order must match the table DDL exactly).
        features_columns = [
            "subject_id",
            "eye",
            "source_name",
            "baseline_diameter",
            "constriction_amplitude",
            "constriction_amplitude_rel",
            "max_constriction_diameter",
            "latency_to_constriction",
            "latency_75pct",
            "time_to_redilation",
            "max_constriction_velocity",
            "mean_constriction_velocity",
            "max_redilation_velocity",
            "pipr_6s",
            "pipr_10s",
            "recovery_time",
            "constriction_duration",
        ]

        # Insert features
        total_rows = 0
        for source_name, df in features_data.items():
            if len(df) == 0:
                continue

            # Copy so the caller's DataFrame is not mutated by the
            # source_name assignment below.
            df = df.copy()
            df["source_name"] = source_name

            # Reorder columns to match schema, adding missing columns as NULL
            ordered_df = pd.DataFrame()
            for col in features_columns:
                if col in df.columns:
                    ordered_df[col] = df[col]
                else:
                    ordered_df[col] = None

            # Insert in chunks
            # (bounds the size of each registered view / insert statement).
            for i in range(0, len(ordered_df), chunk_size):
                chunk = ordered_df.iloc[i : i + chunk_size]
                con.register("chunk_df", chunk)
                con.execute("INSERT INTO plr_features SELECT * FROM chunk_df")
                con.unregister("chunk_df")

            total_rows += len(df)
            # Keep peak memory bounded across many (possibly large) sources.
            gc.collect()

            logger.debug(f"  Inserted {len(df)} rows for {source_name}")

        # Insert provenance
        # (one row per source; provenance dict keys become columns).
        if provenance:
            prov_df = pd.DataFrame(
                [{"source_name": k, **v} for k, v in provenance.items()]
            )
            con.register("prov_df", prov_df)
            con.execute("INSERT INTO feature_provenance SELECT * FROM prov_df")
            con.unregister("prov_df")

    file_size = output_path.stat().st_size / 1024 / 1024
    logger.info(
        f"Features export complete: {output_path} ({file_size:.1f} MB, {total_rows} rows)"
    )
    return output_path

export_results_to_duckdb

export_results_to_duckdb(
    predictions_df: DataFrame,
    metrics_per_fold: DataFrame,
    metrics_aggregate: DataFrame,
    output_path: Union[str, Path],
    calibration_curves: Optional[DataFrame] = None,
    dca_curves: Optional[DataFrame] = None,
    mlflow_runs: Optional[DataFrame] = None,
    chunk_size: int = 1000,
) -> Path

Export classifier results to DuckDB.

PARAMETER DESCRIPTION
predictions_df

All predictions with columns matching schema

TYPE: DataFrame

metrics_per_fold

Metrics per fold

TYPE: DataFrame

metrics_aggregate

Aggregated metrics (mean, CI)

TYPE: DataFrame

output_path

Output .db file path

TYPE: str or Path

calibration_curves

Calibration curve data

TYPE: DataFrame DEFAULT: None

dca_curves

Decision curve analysis data

TYPE: DataFrame DEFAULT: None

mlflow_runs

MLflow run metadata

TYPE: DataFrame DEFAULT: None

chunk_size

Rows per insert batch

TYPE: int DEFAULT: 1000

RETURNS DESCRIPTION
Path

Path to created database

Source code in src/data_io/duckdb_export.py
def export_results_to_duckdb(
    predictions_df: pd.DataFrame,
    metrics_per_fold: pd.DataFrame,
    metrics_aggregate: pd.DataFrame,
    output_path: Union[str, Path],
    calibration_curves: Optional[pd.DataFrame] = None,
    dca_curves: Optional[pd.DataFrame] = None,
    mlflow_runs: Optional[pd.DataFrame] = None,
    chunk_size: int = 1000,
) -> Path:
    """
    Export classifier results to DuckDB.

    Parameters
    ----------
    predictions_df : pd.DataFrame
        All predictions with columns matching schema
    metrics_per_fold : pd.DataFrame
        Metrics per fold
    metrics_aggregate : pd.DataFrame
        Aggregated metrics (mean, CI)
    output_path : str or Path
        Output .db file path
    calibration_curves : pd.DataFrame, optional
        Calibration curve data
    dca_curves : pd.DataFrame, optional
        Decision curve analysis data
    mlflow_runs : pd.DataFrame, optional
        MLflow run metadata
    chunk_size : int, default 1000
        Rows per insert batch

    Returns
    -------
    Path
        Path to created database
    """

    def _insert_df(con, table: str, df: Optional[pd.DataFrame]) -> None:
        # Register the DataFrame as a temporary DuckDB view, bulk-insert it,
        # then drop the view. No-op for None/empty frames. `table` is always
        # a hard-coded table name from this function, never user input.
        if df is None or len(df) == 0:
            return
        con.register("tmp_df", df)
        con.execute(f"INSERT INTO {table} SELECT * FROM tmp_df")
        con.unregister("tmp_df")

    output_path = Path(output_path)
    logger.info(f"Exporting results to {output_path}")

    # Recreate the database from scratch on every export.
    if output_path.exists():
        output_path.unlink()

    with duckdb.connect(str(output_path)) as con:
        con.execute(RESULTS_SCHEMA)

        # Predictions can be large: insert in chunks to bound peak memory,
        # then force a collection of the chunk frames.
        if predictions_df is not None and len(predictions_df) > 0:
            logger.info(f"Inserting {len(predictions_df)} predictions")
            for i in range(0, len(predictions_df), chunk_size):
                _insert_df(con, "predictions", predictions_df.iloc[i : i + chunk_size])
            gc.collect()

        # Insert metrics
        if metrics_per_fold is not None and len(metrics_per_fold) > 0:
            logger.info(f"Inserting {len(metrics_per_fold)} fold metrics")
            _insert_df(con, "metrics_per_fold", metrics_per_fold)

        if metrics_aggregate is not None and len(metrics_aggregate) > 0:
            logger.info(f"Inserting {len(metrics_aggregate)} aggregate metrics")
            _insert_df(con, "metrics_aggregate", metrics_aggregate)

        # Insert optional tables (helper skips None/empty frames).
        _insert_df(con, "calibration_curves", calibration_curves)
        _insert_df(con, "dca_curves", dca_curves)
        _insert_df(con, "mlflow_runs", mlflow_runs)

    file_size = output_path.stat().st_size / 1024 / 1024
    logger.info(f"Results export complete: {output_path} ({file_size:.1f} MB)")
    return output_path

load_features_from_duckdb

load_features_from_duckdb(
    db_path: Union[str, Path],
    source_name: Optional[str] = None,
    split: Optional[str] = None,
) -> Tuple[ndarray, ndarray, List[str]]

Load features from DuckDB for classification.

PARAMETER DESCRIPTION
db_path

Path to features.db

TYPE: str or Path

source_name

Filter for specific pipeline configuration

TYPE: str DEFAULT: None

split

Filter for 'train', 'val', or 'test'

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
X

Feature matrix (n_samples, n_features)

TYPE: ndarray

y

Labels (n_samples,)

TYPE: ndarray

feature_names

Feature column names

TYPE: List[str]

Source code in src/data_io/duckdb_export.py
def load_features_from_duckdb(
    db_path: Union[str, Path],
    source_name: Optional[str] = None,
    split: Optional[str] = None,
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Load features from DuckDB for classification.

    Parameters
    ----------
    db_path : str or Path
        Path to features.db
    source_name : str, optional
        Filter for specific pipeline configuration
    split : str, optional
        Filter for 'train', 'val', or 'test'

    Returns
    -------
    X : np.ndarray
        Feature matrix (n_samples, n_features)
    y : np.ndarray
        Labels (n_samples,)
    feature_names : List[str]
        Feature column names
    """
    db_path = Path(db_path)
    logger.info(f"Loading features from {db_path}")

    with duckdb.connect(str(db_path), read_only=True) as con:
        query = """
            SELECT f.*, m.has_glaucoma, m.split
            FROM plr_features f
            JOIN feature_metadata m
                ON f.subject_id = m.subject_id AND f.eye = m.eye
            WHERE 1=1
        """

        # Bind filter values as prepared-statement parameters instead of
        # interpolating them into the SQL text: avoids SQL injection and
        # breakage on values containing quotes.
        params: List[str] = []
        if source_name:
            query += " AND f.source_name = ?"
            params.append(source_name)
        if split:
            query += " AND m.split = ?"
            params.append(split)

        df = con.execute(query, params).df()

    # Feature columns (exclude metadata)
    metadata_cols = ["subject_id", "eye", "source_name", "has_glaucoma", "split"]
    feature_cols = [c for c in df.columns if c not in metadata_cols]

    X = df[feature_cols].values
    y = df["has_glaucoma"].values.astype(int)

    logger.info(f"Loaded {X.shape[0]} samples, {X.shape[1]} features")
    return X, y, feature_cols

load_results_from_duckdb

load_results_from_duckdb(
    db_path: Union[str, Path],
    table: str = "metrics_aggregate",
) -> DataFrame

Load results from DuckDB.

PARAMETER DESCRIPTION
db_path

Path to results.db

TYPE: str or Path

table

Table to load: "predictions", "metrics_per_fold", "metrics_aggregate", "calibration_curves", "dca_curves", "mlflow_runs"

TYPE: str DEFAULT: "metrics_aggregate"

RETURNS DESCRIPTION
DataFrame

Requested data

Source code in src/data_io/duckdb_export.py
def load_results_from_duckdb(
    db_path: Union[str, Path],
    table: str = "metrics_aggregate",
) -> pd.DataFrame:
    """
    Load results from DuckDB.

    Parameters
    ----------
    db_path : str or Path
        Path to results.db
    table : str, default "metrics_aggregate"
        Table to load: "predictions", "metrics_per_fold", "metrics_aggregate",
        "calibration_curves", "dca_curves", "mlflow_runs"

    Returns
    -------
    pd.DataFrame
        Requested data

    Raises
    ------
    ValueError
        If `table` is not one of the known result tables.
    """
    # Table names cannot be bound as SQL parameters, so validate against the
    # fixed set of result tables before interpolating into the query
    # (prevents SQL injection and catches typos early with a clear error).
    allowed_tables = {
        "predictions",
        "metrics_per_fold",
        "metrics_aggregate",
        "calibration_curves",
        "dca_curves",
        "mlflow_runs",
    }
    if table not in allowed_tables:
        raise ValueError(
            f"Unknown results table '{table}'; expected one of {sorted(allowed_tables)}"
        )

    db_path = Path(db_path)

    with duckdb.connect(str(db_path), read_only=True) as con:
        df = con.execute(f"SELECT * FROM {table}").df()

    logger.info(f"Loaded {len(df)} rows from {table}")
    return df

extract_mlflow_classification_runs

extract_mlflow_classification_runs(
    mlruns_dir: Union[str, Path],
    experiment_id: Optional[str] = None,
    batch_size: int = 10,
) -> Tuple[
    DataFrame, DataFrame, DataFrame, DataFrame, DataFrame
]

Extract classification results from MLflow runs.

Extracts all available metrics and computes STRATOS-required metrics: - Calibration slope, intercept, O:E ratio - Net benefit at 5%, 10%, 20% thresholds - Full DCA curves (50 threshold points from 1% to 50%)

PARAMETER DESCRIPTION
mlruns_dir

Path to mlruns directory

TYPE: str or Path

experiment_id

Specific experiment ID (defaults to classification experiment)

TYPE: str DEFAULT: None

batch_size

Number of runs to process per batch

TYPE: int DEFAULT: 10

RETURNS DESCRIPTION
Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]

(predictions_df, metrics_per_fold_df, metrics_aggregate_df, dca_curves_df, mlflow_runs_df)

Source code in src/data_io/duckdb_export.py
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
def extract_mlflow_classification_runs(
    mlruns_dir: Union[str, Path],
    experiment_id: Optional[str] = None,
    batch_size: int = 10,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Extract classification results from MLflow runs.

    Extracts all available metrics and computes STRATOS-required metrics:
    - Calibration slope, intercept, O:E ratio
    - Net benefit at 5%, 10%, 20% thresholds
    - Full DCA curves (50 threshold points from 1% to 50%)

    Parameters
    ----------
    mlruns_dir : str or Path
        Path to mlruns directory
    experiment_id : str, optional
        Specific experiment ID (defaults to classification experiment)
    batch_size : int
        Number of runs to process per batch

    Returns
    -------
    Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]
        (predictions_df, metrics_per_fold_df, metrics_aggregate_df, dca_curves_df, mlflow_runs_df)
    """
    # Local import: yaml is only needed to read MLflow's per-run meta.yaml.
    import yaml

    mlruns_dir = Path(mlruns_dir)

    # Find classification experiment if not specified
    if experiment_id is None:
        # Look for experiment with most runs (likely classification)
        experiment_id = _find_classification_experiment(mlruns_dir)

    exp_dir = mlruns_dir / experiment_id
    if not exp_dir.exists():
        raise ValueError(f"Experiment directory not found: {exp_dir}")

    logger.info(f"Extracting from experiment {experiment_id}")

    # Collect all run directories
    run_dirs = [
        d for d in exp_dir.iterdir() if d.is_dir() and (d / "meta.yaml").exists()
    ]
    logger.info(f"Found {len(run_dirs)} runs")

    # Accumulators for the five output tables.
    all_predictions = []
    all_metrics = []
    all_aggregate = []
    all_dca_curves = []
    mlflow_runs = []

    # Monotonically increasing primary keys, shared across all runs.
    prediction_id = 0
    metric_id = 0
    dca_id = 0

    # Process in batches
    for batch_start in range(0, len(run_dirs), batch_size):
        batch_dirs = run_dirs[batch_start : batch_start + batch_size]

        for run_dir in batch_dirs:
            try:
                # Load run metadata
                with open(run_dir / "meta.yaml") as f:
                    meta = yaml.safe_load(f)

                run_id = meta.get("run_id", "")
                run_name = meta.get("run_name", "")

                # Skip non-classification runs (case-insensitive matching)
                run_name_upper = run_name.upper()
                if not any(
                    clf.upper() in run_name_upper
                    for clf in [
                        "XGBOOST",
                        "LogisticRegression",
                        "CatBoost",
                        "TabM",
                        "TabPFN",
                    ]
                ):
                    continue

                # Parse run configuration from name
                config = _parse_run_name(run_name)

                # Store MLflow run info
                mlflow_runs.append(
                    {
                        "run_id": run_id,
                        "experiment_name": f"exp_{experiment_id}",
                        "run_name": run_name,
                        "status": meta.get("status", ""),
                        "start_time": meta.get("start_time", ""),
                        "end_time": meta.get("end_time", ""),
                        "params_json": json.dumps(config),
                        "metrics_json": "{}",
                        "tags_json": "{}",
                    }
                )

                # Load metrics pickle
                metrics_path = _find_artifact(run_dir, "metrics", "*.pickle")
                if metrics_path:
                    with load_artifact_safe(metrics_path) as metrics_data:
                        # Extract aggregate metrics
                        if "metrics_stats" in metrics_data:
                            for split_name, split_data in metrics_data[
                                "metrics_stats"
                            ].items():
                                if (
                                    "metrics" not in split_data
                                    or "scalars" not in split_data["metrics"]
                                ):
                                    continue

                                scalars = split_data["metrics"]["scalars"]
                                for metric_name, metric_vals in scalars.items():
                                    # Only dict entries with a "mean" key are
                                    # treated as aggregate statistics.
                                    if (
                                        isinstance(metric_vals, dict)
                                        and "mean" in metric_vals
                                    ):
                                        all_aggregate.append(
                                            {
                                                "aggregate_id": len(all_aggregate),
                                                "source_name": config.get(
                                                    "source_name", run_name
                                                ),
                                                "classifier": config.get(
                                                    "classifier", "unknown"
                                                ),
                                                "metric_name": f"{split_name}/{metric_name}",
                                                "mean": metric_vals.get("mean", np.nan),
                                                "std": metric_vals.get("std", np.nan),
                                                "ci_lower": metric_vals.get(
                                                    "ci", [np.nan, np.nan]
                                                )[0]
                                                if isinstance(
                                                    metric_vals.get("ci"),
                                                    (list, np.ndarray),
                                                )
                                                else np.nan,
                                                "ci_upper": metric_vals.get(
                                                    "ci", [np.nan, np.nan]
                                                )[1]
                                                if isinstance(
                                                    metric_vals.get("ci"),
                                                    (list, np.ndarray),
                                                )
                                                else np.nan,
                                                "median": np.nan,
                                                "q25": np.nan,
                                                "q75": np.nan,
                                                "n_observations": metric_vals.get(
                                                    "n", 0
                                                ),
                                            }
                                        )

                        # Extract per-iteration metrics (all available, not just AUROC)
                        if "metrics_iter" in metrics_data:
                            for split_name, split_data in metrics_data[
                                "metrics_iter"
                            ].items():
                                if (
                                    "metrics" not in split_data
                                    or "scalars" not in split_data["metrics"]
                                ):
                                    continue

                                scalars = split_data["metrics"]["scalars"]

                                # Map MLflow metric names to schema column names
                                METRIC_MAPPING = {
                                    "AUROC": "auroc",
                                    "AUPR": "aupr",
                                    "Brier": "brier_score",
                                    "sensitivity": "sensitivity",
                                    "specificity": "specificity",
                                    "PPV": "ppv",
                                    "NPV": "npv",
                                    "F1": "f1_score",
                                    "accuracy": "accuracy",
                                }

                                # Determine number of folds from AUROC (or any available metric)
                                n_folds = 0
                                for metric_name in scalars:
                                    vals = scalars[metric_name]
                                    if isinstance(vals, (list, np.ndarray)):
                                        n_folds = min(len(vals), 10)  # Cap at 10
                                        break

                                # Extract all metrics per fold
                                for fold in range(n_folds):
                                    fold_metrics = {
                                        "metric_id": metric_id,
                                        "source_name": config.get(
                                            "source_name", run_name
                                        ),
                                        "classifier": config.get(
                                            "classifier", "unknown"
                                        ),
                                        "fold": fold,
                                        # Initialize all columns
                                        "auroc": None,
                                        "aupr": None,
                                        "brier_score": None,
                                        "calibration_slope": None,
                                        "calibration_intercept": None,
                                        "e_o_ratio": None,
                                        "sensitivity": None,
                                        "specificity": None,
                                        "ppv": None,
                                        "npv": None,
                                        "f1_score": None,
                                        "accuracy": None,
                                        "net_benefit_5pct": None,
                                        "net_benefit_10pct": None,
                                        "net_benefit_20pct": None,
                                    }

                                    # Extract available metrics
                                    for (
                                        mlflow_name,
                                        schema_col,
                                    ) in METRIC_MAPPING.items():
                                        if mlflow_name in scalars:
                                            vals = scalars[mlflow_name]
                                            if isinstance(
                                                vals, (list, np.ndarray)
                                            ) and fold < len(vals):
                                                val = vals[fold]
                                                # Non-finite values are stored as NULL.
                                                fold_metrics[schema_col] = (
                                                    float(val)
                                                    if np.isfinite(val)
                                                    else None
                                                )

                                    all_metrics.append(fold_metrics)
                                    metric_id += 1

                # Load dict_arrays pickle for predictions
                arrays_path = _find_artifact(run_dir, "dict_arrays", "*.pickle")
                if arrays_path:
                    with load_artifact_safe(arrays_path) as arrays_data:
                        # Extract subject codes if available
                        subject_codes_test = arrays_data.get("subject_codes_test", [])
                        y_test = arrays_data.get("y_test", np.array([]))

                        # Note: Full predictions require subjectwise_stats from metrics
                        # For now, store aggregated info
                        # NOTE(review): `metrics_data` was bound inside the earlier
                        # `with load_artifact_safe(...)` block and is reused here after
                        # that context manager exited — assumes load_artifact_safe
                        # keeps the loaded object valid after __exit__. TODO confirm.
                        if metrics_path and "subjectwise_stats" in metrics_data:
                            subj_stats = metrics_data.get("subjectwise_stats", {})
                            if "test" in subj_stats and "preds" in subj_stats["test"]:
                                preds = subj_stats["test"]["preds"]
                                y_pred_proba = preds.get("y_pred_proba", {})
                                y_pred_mean = y_pred_proba.get("mean", np.array([]))

                                for i, (code, y_true_val) in enumerate(
                                    zip(subject_codes_test, y_test)
                                ):
                                    # Fall back to 0.5 when there are fewer
                                    # probabilities than test subjects.
                                    prob = (
                                        y_pred_mean[i] if i < len(y_pred_mean) else 0.5
                                    )
                                    all_predictions.append(
                                        {
                                            "prediction_id": prediction_id,
                                            "subject_id": str(code),
                                            "eye": "OD",  # Default
                                            "fold": 0,  # Bootstrap aggregated
                                            "bootstrap_iter": 0,
                                            "outlier_method": config.get(
                                                "outlier_method", ""
                                            ),
                                            "imputation_method": config.get(
                                                "imputation_method", ""
                                            ),
                                            "featurization": config.get(
                                                "featurization", ""
                                            ),
                                            "classifier": config.get("classifier", ""),
                                            "source_name": config.get(
                                                "source_name", run_name
                                            ),
                                            "y_true": int(y_true_val),
                                            "y_pred": int(prob >= 0.5),
                                            "y_prob": float(prob),
                                            "mlflow_run_id": run_id,
                                        }
                                    )
                                    prediction_id += 1

                                # Compute STRATOS metrics from predictions
                                y_true_arr = np.array(y_test)
                                y_prob_arr = (
                                    np.array(y_pred_mean)
                                    if len(y_pred_mean) > 0
                                    else np.array([])
                                )

                                # Require >10 samples and matching lengths before
                                # computing calibration / net-benefit statistics.
                                if len(y_true_arr) > 10 and len(y_prob_arr) == len(
                                    y_true_arr
                                ):
                                    stratos_metrics = _compute_stratos_metrics(
                                        y_true_arr, y_prob_arr
                                    )

                                    # Update the last fold metric with STRATOS values
                                    # (Store as aggregate since we have one y_true/y_prob set)
                                    if all_metrics:
                                        # Find metrics for this run and update them
                                        # NOTE(review): this updates every accumulated
                                        # fold row whose source_name matches, which may
                                        # include rows from earlier runs that share the
                                        # same source_name — confirm this is intended.
                                        source = config.get("source_name", run_name)
                                        for m in all_metrics:
                                            if m["source_name"] == source:
                                                m["calibration_slope"] = (
                                                    stratos_metrics.get(
                                                        "calibration_slope"
                                                    )
                                                )
                                                m["calibration_intercept"] = (
                                                    stratos_metrics.get(
                                                        "calibration_intercept"
                                                    )
                                                )
                                                m["e_o_ratio"] = stratos_metrics.get(
                                                    "e_o_ratio"
                                                )
                                                m["net_benefit_5pct"] = (
                                                    stratos_metrics.get(
                                                        "net_benefit_5pct"
                                                    )
                                                )
                                                m["net_benefit_10pct"] = (
                                                    stratos_metrics.get(
                                                        "net_benefit_10pct"
                                                    )
                                                )
                                                m["net_benefit_20pct"] = (
                                                    stratos_metrics.get(
                                                        "net_benefit_20pct"
                                                    )
                                                )

                                    # Compute DCA curves
                                    try:
                                        from ..stats.clinical_utility import (
                                            decision_curve_analysis,
                                        )

                                        dca_df = decision_curve_analysis(
                                            y_true_arr,
                                            y_prob_arr,
                                            threshold_range=(0.01, 0.50),
                                            n_thresholds=50,
                                        )

                                        source = config.get("source_name", run_name)
                                        classifier = config.get("classifier", "unknown")

                                        for _, row in dca_df.iterrows():
                                            all_dca_curves.append(
                                                {
                                                    "dca_id": dca_id,
                                                    "source_name": source,
                                                    "classifier": classifier,
                                                    "threshold": float(
                                                        row["threshold"]
                                                    ),
                                                    "net_benefit_model": float(
                                                        row["nb_model"]
                                                    ),
                                                    "net_benefit_all": float(
                                                        row["nb_all"]
                                                    ),
                                                    "net_benefit_none": float(
                                                        row["nb_none"]
                                                    ),
                                                    "sensitivity": float(
                                                        row["sensitivity"]
                                                    ),
                                                    "specificity": float(
                                                        row["specificity"]
                                                    ),
                                                }
                                            )
                                            dca_id += 1

                                    except Exception as e:
                                        logger.warning(
                                            f"Failed to compute DCA curves: {e}"
                                        )

            except Exception as e:
                # Best-effort extraction: a broken run is logged and skipped,
                # never fatal for the whole export.
                logger.warning(f"Error processing run {run_dir.name}: {e}")
                continue

        # Cleanup after each batch
        gc.collect()
        logger.info(
            f"Processed {min(batch_start + batch_size, len(run_dirs))}/{len(run_dirs)} runs"
        )

    predictions_df = (
        pd.DataFrame(all_predictions) if all_predictions else pd.DataFrame()
    )
    metrics_per_fold_df = pd.DataFrame(all_metrics) if all_metrics else pd.DataFrame()
    metrics_aggregate_df = (
        pd.DataFrame(all_aggregate) if all_aggregate else pd.DataFrame()
    )
    dca_curves_df = pd.DataFrame(all_dca_curves) if all_dca_curves else pd.DataFrame()
    mlflow_runs_df = pd.DataFrame(mlflow_runs) if mlflow_runs else pd.DataFrame()

    logger.info(
        f"Extracted {len(predictions_df)} predictions, "
        f"{len(metrics_per_fold_df)} fold metrics, "
        f"{len(metrics_aggregate_df)} aggregate metrics, "
        f"{len(dca_curves_df)} DCA curve points"
    )

    return (
        predictions_df,
        metrics_per_fold_df,
        metrics_aggregate_df,
        dca_curves_df,
        mlflow_runs_df,
    )

export_mlflow_to_duckdb

export_mlflow_to_duckdb(
    mlruns_dir: Union[str, Path],
    output_path: Union[str, Path],
    experiment_id: Optional[str] = None,
) -> Path

Export MLflow classification results to DuckDB.

PARAMETER DESCRIPTION
mlruns_dir

Path to mlruns directory

TYPE: str or Path

output_path

Output .db file path

TYPE: str or Path

experiment_id

Specific experiment ID

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
Path

Path to created database

Source code in src/data_io/duckdb_export.py
def export_mlflow_to_duckdb(
    mlruns_dir: Union[str, Path],
    output_path: Union[str, Path],
    experiment_id: Optional[str] = None,
) -> Path:
    """Export MLflow classification results into a DuckDB database.

    Extracts predictions, per-fold and aggregate metrics, DCA curves, and
    run metadata from an mlruns directory, then writes them all to
    ``output_path``.

    Parameters
    ----------
    mlruns_dir : str or Path
        Path to mlruns directory
    output_path : str or Path
        Output .db file path
    experiment_id : str, optional
        Specific experiment ID

    Returns
    -------
    Path
        Path to created database
    """
    extracted = extract_mlflow_classification_runs(mlruns_dir, experiment_id)
    predictions, per_fold, aggregate, dca, runs = extracted

    return export_results_to_duckdb(
        predictions_df=predictions,
        metrics_per_fold=per_fold,
        metrics_aggregate=aggregate,
        output_path=output_path,
        dca_curves=dca,
        mlflow_runs=runs,
    )

main

main()

Command-line interface for DuckDB export/analysis.

Supports two main commands: - export: Export MLflow runs to a DuckDB database - analyze: Run analysis from existing features or results databases

RETURNS DESCRIPTION
None

Executes CLI commands and writes output to files.

Source code in src/data_io/duckdb_export.py
def main():
    """Command-line entry point for DuckDB export/analysis.

    Two subcommands are supported:
    - export: Export MLflow runs to a DuckDB database
    - analyze: Run analysis from existing features or results databases

    Returns
    -------
    None
        Executes CLI commands and writes output to files.
    """
    import argparse

    cli = argparse.ArgumentParser(
        description="Export features/results to DuckDB or run analysis from artifacts"
    )
    commands = cli.add_subparsers(dest="command", help="Command to run")

    # "export": dump MLflow runs into a DuckDB file
    exporter = commands.add_parser("export", help="Export from mlruns")
    exporter.add_argument("--mlruns", required=True, help="Path to mlruns directory")
    exporter.add_argument("--experiment-id", help="Specific experiment ID")
    exporter.add_argument(
        "--output", default="foundation_plr_results.db", help="Output database file"
    )

    # "analyze": run the analysis pipeline from an existing database
    analyzer = commands.add_parser("analyze", help="Run analysis from artifacts")
    analyzer.add_argument("--from-features", help="Path to features.db")
    analyzer.add_argument("--from-results", help="Path to results.db")
    analyzer.add_argument("--output-dir", default="outputs/analysis")

    args = cli.parse_args()

    if args.command == "export":
        logger.info(f"Export from {args.mlruns}")
        db_path = export_mlflow_to_duckdb(args.mlruns, args.output, args.experiment_id)
        logger.info(f"Exported to {db_path}")

    elif args.command == "analyze":
        # Features database supports the full pipeline; results database
        # only supports the statistics stage.
        if args.from_features:
            pipeline = DuckDBAnalysisPipeline.from_features(
                args.from_features, output_dir=args.output_dir
            )
            pipeline.run_classification()
            pipeline.run_statistics()
        elif args.from_results:
            pipeline = DuckDBAnalysisPipeline.from_results(
                args.from_results, output_dir=args.output_dir
            )
            pipeline.run_statistics()
        else:
            logger.error("Specify --from-features or --from-results")

    else:
        cli.print_help()

Stratification

stratification_utils

add_split_col_to_dataframe

add_split_col_to_dataframe(
    df_raw: DataFrame, split_codes: dict
)

Add a 'split' column to dataframe based on subject code assignments.

PARAMETER DESCRIPTION
df_raw

Raw data dataframe.

TYPE: DataFrame

split_codes

Dictionary mapping split names to lists of subject codes.

TYPE: dict

RETURNS DESCRIPTION
DataFrame

Dataframe with added 'split' column.

Source code in src/data_io/stratification_utils.py
def add_split_col_to_dataframe(df_raw: pl.DataFrame, split_codes: dict):
    """Add a 'split' column to dataframe based on subject code assignments.

    Builds one chained when/then expression covering all splits, so the
    dataframe is rewritten in a single pass instead of once per subject
    code (the previous per-code loop issued one ``with_columns`` call per
    code, which is quadratic overall).

    Parameters
    ----------
    df_raw : pl.DataFrame
        Raw data dataframe (a pandas DataFrame is converted to Polars).
    split_codes : dict
        Dictionary mapping split names to lists of subject codes.

    Returns
    -------
    pl.DataFrame
        Dataframe with added 'split' column; rows whose subject_code is in
        no split keep a null value.
    """
    if isinstance(df_raw, pd.DataFrame):
        df_raw = pl.DataFrame(df_raw)  # if a Pandas, convert to Polars

    # Start from null (unassigned) and layer each split on top. Because
    # each later split wraps the expression as the outer `when`, a code
    # listed in several splits gets the *last* split — the same overwrite
    # order the original per-code loop produced.
    split_expr = pl.lit(None)
    for split, codes in split_codes.items():
        split_expr = (
            pl.when(pl.col("subject_code").is_in(list(codes)))
            .then(pl.lit(split))
            .otherwise(split_expr)
        )

    return df_raw.with_columns(split_expr.alias("split"))

multicol_stratification

multicol_stratification(
    df_tmp: DataFrame,
    test_size: float,
    stratify_columns: list,
    cfg: DictConfig,
    col_to_return: str = "subject_code",
) -> dict

Perform multi-column iterative stratification for train/test split.

Custom iterative train test split which 'maintains balanced representation with respect to order-th label combinations.'

PARAMETER DESCRIPTION
df_tmp

Temporary dataframe with columns to stratify on.

TYPE: DataFrame

test_size

Proportion of data to use for test set (0-1).

TYPE: float

stratify_columns

List of column names to use for stratification.

TYPE: list

cfg

Configuration dictionary.

TYPE: DictConfig

col_to_return

Column name to return for each split, by default "subject_code".

TYPE: str DEFAULT: 'subject_code'

RETURNS DESCRIPTION
dict

Dictionary with 'train' and 'test' keys mapping to lists of values.

References
  • https://www.abzu.ai/data-science/stratified-data-splitting-part-2/
  • https://madewithml.com/courses/mlops/splitting/#stratified-split
Source code in src/data_io/stratification_utils.py
def multicol_stratification(
    df_tmp: pd.DataFrame,
    test_size: float,
    stratify_columns: list,
    cfg: DictConfig,
    col_to_return: str = "subject_code",
) -> dict:
    """Perform multi-column iterative stratification for train/test split.

    Custom iterative train test split which 'maintains balanced representation
    with respect to order-th label combinations.'

    Parameters
    ----------
    df_tmp : pd.DataFrame
        Temporary dataframe with columns to stratify on.
    test_size : float
        Proportion of data to use for test set (0-1).
    stratify_columns : list
        List of column names to use for stratification.
    cfg : DictConfig
        Configuration dictionary.
        NOTE(review): not read anywhere inside this function — confirm
        whether it is still needed in the signature.
    col_to_return : str, optional
        Column name to return for each split, by default "subject_code".

    Returns
    -------
    dict
        Dictionary with 'train' and 'test' keys mapping to lists of values.

    References
    ----------
    - https://www.abzu.ai/data-science/stratified-data-splitting-part-2/
    - https://madewithml.com/courses/mlops/splitting/#stratified-split
    """
    # One-hot encode the stratify columns and concatenate them
    if isinstance(df_tmp, pl.DataFrame):
        df_tmp = df_tmp.to_pandas()
    # Each stratify column expands to indicator columns; concatenated they
    # form the multi-label matrix that the iterative stratifier balances on.
    one_hot_cols = [pd.get_dummies(df_tmp[col]) for col in stratify_columns]
    one_hot_cols = pd.concat(one_hot_cols, axis=1).to_numpy()
    # NOTE(review): fold sizes are given as [test_size, 1 - test_size] and the
    # first yielded fold pair is unpacked as (train, test). This ordering
    # follows the referenced blog posts — confirm against skmultilearn's
    # IterativeStratification docs if split proportions ever look swapped.
    stratifier = IterativeStratification(
        n_splits=2,
        order=len(stratify_columns),
        sample_distribution_per_fold=[test_size, 1 - test_size],
    )
    train_indices, test_indices = next(
        stratifier.split(df_tmp.to_numpy(), one_hot_cols)
    )
    # Return the train and test set dataframes
    train, test = (
        df_tmp.iloc[train_indices][col_to_return],
        df_tmp.iloc[test_indices][col_to_return],
    )
    return {"train": list(train), "test": list(test)}

create_tmp_stratification_df

create_tmp_stratification_df(
    df_raw: DataFrame, stratify_columns: ListConfig
)

Create a temporary dataframe for stratification with binned features.

PARAMETER DESCRIPTION
df_raw

Raw data dataframe.

TYPE: DataFrame

stratify_columns

List of columns to use for stratification.

TYPE: ListConfig

RETURNS DESCRIPTION
DataFrame

Temporary pandas dataframe with subject_code, no_outliers_bins, and class_label columns.

Source code in src/data_io/stratification_utils.py
def create_tmp_stratification_df(df_raw: pl.DataFrame, stratify_columns: ListConfig):
    """Build a small per-subject dataframe used for stratified splitting.

    One row is produced per unique subject, holding the subject code, its
    class label, and its outlier count discretized into quantile bins.

    Parameters
    ----------
    df_raw : pl.DataFrame
        Raw data dataframe.
    stratify_columns : ListConfig
        List of columns to use for stratification.
        NOTE(review): not used inside this function — the output columns are
        hard-coded; confirm whether it should drive the column selection.

    Returns
    -------
    pd.DataFrame
        Pandas dataframe with subject_code, no_outliers_bins, and
        class_label columns.
    """
    # One row per subject for the two-column stratification
    unique_rows = get_unique_polars_rows(df_raw, "subject_code")

    # Bin the missingness proxy (number of outliers) into quantile bins
    binned_outliers = bin_outlier_counts(unique_rows["no_outliers"].to_numpy())

    return pd.DataFrame(
        {
            "subject_code": unique_rows["subject_code"].to_numpy(),
            "no_outliers_bins": binned_outliers,
            "class_label": unique_rows["class_label"].to_numpy(),
        }
    )

bin_outlier_counts

bin_outlier_counts(outliers: list, no_of_bins: int = 5)

Bin outlier counts into quantile-based categories.

PARAMETER DESCRIPTION
outliers

List of outlier counts per subject.

TYPE: list

no_of_bins

Number of bins to create, by default 5.

TYPE: int DEFAULT: 5

RETURNS DESCRIPTION
ndarray

Array of bin labels for each subject.

Source code in src/data_io/stratification_utils.py
def bin_outlier_counts(outliers: list, no_of_bins: int = 5):
    """Discretize per-subject outlier counts into quantile-based bins.

    Parameters
    ----------
    outliers : list
        List of outlier counts per subject.
    no_of_bins : int, optional
        Number of bins to create, by default 5.

    Returns
    -------
    np.ndarray
        Array of bin labels (floats 0.0 .. no_of_bins - 1) for each subject.

    Raises
    ------
    ValueError
        Propagated from ``pd.qcut`` when the quantile edges are not unique
        (e.g. heavily tied outlier counts).
    """
    # Float labels 0.0, 1.0, ... matching a linspace over the bin range
    bin_labels = np.arange(no_of_bins, dtype=float)
    binned = pd.qcut(outliers, no_of_bins, labels=bin_labels)
    return binned.to_numpy()

create_splits_to_df

create_splits_to_df(df_raw, cfg: DictConfig)

Create stratified splits and add split column to dataframe.

PARAMETER DESCRIPTION
df_raw

Raw data dataframe.

TYPE: DataFrame

cfg

Configuration with STRATIFICATION settings.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

Dataframe with added 'split' column.

RAISES DESCRIPTION
ValueError

If any data points have missing split assignments.

Source code in src/data_io/stratification_utils.py
def create_splits_to_df(df_raw, cfg: DictConfig):
    """Create stratified splits and add split column to dataframe.

    Parameters
    ----------
    df_raw : pl.DataFrame
        Raw data dataframe.
    cfg : DictConfig
        Configuration with STRATIFICATION settings.

    Returns
    -------
    pl.DataFrame
        Dataframe with added 'split' column.

    Raises
    ------
    ValueError
        If any data points have missing split assignments.
    """
    # Create three-column df for two-column stratification to get back the
    # subject codes per split.
    # BUGFIX: cfg["STRATIFICATION"]["test_size"] (a float) was previously
    # passed as stratify_columns; pass the actual column list instead, as
    # done for multicol_stratification below.
    df_tmp = create_tmp_stratification_df(
        df_raw, stratify_columns=cfg["STRATIFICATION"]["stratify_columns"]
    )
    # Get subject codes belonging to the training and validation sets
    split_codes = multicol_stratification(
        df_tmp=df_tmp,
        test_size=cfg["STRATIFICATION"]["test_size"],
        stratify_columns=list(cfg["STRATIFICATION"]["stratify_columns"]),
        cfg=cfg,
    )

    # Add the split column to the raw data
    df_raw = add_split_col_to_dataframe(df_raw, split_codes)

    # Check that all the data has a split (count nulls only once)
    n_missing = df_raw["split"].null_count()
    if n_missing > 0:
        logger.error(f"Data has {n_missing} missing splits")
        raise ValueError(f"Data has {n_missing} missing splits")

    return df_raw

stratify_splits

stratify_splits(df_raw: DataFrame, cfg: DictConfig)

Main function to stratify data into train and test splits.

Performs multi-column stratification based on class labels and outlier counts, ensuring balanced representation in both splits.

PARAMETER DESCRIPTION
df_raw

Raw data dataframe with all subjects.

TYPE: DataFrame

cfg

Configuration with STRATIFICATION settings.

TYPE: DictConfig

RETURNS DESCRIPTION
tuple

Tuple containing (df_train, df_test) as Polars dataframes.

Source code in src/data_io/stratification_utils.py
def stratify_splits(
    df_raw: pl.DataFrame,
    cfg: DictConfig,
):
    """Split the raw data into stratified train and test sets.

    Multi-column stratification on class labels and binned outlier counts
    keeps both the labels and the data-quality distribution balanced
    across the two splits.

    Parameters
    ----------
    df_raw : pl.DataFrame
        Raw data dataframe with all subjects.
    cfg : DictConfig
        Configuration with STRATIFICATION settings.

    Returns
    -------
    tuple
        Tuple containing (df_train, df_test) as Polars dataframes.
    """
    logger.info("Create splits (train/test)")
    df_with_splits = create_splits_to_df(df_raw, cfg)

    # Materialize one Polars dataframe per split
    splits = {
        name: df_with_splits.filter(pl.col("split") == name)
        for name in ("train", "test")
    }

    # Check against data leakage
    check_data_import(splits["train"], splits["test"], display_outliers=False)

    return splits["train"], splits["test"]

Source Definitions

define_sources_for_flow

get_best_mlflow_col_for_imputation

get_best_mlflow_col_for_imputation(
    cfg: DictConfig, string: str = "MAE"
) -> str

Get the MLflow column name for the best imputation metric.

PARAMETER DESCRIPTION
cfg

Configuration dictionary containing IMPUTATION_METRICS settings.

TYPE: DictConfig

string

Metric identifier (e.g., "MAE"), by default "MAE".

TYPE: str DEFAULT: 'MAE'

RETURNS DESCRIPTION
str

MLflow column name for the specified metric.

Source code in src/data_io/define_sources_for_flow.py
def get_best_mlflow_col_for_imputation(cfg: DictConfig, string: str = "MAE") -> str:
    """Resolve the MLflow column name for the chosen imputation metric.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary containing IMPUTATION_METRICS settings.
    string : str, optional
        Metric identifier (e.g., "MAE"), by default "MAE".

    Returns
    -------
    str
        MLflow column name for the specified metric.
    """
    # The best-metric mapping translates metric identifiers to MLflow columns
    return cfg["IMPUTATION_METRICS"]["best_metric"][string]

get_best_string_for_imputation

get_best_string_for_imputation(
    cfg: DictConfig, split: str = "test"
) -> dict

Get the best metric configuration dictionary for imputation.

PARAMETER DESCRIPTION
cfg

Configuration dictionary containing IMPUTATION_METRICS settings.

TYPE: DictConfig

split

Data split name, by default "test".

TYPE: str DEFAULT: 'test'

RETURNS DESCRIPTION
dict

Best metric configuration dictionary.

Source code in src/data_io/define_sources_for_flow.py
def get_best_string_for_imputation(cfg: DictConfig, split: str = "test") -> dict:
    """Fetch the best-metric configuration dictionary for imputation.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary containing IMPUTATION_METRICS settings.
    split : str, optional
        Data split name, by default "test".
        NOTE(review): not used in the lookup below — confirm whether the
        metric choice was ever meant to depend on the split.

    Returns
    -------
    dict
        Best metric configuration dictionary.
    """
    return cfg["IMPUTATION_METRICS"]["best_metric"]

get_best_string_for_outlier_detection

get_best_string_for_outlier_detection(
    cfg: DictConfig,
) -> dict

Get the best metric configuration for outlier detection.

PARAMETER DESCRIPTION
cfg

Configuration dictionary containing OUTLIER_DETECTION settings.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Best metric configuration for outlier detection.

Source code in src/data_io/define_sources_for_flow.py
def get_best_string_for_outlier_detection(cfg: DictConfig) -> dict:
    """Fetch the best-metric configuration for outlier detection.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary containing OUTLIER_DETECTION settings.

    Returns
    -------
    dict
        Best metric configuration for outlier detection.
    """
    # The config itself names which entry counts as "best"; follow that key
    selection_key = cfg["OUTLIER_DETECTION"]["what_is_best"]
    return cfg["OUTLIER_DETECTION"][selection_key]

get_best_string_for_classification

get_best_string_for_classification(cfg: DictConfig) -> dict

Get the best metric configuration for classification.

PARAMETER DESCRIPTION
cfg

Configuration dictionary containing CLASSIFICATION_SETTINGS.

TYPE: DictConfig

RETURNS DESCRIPTION
dict

Best metric configuration for classification.

Source code in src/data_io/define_sources_for_flow.py
def get_best_string_for_classification(cfg: DictConfig) -> dict:
    """Fetch the best-metric configuration for classification.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary containing CLASSIFICATION_SETTINGS.

    Returns
    -------
    dict
        Best metric configuration for classification.
    """
    classification_cfg = cfg["CLASSIFICATION_SETTINGS"]
    return classification_cfg["BEST_METRIC"]

get_best_dict

get_best_dict(task: str, cfg: DictConfig) -> Optional[dict]

Get the best metric dictionary for a given task.

PARAMETER DESCRIPTION
task

Task name ("outlier_detection", "imputation", "featurization", or "classification").

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

RETURNS DESCRIPTION
dict or None

Best metric configuration dictionary, or None for featurization.

RAISES DESCRIPTION
NotImplementedError

If the task is unknown.

Source code in src/data_io/define_sources_for_flow.py
def get_best_dict(task: str, cfg: DictConfig) -> Optional[dict]:
    """Look up the best-metric dictionary for a given task.

    Parameters
    ----------
    task : str
        Task name ("outlier_detection", "imputation", "featurization", or "classification").
    cfg : DictConfig
        Configuration dictionary.

    Returns
    -------
    dict or None
        Best metric configuration dictionary, or None for featurization
        (which has no metric to rank by).

    Raises
    ------
    NotImplementedError
        If the task is unknown.
    """
    # Guard-clause dispatch: each known task returns directly
    if task == "featurization":
        return None
    if task == "outlier_detection":
        return get_best_string_for_outlier_detection(cfg)
    if task == "imputation":
        return get_best_string_for_imputation(cfg)
    if task == "classification":
        return get_best_string_for_classification(cfg)

    logger.error(f"Unknown task: {task}")
    raise NotImplementedError(f"Unknown task: {task}")

get_best_run_dict

get_best_run_dict(
    run_df: DataFrame, best_dict: dict, task: str
) -> DataFrame

Sort MLflow runs by the best metric and return the sorted dataframe.

PARAMETER DESCRIPTION
run_df

DataFrame of MLflow runs.

TYPE: DataFrame

best_dict

Configuration specifying the metric column and sort direction.

TYPE: dict

task

Task name for metric column selection.

TYPE: str

RETURNS DESCRIPTION
DataFrame

Sorted DataFrame with best runs first.

RAISES DESCRIPTION
ValueError

If task is unknown or metric column not found.

Source code in src/data_io/define_sources_for_flow.py
def get_best_run_dict(run_df: pd.DataFrame, best_dict: dict, task: str) -> pd.DataFrame:
    """Sort MLflow runs by the best metric and return the sorted dataframe.

    After sorting, row 0 is the best run for the configured metric.

    Parameters
    ----------
    run_df : pd.DataFrame
        DataFrame of MLflow runs.
    best_dict : dict
        Configuration specifying the metric column and sort direction.
        Must provide a "direction" key ("ASC" or "DESC"); a "string" key is
        referenced in the missing-column error message below.
    task : str
        Task name for metric column selection.

    Returns
    -------
    pd.DataFrame
        Sorted DataFrame with best runs first.

    Raises
    ------
    ValueError
        If task is unknown, the metric column is not found, or the sort
        direction is neither "ASC" nor "DESC".
    """
    if task == "outlier_detection":
        col_name = get_col_for_for_best_anomaly_detection_metric(best_dict, task)
    elif task == "imputation":
        col_name = get_best_imputation_col_name(best_dict)
    elif task == "classification":
        # NOTE(review): classification reuses the imputation column resolver;
        # presumably both best_dicts share the same layout — confirm.
        col_name = get_best_imputation_col_name(best_dict)
    else:
        logger.error(f"Unknown task: {task}")
        raise ValueError(f"Unknown task: {task}")

    if col_name not in run_df.columns:
        # NOTE(review): this error path assumes best_dict has a "string" key;
        # a malformed best_dict would raise KeyError here instead.
        logger.error(f"Unknown string: {best_dict['string']}")
        logger.error(f"Available columns: {run_df.columns}")
        raise ValueError(f"Unknown string: {best_dict['string']}")

    if best_dict["direction"] == "ASC":
        run_df = run_df.sort_values(by=col_name, ascending=True)
    elif best_dict["direction"] == "DESC":
        run_df = run_df.sort_values(by=col_name, ascending=False)
    else:
        logger.error(f"Unknown direction: {best_dict['direction']}")
        raise ValueError(f"Unknown direction: {best_dict['direction']}")

    return run_df

drop_ensemble_runs

drop_ensemble_runs(runs_model: DataFrame) -> DataFrame

Remove ensemble runs from a MLflow runs dataframe.

PARAMETER DESCRIPTION
runs_model

DataFrame of MLflow runs.

TYPE: DataFrame

RETURNS DESCRIPTION
DataFrame

DataFrame with ensemble runs removed.

Source code in src/data_io/define_sources_for_flow.py
def drop_ensemble_runs(runs_model: pd.DataFrame) -> pd.DataFrame:
    """Remove ensemble runs from a MLflow runs dataframe.

    Parameters
    ----------
    runs_model : pd.DataFrame
        DataFrame of MLflow runs.

    Returns
    -------
    pd.DataFrame
        DataFrame with ensemble runs removed.
    """
    logger.info("Dropping ensemble runs")
    # Vectorized filter instead of the previous quadratic row-by-row
    # pd.concat loop; regex=False keeps the plain-substring semantics of
    # the original `"ensemble" not in name` check.
    keep_mask = ~runs_model["tags.mlflow.runName"].str.contains(
        "ensemble", regex=False
    )
    return runs_model[keep_mask]

foundation_model_filter

foundation_model_filter(
    mlflow_runs: DataFrame,
    best_dict: dict,
    model_name: str,
    task: str,
) -> Optional[DataFrame]

Filter foundation model runs to get best zeroshot and finetuned variants.

The idea is to get both the zeroshot and finetuned variants for the foundation models (when available), whereas for the more traditional models a zeroshot option typically does not exist (or performs poorly).

PARAMETER DESCRIPTION
mlflow_runs

DataFrame of all MLflow runs.

TYPE: DataFrame

best_dict

Configuration specifying the best metric and direction.

TYPE: dict

model_name

Name of the foundation model to filter for.

TYPE: str

task

Task name for metric selection.

TYPE: str

RETURNS DESCRIPTION
DataFrame or None

DataFrame with filtered runs, or None if no runs found.

Source code in src/data_io/define_sources_for_flow.py
def foundation_model_filter(
    mlflow_runs: pd.DataFrame, best_dict: dict, model_name: str, task: str
) -> Optional[pd.DataFrame]:
    """Filter foundation model runs to get best zeroshot and finetuned variants.

    The idea is to get both the zeroshot and finetuned model with the foundation
    models (if available) whereas with the more traditional models, the zeroshot
    option is not there typically (or does not perform so well at all).

    Parameters
    ----------
    mlflow_runs : pd.DataFrame
        DataFrame of all MLflow runs.
    best_dict : dict
        Configuration specifying the best metric and direction.
    model_name : str
        Name of the foundation model to filter for.
    task : str
        Task name for metric selection.

    Returns
    -------
    pd.DataFrame or None
        DataFrame with filtered runs, or None if no runs found.
    """
    df = pd.DataFrame()
    runs_model = mlflow_runs[
        mlflow_runs["tags.mlflow.runName"].str.contains(model_name)
    ]
    runs_model = drop_ensemble_runs(runs_model)

    # BUGFIX: previously this checked `df.shape[0] > 0` on the freshly
    # created, always-empty accumulator, so the selection loop below never
    # ran and the function unconditionally returned None.
    if runs_model.shape[0] > 0:
        criteria = ["zeroshot", "finetune"]
        data_sources = ["gt", "orig"]
        # You should get 3 (or 4) runs, since the zeroshot model is always
        # evaluated on the "orig" split, so for it the "gt"/"orig" source
        # distinction does not matter.
        for criterion in criteria:
            for data_source in data_sources:
                try:
                    runs_criterion = runs_model[
                        runs_model["tags.mlflow.runName"].str.contains(criterion)
                    ]
                except Exception as e:
                    logger.error(
                        f"Could not filter the runs with criterion: {criterion}, model: {model_name}"
                    )
                    raise e
                # BUGFIX: the boolean mask must be built from runs_criterion
                # itself; masking with a series derived from runs_model
                # misaligns indices once rows have been filtered out.
                runs_criterion_source = runs_criterion[
                    runs_criterion["tags.mlflow.runName"].str.contains(data_source)
                ]
                if runs_criterion_source.shape[0] > 0:
                    run_df = get_best_run_dict(runs_criterion_source, best_dict, task)
                    df = pd.concat([df, run_df.iloc[0:1]])

        return df

    else:
        logger.warning(f"No runs found for foundation model: {model_name}")
        return None

get_best_imputation_runs

get_best_imputation_runs(
    mlflow_runs: DataFrame, task: str, cfg: DictConfig
) -> DataFrame

Get the best run for each unique imputer+outlier combination.

Unlike best outlier_runs, we now have added 3 new fields to the mlflow_runs: 1. imputer_model 2. outlier_source 3. unique_combo

And the unique_combo defines how many imputer+outlier_source combos we have for featurization.

PARAMETER DESCRIPTION
mlflow_runs

DataFrame of MLflow runs with unique_combo column.

TYPE: DataFrame

task

Task name for metric selection.

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

DataFrame with one best run per unique combination.

Source code in src/data_io/define_sources_for_flow.py
def get_best_imputation_runs(
    mlflow_runs: pd.DataFrame, task: str, cfg: DictConfig
) -> pd.DataFrame:
    """Get the best run for each unique imputer+outlier combination.

    Unlike best outlier_runs, we now have added 3 new fields to the mlflow_runs:
    1. imputer_model
    2. outlier_source
    3. unique_combo

    And the unique_combo defines how many imputer+outlier_source combos we have
    for featurization.

    Parameters
    ----------
    mlflow_runs : pd.DataFrame
        DataFrame of MLflow runs with unique_combo column.
    task : str
        Task name for metric selection.
    cfg : DictConfig
        Configuration dictionary.

    Returns
    -------
    pd.DataFrame
        DataFrame with one best run per unique combination.
    """
    unique_combos = mlflow_runs["unique_combo"].unique()
    best_dict = get_best_dict(task, cfg)

    # Collect the single best row per combo, then concatenate once
    # (pd.concat inside the loop is quadratic in the number of combos).
    best_rows = []
    for unique_combo in unique_combos:
        combo_runs = mlflow_runs[mlflow_runs["unique_combo"] == unique_combo]
        best_run = get_best_run_dict(combo_runs, best_dict, task)[0:1]
        assert best_run.shape[0] == 1, f"Expected 1 run, got {best_run.shape[0]}"
        logger.debug(f"Unique combo: {unique_combo}")
        best_rows.append(best_run)

    runs_out = pd.concat(best_rows) if best_rows else pd.DataFrame()

    assert len(runs_out) == len(unique_combos), (
        f"Expected {len(unique_combos)} runs, got {len(runs_out)}"
    )

    return runs_out

drop_foundational_models

drop_foundational_models(
    mlflow_runs: DataFrame,
    foundation_model_names: list[str],
) -> DataFrame

Remove foundation model runs from MLflow runs dataframe.

PARAMETER DESCRIPTION
mlflow_runs

DataFrame of MLflow runs.

TYPE: DataFrame

foundation_model_names

List of foundation model name strings to filter out.

TYPE: list

RETURNS DESCRIPTION
DataFrame

DataFrame with foundation model runs removed.

Source code in src/data_io/define_sources_for_flow.py
def drop_foundational_models(
    mlflow_runs: pd.DataFrame, foundation_model_names: list[str]
) -> pd.DataFrame:
    """Remove foundation model runs from MLflow runs dataframe.

    Parameters
    ----------
    mlflow_runs : pd.DataFrame
        DataFrame of MLflow runs.
    foundation_model_names : list
        List of foundation model name strings to filter out.

    Returns
    -------
    pd.DataFrame
        DataFrame with foundation model runs removed.
    """
    # Vectorized over the run-name column instead of the previous
    # row-by-row pd.concat loop (quadratic in the number of rows).
    # A run is foundational if any model name is a substring of its name.
    is_foundational = mlflow_runs["tags.mlflow.runName"].apply(
        lambda run_name: any(name in run_name for name in foundation_model_names)
    )
    return mlflow_runs[~is_foundational]

get_best_model_runs

get_best_model_runs(
    mlflow_runs: DataFrame, task: str, cfg: DictConfig
) -> DataFrame

Get the best runs for each model type (foundation and traditional).

Manual definition of what you want to compare. You could also simply use all MLflow runs as source, and put the return_subset to False. This artisanal selection is here just to reduce the number of combos to return.

PARAMETER DESCRIPTION
mlflow_runs

DataFrame of all MLflow runs.

TYPE: DataFrame

task

Task name for metric selection.

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

RETURNS DESCRIPTION
DataFrame

DataFrame with best runs for each model type.

Source code in src/data_io/define_sources_for_flow.py
def get_best_model_runs(
    mlflow_runs: pd.DataFrame, task: str, cfg: DictConfig
) -> pd.DataFrame:
    """Collect the best runs for each model type (foundation and traditional).

    Manual definition of what you want to compare. You could also simply use
    all MLflow runs as source, and put the return_subset to False. This
    artisanal selection is here just to reduce the number of combos to return.

    Parameters
    ----------
    mlflow_runs : pd.DataFrame
        DataFrame of all MLflow runs.
    task : str
        Task name for metric selection.
    cfg : DictConfig
        Configuration dictionary.

    Returns
    -------
    pd.DataFrame
        DataFrame with best runs for each model type.
    """
    best_dict = get_best_dict(task, cfg)
    foundation_model_names = get_foundation_model_names()

    # Best zeroshot/finetuned runs for each foundation model (skip models
    # that produced no runs)
    selected = [
        filtered
        for filtered in (
            foundation_model_filter(mlflow_runs, best_dict, model_name=name, task=task)
            for name in foundation_model_names
        )
        if filtered is not None
    ]

    # The non-foundational (traditional) model runs
    selected.append(drop_foundational_models(mlflow_runs, foundation_model_names))

    # The ensemble runs
    selected.append(
        mlflow_runs[mlflow_runs["tags.mlflow.runName"].str.contains("ensemble")]
    )

    runs_out = pd.concat(selected)

    logger.info(f"Found {len(runs_out)} runs")

    return runs_out

get_unique_outlier_runs

get_unique_outlier_runs(
    mlflow_runs: DataFrame, cfg: DictConfig, task: str
) -> DataFrame

Get one best run per unique run name from MLflow runs.

PARAMETER DESCRIPTION
mlflow_runs

DataFrame of MLflow runs.

TYPE: DataFrame

cfg

Configuration dictionary.

TYPE: DictConfig

task

Task name for metric selection.

TYPE: str

RETURNS DESCRIPTION
DataFrame

DataFrame with one best run per unique run name.

Source code in src/data_io/define_sources_for_flow.py
def get_unique_outlier_runs(
    mlflow_runs: pd.DataFrame, cfg: DictConfig, task: str
) -> pd.DataFrame:
    """Get one best run per unique run name from MLflow runs.

    Parameters
    ----------
    mlflow_runs : pd.DataFrame
        DataFrame of MLflow runs.
    cfg : DictConfig
        Configuration dictionary.
    task : str
        Task name for metric selection.

    Returns
    -------
    pd.DataFrame
        DataFrame with one best run per unique run name.
    """
    best_dict = get_best_dict(task, cfg)
    run_names = mlflow_runs["tags.mlflow.runName"].unique()

    # Collect one row (the best run) per unique run name, concat once at the end
    selected_rows = []
    for name in run_names:
        candidates = mlflow_runs[mlflow_runs["tags.mlflow.runName"] == name]
        if best_dict is not None:
            candidates = get_best_run_dict(candidates, best_dict, task)
        else:
            # e.g. featurization does not have any metrics, so sort by latest
            candidates = candidates.sort_values(by="start_time", ascending=False)
        selected_rows.append(candidates.iloc[0:1])

    best_runs = pd.concat(selected_rows) if selected_rows else pd.DataFrame()

    assert len(best_runs) == len(run_names), (
        f"Expected {len(run_names)} runs, got {len(best_runs)}"
    )

    return best_runs

parse_run_name_for_two_model_names

parse_run_name_for_two_model_names(
    run_name: str, delimiter: str = "__"
) -> tuple[str, str]

Parse a run name to extract imputer model and outlier source names.

PARAMETER DESCRIPTION
run_name

MLflow run name in format "imputer__outlier_source".

TYPE: str

delimiter

Delimiter separating model names, by default "__".

TYPE: str DEFAULT: '__'

RETURNS DESCRIPTION
tuple

Tuple containing (imputer_model, outlier_source) as simplified names.

RAISES DESCRIPTION
Exception

If the run name cannot be parsed.

Source code in src/data_io/define_sources_for_flow.py
def parse_run_name_for_two_model_names(
    run_name: str, delimiter: str = "__"
) -> tuple[str, str]:
    """Parse a run name to extract imputer model and outlier source names.

    The expected format is ``"imputer__outlier_source"``. If the outlier
    source itself leaked extra delimiters into the run name, all fields
    beyond the first are folded back together with a single underscore
    (matching the previous behavior for three-field names, and now also
    handling four or more fields instead of crashing).

    Parameters
    ----------
    run_name : str
        MLflow run name in format "imputer__outlier_source".
    delimiter : str, optional
        Delimiter separating model names, by default "__".

    Returns
    -------
    tuple
        Tuple containing (imputer_model, outlier_source) as simplified names.

    Raises
    ------
    ValueError
        If the run name does not contain the delimiter at all.
    """
    fields = run_name.split(delimiter)
    if len(fields) < 2:
        logger.error('Could not parse run name "{}"'.format(run_name))
        raise ValueError(
            'Could not parse run name "{}" with delimiter "{}"'.format(
                run_name, delimiter
            )
        )
    imputer_model = fields[0]
    # fold any extra delimiter-separated fields back into the outlier source
    outlier_source = "_".join(fields[1:])

    imputer_model = simplify_model_name(imputer_model)
    outlier_source = simplify_model_name(outlier_source)
    # TimesNet outlier runs were trained on ground-truth labels; tag them as such
    if outlier_source == "TimesNet":
        outlier_source = "TimesNet-gt"
    return imputer_model, outlier_source

simplify_model_name

simplify_model_name(
    model_name: str, delimiter: str = "_"
) -> str

Simplify a model name by extracting the core identifier.

Handles special cases for ensemble and MOMENT model naming conventions.

PARAMETER DESCRIPTION
model_name

Full model name from MLflow run.

TYPE: str

delimiter

Delimiter in the model name, by default "_".

TYPE: str DEFAULT: '_'

RETURNS DESCRIPTION
str

Simplified model name.

Source code in src/data_io/define_sources_for_flow.py
def simplify_model_name(model_name: str, delimiter: str = "_") -> str:
    """Simplify a model name by extracting the core identifier.

    Handles special cases for ensemble and MOMENT model naming conventions.

    Parameters
    ----------
    model_name : str
        Full model name from MLflow run.
    delimiter : str, optional
        Delimiter in the model name, by default "_".

    Returns
    -------
    str
        Simplified model name.
    """
    # Keep "pupil_*" names as a single token so the split below doesn't cut them
    model_name = model_name.replace("pupil_", "pupil-")
    model_name_out = model_name.split(delimiter)[0]
    if "ensemble" in model_name:
        if "gt_thresholded" not in model_name:
            model_name_out = model_name_out.replace("ensembleThresholded", "ensemble")
    elif "MOMENT" in model_name:
        model_name_fields = model_name.split("_")
        # FIX: removed duplicated "model_name_out = model_name_out =" assignment
        # and guard the fields[3] access so short MOMENT names no longer raise
        # IndexError; they fall through to the imputation naming branch.
        if len(model_name_fields) > 3 and model_name_fields[3] in ("gt", "orig"):
            # Outlier naming, e.g. "MOMENT_finetune_<x>_gt" -> "MOMENT-gt-finetune"
            model_name_out = (
                model_name_fields[0]
                + "-"
                + model_name_fields[3]
                + "-"
                + model_name_fields[1]
            )
        else:
            # Imputation naming, e.g. "MOMENT_zeroshot_..." -> "MOMENT-zeroshot"
            model_name_out = model_name_fields[0] + "-" + model_name_fields[1]

    model_name_out = model_name_out.replace("UniTS-Outlier", "UniTS")

    return model_name_out

get_unique_combo_runs

get_unique_combo_runs(
    mlflow_runs_in: DataFrame,
    cfg: DictConfig,
    task: str,
    delimiter: str = "__",
) -> DataFrame

Add unique combo columns to MLflow runs for imputer+outlier tracking.

Parses run names to extract imputer_model and outlier_source, creating a unique_combo identifier for each combination.

PARAMETER DESCRIPTION
mlflow_runs_in

Input MLflow runs DataFrame.

TYPE: DataFrame

cfg

Configuration dictionary.

TYPE: DictConfig

task

Task name.

TYPE: str

delimiter

Delimiter separating model names, by default "__".

TYPE: str DEFAULT: '__'

RETURNS DESCRIPTION
DataFrame

DataFrame with added imputer_model, outlier_source, and unique_combo columns.

RAISES DESCRIPTION
ValueError

If any run has a NaN run_id.

Source code in src/data_io/define_sources_for_flow.py
def get_unique_combo_runs(
    mlflow_runs_in: pd.DataFrame, cfg: DictConfig, task: str, delimiter: str = "__"
) -> pd.DataFrame:
    """Add unique combo columns to MLflow runs for imputer+outlier tracking.

    Parses run names to extract imputer_model and outlier_source, creating
    a unique_combo identifier for each combination.

    Parameters
    ----------
    mlflow_runs_in : pd.DataFrame
        Input MLflow runs DataFrame.
    cfg : DictConfig
        Configuration dictionary.
    task : str
        Task name.
    delimiter : str, optional
        Delimiter separating model names, by default "__".

    Returns
    -------
    pd.DataFrame
        DataFrame with added imputer_model, outlier_source, and unique_combo columns.

    Raises
    ------
    ValueError
        If any run has a NaN run_id.
    """
    # Work on a copy so the caller's DataFrame is left untouched
    mlflow_runs = deepcopy(mlflow_runs_in)
    mlflow_runs = mlflow_runs.reset_index(drop=True)
    mlflow_runs["imputer_model"] = ""
    mlflow_runs["outlier_source"] = ""
    mlflow_runs["unique_combo"] = ""

    for i, run_df in enumerate(mlflow_runs.iterrows()):
        run_name = run_df[1]["tags.mlflow.runName"]
        imputer_model, outlier_source = parse_run_name_for_two_model_names(
            run_name, delimiter=delimiter
        )
        mlflow_runs.at[i, "imputer_model"] = imputer_model
        mlflow_runs.at[i, "outlier_source"] = outlier_source
        unique_combo_string = f"{imputer_model}{delimiter}{outlier_source}"
        mlflow_runs.at[i, "unique_combo"] = unique_combo_string
        logger.debug(
            f"{i + 1}/{mlflow_runs.shape[0]}: {imputer_model}, {outlier_source}, {run_name}"
        )
        logger.info(f"{i + 1}/{mlflow_runs.shape[0]}: {unique_combo_string}")

    assert len(mlflow_runs) == len(mlflow_runs_in), (
        "input mlflow_run has different number of runs ({}) "
        "than the output mlflow_run ({})".format(len(mlflow_runs_in), len(mlflow_runs))
    )

    # FIX: the previous loop overwrote `isnan` on every float run_id, so only
    # the *last* float was actually checked; a NaN earlier in the list was
    # silently missed. Check all run_ids instead.
    run_ids = mlflow_runs["run_id"].tolist()
    has_nan = any(
        isinstance(run_id, float) and np.isnan(run_id) for run_id in run_ids
    )

    if has_nan:
        logger.error("You have NaN runs?")
        logger.error(run_ids)
        logger.error(mlflow_runs)
        raise ValueError("You have NaN runs?")

    return mlflow_runs

get_previous_best_mlflow_runs

get_previous_best_mlflow_runs(
    experiment_name: str,
    cfg: DictConfig,
    task: str = "outlier_detection",
    return_subset: bool = True,
) -> Optional[DataFrame]

Get the best MLflow runs from a previous experiment for use as data sources.

PARAMETER DESCRIPTION
experiment_name

Name of the MLflow experiment.

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

task

Task name to determine filtering strategy, by default "outlier_detection".

TYPE: str DEFAULT: 'outlier_detection'

return_subset

Whether to return a curated subset or all runs, by default True.

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
DataFrame or None

DataFrame of best MLflow runs, or None if no runs found.

RAISES DESCRIPTION
ValueError

If the task is unknown.

NotImplementedError

If featurization task is requested.

Source code in src/data_io/define_sources_for_flow.py
def get_previous_best_mlflow_runs(
    experiment_name: str,
    cfg: DictConfig,
    task: str = "outlier_detection",
    return_subset: bool = True,
) -> Optional[pd.DataFrame]:
    """Get the best MLflow runs from a previous experiment for use as data sources.

    Parameters
    ----------
    experiment_name : str
        Name of the MLflow experiment.
    cfg : DictConfig
        Configuration dictionary.
    task : str, optional
        Task name to determine filtering strategy, by default "outlier_detection".
    return_subset : bool, optional
        Whether to return a curated subset or all runs, by default True.

    Returns
    -------
    pd.DataFrame or None
        DataFrame of best MLflow runs, or None if no runs found.

    Raises
    ------
    ValueError
        If the task is unknown.
    NotImplementedError
        If featurization task is requested.
    """
    # Every run logged under this experiment
    all_runs = mlflow.search_runs(experiment_names=[experiment_name])
    if all_runs.shape[0] == 0:
        return None

    # Collapse to one run per unique run_name (best metric/loss per name)
    mlflow_runs = get_unique_outlier_runs(all_runs, cfg, task=task)

    # Manual selection below; the task names refer to *previous* tasks
    if task == "outlier_detection":
        # Outlier detection results -> imputation source
        if return_subset:
            # keep only the best zero-shot and best finetuned run per model,
            # not every hyperparameter combo tried (reflected in the run_name)
            mlflow_runs = get_best_model_runs(mlflow_runs, task, cfg)
            logger.info(
                "Returning subset of the outlier detection runs, {} runs".format(
                    len(mlflow_runs)
                )
            )
    elif task == "imputation":
        # Imputation results -> featurization source
        mlflow_runs_combo = get_unique_combo_runs(mlflow_runs, cfg, task=task)
        if return_subset:
            # Note! with the MOMENT models, the logic is the same as for
            # outlier detection, i.e. keep best zeroshot and finetuned model
            mlflow_runs = get_best_imputation_runs(mlflow_runs_combo, task, cfg)
            logger.info(
                "Returning subset of the imputation runs, {} runs".format(
                    len(mlflow_runs)
                )
            )
    elif task == "featurization":
        # PLR features (handcrafted or embeddings) -> classification source
        logger.info("Returning all the featurization runs")
        raise NotImplementedError("Featurization not implemented")
    elif task == "classification":
        best_unique_models = get_grouped_classification_runs(
            {}, experiment_name, cfg, task
        )
        # TODO! you could use this as a sanity check
        no_of_mlflow_runs = mlflow_runs.shape[0]
        no_of_unique_submodels_in_ensemble = sum(
            len(ensemble_dict) for ensemble_dict in best_unique_models.values()
        )
        if no_of_mlflow_runs != no_of_unique_submodels_in_ensemble:
            logger.error(
                "The number of unique submodels in the ensemble does not match the number of MLflow runs"
            )
            logger.error(
                f"No of MLflow runs: {no_of_mlflow_runs}, "
                f"no of unique submodels in ensemble: {no_of_unique_submodels_in_ensemble}"
            )
    else:
        logger.error(f"Unknown task: {task}")
        raise ValueError(f"Unknown task: {task}")

    return mlflow_runs

get_arrays_for_splits_from_imputer_artifacts

get_arrays_for_splits_from_imputer_artifacts(
    artifacts: dict, run_name: str
) -> dict[str, dict[str, ndarray]]

Extract imputation arrays from MLflow imputer artifacts.

PARAMETER DESCRIPTION
artifacts

MLflow artifacts containing imputation results.

TYPE: dict

run_name

MLflow run name for logging.

TYPE: str

RETURNS DESCRIPTION
dict

Dictionary with train/test splits containing X, CI_pos, CI_neg, and mask arrays.

Source code in src/data_io/define_sources_for_flow.py
def get_arrays_for_splits_from_imputer_artifacts(
    artifacts: dict, run_name: str
) -> dict[str, dict[str, np.ndarray]]:
    """Extract imputation arrays from MLflow imputer artifacts.

    Parameters
    ----------
    artifacts : dict
        MLflow artifacts containing imputation results.
    run_name : str
        MLflow run name (kept for interface parity with the outlier-artifact
        extractor; not read inside this function).

    Returns
    -------
    dict
        Dictionary with train/test splits containing X, CI_pos, CI_neg, and mask arrays.
    """
    dict_out = {}
    imputation = artifacts["model_artifacts"]["imputation"]
    for split, split_artifacts in imputation.items():
        imputation_dict = split_artifacts["imputation_dict"]
        imputation_mean = imputation_dict["imputation"]["mean"]
        mask = imputation_dict["indicating_mask"]
        ci_pos = imputation_dict["imputation"]["imputation_ci_pos"]
        ci_neg = imputation_dict["imputation"]["imputation_ci_neg"]
        if imputation_mean.ndim == 3:
            # this is in "PyPOTS space"; keep only the first feature to make it 2D
            imputation_mean = imputation_mean[:, :, 0]
            if mask.ndim == 3:
                mask = mask[:, :, 0]
            if ci_pos is not None and ci_pos.ndim == 3:
                ci_pos = ci_pos[:, :, 0]
                ci_neg = ci_neg[:, :, 0]

        if ci_pos is None:
            # no confidence intervals from this imputer -> NaN placeholders
            # (np.full_like replaces the previous ones_like + in-place NaN fill)
            ci_pos = np.full_like(imputation_mean, np.nan)
            ci_neg = np.full_like(imputation_mean, np.nan)

        dict_out[split] = {
            # Same as reconstruction
            "X": imputation_mean,
            "CI_pos": ci_pos,
            "CI_neg": ci_neg,
            "mask": mask,
        }

    return dict_out

check_arrays

check_arrays(
    splits_dicts: dict[str, dict[str, ndarray]], task: str
) -> None

Validate that X and mask arrays have matching shapes.

PARAMETER DESCRIPTION
splits_dicts

Dictionary of split dictionaries containing X and mask arrays.

TYPE: dict

task

Task name for error messages.

TYPE: str

RAISES DESCRIPTION
AssertionError

If X and mask arrays have different shapes.

Source code in src/data_io/define_sources_for_flow.py
def check_arrays(splits_dicts: dict[str, dict[str, np.ndarray]], task: str) -> None:
    """Validate that X and mask arrays have matching shapes.

    Parameters
    ----------
    splits_dicts : dict
        Dictionary of split dictionaries containing X and mask arrays.
    task : str
        Task name (kept for interface parity with callers).

    Raises
    ------
    AssertionError
        If X and mask arrays have different shapes.
    """
    for split_name, arrays in splits_dicts.items():
        x_shape = arrays["X"].shape
        mask_shape = arrays["mask"].shape
        assert x_shape == mask_shape, (
            "{} X and mask have different sizes, X: {}, mask: {}".format(
                split_name, x_shape, mask_shape
            )
        )

get_best_epoch

get_best_epoch(
    outlier_artifacts: dict,
) -> tuple[dict, bool]

Extract the best epoch results from outlier detection artifacts.

Handles multiple artifact formats from different outlier detection methods.

PARAMETER DESCRIPTION
outlier_artifacts

Dictionary of outlier detection artifacts from MLflow.

TYPE: dict

RETURNS DESCRIPTION
tuple

Tuple containing (results_best, simple_format) where simple_format indicates the artifact structure type.

Source code in src/data_io/define_sources_for_flow.py
def get_best_epoch(outlier_artifacts: dict) -> tuple[dict, bool]:
    """Extract the best epoch results from outlier detection artifacts.

    Handles multiple artifact formats from different outlier detection methods.

    Parameters
    ----------
    outlier_artifacts : dict
        Dictionary of outlier detection artifacts from MLflow.

    Returns
    -------
    tuple
        Tuple containing (results_best, simple_format) where simple_format
        indicates the artifact structure type.
    """
    # TODO! There is no need to have all these options, have all the outlier detection method output the same format
    if "outlier_results" in outlier_artifacts:
        # per-epoch logging format; pick the best epoch, or the last logged one
        # when no best epoch was recorded
        per_epoch = outlier_artifacts["outlier_results"]
        best_epoch = outlier_artifacts["metadata"]["best_epoch"]
        if best_epoch is None:
            best_epoch = list(per_epoch.keys())[-1]
        return per_epoch[best_epoch], False
    if "best_arrays" in outlier_artifacts:
        return outlier_artifacts["best_arrays"], True
    if "results" in outlier_artifacts:
        return outlier_artifacts["results"], True
    # e.g. sklearn, Prophet log the arrays at the artifact root
    return outlier_artifacts, True

if_pick_the_split

if_pick_the_split(run_name: str, split: str) -> bool

Determine if a given split should be processed based on run name.

PARAMETER DESCRIPTION
run_name

MLflow run name.

TYPE: str

split

Split name to check.

TYPE: str

RETURNS DESCRIPTION
bool

True if the split should be picked, False otherwise.

Source code in src/data_io/define_sources_for_flow.py
def if_pick_the_split(run_name: str, split: str) -> bool:
    """Determine if a given split should be processed based on run name.

    Parameters
    ----------
    run_name : str
        MLflow run name.
    split : str
        Split name to check.

    Returns
    -------
    bool
        True if the split should be picked, False otherwise.
    """
    if "outlier" in split:
        return True

    # simple detectors did not produce any dedicated outlier split, so every
    # split of theirs is processed
    detector_names = get_simple_outlier_detectors()
    return any(name in run_name for name in detector_names)

get_arrays_for_splits_from_outlier_artifacts

get_arrays_for_splits_from_outlier_artifacts(
    outlier_artifacts: dict, run_name: str
) -> dict[str, dict[str, ndarray]]

Extract reconstruction and mask arrays from outlier detection artifacts.

Handles multiple artifact formats from different outlier detection methods (MOMENT, TimesNet, LOF, Prophet, etc.).

PARAMETER DESCRIPTION
outlier_artifacts

Dictionary of outlier detection artifacts from MLflow.

TYPE: dict

run_name

MLflow run name for logging and method detection.

TYPE: str

RETURNS DESCRIPTION
dict

Dictionary with train/test splits containing X (reconstruction) and mask arrays. Format: {split: {'X': np.array, 'mask': np.array}}

RAISES DESCRIPTION
ValueError

If arrays cannot be extracted from the artifacts.

AssertionError

If no splits were selected for analysis.

Source code in src/data_io/define_sources_for_flow.py
def get_arrays_for_splits_from_outlier_artifacts(
    outlier_artifacts: dict, run_name: str
) -> dict[str, dict[str, np.ndarray]]:
    """Extract reconstruction and mask arrays from outlier detection artifacts.

    Handles multiple artifact formats from different outlier detection methods
    (MOMENT, TimesNet, LOF, Prophet, etc.). The supported layouts are probed
    in order via nested try/except blocks; the fallback order is significant.

    Parameters
    ----------
    outlier_artifacts : dict
        Dictionary of outlier detection artifacts from MLflow.
    run_name : str
        MLflow run name for logging and method detection.

    Returns
    -------
    dict
        Dictionary with train/test splits containing X (reconstruction) and mask arrays.
        Format: {split: {'X': np.array, 'mask': np.array}}

    Raises
    ------
    ValueError
        If arrays cannot be extracted from the artifacts.
    AssertionError
        If no splits were selected for analysis.
    """
    # best_arrays_format = "best_arrays" in outlier_artifacts
    results_best, simple_format = get_best_epoch(outlier_artifacts)
    dict_out = {}
    for split in results_best.keys():
        if if_pick_the_split(run_name, split):
            # Remember that the "vanilla train and test" were used for reconstruction learning, and did not
            # contain any outlier labels (unsupervised learning), and the outlier detection capability was
            # evaluated using the "pupil_orig" data (from both "train" and "test")
            # Two-field split names keep the second field; one-field names pass through.
            split_fields = split.split("_")
            if len(split_fields) == 2:
                split_out = split_fields[1]
            elif len(split_fields) == 1:
                split_out = split_fields[0]
            else:
                logger.error("What is this split = {}".format(split))
                raise ValueError(f"Unknown split {split}")

            try:
                if simple_format:
                    # TimesNet (well not specific to TimesNet, simpler structure, move on to this?)
                    # Layout probe 1: arrays at the split root
                    try:
                        dict_out[split_out] = {
                            "X": results_best[split]["preds"],
                            "mask": results_best[split]["pred_mask"],
                        }
                    except Exception:
                        # Layout probe 2: arrays nested under an "arrays" key
                        try:
                            dict_out[split_out] = {
                                "X": results_best[split]["arrays"]["preds"],
                                "mask": results_best[split]["arrays"]["pred_mask"],
                            }
                        except Exception:
                            try:
                                # e.g. LOF, Prophet, etc. do not reconstruct
                                # -> synthesize an all-NaN X matching the mask shape
                                X_nan = np.zeros_like(
                                    results_best[split]["arrays"]["pred_mask"]
                                ).astype(float)
                                X_nan[:] = np.nan
                                dict_out[split_out] = {
                                    "X": X_nan,
                                    "mask": results_best[split]["arrays"]["pred_mask"],
                                }
                            except Exception as e:
                                logger.error(
                                    "Could not get best results!, error = {}".format(e)
                                )
                                raise e

                    assert len(dict_out[split_out]["X"].shape) == 2, (
                        "reconstructed signal array needs to be 2D"
                    )
                    assert len(dict_out[split_out]["mask"].shape) == 2, (
                        "mask needs to be 2D"
                    )

                else:
                    # MOMENT
                    # Finetuned
                    dict_out[split_out] = {
                        # Same as reconstruction
                        "X": results_best[split]["results_dict"]["split_results"][
                            "arrays"
                        ]["preds"],
                        # When anomaly score is used with the adaptive f1 to get timepoint-wise labels
                        "mask": results_best[split]["results_dict"]["preds"]["arrays"][
                            "pred_mask"
                        ],
                    }
            except Exception as e:
                # re-raise as ValueError with context so the caller sees which
                # split/run combination failed
                logger.error(f"Error: {e}")
                logger.error(
                    f"split: {split}, split_out: {split_out}, run_name: {run_name}"
                )
                raise ValueError(f"Error: {e}")

    assert len(dict_out) > 0, (
        "You did not pick anything from this model to analyze! "
        "Glitch in if_pick_the_split(run_name, split)?"
    )

    return dict_out

get_ensembled_anomaly_masks

get_ensembled_anomaly_masks(
    artifacts: dict,
) -> dict[str, dict[str, ndarray]]

Create ensembled anomaly masks from individual detector masks.

PARAMETER DESCRIPTION
artifacts

Dictionary mapping splits to 3D arrays of individual detector masks.

TYPE: dict

RETURNS DESCRIPTION
dict

Dictionary with ensembled masks for each split.

Source code in src/data_io/define_sources_for_flow.py
def get_ensembled_anomaly_masks(artifacts: dict) -> dict[str, dict[str, np.ndarray]]:
    """Create ensembled anomaly masks from individual detector masks.

    Parameters
    ----------
    artifacts : dict
        Dictionary mapping splits to 3D arrays of individual detector masks.

    Returns
    -------
    dict
        Dictionary with ensembled masks for each split.
    """
    # collapse each split's stacked detector masks into a single ensemble mask
    return {
        split_name: {"mask": ensemble_masks(stacked_masks)}
        for split_name, stacked_masks in artifacts.items()
    }

get_source_data

get_source_data(
    mlflow_runs: DataFrame, cfg: DictConfig, task: str
) -> tuple[dict, dict]

Load source data arrays from MLflow artifacts for each run.

PARAMETER DESCRIPTION
mlflow_runs

DataFrame of MLflow runs to process.

TYPE: DataFrame

cfg

Configuration dictionary.

TYPE: DictConfig

task

Task name ("outlier_detection" or "imputation").

TYPE: str

RETURNS DESCRIPTION
tuple

Tuple containing (dicts_out, mlflow_dict) where dicts_out contains the data arrays and mlflow_dict maps source names to MLflow run info.

RAISES DESCRIPTION
ValueError

If the task is unknown.

Source code in src/data_io/define_sources_for_flow.py
def get_source_data(
    mlflow_runs: pd.DataFrame, cfg: DictConfig, task: str
) -> tuple[dict, dict]:
    """Load source data arrays from MLflow artifacts for each run.

    Parameters
    ----------
    mlflow_runs : pd.DataFrame
        DataFrame of MLflow runs to process.
    cfg : DictConfig
        Configuration dictionary.
    task : str
        Task name ("outlier_detection" or "imputation").

    Returns
    -------
    tuple
        Tuple containing (dicts_out, mlflow_dict) where dicts_out contains
        the data arrays and mlflow_dict maps source names to MLflow run info.

    Raises
    ------
    ValueError
        If the task is unknown.
    """
    dicts_out = {}
    mlflow_dict = {}
    for mlflow_row in (
        pbar := tqdm(
            mlflow_runs.iterrows(),
            desc="Getting source data",
            total=mlflow_runs.shape[0],
        )
    ):
        # iterrows() yields (index, Series); the Series carries the run columns
        mlflow_run = mlflow_row[1]
        run_name = mlflow_run["tags.mlflow.runName"]
        model_name, model_key = get_model_name_from_run_name(run_name, task)
        logger.info(f"Model name: {model_name}, run_name: {run_name}")
        artifacts = outlier_detection_artifacts_dict(mlflow_run, model_name, task)
        if task == "outlier_detection":
            if "ensemble" in model_name:
                # ensemble artifacts only carry masks (no reconstructed signal)
                dicts_out[model_key] = get_ensembled_anomaly_masks(artifacts)
            else:
                dicts_out[model_key] = get_arrays_for_splits_from_outlier_artifacts(
                    outlier_artifacts=artifacts, run_name=run_name
                )
                check_arrays(splits_dicts=dicts_out[model_key], task=task)
            mlflow_dict[model_key] = mlflow_run
            # free the (potentially large) artifacts before the next iteration
            del artifacts
            pbar.set_description(
                f"Import Sources | RAM use: {psutil.virtual_memory().percent} %: {model_name}"
            )
        elif task == "imputation":
            # Use the slightly shorter unique combo name; the exact names and
            # run_id will still be returned in the mlflow_dict
            dict_key = mlflow_row[1]["unique_combo"]
            dicts_out[dict_key] = get_arrays_for_splits_from_imputer_artifacts(
                artifacts, run_name
            )
            check_arrays(splits_dicts=dicts_out[dict_key], task=task)
            mlflow_dict[dict_key] = mlflow_run
            # NOTE(review): unlike the outlier branch, artifacts is not deleted
            # and the progress-bar description is not refreshed here — confirm
            # this asymmetry is intentional
        else:
            logger.error(f"Unknown task: {task}")
            raise ValueError(f"Unknown task: {task}")

    logger.info(f"Found {len(list(dicts_out.keys()))} sources")

    return dicts_out, mlflow_dict

add_array_to_dict

add_array_to_dict(
    array_to_add: ndarray,
    key: str,
    astype: Optional[str] = None,
) -> ndarray

Validate and optionally cast a numpy array before adding to a dictionary.

PARAMETER DESCRIPTION
array_to_add

Array to validate and add.

TYPE: ndarray

key

Dictionary key name for error messages.

TYPE: str

astype

Data type to cast the array to, by default None.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
ndarray

Validated (and optionally cast) array.

RAISES DESCRIPTION
ValueError

If the input is not a numpy array.

Source code in src/data_io/define_sources_for_flow.py
def add_array_to_dict(
    array_to_add: np.ndarray, key: str, astype: Optional[str] = None
) -> np.ndarray:
    """Validate and optionally cast a numpy array before adding to a dictionary.

    Parameters
    ----------
    array_to_add : np.ndarray
        Array to validate and add.
    key : str
        Dictionary key name for error messages.
    astype : str, optional
        Data type to cast the array to, by default None.

    Returns
    -------
    np.ndarray
        Validated (and optionally cast) array.

    Raises
    ------
    ValueError
        If the input is not a numpy array.
    """
    # Guard clause: reject anything that is not a numpy array up front
    if not isinstance(array_to_add, np.ndarray):
        logger.error(
            "You are trying to add (key={}) a non-Numpy array, type = {}".format(
                key, type(array_to_add)
            )
        )
        if isinstance(array_to_add, dict):
            logger.error(f"dict keys = {list(array_to_add.keys())}")
        raise ValueError(
            "You are trying to add a non-Numpy array, type = {}".format(
                type(array_to_add)
            )
        )

    return array_to_add if astype is None else array_to_add.astype(astype)

get_dict_per_col_name

get_dict_per_col_name(
    col_names: list[str],
    data_dict: dict,
    col_name: str,
    mask_col: str,
) -> dict

Create data dictionary structure for a specific column name.

PARAMETER DESCRIPTION
col_names

List of column names to process (typically pupil signal columns).

TYPE: list

data_dict

Source data dictionary with df and preprocess keys.

TYPE: dict

col_name

Target column name for the data.

TYPE: str

mask_col

Column name for the mask data.

TYPE: str

RETURNS DESCRIPTION
dict

Structured data dictionary with X, X_GT, and mask arrays.

Source code in src/data_io/define_sources_for_flow.py
def get_dict_per_col_name(
    col_names: list[str], data_dict: dict, col_name: str, mask_col: str
) -> dict:
    """Create data dictionary structure for a specific column name.

    Builds a new dictionary with the same split layout as ``data_dict`` but
    whose per-split ``data`` entry exposes exactly three arrays: ``X``
    (model input), ``X_GT`` (ground truth) and ``mask``.

    Parameters
    ----------
    col_names : list
        List of column names to process (typically pupil signal columns).
    data_dict : dict
        Source data dictionary with df and preprocess keys.
    col_name : str
        Target column name for the data.
        NOTE(review): this parameter is currently unused in the body —
        confirm whether ``X`` was meant to be read from ``col_name``
        instead of the ``col_names`` loop variable.
    mask_col : str
        Column name for the mask data.

    Returns
    -------
    dict
        Structured data dictionary with X, X_GT, and mask arrays.
        NOTE(review): ``data_dicts`` is reset at the top of every loop
        iteration, so only the structure built for the LAST entry of
        ``col_names`` is returned — presumably fine because the caller
        asserts a single source, but confirm if ``col_names`` ever grows.
    """
    data_dicts = {}
    for gt_source in col_names:  # see if this for loop is really necessary
        data_dicts = {}
        data_dicts["df"] = {}
        data_dicts["preprocess"] = data_dict["preprocess"]
        # TODO! if you add a new key upstream, this will lose it

        for split in data_dict["df"].keys():
            data_dicts["df"][split] = {}
            for key, split_dict in data_dict["df"][split].items():
                if key != "data":
                    # copy as it is, the other dicts (time, light, metadata, ...)
                    assert isinstance(split_dict, dict), (
                        "This should be a dict, not {}".format(type(split_dict))
                    )
                    data_dicts["df"][split][key] = split_dict
                if key == "labels":
                    if "data" not in data_dicts["df"][split].keys():
                        data_dicts["df"][split]["data"] = {}
                    # Human annotated ground truth for missing values, you are evaluating
                    # on how well the model(s) can impute these
                    data_dicts["df"][split]["data"]["mask"] = add_array_to_dict(
                        array_to_add=data_dict["df"][split][key][mask_col],
                        astype="int",
                        key=key,
                    )
                if key == "data":
                    if "data" not in data_dicts["df"][split].keys():
                        data_dicts["df"][split]["data"] = {}
                    # Set the "col_name" to the column name of the pupil signal, e.g.
                    # "pupil_gt" or "pupil_raw", "pupil_orig that you wish to train the model with
                    data_dicts["df"][split]["data"]["X"] = add_array_to_dict(
                        array_to_add=data_dict["df"][split][key][gt_source], key=key
                    )

                    # This is the denoised ground truth, if you trained on the "pupil_gt" data,
                    # this will be exactly the same then
                    data_dicts["df"][split]["data"]["X_GT"] = add_array_to_dict(
                        array_to_add=data_dict["df"][split][key][gt_source], key=key
                    )

    return data_dicts

print_mask_stats

print_mask_stats(data_dict: dict, mask_col: str) -> None

Log statistics about the mask coverage for each split.

PARAMETER DESCRIPTION
data_dict

Data dictionary containing df with mask arrays.

TYPE: dict

mask_col

Name of the mask column for logging.

TYPE: str

Source code in src/data_io/define_sources_for_flow.py
def print_mask_stats(data_dict: dict, mask_col: str) -> None:
    """Report, per split, what fraction of the mask array is set.

    Parameters
    ----------
    data_dict : dict
        Data dictionary whose ``df`` entry maps split names to dicts
        holding a ``data``/``mask`` numpy array.
    mask_col : str
        Mask column name, included in the log line for context.
    """
    for split_name, split_content in data_dict["df"].items():
        mask_array = split_content["data"]["mask"]
        pct_true = 100 * (np.sum(mask_array) / mask_array.size)
        logger.info(
            f"Split = {split_name}: {pct_true:.2f}% of mask is True ({mask_col})"
        )

import_data_for_flow

import_data_for_flow(
    cfg: DictConfig, task: str
) -> tuple[dict, str, dict]

Import data from DuckDB and prepare it for the processing flow.

PARAMETER DESCRIPTION
cfg

Configuration dictionary.

TYPE: DictConfig

task

Task name ("outlier_detection" or "imputation").

TYPE: str

RETURNS DESCRIPTION
tuple

Tuple containing (data_dicts, input_signal, data_dict) where data_dicts is the structured source data, input_signal is the column name for input data, and data_dict is the raw imported data.

RAISES DESCRIPTION
ValueError

If the task is unknown.

Source code in src/data_io/define_sources_for_flow.py
def import_data_for_flow(cfg: DictConfig, task: str) -> tuple[dict, str, dict]:
    """Import data from DuckDB and prepare it for the processing flow.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary.
    task : str
        Task name ("outlier_detection" or "imputation").

    Returns
    -------
    tuple
        Tuple containing (data_dicts, input_signal, data_dict) where
        data_dicts is the structured source data, input_signal is the
        column name for input data, and data_dict is the raw imported data.

    Raises
    ------
    ValueError
        If the task is unknown.
    """
    # Import the dataframe and convert it into the nested dict layout the flow uses
    data_df = flow_import_data(cfg=cfg)
    data_dict = convert_df_to_dict(data_df=data_df, cfg=cfg)

    if task == "outlier_detection":
        # Note! outlier_detection refers to the previous task before imputation training, as in where is the data
        # that you are going to use for training the imputation models
        gt_signal = cfg["IMPUTATION_TRAINING"]["gt_col_name"]
        # for non-reconstructing methods, you need some pupil signal to mask with the pred_mask then
        input_signal = "pupil_orig_imputed"
        gt_signal_out = gt_signal
        # Remember: this task above now refers for the MLflow task, i.e.
        # for imputation training, you are getting the data from the previous outlier detection task
        mask_col = "imputation_mask"
    elif task == "imputation":
        # You could use some other key for the featurization, but atm just using the same as used
        # before in the imputation training
        gt_signal = cfg["IMPUTATION_TRAINING"]["gt_col_name"]
        gt_signal_out = gt_signal.replace("_", "-")
        # we have now two sources/models encoded in the name, i.e. imputation_model__outlier_source
        gt_signal_out = f"{gt_signal_out}__{gt_signal_out}"
        input_signal = "pupil_raw_imputed"
        # Now the imputation mask might or might not be needed when featurising PLR or getting the embeddings
        mask_col = "imputation_mask"
    else:
        logger.error(f"Unknown task: {task}")
        raise ValueError(f"Unknown task: {task}")

    data_dicts = {}
    data_dicts[gt_signal_out] = get_dict_per_col_name(
        col_names=cfg["IMPUTATION_TRAINING"]["col_name"],
        data_dict=data_dict,
        col_name=gt_signal_out,
        mask_col=mask_col,
    )

    print_mask_stats(data_dict=data_dicts[gt_signal_out], mask_col=mask_col)

    if task == "imputation":
        # rename key to match the "2-field encoding" of the sources (if you start parsing these or something)
        # NOTE(review): pop() followed by assignment to the SAME key is a no-op;
        # presumably a different target key name was intended here — confirm.
        data_dicts[gt_signal_out] = data_dicts.pop(gt_signal_out)

    return data_dicts, input_signal, data_dict

check_combination

check_combination(
    source_data: dict, source_name: str, split: str
) -> None

Validate that X, mask, and time arrays have consistent dimensions.

PARAMETER DESCRIPTION
source_data

Dictionary of source data.

TYPE: dict

source_name

Name of the source to check.

TYPE: str

split

Split name to check.

TYPE: str

RAISES DESCRIPTION
AssertionError

If array dimensions are inconsistent.

Source code in src/data_io/define_sources_for_flow.py
def check_combination(source_data: dict, source_name: str, split: str) -> None:
    """Validate that X, mask, and time arrays have consistent dimensions.

    Parameters
    ----------
    source_data : dict
        Dictionary of source data.
    source_name : str
        Name of the source to check.
    split : str
        Split name to check.

    Raises
    ------
    AssertionError
        If array dimensions are inconsistent.
    """
    # Hoist the deep lookups once for readability
    split_dict = source_data[source_name]["df"][split]
    n_samples = split_dict["data"]["X"].shape[0]
    # X and mask must be aligned sample-for-sample
    assert n_samples == split_dict["data"]["mask"].shape[0], (
        "X and mask have different number of samples"
    )
    # X and the time vector must cover the same number of samples
    # (BUGFIX: "detecion" typo corrected in the message below)
    n_time = split_dict["time"]["time"].shape[0]
    assert n_samples == n_time, (
        "X and time have different number of time points, "
        "outlier detection data had {} samples, and {} time points"
    ).format(n_samples, n_time)

check_gt_and_X

check_gt_and_X(
    source_data: dict, source_name: str, split: str
) -> None

Check if X and X_GT arrays are identical and warn if so.

PARAMETER DESCRIPTION
source_data

Dictionary of source data.

TYPE: dict

source_name

Name of the source to check.

TYPE: str

split

Split name for logging.

TYPE: str

Source code in src/data_io/define_sources_for_flow.py
def check_gt_and_X(source_data: dict, source_name: str, split: str) -> None:
    """Check if X and X_GT arrays are identical and warn if so.

    Note: every split of the source is inspected regardless of the
    ``split`` argument; the argument is kept for interface compatibility.

    Parameters
    ----------
    source_data : dict
        Dictionary of source data.
    source_name : str
        Name of the source to check.
    split : str
        Split name (unused; all splits are checked — see note above).
    """
    model_data = source_data[source_name]
    # BUGFIX: the loop variable used to shadow the `split` parameter;
    # iterate under a distinct name so the parameter is not silently clobbered.
    for split_name, split_dict in model_data["df"].items():
        data_arrays = split_dict["data"]
        # X may be "pupil_gt" (not recommended) or the reconstruction coming
        # from the previous outlier detection, possibly with glitches left
        X = data_arrays["X"]
        X_GT = data_arrays["X_GT"]  # from "pupil_gt"
        assert isinstance(X, np.ndarray), "X should be a numpy array"
        assert isinstance(X_GT, np.ndarray), "X_GT should be a numpy array"
        if np.all(X == X_GT):
            logger.warning(
                "Your X and Ground truth seem to be same, is this how you want things to be?"
            )
            logger.warning(f"source_name={source_name}, split={split_name}")

add_CI_to_data_dicts

add_CI_to_data_dicts(data_dicts: dict) -> dict

Add placeholder confidence interval arrays to data dictionaries.

The featurization script assumes CI arrays exist, so this adds NaN-filled arrays where they are missing.

PARAMETER DESCRIPTION
data_dicts

Data dictionaries to add CI arrays to.

TYPE: dict

RETURNS DESCRIPTION
dict

Updated data dictionaries with CI_pos and CI_neg arrays.

Source code in src/data_io/define_sources_for_flow.py
def add_CI_to_data_dicts(data_dicts: dict) -> dict:
    """Add placeholder confidence interval arrays to data dictionaries.

    The featurization script assumes CI arrays exist, so this adds NaN-filled
    arrays where they are missing.

    Parameters
    ----------
    data_dicts : dict
        Data dictionaries to add CI arrays to.

    Returns
    -------
    dict
        Updated data dictionaries with CI_pos and CI_neg arrays.
    """
    # Featurization script assumes that you have something here
    logger.info("Adding CI to data dicts")
    for pupil_col in data_dicts.keys():
        for split in data_dicts[pupil_col]["df"].keys():
            data_entry = data_dicts[pupil_col]["df"][split]["data"]
            for ci_key in ("CI_pos", "CI_neg"):
                if ci_key not in data_entry.keys():
                    # BUGFIX: allocate a fresh array per key — previously both
                    # CI_pos and CI_neg referenced the SAME array object, so a
                    # later in-place mutation of one would silently change the
                    # other. np.full also guarantees a float dtype so NaN is
                    # representable even if X carries an integer dtype
                    # (ones_like + NaN assignment would fail there).
                    data_entry[ci_key] = np.full(data_entry["X"].shape, np.nan)

    return data_dicts

add_mlflow_dict_to_sources

add_mlflow_dict_to_sources(
    sources: dict, mlflow_dict: Optional[dict]
) -> dict

Add MLflow run information to each source dictionary.

PARAMETER DESCRIPTION
sources

Dictionary of source data.

TYPE: dict

mlflow_dict

Dictionary mapping source names to MLflow run info.

TYPE: dict or None

RETURNS DESCRIPTION
dict

Updated sources dictionary with mlflow key added to each source.

Source code in src/data_io/define_sources_for_flow.py
def add_mlflow_dict_to_sources(sources: dict, mlflow_dict: Optional[dict]) -> dict:
    """Attach MLflow run information to every source.

    Parameters
    ----------
    sources : dict
        Dictionary of source data.
    mlflow_dict : dict or None
        Dictionary mapping source names to MLflow run info.

    Returns
    -------
    dict
        Updated sources dictionary with an ``mlflow`` key on each source.
    """
    for name in sources:
        run_info = None
        if mlflow_dict is not None and name in mlflow_dict:
            run_info = mlflow_dict[name]
        # Sources not tracked by MLflow (e.g. "pupil_gt" straight from DuckDB)
        # get an explicit None so downstream code can rely on the key existing
        sources[name]["mlflow"] = run_info
    return sources

check_sources

check_sources(sources: dict) -> None

Quality check all source data for NaN values.

PARAMETER DESCRIPTION
sources

Dictionary of source data to check.

TYPE: dict

RAISES DESCRIPTION
ValueError

If any source contains NaN values in its data arrays.

Source code in src/data_io/define_sources_for_flow.py
def check_sources(sources: dict) -> None:
    """Quality check all source data for NaN values.

    Parameters
    ----------
    sources : dict
        Dictionary of source data to check.

    Raises
    ------
    ValueError
        If any source contains NaN values in its data arrays.
    """
    logger.info("Checking quality (QA) of the sources")
    for source_name, source_dict in sources.items():
        for split_name, split_dict in source_dict["df"].items():
            for var_name, arr in split_dict["data"].items():
                nan_count = np.isnan(arr).sum()
                if nan_count > 0:
                    # Log before raising so the failing array is easy to locate
                    msg = (
                        f"source_name={source_name}, split={split_name}, "
                        f"var_name={var_name}, no_nan={nan_count}"
                    )
                    logger.error(msg)
                    raise ValueError(msg)

combine_source_with_data_dicts

combine_source_with_data_dicts(
    source_data: Optional[dict],
    data_dicts_for_source: dict,
    mlflow_dict: Optional[dict],
    cfg: DictConfig,
    task: str,
    input_signal: str,
    data_dict: dict,
) -> dict

Combine MLflow source data with the base data dictionary template.

Merges reconstruction/mask arrays from MLflow runs with the full data structure (time, metadata, etc.) from the DuckDB import.

PARAMETER DESCRIPTION
source_data

Dictionary of source data from MLflow runs.

TYPE: dict or None

data_dicts_for_source

Base data dictionary template from DuckDB import.

TYPE: dict

mlflow_dict

Dictionary mapping source names to MLflow run info.

TYPE: dict

cfg

Configuration dictionary.

TYPE: DictConfig

task

Task name ("outlier_detection" or "imputation").

TYPE: str

input_signal

Column name for input data when no reconstruction available.

TYPE: str

data_dict

Raw imported data dictionary for fallback values.

TYPE: dict

RETURNS DESCRIPTION
dict

Combined sources dictionary with full data structure for each source.

Notes

Expected data_dict_template structure: a top-level df dict with one entry per split (train, test), each split holding time, data, labels, light, and metadata dicts; plus a top-level preprocess dict containing a standardization dict.

Source code in src/data_io/define_sources_for_flow.py
def combine_source_with_data_dicts(
    source_data: Optional[dict],
    data_dicts_for_source: dict,
    mlflow_dict: Optional[dict],
    cfg: DictConfig,
    task: str,
    input_signal: str,
    data_dict: dict,
) -> dict:
    """Combine MLflow source data with the base data dictionary template.

    Merges reconstruction/mask arrays from MLflow runs with the full data
    structure (time, metadata, etc.) from the DuckDB import.

    Parameters
    ----------
    source_data : dict or None
        Dictionary of source data from MLflow runs.
    data_dicts_for_source : dict
        Base data dictionary template from DuckDB import
        (must contain exactly one entry).
    mlflow_dict : dict
        Dictionary mapping source names to MLflow run info.
    cfg : DictConfig
        Configuration dictionary.
    task : str
        Task name ("outlier_detection" or "imputation").
    input_signal : str
        Column name for input data when no reconstruction available.
    data_dict : dict
        Raw imported data dictionary for fallback values.

    Returns
    -------
    dict
        Combined sources dictionary with full data structure for each source.

    Notes
    -----
    Expected data_dict_template structure:
        df: dict
            train: dict
                time: dict
                data: dict
                labels: dict
                light: dict
                metadata: dict
            test: dict
                same as train
        preprocess: dict
            standardization: dict
    """
    # e.g. "pupil_gt" or "pupil_gt" and "pupil_raw"
    no_of_data_dicts_for_source = len(data_dicts_for_source)
    assert no_of_data_dicts_for_source == 1, (
        "Implement more input data sources if you feel like"
    )
    pupil_col = list(data_dicts_for_source.keys())[0]
    data_dict_template = data_dicts_for_source[pupil_col]

    # Change the 'data' of the template based on each source (outlier detection model)
    if source_data is not None:
        for i, (source_name, source_dict) in enumerate(
            tqdm(source_data.items(), desc="Getting data sources")
        ):
            logger.debug(
                "Picking data for source = {} (#{}/{})".format(
                    source_name, i + 1, len(source_data)
                )
            )
            # Deep copy so per-source mutations never leak into the shared template
            dict_tmp = deepcopy(data_dict_template)
            for split, split_dict in source_dict.items():
                if "X" in split_dict:
                    # with non-ensemble methods, you typically have the input here as well
                    if np.all(np.isnan(split_dict["X"])):
                        # for sure you don't have any reconstruction done and you need to pick "orig_data
                        # This is true from non-reconstructing methods like LOF, OneClassSVM, Prophet, etc.
                        # We don't have any prediction so we just use the original data, that you can then mask
                        # with the pred_mask returned by these "simple methods"
                        dict_tmp["df"][split]["data"]["X"] = data_dict["df"][split][
                            "data"
                        ][input_signal]
                        # Only log this once, on the first split, to avoid spam
                        if split == list(source_dict.keys())[0]:
                            logger.debug(
                                "No reconstruction from {}, "
                                'using "{}" pupil column as the original data'.format(
                                    source_name, input_signal
                                )
                            )
                        # import matplotlib.pyplot as plt
                        # plt.plot(dict_tmp["df"][split]["data"]["X"][0,:])
                        # plt.show()
                    else:
                        if not cfg["IMPUTATION_TRAINING"]["use_orig_data"]:
                            # This would come from the outlier detection model, e.g. MOMENT/TimesNet reconstructs
                            # the data (reconstructs from the "outlier_test/train" which then are the "pupil_raw")
                            # So setting "use_orig_data = False" you are having realistic evaluation of processing steps
                            dict_tmp["df"][split]["data"]["X"] = split_dict["X"]
                        else:
                            # Otherwise, use the data vector from the original data (DuckDB)
                            # i.e. "pupil_gt" that you mask with the outlier detection results
                            # You could use this in isolation for testing, but the "pupil_gt" is human-annotated ground truth
                            # that you won't be having access in "real-world" where you want an automatic algorithm and don't
                            # necessarily have the time for proofreading
                            if not np.all(np.isnan(split_dict["X"])):
                                logger.warning(
                                    "You are using the original non-reconstructed data "
                                    "even though this is not coming from a simple method (i.e. all values NaN)"
                                )

                    # Masks are stored as int (0/1) downstream
                    dict_tmp["df"][split]["data"]["mask"] = split_dict["mask"].astype(
                        int
                    )
                    if task == "imputation":
                        dict_tmp["df"][split]["data"]["CI_pos"] = split_dict["CI_pos"]
                        dict_tmp["df"][split]["data"]["CI_neg"] = split_dict["CI_neg"]
                else:
                    # only update the mask with the ensemble methods
                    dict_tmp["df"][split]["data"]["mask"] = split_dict["mask"].astype(
                        int
                    )

            source_data[source_name] = dict_tmp
            # NOTE(review): `split` here is the last key iterated above, so
            # check_combination validates only that split's dimensions;
            # check_gt_and_X loops over all splits internally regardless.
            check_combination(source_data, source_name, split)
            check_gt_and_X(source_data, source_name, split)

    else:
        logger.debug("No source data, using only the original data as the source")

    if task == "imputation":
        data_dicts_for_source = add_CI_to_data_dicts(data_dicts=data_dicts_for_source)

    if source_data is not None:
        # Sort the sources
        source_data = dict(sorted(source_data.items()))
        sources = {**data_dicts_for_source, **source_data}
    else:
        sources = data_dicts_for_source

    # add the MLflow dict to the sources, if you need to trace later where the source was from
    sources = add_mlflow_dict_to_sources(sources, mlflow_dict)

    # quality checking
    # check_sources(sources)

    return sources

define_sources_for_flow

define_sources_for_flow(
    prev_experiment_name: str,
    cfg: DictConfig,
    task: str = "outlier_detection",
) -> dict

Define all data sources for a processing flow from previous MLflow experiments.

Main entry point for loading source data. It combines two sources: (1) the best runs from the previous MLflow experiment, and (2) the original ground-truth data from DuckDB.

PARAMETER DESCRIPTION
prev_experiment_name

Name of the previous MLflow experiment to get runs from.

TYPE: str

cfg

Configuration dictionary.

TYPE: DictConfig

task

Task name for determining data source type, by default "outlier_detection".

TYPE: str DEFAULT: 'outlier_detection'

RETURNS DESCRIPTION
dict

Dictionary of all source data, including both MLflow and ground truth sources.

Source code in src/data_io/define_sources_for_flow.py
def define_sources_for_flow(
    prev_experiment_name: str, cfg: DictConfig, task: str = "outlier_detection"
) -> dict:
    """Define all data sources for a processing flow from previous MLflow experiments.

    Main entry point for loading source data. Combines:
    1. Best runs from the previous MLflow experiment
    2. Original ground truth data from DuckDB

    Parameters
    ----------
    prev_experiment_name : str
        Name of the previous MLflow experiment to get runs from.
    cfg : DictConfig
        Configuration dictionary.
    task : str, optional
        Task name for determining data source type, by default "outlier_detection".

    Returns
    -------
    dict
        Dictionary of all source data, including both MLflow and ground truth sources.
    """
    logger.debug(f"Defining sources for the flow = {task}")

    # Best runs from the previous experiment (may be None if nothing matches)
    mlflow_runs = get_previous_best_mlflow_runs(
        experiment_name=prev_experiment_name, cfg=cfg, task=task
    )

    # Pull the actual arrays for those runs so they can feed the flow
    # (e.g. the imputation stage)
    source_data, mlflow_dict = None, None
    if mlflow_runs is not None:
        source_data, mlflow_dict = get_source_data(mlflow_runs, cfg, task=task)

    # Human-annotated ("ground truth") data that serves as the BASELINE
    # against the various outlier detection and imputation methods
    data_dicts_for_source, input_signal, data_dict = import_data_for_flow(cfg, task)

    # MLflow outputs (from outlier detection) lack the extra keys of the
    # imported data, so merge them to give every source the same structure
    logger.info("Combine sources with data dicts")
    sources = combine_source_with_data_dicts(
        source_data,
        data_dicts_for_source,
        mlflow_dict,
        cfg,
        task,
        input_signal,
        data_dict,
    )

    logger.info(f"Total of {len(sources)} sources for processing")

    return sources

PyTorch Data

torch_data

trim_data

trim_data(x)

Trim PLR data to remove edge artifacts.

Removes the first 3 and last 2 timepoints from PLR recordings to get a clean 1976-sample signal.

PARAMETER DESCRIPTION
x

Input array with shape (n_subjects, n_timepoints).

TYPE: ndarray

RETURNS DESCRIPTION
ndarray

Trimmed array with shape (n_subjects, 1976).

Source code in src/data_io/torch_data.py
def trim_data(x):
    """Trim PLR data to remove edge artifacts.

    Drops the first 3 and the last 2 timepoints of each recording,
    yielding a clean 1976-sample signal.

    Parameters
    ----------
    x : np.ndarray
        Input array with shape (n_subjects, n_timepoints).

    Returns
    -------
    np.ndarray
        Trimmed array with shape (n_subjects, 1976).
    """
    start, stop = 3, 1979  # 1979 - 3 = 1976 retained samples
    return x[:, start:stop]

nan_padding

nan_padding(x, n: int = 1981)

Pad trimmed data back to original length with NaN values.

Inverse operation of trim_data, fills edge positions with NaN.

PARAMETER DESCRIPTION
x

Trimmed input array with shape (n_subjects, 1976).

TYPE: ndarray

n

Target output length, by default 1981.

TYPE: int DEFAULT: 1981

RETURNS DESCRIPTION
ndarray

Padded array with shape (n_subjects, n) with NaN at edges.

Source code in src/data_io/torch_data.py
def nan_padding(x, n: int = 1981):
    """Pad trimmed data back to original length with NaN values.

    Inverse operation of ``trim_data``: the input is placed at offset 3
    and the remaining edge positions are filled with NaN.

    Parameters
    ----------
    x : np.ndarray
        Trimmed input array with shape (n_subjects, n_timepoints),
        canonically (n_subjects, 1976).
    n : int, optional
        Target output length, by default 1981.

    Returns
    -------
    np.ndarray
        Padded array with shape (n_subjects, n) with NaN at edges.
    """
    offset = 3  # trim_data removed 3 leading timepoints
    # Generalized: size the destination slice from the input width instead of
    # the hard-coded 1979, so non-canonical widths round-trip too (the result
    # is unchanged for the standard 1976-sample input). np.full replaces the
    # zeros-then-overwrite-with-NaN two-step.
    x_out = np.full((x.shape[0], n), np.nan)
    x_out[:, offset : offset + x.shape[1]] = x
    return x_out

get_outlier_data

get_outlier_data(data_dict, split)

Extract outlier detection data and masks from data dictionary.

Retrieves the imputed original pupil data and corresponding outlier mask for evaluating outlier detection algorithms.

PARAMETER DESCRIPTION
data_dict

Hierarchical data dictionary with split keys.

TYPE: dict

split

Data split to extract ("train" or "test").

TYPE: str

RETURNS DESCRIPTION
tuple of np.ndarray

X: pupil data array (n_subjects, n_timepoints) mask: outlier mask array where 1=outlier, 0=normal

RAISES DESCRIPTION
AssertionError

If no outliers are labeled in the mask.

Source code in src/data_io/torch_data.py
def get_outlier_data(data_dict, split):
    """Extract outlier detection data and masks from data dictionary.

    Retrieves the imputed original pupil data and corresponding outlier
    mask for evaluating outlier detection algorithms.

    Parameters
    ----------
    data_dict : dict
        Hierarchical data dictionary with split keys.
    split : str
        Data split to extract ("train" or "test").

    Returns
    -------
    tuple of np.ndarray
        X: pupil data array (n_subjects, n_timepoints)
        mask: outlier mask array where 1=outlier, 0=normal

    Raises
    ------
    AssertionError
        If no outliers are labeled in the mask.
    """
    split_dict = data_dict[split]
    # Evaluates whether a network can remove outliers: the imputed original
    # signal is the input
    X = split_dict["data"]["pupil_orig_imputed"]
    # 1 marks an outlier sample, 0 a normal one — the labels or "mask" in
    # the Moment code nomenclature
    mask = split_dict["labels"]["outlier_mask"]
    logger.debug(f"outlier_ratio: {mask.sum() / mask.size}")
    assert mask.sum() > 0, "No outliers labeled in the mask"

    return X, mask

pick_pupil_data_col

pick_pupil_data_col(train_on, data_dict, split)

Select appropriate pupil data column based on training configuration.

Retrieves the correct data column (ground truth, raw imputed, or original imputed) and corresponding mask based on the train_on parameter.

PARAMETER DESCRIPTION
train_on

Data column to use: "pupil_gt", "pupil_raw_imputed", or "pupil_orig_imputed".

TYPE: str

data_dict

Hierarchical data dictionary with split keys.

TYPE: dict

split

Data split to extract ("train" or "test").

TYPE: str

RETURNS DESCRIPTION
tuple of np.ndarray

X: pupil data array (n_subjects, n_timepoints) mask: corresponding mask array (zeros for gt/raw, outlier_mask for orig)

RAISES DESCRIPTION
ValueError

If train_on parameter is not recognized.

Source code in src/data_io/torch_data.py
def pick_pupil_data_col(train_on, data_dict, split):
    """Select appropriate pupil data column based on training configuration.

    Retrieves the correct data column (ground truth, raw imputed, or original
    imputed) and corresponding mask based on the train_on parameter.

    Parameters
    ----------
    train_on : str
        Data column to use: "pupil_gt", "pupil_raw_imputed", or "pupil_orig_imputed".
    data_dict : dict
        Hierarchical data dictionary with split keys.
    split : str
        Data split to extract ("train" or "test").

    Returns
    -------
    tuple of np.ndarray
        X: pupil data array (n_subjects, n_timepoints)
        mask: corresponding mask array (zeros for gt/raw, outlier_mask for orig)

    Raises
    ------
    ValueError
        If train_on parameter is not recognized.
    """
    split_entry = data_dict[split]
    if train_on == "pupil_gt":
        # Denoised (clean) signal; nothing gets masked out
        X = split_entry["data"]["pupil_gt"]
        mask: np.ndarray = np.zeros_like(X)
        logger.info('Picking "pupil_gt" data')
    elif train_on == "pupil_raw_imputed":
        # Raw data with the outliers already removed/imputed
        X = split_entry["data"]["pupil_raw_imputed"]
        mask: np.ndarray = np.zeros_like(X)
        logger.info('Picking "pupil_raw_data" data')
    elif train_on == "pupil_orig_imputed":
        # Original raw data most other methods trained on; here the mask
        # carries the human-labeled outliers
        X = split_entry["data"]["pupil_orig_imputed"]
        mask = split_entry["labels"]["outlier_mask"]
        logger.info('Picking "pupil_orig_data" data')
    else:
        logger.error(f"Unknown train_on = {train_on}")
        raise ValueError(f"Unknown train_on = {train_on}")
        # mask comes from the outlier dataloaders, we just reconstruct the clean signal

    return X, mask

dataset_outlier_detection_selector

dataset_outlier_detection_selector(
    detection_type: str,
    train_on: str,
    split: str,
    split_data: str,
    data_dict: dict,
)

Select data for outlier detection based on detection type and split.

Routes data selection based on whether fine-tuning or zero-shot detection is used, and whether outlier-specific splits are requested.

PARAMETER DESCRIPTION
detection_type

Detection approach: "fine-tune" or "zero-shot".

TYPE: str

train_on

Data column to use for training.

TYPE: str

split

Data split ("train" or "test").

TYPE: str

split_data

Specific split type (e.g., "outlier_train", "outlier_test").

TYPE: str

data_dict

Hierarchical data dictionary.

TYPE: dict

RETURNS DESCRIPTION
tuple of np.ndarray

X: data array and mask: outlier mask array.

RAISES DESCRIPTION
ValueError

If detection_type is not recognized.

AssertionError

If outlier split requested but mask has no outliers.

Source code in src/data_io/torch_data.py
def dataset_outlier_detection_selector(
    detection_type: str, train_on: str, split: str, split_data: str, data_dict: dict
):
    """Select data for outlier detection based on detection type and split.

    Routes data selection based on whether fine-tuning or zero-shot detection
    is used, and whether outlier-specific splits are requested.

    Parameters
    ----------
    detection_type : str
        Detection approach: "fine-tune" or "zero-shot".
    train_on : str
        Data column to use for training.
    split : str
        Data split ("train" or "test").
    split_data : str
        Specific split type (e.g., "outlier_train", "outlier_test").
    data_dict : dict
        Hierarchical data dictionary.

    Returns
    -------
    tuple of np.ndarray
        X: data array and mask: outlier mask array.

    Raises
    ------
    ValueError
        If detection_type is not recognized.
    AssertionError
        If fine-tuning on an outlier split but the mask has no outliers.
    """
    # Validate up front: the two recognized modes shared identical selection
    # logic, differing only in the labeled-outlier assertion below.
    if detection_type not in ("fine-tune", "zero-shot"):
        logger.error("Unknown detection type = {}".format(detection_type))
        raise ValueError("Unknown detection type = {}".format(detection_type))

    if "outlier" in split_data:
        # You only need the outlier split for outlier detection. When doing imputation training,
        # you train for the reconstruction, mask out the missing time points, and check
        # how well the "vanilla test" is reconstructed on missing mask points.
        X, mask = get_outlier_data(data_dict, split)
        if detection_type == "fine-tune":
            # Fine-tuning needs labeled outliers to learn from; zero-shot does not.
            assert mask.sum() > 0, "No outliers labeled in the mask"
    else:
        X, mask = pick_pupil_data_col(train_on, data_dict, split)

    return X, mask

dataset_ts_cls_selector

dataset_ts_cls_selector(
    detection_type: str,
    train_on: str,
    split: str,
    data_dict: dict,
)

Select data for time series classification task.

Extracts features and class labels for binary classification, encoding string labels to integers.

PARAMETER DESCRIPTION
detection_type

Detection approach: "fine-tune" or "full-finetune".

TYPE: str

train_on

Data column to use (not used directly, for interface consistency).

TYPE: str

split

Data split ("train" or "test").

TYPE: str

data_dict

Hierarchical data dictionary.

TYPE: dict

RETURNS DESCRIPTION
tuple

X: feature array (n_subjects, n_features) labels: integer class labels (n_subjects,)

RAISES DESCRIPTION
ValueError

If detection_type is not recognized.

AssertionError

If number of unique classes is not 2, or if label count mismatches X.

Source code in src/data_io/torch_data.py
def dataset_ts_cls_selector(
    detection_type: str, train_on: str, split: str, data_dict: dict
):
    """Select data for time series classification task.

    Pulls the feature matrix and per-subject class labels for binary
    classification, converting string labels to integer codes.

    Parameters
    ----------
    detection_type : str
        Detection approach: "fine-tune" or "full-finetune".
    train_on : str
        Data column to use (not used directly, kept for interface consistency).
    split : str
        Data split ("train" or "test").
    data_dict : dict
        Hierarchical data dictionary.

    Returns
    -------
    tuple
        X: feature array (n_subjects, n_features)
        labels: integer class labels (n_subjects,)

    Raises
    ------
    ValueError
        If detection_type is not recognized.
    AssertionError
        If the number of unique classes is not 2, or label count mismatches X.
    """
    # Guard clause: bail out early on an unsupported detection mode.
    if detection_type not in ("fine-tune", "full-finetune"):
        logger.error("Unknown detection type = {}".format(detection_type))
        raise ValueError("Unknown detection type = {}".format(detection_type))

    split_entry = data_dict[split]
    X = split_entry["data"]["X"]
    # Labels are replicated per time point; the first column suffices.
    labels = split_entry["labels"]["class_label"][:, 0]

    unique_classes = np.unique(labels)
    assert len(unique_classes) == 2, (
        "You have != 2 classes, unique classes = {}".format(unique_classes)
    )

    labels = encode_labels_to_integers(labels)
    assert len(labels) == X.shape[0], (
        "Labels and X must have the same number of samples"
    )

    return X, labels

dataset_imputation_selector

dataset_imputation_selector(
    detection_type: str,
    train_on: str,
    split: str,
    data_dict: dict,
)

Select data for imputation task.

Extracts data and missingness mask for training imputation models to reconstruct missing values.

PARAMETER DESCRIPTION
detection_type

Detection approach: "fine-tune" or "zero-shot".

TYPE: str

train_on

Data column: "pupil_gt" or "pupil_raw_imputed".

TYPE: str

split

Data split ("train" or "test").

TYPE: str

data_dict

Hierarchical data dictionary.

TYPE: dict

RETURNS DESCRIPTION
tuple of np.ndarray

X: data array and mask: missingness mask.

RAISES DESCRIPTION
ValueError

If detection_type or train_on is not recognized.

NotImplementedError

If train_on is "pupil_raw_imputed" (not yet implemented).

AssertionError

If mask has no missing points.

Source code in src/data_io/torch_data.py
def dataset_imputation_selector(
    detection_type: str, train_on: str, split: str, data_dict: dict
):
    """Select data for imputation task.

    Extracts data and missingness mask for training imputation models
    to reconstruct missing values.

    Parameters
    ----------
    detection_type : str
        Detection approach: "fine-tune" or "zero-shot".
    train_on : str
        Data column: "pupil_gt" or "pupil_raw_imputed".
    split : str
        Data split ("train" or "test").
    data_dict : dict
        Hierarchical data dictionary.

    Returns
    -------
    tuple of np.ndarray
        X: data array and mask: missingness mask.

    Raises
    ------
    ValueError
        If detection_type or train_on is not recognized.
    NotImplementedError
        If train_on is "pupil_raw_imputed" (not yet implemented).
    AssertionError
        If mask has no missing points.
    """
    if detection_type == "fine-tune":
        if train_on == "pupil_gt":
            # This is the denoised (clean signal)
            X = data_dict[split]["data"]["X"]
        elif train_on == "pupil_raw_imputed":
            # X and X_GT are the same if you decided to train on pupil_gt, and different with something
            # else, and with "pupil_raw_imputed", the X would already be this? Debug more if you actually start
            # using these
            logger.error("Not yet implemented")
            raise NotImplementedError("Not yet implemented")
        else:
            # Implement here if you use pupil_raw or pupil_orig, or should not need implementation?
            # just the correct column options?
            logger.error("Unknown train_on = {}".format(train_on))
            raise ValueError("Unknown train_on = {}".format(train_on))
        mask = data_dict[split]["data"]["mask"]

    elif detection_type == "zero-shot":
        X = data_dict[split]["data"]["X"]
        mask = data_dict[split]["data"]["mask"]

    else:
        logger.error("Unknown detection type = {}".format(detection_type))
        raise ValueError("Unknown detection type = {}".format(detection_type))

    # In the imputation task, we don't have at the moment any outlier_x split so the mask should contain
    # some missingness points, otherwise we can't compute any metrics for the imputation performance.
    # Fix: the message previously said "outliers", but this mask marks missing points, not outliers.
    assert mask.sum() > 0, "No missing points labeled in the mask"

    return X, mask

dataset_data_array_selector

dataset_data_array_selector(
    split_data,
    task,
    data_dict,
    detection_type: str = "zero-shot",
    train_on: str = "gt",
)

Main dispatcher for selecting data arrays based on task and split.

Routes data extraction to appropriate task-specific selector based on the task type (outlier detection, imputation, or classification).

PARAMETER DESCRIPTION
split_data

Split specification: "train", "test", "outlier_train", "outlier_test".

TYPE: str

task

Task type: "outlier_detection", "imputation", or "ts_cls".

TYPE: str

data_dict

Hierarchical data dictionary.

TYPE: dict

detection_type

Detection approach, by default "zero-shot".

TYPE: str DEFAULT: 'zero-shot'

train_on

Data column to use, by default "gt".

TYPE: str DEFAULT: 'gt'

RETURNS DESCRIPTION
tuple of np.ndarray

X: data array and mask/label array depending on task.

RAISES DESCRIPTION
ValueError

If split_data or task is not recognized.

AssertionError

If X contains NaN values or shape mismatch with mask.

Source code in src/data_io/torch_data.py
def dataset_data_array_selector(
    split_data,
    task,
    data_dict,
    detection_type: str = "zero-shot",
    train_on: str = "gt",
):
    """Main dispatcher for selecting data arrays based on task and split.

    Routes data extraction to appropriate task-specific selector based on
    the task type (outlier detection, imputation, or classification).

    Parameters
    ----------
    split_data : str
        Split specification: "train", "test", "outlier_train", "outlier_test".
    task : str
        Task type: "outlier_detection", "imputation", or "ts_cls".
    data_dict : dict
        Hierarchical data dictionary.
    detection_type : str, optional
        Detection approach, by default "zero-shot".
    train_on : str, optional
        Data column to use, by default "gt".

    Returns
    -------
    tuple of np.ndarray
        X: data array and mask/label array depending on task.

    Raises
    ------
    ValueError
        If split_data or task is not recognized.
    AssertionError
        If X contains NaN values or shape mismatch with mask.
    """
    if split_data == "outlier_test":
        # "outlier" is a "virtual split" that takes data from different processing level,
        # but we want to use the "test" or "val" split now as it is not used for training directly
        # (well now it has been used to select the best model, so there is no real extensive validation atm)
        split = "test"
    elif split_data == "outlier_train":
        split = "train"  # same as with the outlier_test
    else:
        # otherwise these are the same
        split = split_data

    # After normalization above, only "train"/"test" are reachable here
    # (the original list also carried the never-matching "outlier_*" entries).
    if split not in ("train", "test"):
        logger.error("Unrecognized split = {}".format(split))
        raise ValueError("Unrecognized split = {}".format(split))

    if task == "outlier_detection":
        X, mask = dataset_outlier_detection_selector(
            detection_type=detection_type,
            train_on=train_on,
            split=split,
            split_data=split_data,
            data_dict=data_dict,
        )
    elif task == "imputation":
        X, mask = dataset_imputation_selector(
            detection_type=detection_type,
            train_on=train_on,
            split=split,
            data_dict=data_dict,
        )
    elif task == "ts_cls":
        # well this is X, and label
        X, label = dataset_ts_cls_selector(
            detection_type=detection_type,
            train_on=train_on,
            split=split,
            data_dict=data_dict,
        )
        assert np.isnan(X).sum() == 0, "Missing values in the data"
        return X, label
    else:
        # Fix: message previously misspelled "Unkown".
        logger.error("Unknown task = {}".format(task))
        raise ValueError("Unknown task = {}".format(task))

    # Depending on the algorithm, you might want to use the "pupil_orig" or "pupil_raw" data as
    # well without the imputation, if NaNs or missing problems in general (irregular sampling) is okay
    assert isinstance(X, np.ndarray), "X must be a Numpy array, not {}".format(
        type(X)
    )
    assert np.isnan(X).sum() == 0, "Missing values in the data"

    assert X.shape[0] == mask.shape[0], (
        "X and mask must have the same number of samples"
    )

    return X, mask

pick_splits_from_data_dict_to_ts

pick_splits_from_data_dict_to_ts(
    data_dict_df, model_cfg, train_on
)

Extract train/test splits from data dictionary for time series export.

Organizes data into split-specific dictionaries with X (data), y (outlier mask), and time arrays for downstream processing.

PARAMETER DESCRIPTION
data_dict_df

Hierarchical data dictionary with split keys containing data arrays.

TYPE: dict

model_cfg

Model configuration (not currently used but kept for interface).

TYPE: DictConfig

train_on

Data column to extract (e.g., "pupil_orig_imputed").

TYPE: str

RETURNS DESCRIPTION
dict

Dictionary with "train" and "test" keys, each containing X (data array), y (outlier mask), and time (time vector).

References
  • https://github.com/eBay/RANSynCoders/blob/main/example.ipynb
Source code in src/data_io/torch_data.py
def pick_splits_from_data_dict_to_ts(data_dict_df, model_cfg, train_on):
    """Extract train/test splits from data dictionary for time series export.

    Builds, for each split in the input dictionary, a flat record holding
    the data array (X), the outlier mask (y), and the time vector.

    Parameters
    ----------
    data_dict_df : dict
        Hierarchical data dictionary with split keys containing data arrays.
    model_cfg : DictConfig
        Model configuration (not currently used but kept for interface).
    train_on : str
        Data column to extract (e.g., "pupil_orig_imputed").

    Returns
    -------
    dict
        Dictionary with "train" and "test" keys, each containing:
        - X: data array
        - y: outlier mask
        - time: time vector

    References
    ----------
    - https://github.com/eBay/RANSynCoders/blob/main/example.ipynb
    """
    return {
        split: {
            "X": payload["data"][train_on],
            "y": payload["labels"]["outlier_mask"],
            "time": payload["time"]["time"],
        }
        for split, payload in data_dict_df.items()
    }

create_dataset_from_numpy

create_dataset_from_numpy(
    data_dict_df: dict,
    dataset_cfg: DictConfig,
    model_cfg: DictConfig,
    split: str,
    task: str = "imputation",
    model_name: str = None,
)

Create a PyTorch TensorDataset from numpy arrays.

Converts data dictionary arrays to PyTorch tensors and optionally applies trimming for foundation model compatibility.

PARAMETER DESCRIPTION
data_dict_df

Hierarchical data dictionary containing numpy arrays.

TYPE: dict

dataset_cfg

Dataset configuration with trim_to_size and other settings.

TYPE: DictConfig

model_cfg

Model configuration with MODEL settings (detection_type, train_on).

TYPE: DictConfig

split

Data split to use ("train", "test", "outlier_train", "outlier_test").

TYPE: str

task

Task type for data selection, by default "imputation".

TYPE: str DEFAULT: 'imputation'

model_name

Model name for trim configuration, by default None.

TYPE: str DEFAULT: None

RETURNS DESCRIPTION
TensorDataset

PyTorch dataset with (X, mask, input_mask) tensors.

Source code in src/data_io/torch_data.py
def create_dataset_from_numpy(
    data_dict_df: dict,
    dataset_cfg: DictConfig,
    model_cfg: DictConfig,
    split: str,
    task: str = "imputation",
    model_name: str = None,
):
    """Create a PyTorch TensorDataset from numpy arrays.

    Selects the task-appropriate arrays from the data dictionary, builds an
    observation (input) mask, and wraps everything as torch tensors. Optional
    trimming is applied for foundation-model (MOMENT) compatibility.

    Parameters
    ----------
    data_dict_df : dict
        Hierarchical data dictionary containing numpy arrays.
    dataset_cfg : DictConfig
        Dataset configuration with trim_to_size and other settings.
    model_cfg : DictConfig
        Model configuration with MODEL settings (detection_type, train_on).
    split : str
        Data split to use ("train", "test", "outlier_train", "outlier_test").
    task : str, optional
        Task type for data selection, by default "imputation".
    model_name : str, optional
        Model name for trim configuration, by default None.

    Returns
    -------
    TensorDataset
        PyTorch dataset with (X, mask, input_mask) tensors.
    """
    model_params = model_cfg["MODEL"]
    # Pick the needed arrays from the model artifacts dictionary
    X, mask = dataset_data_array_selector(
        split_data=split,
        task=task,
        data_dict=data_dict_df,
        detection_type=model_params["detection_type"],
        train_on=model_params["train_on"],
    )

    if dataset_cfg["trim_to_size"] is not None:
        # Fixed windows for MOMENT; a class-based sliding-window sampler might
        # be better in the future if MOMENT shows some promise.
        X, mask, input_mask = transform_data_for_momentfm(
            X, mask, dataset_cfg, model_name
        )
    else:
        # 1 where a value is observed, 0 where it is NaN (assumes 2-D X here)
        input_mask = np.zeros((X.shape[0], X.shape[1]))
        input_mask[~np.isnan(X)] = 1

    # Wrap as torch tensors for DataLoader consumption
    return TensorDataset(
        torch.Tensor(X),
        torch.Tensor(mask),
        torch.Tensor(input_mask),
    )