Skip to content

Commit 883c8f1

Browse files
authored
Merge pull request #33 from adaa-polsl/feat-user-defined-measures
Feat user defined measures
2 parents 6088f0d + cad00e7 commit 883c8f1

File tree

5 files changed

+110
-13
lines changed

5 files changed

+110
-13
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,5 @@ notebooks/
293293
classification_tabular_datasets/
294294
.coverage
295295
junit.xml
296-
a.ipynb
296+
a.ipynb
297+
java.logs

rulekit/_helpers.py

+32-8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io
44
import json
55
from typing import Any
6+
from typing import Callable
67
from typing import Optional
78
from typing import Union
89

@@ -45,6 +46,11 @@ class RuleGeneratorConfigurator:
4546
"""Class for configuring rule induction parameters
4647
"""
4748

49+
_MEASURES_PARAMETERS: list[str] = [
50+
'induction_measure', 'pruning_measure', 'voting_measure',
51+
]
52+
_USER_DEFINED_MEASURE_VALUE: str = 'UserDefined'
53+
4854
def __init__(self, rule_generator):
4955
self.rule_generator = rule_generator
5056
self.LogRank = None # pylint: disable=invalid-name
@@ -96,16 +102,25 @@ def _configure_measure_parameter(self, param_name: str, param_value: Union[str,
96102
if isinstance(param_value, Measures):
97103
self.rule_generator.setParameter(
98104
param_name, param_value.value)
99-
if isinstance(param_value, str):
100-
self.rule_generator.setParameter(param_name, 'UserDefined')
101-
self.rule_generator.setParameter(param_name, param_value)
105+
if isinstance(param_value, Callable):
106+
self._configure_user_defined_measure_parameter(
107+
param_name, param_value)
108+
109+
def _configure_user_defined_measure_parameter(self, param_name: str, param_value: Any):
    """Register a Python callable as a user-defined quality measure.

    The callable is wrapped by ``_user_defined_measure_factory`` and the
    resulting object is passed to the rule generator setter that
    corresponds to ``param_name``. Afterwards the parameter itself is set
    to ``_USER_DEFINED_MEASURE_VALUE`` on the rule generator.

    Args:
        param_name: One of ``'induction_measure'``, ``'pruning_measure'``
            or ``'voting_measure'``.
        param_value: The measure callable to wrap.

    Raises:
        KeyError: If ``param_name`` is not one of the three measure
            parameter names listed above.
    """
    # Function-scope import, kept exactly where the original code had it.
    from rulekit.params import _user_defined_measure_factory

    wrapped_measure = _user_defined_measure_factory(param_value)
    generator = self.rule_generator
    # NOTE(review): the 'Purning' spelling is reproduced as found; it is
    # presumably the name exposed by the Java object - confirm before renaming.
    setters = {
        'induction_measure': generator.setUserMeasureInductionObject,
        'pruning_measure': generator.setUserMeasurePurningObject,
        'voting_measure': generator.setUserMeasureVotingObject,
    }
    register_measure = setters[param_name]
    register_measure(wrapped_measure)
    generator.setParameter(param_name, self._USER_DEFINED_MEASURE_VALUE)
102119

103120
def _configure_rule_generator(self, **kwargs: dict[str, Any]):
104-
if kwargs.get('induction_measure') == Measures.LogRank or \
105-
kwargs.get('pruning_measure') == Measures.LogRank or \
106-
kwargs.get('voting_measure') == Measures.LogRank:
121+
if any([kwargs.get(param_name) == Measures.LogRank for param_name in self._MEASURES_PARAMETERS]):
107122
self.LogRank = JClass('adaa.analytics.rules.logic.quality.LogRank')
108-
for measure_param_name in ['induction_measure', 'pruning_measure', 'voting_measure']:
123+
for measure_param_name in self._MEASURES_PARAMETERS:
109124
measure_param_value: Measures = kwargs.pop(
110125
measure_param_name, None)
111126
self._configure_measure_parameter(
@@ -125,6 +140,15 @@ def _validate_rule_generator_parameters(self, **python_parameters: dict[str, Any
125140
ValueError: If failed to retrieve RuleGenerator parameters JSON
126141
RuleKitMisconfigurationException: If Java and Python parameters do not match
127142
"""
143+
def are_params_equal(java_params: dict[str, Any], python_params: dict[str, Any]) -> bool:
    """Tell whether the Java-side and Python-side parameters agree.

    Both mappings must have exactly the same set of keys. Values are then
    compared key by key, except for keys whose Python value is callable
    (user-defined measures), which are exempt from the value comparison.

    Args:
        java_params: Parameter values read back from the Java side.
        python_params: Parameter values supplied on the Python side.

    Returns:
        ``True`` when the key sets match and every non-callable Python
        value equals its Java counterpart, ``False`` otherwise.
    """
    if java_params.keys() != python_params.keys():
        return False
    for key, python_value in python_params.items():
        # A callable cannot be compared with the value reported by Java,
        # so it is skipped rather than reported as a mismatch.
        if callable(python_value):
            continue
        if java_params[key] != python_value:
            return False
    return True
151+
128152
python_parameters = dict(python_parameters)
129153
for param_name, param_value in python_parameters.items():
130154
# convert measures to strings values for comparison
@@ -152,7 +176,7 @@ def _validate_rule_generator_parameters(self, **python_parameters: dict[str, Any
152176
param_name: str(java_params[param_name])
153177
for param_name in python_parameters.keys()
154178
}
155-
if java_params != python_parameters:
179+
if not are_params_equal(java_params, python_parameters):
156180
raise RuleKitMisconfigurationException(
157181
java_parameters=java_params,
158182
python_parameters=python_parameters

rulekit/exceptions.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Module containing classes for handling exceptions."""
22
from typing import Any
3+
from typing import Callable
34

45
from jpype import JException
56

@@ -71,7 +72,9 @@ def _prepare_message(
7172
java_value = java_parameters.get(key)
7273
python_value = python_parameters.get(key)
7374
line: str = f' {key}: ({java_value}, {python_value}),'
74-
if java_value != python_value:
75+
# skip check for user defined measures
76+
skip_check: bool = isinstance(python_value, Callable)
77+
if java_value != python_value and not skip_check:
7578
line = f'{line} <-- **DIFFERENT**'
7679
params_lines.append(line)
7780
message: str = (

rulekit/params.py

+28-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
11
"""Contains constants and classes for specyfing models parameters
22
"""
33
from enum import Enum
4+
from typing import Callable
45
from typing import Optional
56
from typing import Tuple
7+
from typing import Union
68

9+
from jpype import JImplements
10+
from jpype import JOverride
11+
from jpype.types import JDouble
712
from pydantic import BaseModel # pylint: disable=no-name-in-module
813

914
MAX_INT: int = 2147483647 # max integer value in Java
1015

16+
_UserDefinedMeasure = Callable[[float, float, float, float], float]
17+
18+
19+
def _user_defined_measure_factory(measure_function: _UserDefinedMeasure):
    """Wrap a Python measure function in an object implementing ``IUserMeasure``.

    A class decorated with ``@JImplements(IUserMeasure)`` is defined locally
    and a single instance of it is returned. Its ``getResult`` method
    converts the four received arguments to Python ``float`` and forwards
    them, in the same order, to ``measure_function``.

    Args:
        measure_function: Callable taking four floats ``(p, n, P, N)`` and
            returning the measure value.
            NOTE(review): presumably covered positives / negatives and
            total positives / negatives - confirm against the Java docs.

    Returns:
        A new instance of the locally defined ``IUserMeasure`` implementation.
    """
    # NOTE(review): imported inside the function, presumably because the Java
    # package is only importable once the JVM has been started - confirm.
    from adaa.analytics.rules.logic.quality import \
        IUserMeasure  # pylint: disable=import-outside-toplevel,import-error

    @JImplements(IUserMeasure)
    class _UserMeasure:  # pylint: disable=invalid-name,missing-function-docstring

        @JOverride
        def getResult(self, p: JDouble, n: JDouble, P: JDouble, N: JDouble) -> float:
            # Hand plain Python floats to the user's function.
            arguments = (float(p), float(n), float(P), float(N))
            return measure_function(*arguments)

    measure_proxy = _UserMeasure()
    return measure_proxy
33+
1134

1235
class Measures(Enum):
1336
# pylint: disable=invalid-name
@@ -92,14 +115,16 @@ class Measures(Enum):
92115
'penalty_saturation': 0.2,
93116
}
94117

118+
_QualityMeasure = Union[Measures, _UserDefinedMeasure]
119+
95120

96121
class ModelsParams(BaseModel):
97122
"""Model for validating models hyperparameters
98123
"""
99124
minsupp_new: Optional[float] = DEFAULT_PARAMS_VALUE['minsupp_new']
100-
induction_measure: Optional[Measures] = DEFAULT_PARAMS_VALUE['induction_measure']
101-
pruning_measure: Optional[Measures] = DEFAULT_PARAMS_VALUE['pruning_measure']
102-
voting_measure: Optional[Measures] = DEFAULT_PARAMS_VALUE['voting_measure']
125+
induction_measure: Optional[_QualityMeasure] = DEFAULT_PARAMS_VALUE['induction_measure']
126+
pruning_measure: Optional[_QualityMeasure] = DEFAULT_PARAMS_VALUE['pruning_measure']
127+
voting_measure: Optional[_QualityMeasure] = DEFAULT_PARAMS_VALUE['voting_measure']
103128
max_growing: Optional[float] = DEFAULT_PARAMS_VALUE['max_growing']
104129
enable_pruning: Optional[bool] = DEFAULT_PARAMS_VALUE['enable_pruning']
105130
ignore_missing: Optional[bool] = DEFAULT_PARAMS_VALUE['ignore_missing']

tests/test_classifier.py

+44
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from rulekit import classification
1313
from rulekit.events import RuleInductionProgressListener
14+
from rulekit.params import Measures
1415
from rulekit.rules import Rule
1516
from tests.utils import assert_accuracy_is_greater
1617
from tests.utils import assert_rules_are_equals
@@ -282,6 +283,49 @@ def test_left_open_intervals_in_expert_induction(self):
282283
expert_forbidden_conditions=expert_forbidden_conditions
283284
)
284285

286+
def test_user_defined_measures(self):
    """Check that Python-defined measures can drive rule induction.

    A Python re-implementation of full coverage must produce the same rule
    weights and the same rule texts as ``Measures.FullCoverage``; a voting
    measure that always returns zero must give every rule a weight of 0.0.
    """
    def full_coverage(p: float, n: float, P: float, N: float) -> float:
        return (p + n) / (P + N)

    def zero_measure(p: float, n: float, P: float, N: float) -> float:
        return 0.0

    measure_params = ('induction_measure', 'pruning_measure', 'voting_measure')

    # Same measure for all three parameters: once as a Python callable,
    # once as the built-in enum value.
    python_clf = classification.RuleClassifier(
        **dict.fromkeys(measure_params, full_coverage)
    )
    java_clf = classification.RuleClassifier(
        **dict.fromkeys(measure_params, Measures.FullCoverage)
    )
    x, y = load_iris(return_X_y=True)

    for classifier in (python_clf, java_clf):
        classifier.fit(x, y)

    python_rules = python_clf.model.rules
    java_rules = java_clf.model.rules
    self.assertEqual(
        [rule.weight for rule in python_rules],
        [rule.weight for rule in java_rules],
        'Weights should be equal'
    )
    self.assertEqual(
        [str(rule) for rule in python_rules],
        [str(rule) for rule in java_rules],
        'Rules should be equal'
    )

    # Only the voting measure is user-defined here.
    python_clf2 = classification.RuleClassifier(
        induction_measure=Measures.FullCoverage,
        pruning_measure=Measures.FullCoverage,
        voting_measure=zero_measure,
    )
    python_clf2.fit(x, y)
    zero_weight_flags = [
        rule.weight == 0.0 for rule in python_clf2.model.rules
    ]
    self.assertTrue(all(zero_weight_flags))
328+
285329

286330
if __name__ == '__main__':
287331
unittest.main()

0 commit comments

Comments
 (0)