Source code for mlcompare.models.models

from __future__ import annotations as _annotations

import inspect
import logging
import pickle
from abc import ABC, abstractmethod
from importlib import import_module
from pathlib import Path
from typing import Any, Generator, Literal, TypeAlias

import pandas as pd
from pydantic import BaseModel

from ..params_reader import ParamsInput, ParamsReader

logger = logging.getLogger(__name__)

LibraryNames: TypeAlias = Literal[
    "sklearn",
    "scikit-learn",
    "skl",
    "xgboost",
    "xgb",
    "pytorch",
    "torch",
    "tensorflow",
    "tf",
]


class LibraryModel(ABC, BaseModel):
    """
    A base class for models from different machine learning libraries.

    Attributes:
    -----------
        name (str): Class name of the model. Ex: RandomForestRegressor.
        module (str | None): Module containing the model class if it's not imported at the library level.
        params (dict | None): Parameters to pass to the model class constructor if any.
        _library (str): Import name of the library the model class comes from; set by each subclass.
        _ml_model (Any): Model object instantiated from the library, accessed by the `train` and `predict` methods.
    """

    name: str
    module: str | None = None
    params: dict | None = None
    _library: Literal["sklearn", "xgboost", "torch", "tensorflow"]
    _ml_model: Any = None

    @abstractmethod
    def model_post_init(self, context: Any, /) -> None:
        """Pydantic post-init hook; subclasses use it to build `_ml_model`."""
        ...

    @abstractmethod
    def train(self, X_train: pd.DataFrame, y_train: pd.DataFrame | pd.Series) -> None:
        """Fit the wrapped model on the training data."""
        ...

    @abstractmethod
    def predict(self, X_test: pd.DataFrame):
        """Return the wrapped model's predictions for `X_test`."""
        ...

    def save(self, save_directory: Path | str) -> None:
        """
        Pickle the wrapped model object to `<save_directory>/<name>.pkl`.

        Args:
        -----
            save_directory (Path | str): Existing directory to write the pickle file into.
        """
        save_file = Path(save_directory) / f"{self.name}.pkl"
        # NOTE: pickle files execute code on load; only load them from trusted locations.
        with open(save_file, "wb") as file:
            pickle.dump(self._ml_model, file)

    def resolve_model_submodule(self) -> Any | None:
        """
        Search the library's public submodules for a class named `self.name`.

        Every entry of the library's `__all__` is tried as a submodule name. Entries that can't be
        imported as modules are skipped (logged at debug level).

        Returns:
        --------
            The first matching class, or None if no searched submodule defines it.
        """
        imported_library = import_module(self._library)
        # Not every package defines __all__; a missing one simply means there is nothing to search.
        library_modules = getattr(imported_library, "__all__", [])

        for module_name in library_modules:
            qualified_name = f"{imported_library.__name__}.{module_name}"
            try:
                module = import_module(qualified_name)
                candidate = getattr(module, self.name, None)
            except Exception as err:
                # Best-effort search: __all__ may list non-module names, and some submodules
                # need optional dependencies, so import failures are expected and non-fatal.
                logger.debug("Skipping %s while searching for %s: %s", qualified_name, self.name, err)
                continue

            if inspect.isclass(candidate):
                logger.info(f"{self.name} found in module: {module_name}.")
                return candidate
        return None

    def instantiate_model(self) -> None:
        """
        Import the configured model class, instantiate it with `params`, and store it in `_ml_model`.

        The class is looked up in `<library>.<module>` when `module` is set, otherwise in `<library>`.
        When no module was given and the class isn't at the top level, the library's submodules are
        searched with `resolve_model_submodule`.

        Raises:
        -------
            ImportError: If the library/module can't be imported, or if no module was given and the
                class isn't found in any submodule.
            AttributeError: If `module` was given but doesn't define the class.
            Exception: Whatever the model constructor raises for the given `params`.
        """
        full_import = f"{self._library}.{self.module}" if self.module else self._library

        try:
            model_module = import_module(full_import)
        except ImportError:
            logger.error(
                f"Could not import module {full_import}. Check that you have {self._library} "
                "installed or that the module name is spelled correctly."
            )
            raise

        # Get the model class from the module. If it fails and no module was given, try to find the class within submodules.
        try:
            model_class = getattr(model_module, self.name)
        except AttributeError:
            if self.module:
                logger.error(f"Could not find class: {self.name} in module: {self.module}.")
                raise
            logger.info(f"Searching {self._library} submodules for {self.name}.")
            model_class = self.resolve_model_submodule()
            if model_class is None:
                raise ImportError(
                    f"Could not find class {self.name} in any {self._library} submodules. Please provide a "
                    "module for the model within the config i.e. 'module': 'ensemble'."
                )

        # Initialize the model with the given parameters (an empty/None `params` means no arguments).
        try:
            ml_model = model_class(**(self.params or {}))
        except Exception:
            logger.error(f"Could not initialize model {self.name} with params {self.params}")
            raise
        self._ml_model = ml_model
class SklearnModel(LibraryModel):
    """
    A class used to instantiate and manage a Scikit-learn model.

    Attributes:
    -----------
        name (str): Class name of the model. Ex: RandomForestRegressor.
        module (str | None): Module containing the model class if it's not imported at the library level.
        params (dict | None): Parameters to pass to the model class constructor if any.
    """

    _library = "sklearn"

    def model_post_init(self, context: Any, /) -> None:
        """Instantiate the scikit-learn estimator once pydantic has validated the fields."""
        self.instantiate_model()

    def train(self, X_train: pd.DataFrame, y_train: pd.DataFrame | pd.Series) -> None:
        """Fit the estimator on the training data via its `fit` method."""
        self._ml_model.fit(X_train, y_train)

    def predict(self, X_test: pd.DataFrame):
        """Return the estimator's predictions for `X_test`."""
        return self._ml_model.predict(X_test)
class XGBoostModel(LibraryModel):
    """
    A class used to instantiate and manage an XGBoost model.

    Attributes:
    -----------
        name (str): Class name of the model. Ex: XGBRegressor.
        module (str | None): Module containing the model class if it's not imported at the library level.
        params (dict | None): Parameters to pass to the model class constructor if any.
    """

    _library = "xgboost"

    def model_post_init(self, context: Any, /) -> None:
        """Instantiate the XGBoost model once pydantic has validated the fields."""
        self.instantiate_model()

    def train(self, X_train: pd.DataFrame, y_train: pd.DataFrame | pd.Series) -> None:
        """Fit the model on the training data via its `fit` method."""
        self._ml_model.fit(X_train, y_train)

    def predict(self, X_test: pd.DataFrame):
        """Return the model's predictions for `X_test`."""
        return self._ml_model.predict(X_test)
class PyTorchModel(LibraryModel):
    """
    A class used to instantiate and manage a PyTorch model.

    Attributes:
    -----------
        name (str): Class name of the model. Ex: LSTM.
        module (str | None): Module containing the model class if it's not imported at the library level.
        params (dict | None): Parameters to pass to the model class constructor if any.
        device (Literal["cuda", "mps", "cpu"] | None): Device to use. When None, CUDA is chosen if
            available, then MPS, otherwise CPU.
        activation (str): Activation setting (stored; not used by this class yet).
        loss (str): Loss setting (stored; not used by this class yet).
        optimizer (str): Optimizer setting. Defaults to "Adam".
        batch_size (int): Batch size. Defaults to 32.
        epochs (int): Number of epochs. Defaults to 100.
        _torch_device (torch.device | None): Device resolved by `set_device`.
    """

    _library = "torch"
    _torch_device: Any = None

    device: Literal["cuda", "mps", "cpu"] | None = None
    activation: str
    loss: str
    optimizer: str = "Adam"
    batch_size: int = 32
    epochs: int = 100

    def model_post_init(self, context: Any, /) -> None:
        """Instantiate the model, then resolve `device` into a `torch.device`."""
        self.instantiate_model()
        # `device` is validated as a plain string, never a torch.device, so it is always resolved
        # through set_device (which handles both an explicit choice and auto-detection).
        self.set_device()

    def train(self, X_train: pd.DataFrame, y_train: pd.DataFrame | pd.Series) -> None:
        """Train the wrapped model via its `fit` method."""
        # NOTE(review): plain torch.nn.Module objects don't expose fit/predict; this presumably
        # expects a wrapper model class — confirm.
        self._ml_model.fit(X_train, y_train)

    def predict(self, X_test: pd.DataFrame):
        """Return the wrapped model's predictions for `X_test`."""
        return self._ml_model.predict(X_test)

    def set_device(self) -> None:
        """
        Resolve `device` into a `torch.device` and store it on `_torch_device`.

        Uses the configured `device` when given; otherwise picks CUDA, then MPS, then CPU,
        whichever is available first.

        Raises:
        -------
            ValueError: If `device` is set to a value other than 'cuda', 'mps' or 'cpu'.
        """
        import torch

        if self.device is None:
            if torch.cuda.is_available():
                device_type = "cuda"
            elif torch.backends.mps.is_available():
                device_type = "mps"
            else:
                device_type = "cpu"
        elif self.device in ("cuda", "mps", "cpu"):
            device_type = self.device
        else:
            raise ValueError("Device must be one of 'cuda', 'mps', or 'cpu'.")

        self._torch_device = torch.device(device_type)
        logger.info(f"Pytorch device type set to: {self._torch_device.type}")


class TensorflowModel(LibraryModel):
    """
    A class used to instantiate and manage a TensorFlow model.

    Attributes:
    -----------
        name (str): Class name of the model.
        module (str | None): Module containing the model class if it's not imported at the library level.
        params (dict | None): Parameters to pass to the model class constructor if any.
        activation (str): Activation setting (stored; not used by this class yet).
        loss (str): Loss setting (stored; not used by this class yet).
        optimizer (str): Optimizer setting. Defaults to "Adam".
        epochs (int): Number of epochs. Defaults to 100.
        _ml_model (Any): Instantiated machine learning model, accessed by the `train` and `predict` methods.
    """

    _library = "tensorflow"

    activation: str
    loss: str
    optimizer: str = "Adam"
    epochs: int = 100

    def model_post_init(self, context: Any, /) -> None:
        """Instantiate the TensorFlow model once pydantic has validated the fields."""
        self.instantiate_model()

    def train(self, X_train: pd.DataFrame, y_train: pd.DataFrame | pd.Series) -> None:
        """Train the wrapped model via its `fit` method."""
        self._ml_model.fit(X_train, y_train)

    def predict(self, X_test: pd.DataFrame):
        """Return the wrapped model's predictions for `X_test`."""
        return self._ml_model.predict(X_test)


# Every concrete model wrapper defined in this module.
MLModelType: TypeAlias = SklearnModel | XGBoostModel | PyTorchModel | TensorflowModel
[docs] class ModelFactory: """ Takes in a list of dictionaries and constructs model classes based on the `library` keyword provided for each. The class is designed to be iterated over. Attributes: ----------- params_list (list[dict[str, Any]] | str | Path): List of dictionaries containing dataset parameters or a path to a .json file with one. For a list of keys required in each dictionary, see below: Required keys: - `library` (Literal["sklearn", "xgboost", "pytorch", "tensorflow", "custom"]): The library to use. - `module` (str): Module containing the model class. - `name` (str): Name of the model class. Optional keys: - `params` (dict | None): The parameters to pass to the model class constructor Raises: ------- AssertionError: If `dataset_params` is not a list of dictionaries or a path to a .json file containing one. """ def __init__(self, params_list: ParamsInput) -> None: self.params_list = ParamsReader.read(params_list) def __iter__(self) -> Generator[MLModelType, None, None]: """ Makes the class iterable, yielding dataset instances one by one. Yields: ------- MLModelType: Instance of a LibraryModel child class. """ for params in self.params_list: yield ModelFactory.create(**params)
[docs] @staticmethod def create(library: LibraryNames, **kwargs) -> MLModelType: """ Factory method to create a dataset instance based on the dataset type. Args: ----- library (LibraryNames): Type of dataset to create. **kwargs: Arbitrary keyword arguments to be passed to the dataset class constructor. Returns: -------- BaseDataset: An instance of a dataset class (KaggleDataset or LocalDataset). Raises: ------- ValueError: If an unknown dataset type is provided. """ assert isinstance(library, str), "Library must be a string." library = library.lower() # type: ignore match library: case "sklearn" | "scikit-learn" | "skl": return SklearnModel(**kwargs) case "xgboost" | "xgb": return XGBoostModel(**kwargs) case "pytorch" | "torch": return PyTorchModel(**kwargs) case _: raise ValueError( f"Library: {library} is not supported. Valid library names " "are: 'sklearn', 'xgboost', 'pytorch', or 'tensorflow'. If your model is not " "in one of these libraries use 'custom' and provide a value for 'custom_function' " "that takes in train-test split data and returns an nd.array or pd.Series of " "predictions. See the documentation for more details." )