8
8
9
9
from river import base , metrics , stats , tree
10
10
from river .drift import ADWIN
11
- from river .tree ._nodes import RandomLearningNodeAdaptive # noqa
12
- from river .tree ._nodes import RandomLearningNodeMC # noqa
13
- from river .tree ._nodes import RandomLearningNodeMean # noqa
14
- from river .tree ._nodes import RandomLearningNodeModel # noqa
15
- from river .tree ._nodes import RandomLearningNodeNB # noqa
16
- from river .tree ._nodes import RandomLearningNodeNBA # noqa
11
+ from river .tree .nodes .arf_htc_nodes import (
12
+ RandomLeafMajorityClass ,
13
+ RandomLeafNaiveBayes ,
14
+ RandomLeafNaiveBayesAdaptive ,
15
+ )
16
+ from river .tree .nodes .arf_htr_nodes import (
17
+ RandomLeafAdaptive ,
18
+ RandomLeafMean ,
19
+ RandomLeafModel ,
20
+ )
17
21
from river .tree .splitter import Splitter
18
22
from river .utils .skmultiflow_utils import check_random_state
19
23
@@ -156,8 +160,13 @@ def __init__(
156
160
nb_threshold : int = 0 ,
157
161
nominal_attributes : list = None ,
158
162
splitter : Splitter = None ,
163
+ binary_split : bool = False ,
164
+ max_size : int = 100 ,
165
+ memory_estimate_period : int = 1000000 ,
166
+ stop_mem_management : bool = False ,
167
+ remove_poor_attrs : bool = False ,
168
+ merit_preprune : bool = True ,
159
169
seed = None ,
160
- ** kwargs ,
161
170
):
162
171
super ().__init__ (
163
172
grace_period = grace_period ,
@@ -169,14 +178,19 @@ def __init__(
169
178
nb_threshold = nb_threshold ,
170
179
nominal_attributes = nominal_attributes ,
171
180
splitter = splitter ,
172
- ** kwargs ,
181
+ binary_split = binary_split ,
182
+ max_size = max_size ,
183
+ memory_estimate_period = memory_estimate_period ,
184
+ stop_mem_management = stop_mem_management ,
185
+ remove_poor_attrs = remove_poor_attrs ,
186
+ merit_preprune = merit_preprune ,
173
187
)
174
188
175
189
self .max_features = max_features
176
190
self .seed = seed
177
191
self ._rng = check_random_state (self .seed )
178
192
179
- def _new_learning_node (self , initial_stats = None , parent = None ):
193
+ def _new_leaf (self , initial_stats = None , parent = None ):
180
194
if initial_stats is None :
181
195
initial_stats = {}
182
196
@@ -189,15 +203,15 @@ def _new_learning_node(self, initial_stats=None, parent=None):
189
203
seed = self ._rng .randint (0 , 4294967295 , dtype = "u8" )
190
204
191
205
if self ._leaf_prediction == self ._MAJORITY_CLASS :
192
- return RandomLearningNodeMC (
206
+ return RandomLeafMajorityClass (
193
207
initial_stats , depth , self .splitter , self .max_features , seed ,
194
208
)
195
209
elif self ._leaf_prediction == self ._NAIVE_BAYES :
196
- return RandomLearningNodeNB (
210
+ return RandomLeafNaiveBayes (
197
211
initial_stats , depth , self .splitter , self .max_features , seed ,
198
212
)
199
213
else : # NAIVE BAYES ADAPTIVE (default)
200
- return RandomLearningNodeNBA (
214
+ return RandomLeafNaiveBayesAdaptive (
201
215
initial_stats , depth , self .splitter , self .max_features , seed ,
202
216
)
203
217
@@ -231,8 +245,13 @@ def __init__(
231
245
nominal_attributes : list = None ,
232
246
splitter : Splitter = None ,
233
247
min_samples_split : int = 5 ,
248
+ binary_split : bool = False ,
249
+ max_size : int = 100 ,
250
+ memory_estimate_period : int = 1000000 ,
251
+ stop_mem_management : bool = False ,
252
+ remove_poor_attrs : bool = False ,
253
+ merit_preprune : bool = True ,
234
254
seed = None ,
235
- ** kwargs ,
236
255
):
237
256
super ().__init__ (
238
257
grace_period = grace_period ,
@@ -245,14 +264,19 @@ def __init__(
245
264
nominal_attributes = nominal_attributes ,
246
265
splitter = splitter ,
247
266
min_samples_split = min_samples_split ,
248
- ** kwargs ,
267
+ binary_split = binary_split ,
268
+ max_size = max_size ,
269
+ memory_estimate_period = memory_estimate_period ,
270
+ stop_mem_management = stop_mem_management ,
271
+ remove_poor_attrs = remove_poor_attrs ,
272
+ merit_preprune = merit_preprune ,
249
273
)
250
274
251
275
self .max_features = max_features
252
276
self .seed = seed
253
277
self ._rng = check_random_state (self .seed )
254
278
255
- def _new_learning_node (self , initial_stats = None , parent = None ): # noqa
279
+ def _new_leaf (self , initial_stats = None , parent = None ): # noqa
256
280
"""Create a new learning node.
257
281
258
282
The type of learning node depends on the tree configuration.
@@ -274,11 +298,11 @@ def _new_learning_node(self, initial_stats=None, parent=None): # noqa
274
298
leaf_model = copy .deepcopy (parent ._leaf_model ) # noqa
275
299
276
300
if self .leaf_prediction == self ._TARGET_MEAN :
277
- return RandomLearningNodeMean (
301
+ return RandomLeafMean (
278
302
initial_stats , depth , self .splitter , self .max_features , seed ,
279
303
)
280
304
elif self .leaf_prediction == self ._MODEL :
281
- return RandomLearningNodeModel (
305
+ return RandomLeafModel (
282
306
initial_stats ,
283
307
depth ,
284
308
self .splitter ,
@@ -287,7 +311,7 @@ def _new_learning_node(self, initial_stats=None, parent=None): # noqa
287
311
leaf_model = leaf_model ,
288
312
)
289
313
else : # adaptive learning node
290
- new_adaptive = RandomLearningNodeAdaptive (
314
+ new_adaptive = RandomLeafAdaptive (
291
315
initial_stats ,
292
316
depth ,
293
317
self .splitter ,
@@ -383,18 +407,23 @@ class AdaptiveRandomForestClassifier(BaseForest, base.Classifier):
383
407
property `is_target_class`. This is an advanced option. Special care must be taken when
384
408
choosing different splitters. By default, `tree.splitter.GaussianSplitter` is used
385
409
if `splitter` is `None`.
410
+ binary_split
411
+ [*Tree parameter*] If True, only allow binary splits.
386
412
max_size
387
413
[*Tree parameter*] Maximum memory (MB) consumed by the tree.
388
414
memory_estimate_period
389
415
[*Tree parameter*] Number of instances between memory consumption checks.
416
+ stop_mem_management
417
+ [*Tree parameter*] If True, stop growing as soon as memory limit is hit.
418
+ remove_poor_attrs
419
+ [*Tree parameter*] If True, disable poor attributes to reduce memory usage.
420
+ merit_preprune
421
+ [*Tree parameter*] If True, enable merit-based tree pre-pruning.
390
422
seed
391
423
If `int`, `seed` is used to seed the random number generator;
392
424
If `RandomState`, `seed` is the random number generator;
393
425
If `None`, the random number generator is the `RandomState` instance
394
426
used by `np.random`.
395
- kwargs
396
- Other parameters passed to `tree.HoeffdingTree`. Check the `tree` module documentation
397
- for more information.
398
427
399
428
Examples
400
429
--------
@@ -444,10 +473,13 @@ def __init__(
444
473
nb_threshold : int = 0 ,
445
474
nominal_attributes : list = None ,
446
475
splitter : Splitter = None ,
476
+ binary_split : bool = False ,
447
477
max_size : int = 32 ,
448
- memory_estimate_period : int = 2000000 ,
478
+ memory_estimate_period : int = 2_000_000 ,
479
+ stop_mem_management : bool = False ,
480
+ remove_poor_attrs : bool = False ,
481
+ merit_preprune : bool = True ,
449
482
seed : int = None ,
450
- ** kwargs ,
451
483
):
452
484
super ().__init__ (
453
485
n_models = n_models ,
@@ -473,9 +505,12 @@ def __init__(
473
505
self .nb_threshold = nb_threshold
474
506
self .nominal_attributes = nominal_attributes
475
507
self .splitter = splitter
508
+ self .binary_split = binary_split
476
509
self .max_size = max_size
477
510
self .memory_estimate_period = memory_estimate_period
478
- self .kwargs = kwargs
511
+ self .stop_mem_management = stop_mem_management
512
+ self .remove_poor_attrs = remove_poor_attrs
513
+ self .merit_preprune = merit_preprune
479
514
480
515
@classmethod
481
516
def _unit_test_params (cls ):
@@ -521,10 +556,13 @@ def _new_base_model(self, seed: int):
521
556
nominal_attributes = self .nominal_attributes ,
522
557
splitter = self .splitter ,
523
558
max_depth = self .max_depth ,
524
- memory_estimate_period = self .memory_estimate_period ,
559
+ binary_split = self .binary_split ,
525
560
max_size = self .max_size ,
561
+ memory_estimate_period = self .memory_estimate_period ,
562
+ stop_mem_management = self .stop_mem_management ,
563
+ remove_poor_attrs = self .remove_poor_attrs ,
564
+ merit_preprune = self .merit_preprune ,
526
565
seed = seed ,
527
- ** self .kwargs ,
528
566
)
529
567
530
568
@@ -622,18 +660,23 @@ class AdaptiveRandomForestRegressor(BaseForest, base.Regressor):
622
660
min_samples_split
623
661
[*Tree parameter*] The minimum number of samples every branch resulting from a split
624
662
candidate must have to be considered valid.
663
+ binary_split
664
+ [*Tree parameter*] If True, only allow binary splits.
625
665
max_size
626
666
[*Tree parameter*] Maximum memory (MB) consumed by the tree.
627
667
memory_estimate_period
628
668
[*Tree parameter*] Number of instances between memory consumption checks.
669
+ stop_mem_management
670
+ [*Tree parameter*] If True, stop growing as soon as memory limit is hit.
671
+ remove_poor_attrs
672
+ [*Tree parameter*] If True, disable poor attributes to reduce memory usage.
673
+ merit_preprune
674
+ [*Tree parameter*] If True, enable merit-based tree pre-pruning.
629
675
seed
630
676
If `int`, `seed` is used to seed the random number generator;
631
677
If `RandomState`, `seed` is the random number generator;
632
678
If `None`, the random number generator is the `RandomState` instance
633
679
used by `np.random`.
634
- kwargs
635
- Other parameters passed to `tree.HoeffdingTree`. Check the `tree` module documentation
636
- for more information.
637
680
638
681
References
639
682
----------
@@ -693,8 +736,12 @@ def __init__(
693
736
nominal_attributes : list = None ,
694
737
splitter : Splitter = None ,
695
738
min_samples_split : int = 5 ,
696
- max_size : int = 100 ,
697
- memory_estimate_period : int = 2000000 ,
739
+ binary_split : bool = False ,
740
+ max_size : int = 500 ,
741
+ memory_estimate_period : int = 2_000_000 ,
742
+ stop_mem_management : bool = False ,
743
+ remove_poor_attrs : bool = False ,
744
+ merit_preprune : bool = True ,
698
745
seed : int = None ,
699
746
** kwargs ,
700
747
):
@@ -723,9 +770,12 @@ def __init__(
723
770
self .nominal_attributes = nominal_attributes
724
771
self .splitter = splitter
725
772
self .min_samples_split = min_samples_split
773
+ self .binary_split = binary_split
726
774
self .max_size = max_size
727
775
self .memory_estimate_period = memory_estimate_period
728
- self .kwargs = kwargs
776
+ self .stop_mem_management = stop_mem_management
777
+ self .remove_poor_attrs = remove_poor_attrs
778
+ self .merit_preprune = merit_preprune
729
779
730
780
if aggregation_method in self ._VALID_AGGREGATION_METHOD :
731
781
self .aggregation_method = aggregation_method
@@ -787,10 +837,13 @@ def _new_base_model(self, seed: int):
787
837
model_selector_decay = self .model_selector_decay ,
788
838
nominal_attributes = self .nominal_attributes ,
789
839
splitter = self .splitter ,
840
+ binary_split = self .binary_split ,
790
841
max_size = self .max_size ,
791
842
memory_estimate_period = self .memory_estimate_period ,
843
+ stop_mem_management = self .stop_mem_management ,
844
+ remove_poor_attrs = self .remove_poor_attrs ,
845
+ merit_preprune = self .merit_preprune ,
792
846
seed = seed ,
793
- ** self .kwargs ,
794
847
)
795
848
796
849
@property
0 commit comments