@@ -318,7 +318,7 @@ def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
318
318
def _make_forest (self ) -> MCTSForest :
319
319
r0 , r1 , r2 , r3 , r4 = self .dummy_rollouts ()
320
320
assert r0 .shape
321
- forest = MCTSForest (consolidated = True )
321
+ forest = MCTSForest ()
322
322
forest .extend (r0 )
323
323
forest .extend (r1 )
324
324
forest .extend (r2 )
@@ -363,10 +363,24 @@ def _make_forest_intersect(self) -> MCTSForest:
363
363
forest .extend (rollout5 )
364
364
return forest
365
365
366
+ @staticmethod
367
+ def make_labels (tree ):
368
+ if tree .rollout is not None :
369
+ s = torch .cat (
370
+ [
371
+ tree .rollout ["observation" ][:1 ],
372
+ tree .rollout ["next" , "observation" ],
373
+ ]
374
+ )
375
+ s = s .tolist ()
376
+ return f"{ tree .node_id } : { s } "
377
+ return f"{ tree .node_id } "
378
+
366
379
def test_forest_build (self ):
367
380
r0 , * _ = self .dummy_rollouts ()
368
381
forest = self ._make_forest ()
369
382
tree = forest .get_tree (r0 [0 ])
383
+ # tree.plot(make_labels=self.make_labels)
370
384
371
385
def test_forest_vertices (self ):
372
386
r0 , * _ = self .dummy_rollouts ()
@@ -436,18 +450,6 @@ def test_forest_intersect(self):
436
450
tree = forest .get_tree (state0 )
437
451
subtree = forest .get_tree (TensorDict (observation = 19 ))
438
452
439
- def make_labels (tree ):
440
- if tree .rollout is not None :
441
- s = torch .cat (
442
- [
443
- tree .rollout ["observation" ][:1 ],
444
- tree .rollout ["next" , "observation" ],
445
- ]
446
- )
447
- s = s .tolist ()
448
- return f"{ tree .node_id } : { s } "
449
- return f"{ tree .node_id } "
450
-
451
453
# subtree.plot(make_labels=make_labels)
452
454
# tree.plot(make_labels=make_labels)
453
455
assert tree .get_vertex_by_id (2 ).num_children == 2
0 commit comments