Skip to content

Commit a24a76b

Browse files
author
Saulo Martiello Mastelini
authored
Tree refactor (online-ml#568)
- Standardize branching and sorting operations
- Create the BranchFactory that replaces SplitSuggestion
- Improve documentation
- Update the user-guide
1 parent aaa3d61 commit a24a76b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+2326
-2370
lines changed

docs/user-guide/on-hoeffding-trees.ipynb

+240-72
Large diffs are not rendered by default.

river/anomaly/hst.py

+3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def next(self, x):
4343
return left
4444
return right
4545

46+
def most_common_path(self):
47+
raise NotImplementedError
48+
4649
@property
4750
def repr_split(self):
4851
return f"{self.feature} < {self.threshold:.5f}"

river/ensemble/adaptive_random_forest.py

+86-33
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88

99
from river import base, metrics, stats, tree
1010
from river.drift import ADWIN
11-
from river.tree._nodes import RandomLearningNodeAdaptive # noqa
12-
from river.tree._nodes import RandomLearningNodeMC # noqa
13-
from river.tree._nodes import RandomLearningNodeMean # noqa
14-
from river.tree._nodes import RandomLearningNodeModel # noqa
15-
from river.tree._nodes import RandomLearningNodeNB # noqa
16-
from river.tree._nodes import RandomLearningNodeNBA # noqa
11+
from river.tree.nodes.arf_htc_nodes import (
12+
RandomLeafMajorityClass,
13+
RandomLeafNaiveBayes,
14+
RandomLeafNaiveBayesAdaptive,
15+
)
16+
from river.tree.nodes.arf_htr_nodes import (
17+
RandomLeafAdaptive,
18+
RandomLeafMean,
19+
RandomLeafModel,
20+
)
1721
from river.tree.splitter import Splitter
1822
from river.utils.skmultiflow_utils import check_random_state
1923

@@ -156,8 +160,13 @@ def __init__(
156160
nb_threshold: int = 0,
157161
nominal_attributes: list = None,
158162
splitter: Splitter = None,
163+
binary_split: bool = False,
164+
max_size: int = 100,
165+
memory_estimate_period: int = 1000000,
166+
stop_mem_management: bool = False,
167+
remove_poor_attrs: bool = False,
168+
merit_preprune: bool = True,
159169
seed=None,
160-
**kwargs,
161170
):
162171
super().__init__(
163172
grace_period=grace_period,
@@ -169,14 +178,19 @@ def __init__(
169178
nb_threshold=nb_threshold,
170179
nominal_attributes=nominal_attributes,
171180
splitter=splitter,
172-
**kwargs,
181+
binary_split=binary_split,
182+
max_size=max_size,
183+
memory_estimate_period=memory_estimate_period,
184+
stop_mem_management=stop_mem_management,
185+
remove_poor_attrs=remove_poor_attrs,
186+
merit_preprune=merit_preprune,
173187
)
174188

175189
self.max_features = max_features
176190
self.seed = seed
177191
self._rng = check_random_state(self.seed)
178192

179-
def _new_learning_node(self, initial_stats=None, parent=None):
193+
def _new_leaf(self, initial_stats=None, parent=None):
180194
if initial_stats is None:
181195
initial_stats = {}
182196

@@ -189,15 +203,15 @@ def _new_learning_node(self, initial_stats=None, parent=None):
189203
seed = self._rng.randint(0, 4294967295, dtype="u8")
190204

191205
if self._leaf_prediction == self._MAJORITY_CLASS:
192-
return RandomLearningNodeMC(
206+
return RandomLeafMajorityClass(
193207
initial_stats, depth, self.splitter, self.max_features, seed,
194208
)
195209
elif self._leaf_prediction == self._NAIVE_BAYES:
196-
return RandomLearningNodeNB(
210+
return RandomLeafNaiveBayes(
197211
initial_stats, depth, self.splitter, self.max_features, seed,
198212
)
199213
else: # NAIVE BAYES ADAPTIVE (default)
200-
return RandomLearningNodeNBA(
214+
return RandomLeafNaiveBayesAdaptive(
201215
initial_stats, depth, self.splitter, self.max_features, seed,
202216
)
203217

@@ -231,8 +245,13 @@ def __init__(
231245
nominal_attributes: list = None,
232246
splitter: Splitter = None,
233247
min_samples_split: int = 5,
248+
binary_split: bool = False,
249+
max_size: int = 100,
250+
memory_estimate_period: int = 1000000,
251+
stop_mem_management: bool = False,
252+
remove_poor_attrs: bool = False,
253+
merit_preprune: bool = True,
234254
seed=None,
235-
**kwargs,
236255
):
237256
super().__init__(
238257
grace_period=grace_period,
@@ -245,14 +264,19 @@ def __init__(
245264
nominal_attributes=nominal_attributes,
246265
splitter=splitter,
247266
min_samples_split=min_samples_split,
248-
**kwargs,
267+
binary_split=binary_split,
268+
max_size=max_size,
269+
memory_estimate_period=memory_estimate_period,
270+
stop_mem_management=stop_mem_management,
271+
remove_poor_attrs=remove_poor_attrs,
272+
merit_preprune=merit_preprune,
249273
)
250274

251275
self.max_features = max_features
252276
self.seed = seed
253277
self._rng = check_random_state(self.seed)
254278

255-
def _new_learning_node(self, initial_stats=None, parent=None): # noqa
279+
def _new_leaf(self, initial_stats=None, parent=None): # noqa
256280
"""Create a new learning node.
257281
258282
The type of learning node depends on the tree configuration.
@@ -274,11 +298,11 @@ def _new_learning_node(self, initial_stats=None, parent=None): # noqa
274298
leaf_model = copy.deepcopy(parent._leaf_model) # noqa
275299

276300
if self.leaf_prediction == self._TARGET_MEAN:
277-
return RandomLearningNodeMean(
301+
return RandomLeafMean(
278302
initial_stats, depth, self.splitter, self.max_features, seed,
279303
)
280304
elif self.leaf_prediction == self._MODEL:
281-
return RandomLearningNodeModel(
305+
return RandomLeafModel(
282306
initial_stats,
283307
depth,
284308
self.splitter,
@@ -287,7 +311,7 @@ def _new_learning_node(self, initial_stats=None, parent=None): # noqa
287311
leaf_model=leaf_model,
288312
)
289313
else: # adaptive learning node
290-
new_adaptive = RandomLearningNodeAdaptive(
314+
new_adaptive = RandomLeafAdaptive(
291315
initial_stats,
292316
depth,
293317
self.splitter,
@@ -383,18 +407,23 @@ class AdaptiveRandomForestClassifier(BaseForest, base.Classifier):
383407
property `is_target_class`. This is an advanced option. Special care must be taken when
384408
choosing different splitters. By default, `tree.splitter.GaussianSplitter` is used
385409
if `splitter` is `None`.
410+
binary_split
411+
[*Tree parameter*] If True, only allow binary splits.
386412
max_size
387413
[*Tree parameter*] Maximum memory (MB) consumed by the tree.
388414
memory_estimate_period
389415
[*Tree parameter*] Number of instances between memory consumption checks.
416+
stop_mem_management
417+
[*Tree parameter*] If True, stop growing as soon as the memory limit is hit.
418+
remove_poor_attrs
419+
[*Tree parameter*] If True, disable poor attributes to reduce memory usage.
420+
merit_preprune
421+
[*Tree parameter*] If True, enable merit-based tree pre-pruning.
390422
seed
391423
If `int`, `seed` is used to seed the random number generator;
392424
If `RandomState`, `seed` is the random number generator;
393425
If `None`, the random number generator is the `RandomState` instance
394426
used by `np.random`.
395-
kwargs
396-
Other parameters passed to `tree.HoeffdingTree`. Check the `tree` module documentation
397-
for more information.
398427
399428
Examples
400429
--------
@@ -444,10 +473,13 @@ def __init__(
444473
nb_threshold: int = 0,
445474
nominal_attributes: list = None,
446475
splitter: Splitter = None,
476+
binary_split: bool = False,
447477
max_size: int = 32,
448-
memory_estimate_period: int = 2000000,
478+
memory_estimate_period: int = 2_000_000,
479+
stop_mem_management: bool = False,
480+
remove_poor_attrs: bool = False,
481+
merit_preprune: bool = True,
449482
seed: int = None,
450-
**kwargs,
451483
):
452484
super().__init__(
453485
n_models=n_models,
@@ -473,9 +505,12 @@ def __init__(
473505
self.nb_threshold = nb_threshold
474506
self.nominal_attributes = nominal_attributes
475507
self.splitter = splitter
508+
self.binary_split = binary_split
476509
self.max_size = max_size
477510
self.memory_estimate_period = memory_estimate_period
478-
self.kwargs = kwargs
511+
self.stop_mem_management = stop_mem_management
512+
self.remove_poor_attrs = remove_poor_attrs
513+
self.merit_preprune = merit_preprune
479514

480515
@classmethod
481516
def _unit_test_params(cls):
@@ -521,10 +556,13 @@ def _new_base_model(self, seed: int):
521556
nominal_attributes=self.nominal_attributes,
522557
splitter=self.splitter,
523558
max_depth=self.max_depth,
524-
memory_estimate_period=self.memory_estimate_period,
559+
binary_split=self.binary_split,
525560
max_size=self.max_size,
561+
memory_estimate_period=self.memory_estimate_period,
562+
stop_mem_management=self.stop_mem_management,
563+
remove_poor_attrs=self.remove_poor_attrs,
564+
merit_preprune=self.merit_preprune,
526565
seed=seed,
527-
**self.kwargs,
528566
)
529567

530568

@@ -622,18 +660,23 @@ class AdaptiveRandomForestRegressor(BaseForest, base.Regressor):
622660
min_samples_split
623661
[*Tree parameter*] The minimum number of samples every branch resulting from a split
624662
candidate must have to be considered valid.
663+
binary_split
664+
[*Tree parameter*] If True, only allow binary splits.
625665
max_size
626666
[*Tree parameter*] Maximum memory (MB) consumed by the tree.
627667
memory_estimate_period
628668
[*Tree parameter*] Number of instances between memory consumption checks.
669+
stop_mem_management
670+
[*Tree parameter*] If True, stop growing as soon as the memory limit is hit.
671+
remove_poor_attrs
672+
[*Tree parameter*] If True, disable poor attributes to reduce memory usage.
673+
merit_preprune
674+
[*Tree parameter*] If True, enable merit-based tree pre-pruning.
629675
seed
630676
If `int`, `seed` is used to seed the random number generator;
631677
If `RandomState`, `seed` is the random number generator;
632678
If `None`, the random number generator is the `RandomState` instance
633679
used by `np.random`.
634-
kwargs
635-
Other parameters passed to `tree.HoeffdingTree`. Check the `tree` module documentation
636-
for more information.
637680
638681
References
639682
----------
@@ -693,8 +736,12 @@ def __init__(
693736
nominal_attributes: list = None,
694737
splitter: Splitter = None,
695738
min_samples_split: int = 5,
696-
max_size: int = 100,
697-
memory_estimate_period: int = 2000000,
739+
binary_split: bool = False,
740+
max_size: int = 500,
741+
memory_estimate_period: int = 2_000_000,
742+
stop_mem_management: bool = False,
743+
remove_poor_attrs: bool = False,
744+
merit_preprune: bool = True,
698745
seed: int = None,
699746
**kwargs,
700747
):
@@ -723,9 +770,12 @@ def __init__(
723770
self.nominal_attributes = nominal_attributes
724771
self.splitter = splitter
725772
self.min_samples_split = min_samples_split
773+
self.binary_split = binary_split
726774
self.max_size = max_size
727775
self.memory_estimate_period = memory_estimate_period
728-
self.kwargs = kwargs
776+
self.stop_mem_management = stop_mem_management
777+
self.remove_poor_attrs = remove_poor_attrs
778+
self.merit_preprune = merit_preprune
729779

730780
if aggregation_method in self._VALID_AGGREGATION_METHOD:
731781
self.aggregation_method = aggregation_method
@@ -787,10 +837,13 @@ def _new_base_model(self, seed: int):
787837
model_selector_decay=self.model_selector_decay,
788838
nominal_attributes=self.nominal_attributes,
789839
splitter=self.splitter,
840+
binary_split=self.binary_split,
790841
max_size=self.max_size,
791842
memory_estimate_period=self.memory_estimate_period,
843+
stop_mem_management=self.stop_mem_management,
844+
remove_poor_attrs=self.remove_poor_attrs,
845+
merit_preprune=self.merit_preprune,
792846
seed=seed,
793-
**self.kwargs,
794847
)
795848

796849
@property

river/tree/__init__.py

+4-21
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66
Each family of iDT will be presented in a dedicated section.
77
88
At any moment, iDT might face situations where an input feature previously used to make
9-
a split decision is missing in an incoming sample. In this case, the river's trees follow the
10-
conventions:
11-
12-
- *Learning:* choose the subtree branch most traversed so far to pass the instance on.</br>
13-
* In case of nominal features, a new branch is created to accommodate the new
14-
category.</br>
15-
- *Predicting:* Use the last "reachable" decision node to provide responses.
9+
a split decision is missing in an incoming sample. In this case, the most traversed path is
10+
selected to pass down the instance. Moreover, in the case of nominal features, if a new category
11+
arises and the feature is used in a decision node, a new branch is created to accommodate the new
12+
value.
1613
1714
**1. Hoeffding Trees**
1815
@@ -42,20 +39,6 @@
4239
* Define properties to access leaf prediction strategies, split criteria, and other
4340
relevant characteristics.
4441
45-
All HTs have the following parameters, in addition to their own, that can be selected
46-
using `**kwargs`. The following default values are used, unless otherwise explicitly stated
47-
in the tree documentation.
48-
49-
| Parameter | Description | Default |
50-
| :- | :- | -: |
51-
|`max_depth` | The maximum depth a tree can reach. If `None`, the tree will grow indefinitely. | `None` |
52-
| `binary_split` | If True, only allow binary splits. | `False` |
53-
| `max_size` | The maximum size the tree can reach, in Megabytes (MB). | `100` |
54-
| `memory_estimate_period` | Interval (number of processed instances) between memory consumption checks. | `1_000_000` |
55-
| `stop_mem_management` | If True, stop growing as soon as memory limit is hit. | `False` |
56-
| `remove_poor_attrs` | If True, disable poorly descriptive attributes to reduce memory usage. | `False` |
57-
| `merit_preprune` | If True, enable merit-based tree pre-pruning. | `True` |
58-
5942
"""
6043

6144
from . import splitter

river/tree/_attribute_test/__init__.py

-15
This file was deleted.

0 commit comments

Comments
 (0)