Skip to content

Commit

Permalink
Make lightgbm an optional dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Aug 8, 2023
1 parent bb2576a commit 2b6cae9
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
31 changes: 27 additions & 4 deletions bamt/utils/composite_utils/MLUtils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from catboost import CatBoostClassifier, CatBoostRegressor
from bamt.log import logger_network

# from lightgbm.sklearn import LGBMClassifier, LGBMRegressor
from sklearn.cluster import KMeans
from sklearn.ensemble import (
AdaBoostRegressor,
Expand All @@ -24,6 +24,16 @@
from .CompositeModel import CompositeNode
from random import choice

# Try to import LGBMRegressor and LGBMClassifier from lightgbm; if the package is not available, set them to None
# Optional dependency: import the lightgbm scikit-learn wrappers if available.
# When lightgbm is absent the names are bound to None so later code can check
# for their presence before registering the LGBM models.
try:
    from lightgbm.sklearn import LGBMRegressor, LGBMClassifier
except ImportError:
    # ImportError (not just ModuleNotFoundError) also covers a broken or
    # partially installed lightgbm, so any import failure degrades gracefully.
    LGBMRegressor = None
    LGBMClassifier = None
    logger_network.warning(
        "Install lightgbm (e.g. pip install lightgbm) to use LGBMRegressor and LGBMClassifier"
    )


class MlModels:
def __init__(self):
Expand All @@ -47,7 +57,6 @@ def __init__(self):
"dt": "DecisionTreeClassifier",
"rf": "RandomForestClassifier",
"mlp": "MLPClassifier",
# "lgbm": "LGBMClassifier",
"catboost": "CatBoostClassifier",
"kmeans": "KMeans",
}
Expand All @@ -63,7 +72,7 @@ def __init__(self):
"Ridge": Ridge,
"Lasso": Lasso,
"SGDRegressor": SGDRegressor,
# "LGBMRegressor": LGBMRegressor,
"LGBMRegressor": LGBMRegressor,
"CatBoostRegressor": CatBoostRegressor,
"XGBClassifier": XGBClassifier,
"LogisticRegression": LogisticRegression,
Expand All @@ -72,11 +81,23 @@ def __init__(self):
"DecisionTreeClassifier": DecisionTreeClassifier,
"RandomForestClassifier": RandomForestClassifier,
"MLPClassifier": MLPClassifier,
# "LGBMClassifier": LGBMClassifier,
"LGBMClassifier": LGBMClassifier,
"CatBoostClassifier": CatBoostClassifier,
"KMeans": KMeans,
}

# Include LGBMRegressor and LGBMClassifier if they were imported successfully
if LGBMRegressor is not None:
self.dict_models["LGBMRegressor"] = LGBMRegressor
self.operations_by_types["lgbmreg"] = "LGBMRegressor"
if LGBMClassifier is not None:
self.dict_models["LGBMClassifier"] = LGBMClassifier
self.operations_by_types["lgbm"] = "LGBMClassifier"

if LGBMClassifier and LGBMRegressor is not None:
with open("bamt/utils/composite_utils/lgbm_params.json") as file:
self.lgbm_dict = json.load(file)

def get_model_by_children_type(self, node: CompositeNode):
candidates = []
if node.content["type"] == "cont":
Expand All @@ -87,6 +108,8 @@ def get_model_by_children_type(self, node: CompositeNode):
with open("bamt/utils/composite_utils/models_repo.json", "r") as f:
models_json = json.load(f)
models = models_json["operations"]
if LGBMClassifier and LGBMRegressor is not None:
models = models | self.lgbm_dict
for model, value in models.items():
if (
model not in ["knnreg", "knn", "qda"]
Expand Down
11 changes: 11 additions & 0 deletions bamt/utils/composite_utils/lgbm_params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"lgbm": {
"meta": "sklearn_class",
"tags": ["boosting", "tree", "non_linear", "mix"]
},
"lgbmreg": {
"meta": "sklearn_regr",
"presets": ["*tree"],
"tags": ["boosting", "tree", "non_multi", "non_linear", "mix"]
}
}
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ pgmpy = "0.1.20"
thegolem = ">=0.3.1"
xgboost = ">=1.7.6"
catboost = ">=1.0.6"
lightgbm = {version = ">=4.0.0", optional = true }

[tool.poetry.extras]
composite_extras = ["lightgbm"]


[tool.poetry.dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ def test_learning(self):

self.bn.add_nodes(info)

self.bn.add_edges(self.data)
self.bn.add_edges(self.data, verbose=False)

self.bn.fit_parameters(self.data)

Expand Down

0 comments on commit 2b6cae9

Please sign in to comment.