From 2b6cae949ccda5c97cd65bc274daff66fb0c9747 Mon Sep 17 00:00:00 2001 From: jrzkaminski <86363785+jrzkaminski@users.noreply.github.com> Date: Tue, 8 Aug 2023 16:02:53 +0300 Subject: [PATCH] made lightgbm optional --- bamt/utils/composite_utils/MLUtils.py | 31 ++++++++++++++++++--- bamt/utils/composite_utils/lgbm_params.json | 11 ++++++++ pyproject.toml | 4 +++ tests/test_networks.py | 2 +- 4 files changed, 43 insertions(+), 5 deletions(-) create mode 100644 bamt/utils/composite_utils/lgbm_params.json diff --git a/bamt/utils/composite_utils/MLUtils.py b/bamt/utils/composite_utils/MLUtils.py index 4c032d8..4653a85 100644 --- a/bamt/utils/composite_utils/MLUtils.py +++ b/bamt/utils/composite_utils/MLUtils.py @@ -1,6 +1,6 @@ from catboost import CatBoostClassifier, CatBoostRegressor +from bamt.log import logger_network -# from lightgbm.sklearn import LGBMClassifier, LGBMRegressor from sklearn.cluster import KMeans from sklearn.ensemble import ( AdaBoostRegressor, @@ -24,6 +24,16 @@ from .CompositeModel import CompositeNode from random import choice +# Try to import LGBMRegressor and LGBMClassifier from lightgbm, if not available set to None +try: + from lightgbm.sklearn import LGBMRegressor, LGBMClassifier +except ModuleNotFoundError: + LGBMRegressor = None + LGBMClassifier = None + logger_network.warning( + "Install lightgbm (e.g. pip install lightgbm) to use LGBMRegressor and LGBMClassifier" + ) + class MlModels: def __init__(self): @@ -47,7 +57,6 @@ def __init__(self): "dt": "DecisionTreeClassifier", "rf": "RandomForestClassifier", "mlp": "MLPClassifier", - # "lgbm": "LGBMClassifier", "catboost": "CatBoostClassifier", "kmeans": "KMeans", } @@ -63,7 +72,7 @@ def __init__(self): "Ridge": Ridge, "Lasso": Lasso, "SGDRegressor": SGDRegressor, - # "LGBMRegressor": LGBMRegressor, + "LGBMRegressor": LGBMRegressor, "CatBoostRegressor": CatBoostRegressor, "XGBClassifier": XGBClassifier, "LogisticRegression": LogisticRegression, @@ -72,11 +81,23 @@ def __init__(self): "DecisionTreeClassifier": DecisionTreeClassifier, "RandomForestClassifier": RandomForestClassifier, "MLPClassifier": MLPClassifier, - # "LGBMClassifier": LGBMClassifier, + "LGBMClassifier": LGBMClassifier, "CatBoostClassifier": CatBoostClassifier, "KMeans": KMeans, } + # Include LGBMRegressor and LGBMClassifier if they were imported successfully + if LGBMRegressor is not None: + self.dict_models["LGBMRegressor"] = LGBMRegressor + self.operations_by_types["lgbmreg"] = "LGBMRegressor" + if LGBMClassifier is not None: + self.dict_models["LGBMClassifier"] = LGBMClassifier + self.operations_by_types["lgbm"] = "LGBMClassifier" + + if LGBMClassifier and LGBMRegressor is not None: + with open("bamt/utils/composite_utils/lgbm_params.json") as file: + self.lgbm_dict = json.load(file) + def get_model_by_children_type(self, node: CompositeNode): candidates = [] if node.content["type"] == "cont": @@ -87,6 +108,8 @@ def get_model_by_children_type(self, node: CompositeNode): with open("bamt/utils/composite_utils/models_repo.json", "r") as f: models_json = json.load(f) models = models_json["operations"] + if LGBMClassifier and LGBMRegressor is not None: + models = models | self.lgbm_dict for model, value in models.items(): if ( model not in ["knnreg", "knn", "qda"] diff --git a/bamt/utils/composite_utils/lgbm_params.json b/bamt/utils/composite_utils/lgbm_params.json new file mode 100644 index 0000000..5c5fa9d --- /dev/null +++ b/bamt/utils/composite_utils/lgbm_params.json @@ -0,0 +1,11 @@ +{ + "lgbm": { + "meta": "sklearn_class", + "tags": ["boosting", "tree", "non_linear", "mix"] + }, + "lgbmreg": { + "meta": "sklearn_regr", + "presets": ["*tree"], + "tags": ["boosting", "tree", "non_multi", "non_linear", "mix"] + } +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 12cc4c3..0fad6a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,10 @@ pgmpy = "0.1.20" thegolem = ">=0.3.1" xgboost = ">=1.7.6" catboost = ">=1.0.6" +lightgbm = {version = ">=4.0.0", optional = true } + +[tool.poetry.extras] +composite_extras = ["lightgbm"] [tool.poetry.dev-dependencies] diff --git a/tests/test_networks.py b/tests/test_networks.py index 84f6204..8adc81e 100644 --- a/tests/test_networks.py +++ b/tests/test_networks.py @@ -1104,7 +1104,7 @@ def test_learning(self): self.bn.add_nodes(info) - self.bn.add_edges(self.data) + self.bn.add_edges(self.data, verbose=False) self.bn.fit_parameters(self.data)