3
3
import io
4
4
import json
5
5
from typing import Any
6
+ from typing import Callable
6
7
from typing import Optional
7
8
from typing import Union
8
9
@@ -45,6 +46,11 @@ class RuleGeneratorConfigurator:
45
46
"""Class for configuring rule induction parameters
46
47
"""
47
48
49
+ _MEASURES_PARAMETERS : list [str ] = [
50
+ 'induction_measure' , 'pruning_measure' , 'voting_measure' ,
51
+ ]
52
+ _USER_DEFINED_MEASURE_VALUE : str = 'UserDefined'
53
+
48
54
def __init__ (self , rule_generator ):
49
55
self .rule_generator = rule_generator
50
56
self .LogRank = None # pylint: disable=invalid-name
@@ -96,16 +102,25 @@ def _configure_measure_parameter(self, param_name: str, param_value: Union[str,
96
102
if isinstance (param_value , Measures ):
97
103
self .rule_generator .setParameter (
98
104
param_name , param_value .value )
99
- if isinstance (param_value , str ):
100
- self .rule_generator .setParameter (param_name , 'UserDefined' )
101
- self .rule_generator .setParameter (param_name , param_value )
105
+ if isinstance (param_value , Callable ):
106
+ self ._configure_user_defined_measure_parameter (
107
+ param_name , param_value )
108
+
109
+ def _configure_user_defined_measure_parameter (self , param_name : str , param_value : Any ):
110
+ from rulekit .params import _user_defined_measure_factory
111
+ user_defined_measure = _user_defined_measure_factory (param_value )
112
+ {
113
+ 'induction_measure' : self .rule_generator .setUserMeasureInductionObject ,
114
+ 'pruning_measure' : self .rule_generator .setUserMeasurePurningObject ,
115
+ 'voting_measure' : self .rule_generator .setUserMeasureVotingObject ,
116
+ }[param_name ](user_defined_measure )
117
+ self .rule_generator .setParameter (
118
+ param_name , self ._USER_DEFINED_MEASURE_VALUE )
102
119
103
120
def _configure_rule_generator (self , ** kwargs : dict [str , Any ]):
104
- if kwargs .get ('induction_measure' ) == Measures .LogRank or \
105
- kwargs .get ('pruning_measure' ) == Measures .LogRank or \
106
- kwargs .get ('voting_measure' ) == Measures .LogRank :
121
+ if any ([kwargs .get (param_name ) == Measures .LogRank for param_name in self ._MEASURES_PARAMETERS ]):
107
122
self .LogRank = JClass ('adaa.analytics.rules.logic.quality.LogRank' )
108
- for measure_param_name in [ 'induction_measure' , 'pruning_measure' , 'voting_measure' ] :
123
+ for measure_param_name in self . _MEASURES_PARAMETERS :
109
124
measure_param_value : Measures = kwargs .pop (
110
125
measure_param_name , None )
111
126
self ._configure_measure_parameter (
@@ -125,6 +140,15 @@ def _validate_rule_generator_parameters(self, **python_parameters: dict[str, Any
125
140
ValueError: If failed to retrieve RuleGenerator parameters JSON
126
141
RuleKitMisconfigurationException: If Java and Python parameters do not match
127
142
"""
143
+ def are_params_equal (java_params : dict [str , Any ], python_params : dict [str , Any ]):
144
+ if java_params .keys () != python_params .keys ():
145
+ return False
146
+ for key in java_params .keys ():
147
+ skip_check : bool = isinstance (python_params [key ], Callable )
148
+ if java_params [key ] != python_params [key ] and not skip_check :
149
+ return False
150
+ return True
151
+
128
152
python_parameters = dict (python_parameters )
129
153
for param_name , param_value in python_parameters .items ():
130
154
# convert measures to strings values for comparison
@@ -152,7 +176,7 @@ def _validate_rule_generator_parameters(self, **python_parameters: dict[str, Any
152
176
param_name : str (java_params [param_name ])
153
177
for param_name in python_parameters .keys ()
154
178
}
155
- if java_params != python_parameters :
179
+ if not are_params_equal ( java_params , python_parameters ) :
156
180
raise RuleKitMisconfigurationException (
157
181
java_parameters = java_params ,
158
182
python_parameters = python_parameters
0 commit comments