# Source code for mlcompare.data.split_data

from __future__ import annotations as _annotations

import logging
import pickle
from pathlib import Path
from typing import TypeAlias

import pandas as pd
from pydantic import BaseModel, ConfigDict

logger = logging.getLogger(__name__)

# Return type of `load_split_data`. Element order is fixed:
# (X_train, X_test, y_train, y_test).
SplitDataTuple: TypeAlias = tuple[
    pd.DataFrame,  # X_train: training split features
    pd.DataFrame,  # X_test: testing split features
    pd.DataFrame | pd.Series,  # y_train: training split target values
    pd.DataFrame | pd.Series,  # y_test: testing split target values
]
"""A train-test split also split by features and target variable. Primarily used to both train and evaluate models."""


class SplitData(BaseModel):
    """
    Validates and stores train-test and feature-target split data.

    Attributes:
    -----------
        X_train (pd.DataFrame): Training split features.
        X_test (pd.DataFrame): Testing split features.
        y_train (pd.DataFrame | pd.Series): Training split target values.
        y_test (pd.DataFrame | pd.Series): Testing split target values.
    """

    X_train: pd.DataFrame
    X_test: pd.DataFrame
    y_train: pd.DataFrame | pd.Series
    y_test: pd.DataFrame | pd.Series

    # pandas containers are not native pydantic types; without this setting
    # pydantic refuses to build a schema for the fields above.
    model_config = ConfigDict(arbitrary_types_allowed=True)


def load_split_data(load_path: str | Path) -> SplitDataTuple:
    """
    Loads a SplitData object from a pickle file and returns the data it was holding.

    Args:
    -----
        load_path (str | Path): Path to a pickle file containing a SplitData object.

    Returns:
    --------
        SplitDataTuple:
            pd.DataFrame: Training split features.
            pd.DataFrame: Testing split features.
            pd.DataFrame | pd.Series: Training split target values.
            pd.DataFrame | pd.Series: Testing split target values.

    Raises:
    -------
        ValueError: If `load_path` is neither a string nor a Path object.
        TypeError: If the unpickled object is not a SplitData instance.
    """
    # Guard clause: reject anything that is not a str or Path up front.
    # ValueError (rather than TypeError) is kept for backward compatibility.
    if not isinstance(load_path, (str, Path)):
        raise ValueError("`load_path` must be a string or Path object.")

    load_path = Path(load_path)

    # SECURITY NOTE: unpickling can execute arbitrary code. Only load files
    # that come from a trusted source; the isinstance check below runs after
    # the payload has already been deserialized and offers no protection.
    with open(load_path, "rb") as file:
        split_data = pickle.load(file)

    if not isinstance(split_data, SplitData):
        raise TypeError("Loaded data must be of type SplitData.")

    return (
        split_data.X_train,
        split_data.X_test,
        split_data.y_train,
        split_data.y_test,
    )