# Source code for mlcompare.params_reader

from __future__ import annotations as _annotations

import json
import logging
from pathlib import Path
from typing import Any, TypeAlias

# Module-level logger named after this module; used to report file-loading errors.
logger = logging.getLogger(__name__)

# Accepted forms of pipeline parameters: an in-memory list of parameter
# dictionaries, or a `str`/`Path` that `ParamsReader.read` treats as the
# location of a JSON file to load.
ParamsInput: TypeAlias = list[dict[str, Any]] | str | Path
"""User input for pipelines, containing information to load and process datasets or to create ml models."""


class _InvalidParamsTypeError(TypeError, AssertionError):
    """
    Raised when `params_list` is not a list of dictionaries.

    Subclasses `TypeError` (the contract documented by `ParamsReader.read`) and
    `AssertionError` (what the former `assert`-based checks raised), so callers
    catching either exception type keep working.
    """


class ParamsReader:
    """
    Reads and validates a list of parameters by calling `.read()`.

    The class holds no state; all of its methods are static.
    """

    @staticmethod
    def read(params_list: ParamsInput) -> list[dict[str, Any]]:
        """
        Reads and validates a list of parameters.

        Args:
        -----
            params_list (list[dict[str, Any]] | str | Path): List of dictionaries or a path to a JSON file.

        Returns:
        --------
            list[dict[str, Any]]: List of dictionaries containing parameters. A list passed in directly
            is returned as the same object.

        Raises:
        -------
            TypeError: If the params_list is not a list of dictionaries or a path to a JSON file.
            ValueError: If a column is listed under more than one of 'oneHotEncode', 'targetEncode'
            and 'ordinalEncode' within the same dictionary.
            FileNotFoundError: If the specified file does not exist.
            json.JSONDecodeError: If there is an error decoding the JSON file.
        """
        # A string is always interpreted as a file path, never as inline JSON.
        if isinstance(params_list, str):
            params_list = Path(params_list)

        if isinstance(params_list, Path):
            params_list = ParamsReader._load_params_from_file(params_list)

        ParamsReader._validate_params_list(params_list)
        return params_list

    @staticmethod
    def _load_params_from_file(file_path: Path) -> list[dict[str, Any]]:
        """
        Loads parameters from a JSON file.

        Args:
        -----
            file_path (Path): Path to the JSON file.

        Returns:
        --------
            list[dict[str, Any]]: The decoded JSON content. Its shape is not checked here; that is
            done by `_validate_params_list`.

        Raises:
        -------
            FileNotFoundError: If the specified file does not exist.
            json.JSONDecodeError: If there is an error decoding the JSON file.
        """
        try:
            # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
            with open(file_path, "r", encoding="utf-8") as file:
                return json.load(file)
        except FileNotFoundError:
            logger.error("Could not find file: %s", file_path)
            raise
        except json.JSONDecodeError:
            logger.error("Error decoding JSON from file: %s", file_path)
            raise
        except Exception:
            logger.error("Unexpected error loading params from file: %s", file_path)
            raise

    @staticmethod
    def _validate_params_list(params_list: list[dict[str, Any]]) -> None:
        """
        Checks that `params_list` is a list of dictionaries with non-overlapping encoding columns.

        Args:
        -----
            params_list (list[dict[str, Any]]): Candidate parameters to validate.

        Raises:
        -------
            TypeError: If `params_list` is not a list or one of its elements is not a dictionary.
            The raised exception is also an `AssertionError` for backward compatibility.
            ValueError: If a column appears under more than one of 'oneHotEncode', 'targetEncode'
            and 'ordinalEncode' within the same dictionary.
        """
        # Explicit raises instead of `assert`: assertions are stripped under `python -O`.
        if not isinstance(params_list, list):
            raise _InvalidParamsTypeError(
                "`params_list` must be a list of dictionaries or a path to .json file containing one."
            )
        if not all(isinstance(params, dict) for params in params_list):
            raise _InvalidParamsTypeError("Each list element of `params_list` must be a dictionary.")

        for params in params_list:
            one_hot_encode = set(params.get("oneHotEncode", []))
            target_encode = set(params.get("targetEncode", []))
            ordinal_encode = set(params.get("ordinalEncode", []))

            # (overlapping columns, first key named in the message, second key named in the message),
            # checked in this order so the first conflicting pair is the one reported.
            conflicts = (
                (one_hot_encode & target_encode, "targetEncode", "oneHotEncode"),
                (one_hot_encode & ordinal_encode, "ordinalEncode", "oneHotEncode"),
                (target_encode & ordinal_encode, "targetEncode", "ordinalEncode"),
            )
            for overlap, first_key, second_key in conflicts:
                if overlap:
                    raise ValueError(
                        f"Columns: {overlap} are listed in both '{first_key}' and '{second_key}' for "
                        "one of the datasets."
                    )