From b14955afc5541207dbb6a931a952138646e3b19e Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Thu, 5 Dec 2019 17:50:47 -0500 Subject: [PATCH] fixed deserialization logic for mimic explainer with older versions of lightgbm (#124) --- python/interpret_community/common/constants.py | 1 + .../interpret_community/mimic/models/lightgbm_model.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/interpret_community/common/constants.py b/python/interpret_community/common/constants.py index d52923c6..dcf6e01b 100644 --- a/python/interpret_community/common/constants.py +++ b/python/interpret_community/common/constants.py @@ -197,6 +197,7 @@ class LightGBMSerializationConstants(object): LOGGER = '_logger' MODEL_STR = 'model_str' MULTICLASS = 'multiclass' + REGRESSION = 'regression' TREE_EXPLAINER = '_tree_explainer' OBJECTIVE = 'objective' diff --git a/python/interpret_community/mimic/models/lightgbm_model.py b/python/interpret_community/mimic/models/lightgbm_model.py index 573e5eeb..50fda0ae 100644 --- a/python/interpret_community/mimic/models/lightgbm_model.py +++ b/python/interpret_community/mimic/models/lightgbm_model.py @@ -271,13 +271,19 @@ def _load(properties): # https://github.com/Microsoft/LightGBM/issues/1942 # https://github.com/Microsoft/LightGBM/issues/1217 booster_args = {LightGBMSerializationConstants.MODEL_STR: value} + is_multiclass = json.loads(properties[LightGBMSerializationConstants.MULTICLASS]) + if is_multiclass: + objective = LightGBMSerializationConstants.MULTICLASS + else: + objective = LightGBMSerializationConstants.REGRESSION if LightGBMSerializationConstants.MODEL_STR in inspect.getargspec(Booster).args: - extras = {LightGBMSerializationConstants.OBJECTIVE: LightGBMSerializationConstants.MULTICLASS} + extras = {LightGBMSerializationConstants.OBJECTIVE: objective} lgbm_booster = Booster(**booster_args, params=extras) else: # For backwards compatibility with older versions of lightgbm + booster_args[LightGBMSerializationConstants.OBJECTIVE] = objective lgbm_booster = Booster(params=booster_args) - if json.loads(properties[LightGBMSerializationConstants.MULTICLASS]): + if is_multiclass: new_lgbm = LGBMClassifier() new_lgbm._Booster = lgbm_booster new_lgbm._n_classes = _n_classes