Skip to content

Commit

Permalink
pkg_resources hotfix for composite nets
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Aug 30, 2023
1 parent dfd874a commit 7cb9a95
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions bamt/utils/composite_utils/MLUtils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from random import choice
import pkg_resources

from catboost import CatBoostClassifier, CatBoostRegressor
from sklearn.cluster import KMeans
Expand Down Expand Up @@ -35,6 +36,11 @@
"Install lightgbm (e.g. pip install lightgbm) to enable LGBMRegressor and LGBMClassifier"
)

lgbm_params = "lgbm_params.json"
models_repo = "models_repo.json"
lgbm_params_path = pkg_resources.resource_filename(__name__, lgbm_params)
models_repo_path = pkg_resources.resource_filename(__name__, models_repo)


class MlModels:
def __init__(self):
Expand Down Expand Up @@ -96,7 +102,7 @@ def __init__(self):
self.operations_by_types["lgbm"] = "LGBMClassifier"

if LGBMClassifier and LGBMRegressor is not None:
with open("bamt/utils/composite_utils/lgbm_params.json") as file:
with open(lgbm_params_path) as file:
self.lgbm_dict = json.load(file)

def get_model_by_children_type(self, node: CompositeNode):
Expand All @@ -106,7 +112,7 @@ def get_model_by_children_type(self, node: CompositeNode):
else:
type_model = "class"
forbidden_tags = ["non-default", "expensive"]
with open("bamt/utils/composite_utils/models_repo.json", "r") as f:
with open(models_repo_path, "r") as f:
models_json = json.load(f)
models = models_json["operations"]
if LGBMClassifier and LGBMRegressor is not None:
Expand Down

0 comments on commit 7cb9a95

Please sign in to comment.