orchestration

Pipeline orchestration and hyperparameter management.

Overview

This module handles:

  • Prefect flow orchestration
  • Hyperparameter sweeps
  • Debug utilities

prefect_utils

pre_flow_prefect_checks

pre_flow_prefect_checks(prefect_cfg: DictConfig)

Perform pre-flight checks before running a Prefect flow.

Logs the Hydra output directory, optionally starts the Prefect server, and checks CUDA availability with appropriate warnings.

PARAMETER DESCRIPTION
prefect_cfg

Prefect configuration with SERVER.autostart setting.

TYPE: DictConfig

Source code in src/orchestration/prefect_utils.py
def pre_flow_prefect_checks(prefect_cfg: DictConfig):
    """Perform pre-flight checks before running a Prefect flow.

    Logs the Hydra output directory, optionally starts the Prefect server,
    and checks CUDA availability with appropriate warnings.

    Parameters
    ----------
    prefect_cfg : DictConfig
        Prefect configuration with SERVER.autostart setting.
    """
    logger.info(
        'Hydra output directory = "{}"'.format(
            hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
        )
    )

    if prefect_cfg["SERVER"]["autostart"]:
        pre_check_server()
    # Prefect results/artifacts?
    # https://docs.prefect.io/3.0/develop/results

    # Check that CUDA is available
    if torch.cuda.is_available():
        logger.info("CUDA is available")
    else:
        logger.warning("--")
        logger.warning("-----")
        logger.warning("----------")
        logger.warning("CUDA is not available! You will be training on  (Takes time!)")
        logger.warning("----------")
        logger.warning("-----")
        logger.warning("--")

pre_check_server

pre_check_server()

Check and start the Prefect server if not running.

Attempts to start the Prefect server and logs its status. If port 4200 is already in use, assumes the server is already running.

Notes

Dashboard is available at http://127.0.0.1:4200/dashboard when running.

Source code in src/orchestration/prefect_utils.py
def pre_check_server():
    """Check and start the Prefect server if not running.

    Attempts to start the Prefect server and logs status. If port 4200
    is already in use, assumes server is running.

    Notes
    -----
    Dashboard is available at http://127.0.0.1:4200/dashboard when running.
    """
    # see https://orion-docs.prefect.io/latest/api-ref/prefect/cli/server/
    logger.debug("PREFECT SERVER AUTOSTART=True: Trying to autostart Prefect server")
    p = subprocess.Popen(
        ["nohup", "prefect", "server", "start"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,  # capture stderr too, so the err check below can fire
    )
    out, err = p.communicate()  # TODO! Can jam here, why? add some timeout?
    # output = check_output(cmd, stderr=STDOUT, timeout=seconds)?
    # https://stackoverflow.com/a/12698328/6412152

    logger.debug(out)
    if err:  # communicate() returns empty bytes (not None) for a piped stream
        logger.error(err)

    if "Port 4200 is already in use" in str(out):
        logger.info("Prefect server running")
        # Dashboard by default is here http://127.0.0.1:4200/dashboard
        # TODO! display the URL to the user, with dynamic URL
        logger.info("Prefect dashboard is at http://127.0.0.1:4200/dashboard")
    else:
        logger.info("Prefect server was not running, and it was started!")

pre_check_workpool

pre_check_workpool()

Check and create a Prefect work pool if it does not exist.

Attempts to create a process-type work pool named 'my-work-pool'.

Notes

Work pools are used for distributed task execution in Prefect. See: https://docs.prefect.io/3.0/get-started/quickstart#create-a-work-pool

Source code in src/orchestration/prefect_utils.py
def pre_check_workpool():
    """Check and create a Prefect work pool if not exists.

    Attempts to create a process-type work pool named 'my-work-pool'.

    Notes
    -----
    Work pools are used for distributed task execution in Prefect.
    See: https://docs.prefect.io/3.0/get-started/quickstart#create-a-work-pool
    """
    # https://docs.prefect.io/3.0/get-started/quickstart#create-a-work-pool
    workpool_name = "my-work-pool"
    logger.debug(
        "PREFECT WORKPOOL AUTOSTART=True: Trying to autostart Prefect Work Pool"
    )
    # prefect work-pool create --type process my-work-pool
    p = subprocess.Popen(
        ["nohup", "prefect", "work-pool", "create", "--type", "process", workpool_name],
        stdout=subprocess.PIPE,
    )
    out, err = p.communicate()

    if "already exists" in str(out):
        logger.info('Prefect work pool "{}" already exists'.format(workpool_name))
    else:
        logger.info('Prefect work pool "{}" created'.format(workpool_name))

post_flow_prefect_housekeeping

post_flow_prefect_housekeeping(prefect_cfg: DictConfig)

Perform cleanup tasks after a Prefect flow completes.

Placeholder for post-flow housekeeping such as stopping servers or cleaning up resources.

PARAMETER DESCRIPTION
prefect_cfg

Prefect configuration dictionary.

TYPE: DictConfig

Notes

Currently a stub - implementation pending.

Source code in src/orchestration/prefect_utils.py
def post_flow_prefect_housekeeping(prefect_cfg: DictConfig):
    """Perform cleanup tasks after a Prefect flow completes.

    Placeholder for post-flow housekeeping such as stopping servers
    or cleaning up resources.

    Parameters
    ----------
    prefect_cfg : DictConfig
        Prefect configuration dictionary.

    Notes
    -----
    Currently a stub - implementation pending.
    """
    logger.info("Prefect housekeeping Placeholder")

hyperparameter_sweep_utils

flatten_the_nested_dicts

flatten_the_nested_dicts(cfgs, delimiter='_')

Flatten nested configuration dictionaries to a single level.

Extracts configurations from the first model's nested structure.

PARAMETER DESCRIPTION
cfgs

Nested dictionary with model names as keys.

TYPE: dict

delimiter

Delimiter for flattened keys (currently unused). Default is '_'.

TYPE: str DEFAULT: '_'

RETURNS DESCRIPTION
dict

Flattened configuration dictionary.

Source code in src/orchestration/hyperparameter_sweep_utils.py
def flatten_the_nested_dicts(cfgs, delimiter="_"):
    """Flatten nested configuration dictionaries to a single level.

    Extracts configurations from the first model's nested structure.

    Parameters
    ----------
    cfgs : dict
        Nested dictionary with model names as keys.
    delimiter : str, optional
        Delimiter for flattened keys (currently unused). Default is '_'.

    Returns
    -------
    dict
        Flattened configuration dictionary.
    """
    cfgs_flat = cfgs[list(cfgs.keys())[0]].copy()
    logger.debug(
        "HYPERPARAMETER SEARCH | {} model architectures".format(len(cfgs.keys()))
    )
    logger.info(
        "HYPERPARAMETER SEARCH | {} | {} hyperparameter sets".format(
            list(cfgs.keys())[0], len(cfgs_flat.keys())
        )
    )

    return cfgs_flat
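
As the source shows, the "flattening" simply lifts the inner dict of the first model key to the top level. A plain-dict sketch of that behavior (`flatten_first_model` and the config names are illustrative, not part of the project API):

```python
def flatten_first_model(cfgs: dict) -> dict:
    # Equivalent of flatten_the_nested_dicts minus the logging:
    # take the nested configs under the first model key.
    first_model = next(iter(cfgs))
    return dict(cfgs[first_model])


# One model architecture with two hyperparameter sets (names illustrative)
cfgs = {
    "SAITS": {
        "SAITS_dFfn128": {"d_ffn": 128},
        "SAITS_dFfn256": {"d_ffn": 256},
    },
}
flat = flatten_first_model(cfgs)
# flat now maps hyperparameter-set names directly to their configs
```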

drop_other_models

drop_other_models(cfg_model, model, task: str)

Drop all models from the config except the one specified.

Creates a copy of the configuration with only the specified model, ensuring exactly one model per configuration for hyperparameter sweeps.

PARAMETER DESCRIPTION
cfg_model

Configuration containing multiple models.

TYPE: DictConfig

model

Name of the model to keep.

TYPE: str

task

Task type determining config key: 'outlier_detection', 'imputation', or 'classification'.

TYPE: str

RETURNS DESCRIPTION
DictConfig

Configuration with only the specified model.

RAISES DESCRIPTION
ValueError

If task is not recognized.

AssertionError

If resulting config doesn't have exactly one model.

Source code in src/orchestration/hyperparameter_sweep_utils.py
def drop_other_models(cfg_model, model, task: str):
    """Drop all models from the config except the one specified.

    Creates a copy of the configuration with only the specified model,
    ensuring exactly one model per configuration for hyperparameter sweeps.

    Parameters
    ----------
    cfg_model : DictConfig
        Configuration containing multiple models.
    model : str
        Name of the model to keep.
    task : str
        Task type determining config key: 'outlier_detection', 'imputation',
        or 'classification'.

    Returns
    -------
    DictConfig
        Configuration with only the specified model.

    Raises
    ------
    ValueError
        If task is not recognized.
    AssertionError
        If resulting config doesn't have exactly one model.
    """
    if task == "outlier_detection":
        model_cfg_key = "OUTLIER_MODELS"
    elif task == "imputation":
        model_cfg_key = "MODELS"
    elif task == "classification":
        model_cfg_key = "CLS_MODELS"
    else:
        logger.error(f"Task {task} not recognized")
        raise ValueError(f"Task {task} not recognized")

    cfg_out = cfg_model.copy()
    for model_name in cfg_model[model_cfg_key]:
        if model_name != model:
            logger.debug(f"dropping model {model_name} from the config")
            with open_dict(cfg_out):
                del cfg_out[model_cfg_key][model_name]

    # as you copy the config for each hyperparameter combo, only
    # one model per cfg is allowed
    assert len(cfg_out[model_cfg_key]) == 1, (
        "Only one model per cfg is allowed, you had {}".format(
            list(cfg_out[model_cfg_key].keys())
        )
    )

    return cfg_out

pick_cfg_key

pick_cfg_key(cfg: DictConfig, task: str)

Get the configuration key for models based on task type.

PARAMETER DESCRIPTION
cfg

Configuration dictionary (currently unused).

TYPE: DictConfig

task

Task type: 'outlier_detection', 'imputation', or 'classification'.

TYPE: str

RETURNS DESCRIPTION
str

Configuration key: 'OUTLIER_MODELS', 'MODELS', or 'CLS_MODELS'.

RAISES DESCRIPTION
ValueError

If task is not recognized.

Source code in src/orchestration/hyperparameter_sweep_utils.py
def pick_cfg_key(cfg: DictConfig, task: str):
    """Get the configuration key for models based on task type.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary (currently unused).
    task : str
        Task type: 'outlier_detection', 'imputation', or 'classification'.

    Returns
    -------
    str
        Configuration key: 'OUTLIER_MODELS', 'MODELS', or 'CLS_MODELS'.

    Raises
    ------
    ValueError
        If task is not recognized.
    """
    if task == "outlier_detection":
        cfg_key = "OUTLIER_MODELS"
    elif task == "imputation":
        cfg_key = "MODELS"
    elif task == "classification":
        cfg_key = "CLS_MODELS"
    else:
        raise ValueError(f"Task {task} not recognized")
    return cfg_key
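
The task-to-key mapping above is small enough to express as a lookup table; this standalone sketch (`pick_cfg_key_sketch` is a hypothetical name) behaves the same as the if/elif chain:

```python
# The task -> config-key mapping implemented by pick_cfg_key, as a table.
TASK_TO_CFG_KEY = {
    "outlier_detection": "OUTLIER_MODELS",
    "imputation": "MODELS",
    "classification": "CLS_MODELS",
}


def pick_cfg_key_sketch(task: str) -> str:
    try:
        return TASK_TO_CFG_KEY[task]
    except KeyError:
        raise ValueError(f"Task {task} not recognized")
```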

define_hyperparameter_search

define_hyperparameter_search(
    cfg: DictConfig, task: str, cfg_key: str
) -> dict

Define hyperparameter search configurations for all models.

Creates configuration variants for each model based on their defined search space (LIST or GRID method).

PARAMETER DESCRIPTION
cfg

Main configuration with model definitions.

TYPE: DictConfig

task

Task type for naming conventions.

TYPE: str

cfg_key

Key to access model configurations (e.g., 'MODELS').

TYPE: str

RETURNS DESCRIPTION
dict

Dictionary mapping configuration names to their configs.

RAISES DESCRIPTION
ValueError

If search method is unknown or parameters are missing.

Source code in src/orchestration/hyperparameter_sweep_utils.py
def define_hyperparameter_search(cfg: DictConfig, task: str, cfg_key: str) -> dict:
    """Define hyperparameter search configurations for all models.

    Creates configuration variants for each model based on their defined
    search space (LIST or GRID method).

    Parameters
    ----------
    cfg : DictConfig
        Main configuration with model definitions.
    task : str
        Task type for naming conventions.
    cfg_key : str
        Key to access model configurations (e.g., 'MODELS').

    Returns
    -------
    dict
        Dictionary mapping configuration names to their configs.

    Raises
    ------
    ValueError
        If search method is unknown or parameters are missing.
    """
    cfgs = {}
    no_models = len(cfg[cfg_key].keys())  # dict containing the Hydra DictConfigs
    logger.debug("HYPERPARAMETER SEARCH | {} model architectures".format(no_models))
    logger.debug(list(cfg[cfg_key].keys()))
    for model in cfg[cfg_key]:  # e.g. SAITS
        # Get the preferred search method for this particular model, "LIST" or "GRID",
        # hyperopt/optuna Bayesian optimization done "inside the cfg"
        cfgs[model] = cfg.copy()
        cfgs[model] = drop_other_models(cfg_model=cfgs[model], model=model, task=task)
        if "HYPERPARAMS" in cfg[cfg_key][model]:
            method = cfg[cfg_key][model]["HYPERPARAMS"]["method"]  # e.g. LIST
            if method in cfg[cfg_key][model]["SEARCH_SPACE"]:
                logger.info(
                    f"HYPERPARAMETER SEARCH | method {method} found for model {model}"
                )
                if method == "LIST":
                    cfgs[model] = define_list_hyperparam_combos(
                        cfg_model=cfgs[model],
                        param_dict=cfg[cfg_key][model]["SEARCH_SPACE"][method],
                        model=model,
                        task=task,
                        cfg_key=cfg_key,
                    )
                    # Harmonize maybe a bit this, as there is the extra "model_name" nesting
                    # cfgs = flatten_the_nested_dicts(cfgs)

                elif method == "GRID":
                    cfgs[model] = define_grid_hyperparam_combos(
                        cfg_model=cfgs[model],
                        param_dict=cfg[cfg_key][model]["SEARCH_SPACE"][method],
                        model=model,
                        task=task,
                        cfg_key=cfg_key,
                    )
                else:
                    logger.error(
                        'Method not recognized, must be "LIST" or "GRID", not {}'.format(
                            method
                        )
                    )
                    raise ValueError(
                        'Method not recognized, must be "LIST" or "GRID", not {}'.format(
                            method
                        )
                    )
            else:
                logger.error(
                    f"No parameters defined for search method {method} for model {model}"
                )
                raise ValueError(
                    f"No parameters defined for search method {method} for model {model}"
                )
        else:
            logger.info(
                f"No HYPERPARAMS key found for model {model}, training just with the default hyperparameters"
            )
            method = None

    if method == "LIST":
        cfgs_tmp = cfgs.copy()
        cfgs = {}
        # TODO! Add some checks if you don't have any hyperparams, you need to add extra key then, see TimesNet below
        for model in cfgs_tmp:
            # if model == 'TimesNet':  # quick fix
            #     logger.warning('Manual fix for TimesNet')
            #     cfgs['TimesNet'] = cfg
            for cfg_name, cfg in cfgs_tmp[model].items():
                cfgs[cfg_name] = cfg

    return cfgs

define_hyperparam_group

define_hyperparam_group(cfg: DictConfig, task: str) -> dict

Define a group of hyperparameter configurations for a task.

Main entry point for hyperparameter configuration generation. Either creates multiple configs for hyperparameter search or a single config for default parameters.

PARAMETER DESCRIPTION
cfg

Main configuration with EXPERIMENT.hyperparam_search flag.

TYPE: DictConfig

task

Task type: 'outlier_detection', 'imputation', or 'classification'.

TYPE: str

RETURNS DESCRIPTION
dict

Dictionary mapping run names to their configurations.

RAISES DESCRIPTION
ValueError

If multiple models defined without hyperparameter search enabled, or if task is not recognized.

NotImplementedError

If classification naming is requested (not yet implemented).

Source code in src/orchestration/hyperparameter_sweep_utils.py
def define_hyperparam_group(cfg: DictConfig, task: str) -> dict:
    """Define a group of hyperparameter configurations for a task.

    Main entry point for hyperparameter configuration generation.
    Either creates multiple configs for hyperparameter search or
    a single config for default parameters.

    Parameters
    ----------
    cfg : DictConfig
        Main configuration with EXPERIMENT.hyperparam_search flag.
    task : str
        Task type: 'outlier_detection', 'imputation', or 'classification'.

    Returns
    -------
    dict
        Dictionary mapping run names to their configurations.

    Raises
    ------
    ValueError
        If multiple models defined without hyperparameter search enabled,
        or if task is not recognized.
    NotImplementedError
        If classification naming is requested (not yet implemented).
    """
    cfg_key = pick_cfg_key(cfg=cfg, task=task)
    if cfg["EXPERIMENT"]["hyperparam_search"]:
        try:
            cfgs = define_hyperparameter_search(cfg, task=task, cfg_key=cfg_key)
            logger.info("HYPERPARAMETER SEARCH |total of {} configs".format(len(cfgs)))
        except Exception as e:
            logger.error(f"Error in defining the hyperparameter search experiment: {e}")
            raise e
    else:
        logger.info(
            "Skipping hyperparameter search, "
            "running just one config (single model+single set of hyperparameters)"
        )
        model_name = list(cfg[cfg_key].keys())
        if len(model_name) != 1:
            logger.error(
                "You have multiple models defined in your config, but hyperparameter search is disabled"
            )
            logger.error("Model names: {}".format(model_name))
            raise ValueError(
                "You have multiple models defined in your config, but hyperparameter search is disabled"
            )
        else:
            model_name = model_name[0]
            if task == "outlier_detection":
                cfg_name = update_outlier_detection_run_name(cfg)

            elif task == "imputation":
                cfg_name = create_name_from_model_params(
                    model_name=model_name, param_cfg=cfg[cfg_key][model_name]["MODEL"]
                )
            elif task == "classification":
                logger.error("Implement classification naming")
                raise NotImplementedError("Implement classification naming")
            else:
                logger.error(f"Task {task} not recognized")
                raise ValueError(f"Task {task} not recognized")
            logger.info(f"SINGLE RUN | Model name: {model_name}, cfg_name: {cfg_name}")

        cfgs = {cfg_name: cfg}

    # Flatten the nested dicts
    cfgs_out = {}
    for cfg_name, cfg in cfgs.items():
        # e.g. LOF, PROPHET, TimesNet
        nested_dict_first_key = list(cfg.keys())[0]
        if cfg_name in nested_dict_first_key:
            # If for example one of your models had hyperparams combinations
            # you would have all those nested inside the dictionary, and the first key
            # would contain the model name
            for hyperparam_key in cfg.keys():
                # e.g. LOF-1, LOF-2, LOF-3
                cfgs_out[hyperparam_key] = cfg[hyperparam_key]
        else:
            cfgs_out[cfg_name] = cfg

    return cfgs_out

hyperparamer_list_utils

clean_param

clean_param(param)

Convert snake_case parameter name to camelCase.

PARAMETER DESCRIPTION
param

Parameter name in snake_case (e.g., 'd_ffn').

TYPE: str

RETURNS DESCRIPTION
str

Parameter name in camelCase (e.g., 'dFfn').

Source code in src/orchestration/hyperparamer_list_utils.py
def clean_param(param):
    """Convert snake_case parameter name to camelCase.

    Parameters
    ----------
    param : str
        Parameter name in snake_case (e.g., 'd_ffn').

    Returns
    -------
    str
        Parameter name in camelCase (e.g., 'dFfn').
    """
    fields = param.split("_")
    if len(fields) > 1:
        # e.g. d_ffn -> dFfn
        for i, string in enumerate(fields):
            if i == 0:
                param = string
            else:
                param += string.title()
    return param
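
The conversion can be exercised standalone; this is an equivalent re-implementation without the module's logger dependency:

```python
def clean_param(param: str) -> str:
    """Convert a snake_case name to camelCase, keeping the first field lowercase."""
    fields = param.split("_")
    if len(fields) > 1:
        # e.g. d_ffn -> dFfn, n_layers -> nLayers
        param = fields[0] + "".join(f.title() for f in fields[1:])
    return param
```

Single-field names pass through unchanged, e.g. `clean_param("lr")` returns `"lr"`.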

create_hyperparam_name

create_hyperparam_name(
    param: str,
    value_from_list,
    i: int = 0,
    j: int = 0,
    n_params: int = 1,
    value_key_delimiter: str = "",
    param_delimiter: str = "_",
)

Create a standardized name string for a hyperparameter value.

PARAMETER DESCRIPTION
param

Parameter name.

TYPE: str

value_from_list

Value of the parameter.

TYPE: Any

i

Index in the value list (unused). Default is 0.

TYPE: int DEFAULT: 0

j

Current parameter index. Default is 0.

TYPE: int DEFAULT: 0

n_params

Total number of parameters. Default is 1.

TYPE: int DEFAULT: 1

value_key_delimiter

Delimiter between param name and value. Default is ''.

TYPE: str DEFAULT: ''

param_delimiter

Delimiter between parameters. Default is '_'.

TYPE: str DEFAULT: '_'

RETURNS DESCRIPTION
str

Formatted parameter key string (e.g., 'dFfn128_').

RAISES DESCRIPTION
ValueError

If parameter cleaning fails.

Source code in src/orchestration/hyperparamer_list_utils.py
def create_hyperparam_name(
    param: str,
    value_from_list,
    i: int = 0,
    j: int = 0,
    n_params: int = 1,
    value_key_delimiter: str = "",
    param_delimiter: str = "_",
):
    """Create a standardized name string for a hyperparameter value.

    Parameters
    ----------
    param : str
        Parameter name.
    value_from_list : Any
        Value of the parameter.
    i : int, optional
        Index in the value list (unused). Default is 0.
    j : int, optional
        Current parameter index. Default is 0.
    n_params : int, optional
        Total number of parameters. Default is 1.
    value_key_delimiter : str, optional
        Delimiter between param name and value. Default is ''.
    param_delimiter : str, optional
        Delimiter between parameters. Default is '_'.

    Returns
    -------
    str
        Formatted parameter key string (e.g., 'dFfn128_').

    Raises
    ------
    ValueError
        If parameter cleaning fails.
    """
    try:
        param = clean_param(param)
    except Exception:
        logger.error(f"Error in cleaning the parameter {param}")
        raise ValueError(f"Error in cleaning the parameter {param}")
    param_key = param + value_key_delimiter + str(value_from_list)
    if j + 1 != n_params:
        param_key += param_delimiter
    return param_key

create_name_from_model_params

create_name_from_model_params(
    model_name: str, param_cfg: DictConfig
) -> str

Create a descriptive name from model name and parameters.

PARAMETER DESCRIPTION
model_name

Base model name (e.g., 'SAITS').

TYPE: str

param_cfg

Model parameter configuration.

TYPE: DictConfig

RETURNS DESCRIPTION
str

Combined name with model and parameters (e.g., 'SAITS_dFfn128_nLayers2').

Source code in src/orchestration/hyperparamer_list_utils.py
def create_name_from_model_params(model_name: str, param_cfg: DictConfig) -> str:
    """Create a descriptive name from model name and parameters.

    Parameters
    ----------
    model_name : str
        Base model name (e.g., 'SAITS').
    param_cfg : DictConfig
        Model parameter configuration.

    Returns
    -------
    str
        Combined name with model and parameters (e.g., 'SAITS_dFfn128_nLayers2').
    """
    cfg_name = f"{model_name}_"
    for j, (param, value) in enumerate(param_cfg.items()):
        string = f"{create_hyperparam_name(param, value, j=j, n_params=len(param_cfg.keys()))}"
        cfg_name += string
    return cfg_name
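
Putting the three naming helpers together end to end, here is a self-contained sketch (plain dict in place of `DictConfig`, logging and error handling omitted) that reproduces the docstring's example name:

```python
def clean_param(param: str) -> str:
    # snake_case -> camelCase (see clean_param above)
    fields = param.split("_")
    if len(fields) > 1:
        param = fields[0] + "".join(f.title() for f in fields[1:])
    return param


def create_hyperparam_name(param, value, j=0, n_params=1,
                           value_key_delimiter="", param_delimiter="_"):
    key = clean_param(param) + value_key_delimiter + str(value)
    if j + 1 != n_params:
        key += param_delimiter  # not the last parameter -> add separator
    return key


def create_name_from_model_params(model_name: str, param_cfg: dict) -> str:
    name = f"{model_name}_"
    for j, (param, value) in enumerate(param_cfg.items()):
        name += create_hyperparam_name(param, value, j=j, n_params=len(param_cfg))
    return name


name = create_name_from_model_params("SAITS", {"d_ffn": 128, "n_layers": 2})
# -> "SAITS_dFfn128_nLayers2"
```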

define_list_hyperparam_combos

define_list_hyperparam_combos(
    cfg_model: DictConfig,
    param_dict: DictConfig,
    model: str,
    task: str,
    cfg_key: str,
) -> dict

Generate configurations from parallel parameter lists.

Creates one configuration per index across all parameter lists, where all lists must have the same length.

PARAMETER DESCRIPTION
cfg_model

Base model configuration.

TYPE: DictConfig

param_dict

Dictionary mapping parameter names to value lists.

TYPE: DictConfig

model

Model name.

TYPE: str

task

Task type for naming.

TYPE: str

cfg_key

Configuration key for model access.

TYPE: str

RETURNS DESCRIPTION
dict

Dictionary mapping configuration names to configs.

RAISES DESCRIPTION
ValueError

If parameter not found in model config or lists have different lengths.

Source code in src/orchestration/hyperparamer_list_utils.py
def define_list_hyperparam_combos(
    cfg_model: DictConfig, param_dict: DictConfig, model: str, task: str, cfg_key: str
) -> dict:
    """Generate configurations from parallel parameter lists.

    Creates one configuration per index across all parameter lists,
    where all lists must have the same length.

    Parameters
    ----------
    cfg_model : DictConfig
        Base model configuration.
    param_dict : DictConfig
        Dictionary mapping parameter names to value lists.
    model : str
        Model name.
    task : str
        Task type for naming.
    cfg_key : str
        Configuration key for model access.

    Returns
    -------
    dict
        Dictionary mapping configuration names to configs.

    Raises
    ------
    ValueError
        If parameter not found in model config or lists have different lengths.
    """

    def check_no_of_values(param_dict):
        # all params must have the same number of values in the LIST method
        no_of_values_per_param = []
        for i, param in enumerate(param_dict.keys()):
            no_of_values_per_param.append(len(param_dict[param]))
            if i > 0:
                assert no_of_values_per_param[i] == no_of_values_per_param[i - 1], (
                    "All parameters must have the same number of values"
                )
        return no_of_values_per_param

    cfg_params = {}
    no_of_values_per_param = check_no_of_values(param_dict)[0]
    for i in range(
        no_of_values_per_param
    ):  # e.g, 3 values for each param 'len(list) = 3'
        cfg_tmp = cfg_model.copy()
        param_combo_key = ""
        for j, param in enumerate(
            param_dict.keys()
        ):  # how many hyperparams you wanted to vary
            if param in cfg_tmp[cfg_key][model]["MODEL"]:
                value_from_list = param_dict[param][i]
                # create a name for the hyperparameter, this will be used then in the run names (MLflows)
                param_combo_key += f"{create_hyperparam_name(param, value_from_list, i, j, n_params=len(param_dict.keys()))}"
                cfg_tmp[cfg_key][model]["MODEL"][param] = value_from_list
            else:
                logger.error(
                    f"Parameter {param} not found in the model {model} (typo in your SEARCH_SPACE?)"
                )
                logger.error(
                    f"Possible param keys for assignment = {cfg_tmp[cfg_key][model]['MODEL'].keys()}"
                )
                raise ValueError(
                    f"Parameter {param} not found in the model {model} (typo in your SEARCH_SPACE?)"
                )
        cfg_params[f"{model}_{param_combo_key}"] = cfg_tmp

    return cfg_params
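
The key property of the LIST method is that values are paired by index, not crossed: two parameters with two values each yield two configs, not four. A plain-dict sketch of just that expansion (the project function additionally writes each value into a config copy and builds run names):

```python
def list_combos(param_dict: dict) -> list:
    """Index-aligned combos: the i-th value of every list forms one config."""
    lengths = {len(values) for values in param_dict.values()}
    assert len(lengths) == 1, "All parameters must have the same number of values"
    n = lengths.pop()
    return [{p: vals[i] for p, vals in param_dict.items()} for i in range(n)]


combos = list_combos({"d_ffn": [64, 128], "n_layers": [1, 2]})
# 2 configs: {'d_ffn': 64, 'n_layers': 1} and {'d_ffn': 128, 'n_layers': 2}
```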

define_grid_hyperparam_combos

define_grid_hyperparam_combos(
    cfg_model: DictConfig,
    param_dict: DictConfig,
    model: str,
    task: str,
    cfg_key: str,
) -> dict

Generate configurations for all combinations in a grid search.

Creates the Cartesian product of all parameter value lists.

PARAMETER DESCRIPTION
cfg_model

Base model configuration.

TYPE: DictConfig

param_dict

Dictionary mapping parameter names to value lists.

TYPE: DictConfig

model

Model name.

TYPE: str

task

Task type: 'outlier_detection' or 'imputation'.

TYPE: str

cfg_key

Configuration key for model access.

TYPE: str

RETURNS DESCRIPTION
dict

Dictionary mapping configuration names to configs.

RAISES DESCRIPTION
ValueError

If task type is not recognized for naming.

Source code in src/orchestration/hyperparamer_list_utils.py
def define_grid_hyperparam_combos(
    cfg_model: DictConfig, param_dict: DictConfig, model: str, task: str, cfg_key: str
) -> dict:
    """Generate configurations for all combinations in a grid search.

    Creates the Cartesian product of all parameter value lists.

    Parameters
    ----------
    cfg_model : DictConfig
        Base model configuration.
    param_dict : DictConfig
        Dictionary mapping parameter names to value lists.
    model : str
        Model name.
    task : str
        Task type: 'outlier_detection' or 'imputation'.
    cfg_key : str
        Configuration key for model access.

    Returns
    -------
    dict
        Dictionary mapping configuration names to configs.

    Raises
    ------
    ValueError
        If task type is not recognized for naming.
    """
    choice_keys = list(param_dict.keys())
    choices = list(itertools.product(*param_dict.values()))

    cfgs = {}
    for i, choice in enumerate(choices):
        cfg_tmp = cfg_model.copy()
        for j, key in enumerate(choice_keys):
            logger.debug("Setting {} to {}".format(key, choice[j]))
            cfg_tmp[cfg_key][model]["MODEL"][key] = choice[j]
        if task == "outlier_detection":
            cfg_name = update_outlier_detection_run_name(cfg_tmp)
        elif task == "imputation":
            cfg_name = update_imputation_run_name(cfg_tmp)
        else:
            logger.error("Define some name for the run, task = {}".format(task))
            raise ValueError("Define some name for the run, task = {}".format(task))

        cfgs[cfg_name] = cfg_tmp

    return cfgs
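
In contrast to the LIST method, GRID takes the Cartesian product via `itertools.product`, so the same two-by-two input yields four configs. A minimal sketch of the expansion step (the project function additionally writes values into config copies and derives run names):

```python
import itertools


def grid_combos(param_dict: dict) -> list:
    """Cartesian product: every value of every list against every other."""
    keys = list(param_dict)
    return [dict(zip(keys, choice))
            for choice in itertools.product(*param_dict.values())]


combos = grid_combos({"d_ffn": [64, 128], "n_layers": [1, 2]})
# 4 configs, vs. 2 for the LIST method on the same input
```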

debug_utils

debug_classification_macro

debug_classification_macro(cfg: DictConfig) -> DictConfig

Reduce bootstrap iterations for faster debugging.

Modifies the configuration to use only 50 bootstrap iterations instead of the default (typically 1000).

PARAMETER DESCRIPTION
cfg

Configuration dictionary to modify.

TYPE: DictConfig

RETURNS DESCRIPTION
DictConfig

Modified configuration with reduced bootstrap iterations.

Source code in src/orchestration/debug_utils.py
def debug_classification_macro(cfg: DictConfig) -> DictConfig:
    """Reduce bootstrap iterations for faster debugging.

    Modifies the configuration to use only 50 bootstrap iterations
    instead of the default (typically 1000).

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary to modify.

    Returns
    -------
    DictConfig
        Modified configuration with reduced bootstrap iterations.
    """
    with open_dict(cfg):
        cfg["CLS_EVALUATION"]["BOOTSTRAP"]["n_iterations"] = 50
        logger.info(
            "Setting number of bootstrap iterations to {} to speed up debugging".format(
                cfg["CLS_EVALUATION"]["BOOTSTRAP"]["n_iterations"]
            )
        )
    return cfg

pick_one_model

pick_one_model(
    cfg: DictConfig, model_name: str = "SAITS"
) -> DictConfig

Keep only one model in configuration for testing.

Reduces the model dictionary to contain only the specified model.

PARAMETER DESCRIPTION
cfg

Configuration with MODELS dictionary.

TYPE: DictConfig

model_name

Name of the model to keep. Default is 'SAITS'.

TYPE: str DEFAULT: 'SAITS'

RETURNS DESCRIPTION
DictConfig

Modified configuration with single model.

Source code in src/orchestration/debug_utils.py
def pick_one_model(cfg: DictConfig, model_name: str = "SAITS") -> DictConfig:
    """Keep only one model in configuration for testing.

    Reduces the model dictionary to contain only the specified model.

    Parameters
    ----------
    cfg : DictConfig
        Configuration with MODELS dictionary.
    model_name : str, optional
        Name of the model to keep. Default is 'SAITS'.

    Returns
    -------
    DictConfig
        Modified configuration with single model.
    """
    logger.warning("Picking just one model for testing purposes: {}".format(model_name))
    cfg["MODELS"] = {model_name: cfg["MODELS"][model_name]}
    return cfg
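As a usage sketch (plain dicts standing in for DictConfig; the model names and contents here are illustrative), the reduction is a single-key dict rebuild:

```python
# Hypothetical config with several imputation models registered.
cfg = {"MODELS": {"SAITS": {"lr": 1e-3}, "BRITS": {"lr": 1e-3}, "CSDI": {"lr": 1e-4}}}

# Keep only the requested model, as pick_one_model does.
model_name = "SAITS"
cfg["MODELS"] = {model_name: cfg["MODELS"][model_name]}
print(list(cfg["MODELS"]))  # ['SAITS']
```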

debug_train_only_for_one_epoch

debug_train_only_for_one_epoch(
    cfg: DictConfig,
) -> DictConfig

Reduce all epoch counts to 1 for quick debugging.

Recursively searches configuration for epoch-related keys and sets them to 1 for fast iteration during development.

PARAMETER DESCRIPTION
cfg

Configuration dictionary to modify.

TYPE: DictConfig

RETURNS DESCRIPTION
DictConfig

Modified configuration with all epoch counts set to 1.

Notes

Modifies keys: 'epochs', 'max_epoch', 'train_epochs'.

Source code in src/orchestration/debug_utils.py
def debug_train_only_for_one_epoch(cfg: DictConfig) -> DictConfig:
    """Reduce all epoch counts to 1 for quick debugging.

    Recursively searches configuration for epoch-related keys and
    sets them to 1 for fast iteration during development.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary to modify.

    Returns
    -------
    DictConfig
        Modified configuration with all epoch counts set to 1.

    Notes
    -----
    Modifies keys: 'epochs', 'max_epoch', 'train_epochs'.
    """

    def replace_item(obj: DictConfig, key: str, replace_value: Any) -> DictConfig:
        for k, v in obj.items():
            if isinstance(v, DictConfig):
                obj[k] = replace_item(v, key, replace_value)
        if key in obj:
            old_value = obj[key]
            if isinstance(obj[key], int):
                obj[key] = replace_value
                logger.warning(
                    "Replacing old value '{}={}' with '{}={}'".format(
                        key, old_value, key, obj[key]
                    )
                )
        return obj.copy()

    # Would it be robust enough to just search for the "epoch" substring?
    cfg_modified = replace_item(cfg.copy(), "epochs", replace_value=1)
    cfg_modified = replace_item(cfg_modified.copy(), "max_epoch", replace_value=1)
    cfg_modified = replace_item(cfg_modified.copy(), "train_epochs", replace_value=1)

    return cfg_modified
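A minimal stand-alone sketch of the same recursive replacement, using plain dicts instead of DictConfig (the config keys below are made up for illustration):

```python
def replace_item(obj: dict, key: str, replace_value) -> dict:
    """Recursively set every integer-valued `key` in a nested dict."""
    for k, v in obj.items():
        if isinstance(v, dict):
            obj[k] = replace_item(v, key, replace_value)
    if key in obj and isinstance(obj[key], int):
        obj[key] = replace_value
    return obj

cfg = {"TRAIN": {"epochs": 100, "SCHED": {"max_epoch": 50}}, "name": "demo"}
for key in ("epochs", "max_epoch", "train_epochs"):
    cfg = replace_item(cfg, key, replace_value=1)
print(cfg)  # {'TRAIN': {'epochs': 1, 'SCHED': {'max_epoch': 1}}, 'name': 'demo'}
```

Non-integer values and keys that do not appear (here 'train_epochs') are simply left alone.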

fix_tree_learners_for_debug

fix_tree_learners_for_debug(
    cfg: DictConfig, model_name: str
) -> DictConfig

Reduce tree-based model iterations for debugging.

Specifically handles MissForest by reducing max_iter to 2.

PARAMETER DESCRIPTION
cfg

Configuration dictionary.

TYPE: DictConfig

model_name

Name of the model to check.

TYPE: str

RETURNS DESCRIPTION
DictConfig

Modified configuration if the model is MissForest, otherwise unchanged.

Source code in src/orchestration/debug_utils.py
def fix_tree_learners_for_debug(cfg: DictConfig, model_name: str) -> DictConfig:
    """Reduce tree-based model iterations for debugging.

    Specifically handles MissForest by reducing max_iter to 2.

    Parameters
    ----------
    cfg : DictConfig
        Configuration dictionary.
    model_name : str
        Name of the model to check.

    Returns
    -------
    DictConfig
        Modified configuration if the model is MissForest, otherwise unchanged.
    """
    if "MISSFOREST" in model_name:
        logger.warning("DEBUG | Setting the MissForest max_iter to 2")
        with open_dict(cfg):
            cfg["MODELS"][model_name]["MODEL"]["max_iter"] = 2

    return cfg

tabm_hyperparams

get_metric_of_runs

get_metric_of_runs(metrics, eval_metric, split='test')

Extract a specific evaluation metric from multiple runs.

PARAMETER DESCRIPTION
metrics

Dictionary mapping run indices to their metric results.

TYPE: dict

eval_metric

Name of the evaluation metric to extract (case-insensitive).

TYPE: str

split

Data split to get metrics from. Default is 'test'.

TYPE: str DEFAULT: 'test'

RETURNS DESCRIPTION
ndarray

Array of metric values, one per run.

Source code in src/orchestration/tabm_hyperparams.py
def get_metric_of_runs(metrics, eval_metric, split="test"):
    """Extract a specific evaluation metric from multiple runs.

    Parameters
    ----------
    metrics : dict
        Dictionary mapping run indices to their metric results.
    eval_metric : str
        Name of the evaluation metric to extract (case-insensitive).
    split : str, optional
        Data split to get metrics from. Default is 'test'.

    Returns
    -------
    np.ndarray
        Array of metric values, one per run.
    """
    metric_list = []
    for _, metric in metrics.items():
        scalars = metric["metrics_stats"][split]["metrics"]["scalars"]
        metric_list.append(scalars[eval_metric.upper()]["mean"])
    return np.array(metric_list)
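The nested layout the function expects can be sketched as follows; the metric name and scores are illustrative, not taken from a real run:

```python
import numpy as np

# Two hypothetical runs, each with bootstrap summary statistics.
metrics = {
    0: {"metrics_stats": {"test": {"metrics": {"scalars": {"AUROC": {"mean": 0.81}}}}}},
    1: {"metrics_stats": {"test": {"metrics": {"scalars": {"AUROC": {"mean": 0.85}}}}}},
}

def get_metric_of_runs(metrics, eval_metric, split="test"):
    # Case-insensitive lookup: metric keys are stored upper-case.
    return np.array([
        m["metrics_stats"][split]["metrics"]["scalars"][eval_metric.upper()]["mean"]
        for m in metrics.values()
    ])

print(get_metric_of_runs(metrics, "auroc"))  # [0.81 0.85]
```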

pick_the_best_hyperparam_metrics

pick_the_best_hyperparam_metrics(
    metrics, hparam_cfg, model_cfgs, cfg
)

Select the best hyperparameter configuration based on evaluation metric.

PARAMETER DESCRIPTION
metrics

Dictionary of metrics from each hyperparameter configuration run.

TYPE: dict

hparam_cfg

Hyperparameter search configuration with SEARCH_SPACE.GRID.

TYPE: dict

model_cfgs

List of model configurations corresponding to each run.

TYPE: list

cfg

Main configuration dictionary.

TYPE: DictConfig

RETURNS DESCRIPTION
tuple

A tuple containing:

- best_metrics : dict
  Metrics from the best-performing configuration.
- best_choice : dict
  Dictionary with 'choice', 'choice_keys', and 'model_cfg' for the selected configuration.

Source code in src/orchestration/tabm_hyperparams.py
def pick_the_best_hyperparam_metrics(metrics, hparam_cfg, model_cfgs, cfg):
    """Select the best hyperparameter configuration based on evaluation metric.

    Parameters
    ----------
    metrics : dict
        Dictionary of metrics from each hyperparameter configuration run.
    hparam_cfg : dict
        Hyperparameter search configuration with SEARCH_SPACE.GRID.
    model_cfgs : list
        List of model configurations corresponding to each run.
    cfg : DictConfig
        Main configuration dictionary.

    Returns
    -------
    tuple
        A tuple containing:
        - best_metrics : dict
            Metrics from the best-performing configuration.
        - best_choice : dict
            Dictionary with 'choice', 'choice_keys', and 'model_cfg'
            for the selected configuration.
    """
    eval_metric = get_eval_metric_name("TabM", cfg)
    metric_of_runs = get_metric_of_runs(metrics, eval_metric)
    choice_keys, choices = get_grid_choices(hparam_cfg)
    assert len(choices) == len(metric_of_runs), (
        f"Number of choices {len(choices)} "
        f"does not match number of metrics {len(metric_of_runs)}"
    )

    best_idx = np.argmax(metric_of_runs)
    best_metrics = metrics[best_idx]
    best_choice = {
        "choice": choices[best_idx],
        "choice_keys": choice_keys,
        "model_cfg": model_cfgs[best_idx],
    }
    return best_metrics, best_choice
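Note that np.argmax implies a higher-is-better metric (e.g. AUROC). With illustrative scores and choices, the selection step reduces to:

```python
import numpy as np

# Hypothetical per-run scores and the grid choices that produced them.
metric_of_runs = np.array([0.78, 0.85, 0.81])
choice_keys = ["lr", "n_blocks"]
choices = [(1e-3, 2), (1e-3, 3), (1e-4, 2)]

best_idx = int(np.argmax(metric_of_runs))  # higher-is-better assumed
best_choice = dict(zip(choice_keys, choices[best_idx]))
print(best_idx, best_choice)  # 1 {'lr': 0.001, 'n_blocks': 3}
```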

get_grid_choices

get_grid_choices(hparam_cfg)

Generate all combinations from a grid search space.

PARAMETER DESCRIPTION
hparam_cfg

Configuration with SEARCH_SPACE.GRID containing parameter lists.

TYPE: dict

RETURNS DESCRIPTION
tuple

A tuple containing:

- choice_keys : list
  List of hyperparameter names.
- choices : list
  List of tuples, each representing one parameter combination.

Source code in src/orchestration/tabm_hyperparams.py
def get_grid_choices(hparam_cfg):
    """Generate all combinations from a grid search space.

    Parameters
    ----------
    hparam_cfg : dict
        Configuration with SEARCH_SPACE.GRID containing parameter lists.

    Returns
    -------
    tuple
        A tuple containing:
        - choice_keys : list
            List of hyperparameter names.
        - choices : list
            List of tuples, each representing one parameter combination.
    """
    grid_search_space = hparam_cfg["SEARCH_SPACE"]["GRID"]
    choice_keys = list(grid_search_space.keys())
    choices = list(itertools.product(*grid_search_space.values()))
    return choice_keys, choices
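For example, with a hypothetical two-parameter grid, itertools.product enumerates the Cartesian product in key order:

```python
import itertools

# Illustrative grid: 2 learning rates x 3 block counts = 6 combinations.
grid = {"lr": [1e-3, 1e-4], "n_blocks": [2, 3, 4]}
choice_keys = list(grid.keys())
choices = list(itertools.product(*grid.values()))

print(choice_keys)   # ['lr', 'n_blocks']
print(len(choices))  # 6
print(choices[0])    # (0.001, 2)
```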

create_tabm_grid_experiment

create_tabm_grid_experiment(hparam_cfg, cls_model_cfg)

Create configurations for TabM grid search experiment.

Generates a list of model configurations, one for each combination of hyperparameters in the grid search space.

PARAMETER DESCRIPTION
hparam_cfg

Hyperparameter configuration with SEARCH_SPACE.GRID.

TYPE: dict

cls_model_cfg

Base classifier model configuration to modify.

TYPE: dict

RETURNS DESCRIPTION
list

List of configuration dictionaries, one per grid combination.

Source code in src/orchestration/tabm_hyperparams.py
def create_tabm_grid_experiment(hparam_cfg, cls_model_cfg):
    """Create configurations for TabM grid search experiment.

    Generates a list of model configurations, one for each combination
    of hyperparameters in the grid search space.

    Parameters
    ----------
    hparam_cfg : dict
        Hyperparameter configuration with SEARCH_SPACE.GRID.
    cls_model_cfg : dict
        Base classifier model configuration to modify.

    Returns
    -------
    list
        List of configuration dictionaries, one per grid combination.
    """
    # Timing note: with 50 bootstrap iterations, one bootstrap evaluation
    # took roughly 32 seconds in an earlier run.
    choice_keys, choices = get_grid_choices(hparam_cfg)
    cfgs = []
    for ii, choice in enumerate(choices):
        cfg_tmp = cls_model_cfg.copy()
        for i, key in enumerate(choice_keys):
            cfg_tmp[key] = choice[i]
        cfgs.append(cfg_tmp)
    logger.info(f"Created {len(cfgs)} hyperparameter configurations")
    return cfgs
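The expansion from a base config to one config per grid combination can be sketched with plain dicts (parameter names here are hypothetical):

```python
import itertools

grid = {"lr": [1e-3, 1e-4], "dropout": [0.0, 0.1]}
base_cfg = {"lr": 1e-2, "dropout": 0.5, "n_epochs": 100}

keys = list(grid.keys())
cfgs = []
for choice in itertools.product(*grid.values()):
    cfg = dict(base_cfg)           # shallow copy of the base config
    cfg.update(zip(keys, choice))  # override only the swept keys
    cfgs.append(cfg)

print(len(cfgs))            # 4
print(cfgs[0]["n_epochs"])  # 100 (keys outside the grid are preserved)
```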

create_tabm_hyperparam_experiment

create_tabm_hyperparam_experiment(
    run_name, hparam_cfg, cls_model_cfg
)

Create hyperparameter experiment configurations for TabM.

Determines whether to run grid search based on configuration and run name, and generates appropriate configurations.

PARAMETER DESCRIPTION
run_name

Name of the current run, used to determine whether it is a ground-truth run.

TYPE: str

hparam_cfg

Hyperparameter configuration with search settings.

TYPE: dict

cls_model_cfg

Base classifier configuration.

TYPE: dict

RETURNS DESCRIPTION
list

List of configurations to run. Returns a single-element list containing the original config if grid search is disabled or not applicable.

RAISES DESCRIPTION
NotImplementedError

If LIST search space is specified (not yet implemented).

ValueError

If search space type is unknown.

Source code in src/orchestration/tabm_hyperparams.py
def create_tabm_hyperparam_experiment(run_name, hparam_cfg, cls_model_cfg):
    """Create hyperparameter experiment configurations for TabM.

    Determines whether to run grid search based on configuration and
    run name, and generates appropriate configurations.

    Parameters
    ----------
    run_name : str
        Name of the current run, used to determine whether it is a ground-truth run.
    hparam_cfg : dict
        Hyperparameter configuration with search settings.
    cls_model_cfg : dict
        Base classifier configuration.

    Returns
    -------
    list
        List of configurations to run. Returns a single-element list
        containing the original config if grid search is disabled or
        not applicable.

    Raises
    ------
    NotImplementedError
        If LIST search space is specified (not yet implemented).
    ValueError
        If search space type is unknown.
    """
    if hparam_cfg["HYPERPARAMS"]["run_grid_hyperparam_search"]:
        if hparam_cfg["HYPERPARAMS"]["run_only_on_ground_truth"]:
            if "pupil-gt__pupil-gt" in run_name:
                if "LIST" in hparam_cfg["SEARCH_SPACE"].keys():
                    logger.error("List not implemented yet")
                    raise NotImplementedError("List not implemented yet")
                elif "GRID" in hparam_cfg["SEARCH_SPACE"].keys():
                    cfgs = create_tabm_grid_experiment(hparam_cfg, cls_model_cfg)
                    return cfgs
                else:
                    logger.error(f"Unknown search space, {hparam_cfg['SEARCH_SPACE']}")
                    raise ValueError(
                        f"Unknown search space, {hparam_cfg['SEARCH_SPACE']}"
                    )
            else:
                return [cls_model_cfg]
        else:
            return [cls_model_cfg]

    else:
        return [cls_model_cfg]
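The gating above collapses to a single predicate, sketched below with the run-name substring taken from the source. Note that, as written, grid search only ever runs on ground-truth runs: setting run_only_on_ground_truth to False also falls through to the single-config path.

```python
def runs_grid_search(run_name: str, hparam_cfg: dict) -> bool:
    # All three conditions must hold before any grid configs are built.
    h = hparam_cfg["HYPERPARAMS"]
    return bool(
        h["run_grid_hyperparam_search"]
        and h["run_only_on_ground_truth"]
        and "pupil-gt__pupil-gt" in run_name
    )

hp = {"HYPERPARAMS": {"run_grid_hyperparam_search": True,
                      "run_only_on_ground_truth": True}}
print(runs_grid_search("pupil-gt__pupil-gt_fold0", hp))  # True
print(runs_grid_search("imputed__model-x_fold0", hp))    # False
```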