Skip to content

Commit bdc54d2

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 1224f40 + b8a3823 commit bdc54d2

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

test/test_storage_map.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
318318
def _make_forest(self) -> MCTSForest:
319319
r0, r1, r2, r3, r4 = self.dummy_rollouts()
320320
assert r0.shape
321-
forest = MCTSForest(consolidated=True)
321+
forest = MCTSForest()
322322
forest.extend(r0)
323323
forest.extend(r1)
324324
forest.extend(r2)
@@ -363,10 +363,24 @@ def _make_forest_intersect(self) -> MCTSForest:
363363
forest.extend(rollout5)
364364
return forest
365365

366+
@staticmethod
367+
def make_labels(tree):
368+
if tree.rollout is not None:
369+
s = torch.cat(
370+
[
371+
tree.rollout["observation"][:1],
372+
tree.rollout["next", "observation"],
373+
]
374+
)
375+
s = s.tolist()
376+
return f"{tree.node_id}: {s}"
377+
return f"{tree.node_id}"
378+
366379
def test_forest_build(self):
367380
r0, *_ = self.dummy_rollouts()
368381
forest = self._make_forest()
369382
tree = forest.get_tree(r0[0])
383+
# tree.plot(make_labels=self.make_labels)
370384

371385
def test_forest_vertices(self):
372386
r0, *_ = self.dummy_rollouts()
@@ -436,18 +450,6 @@ def test_forest_intersect(self):
436450
tree = forest.get_tree(state0)
437451
subtree = forest.get_tree(TensorDict(observation=19))
438452

439-
def make_labels(tree):
440-
if tree.rollout is not None:
441-
s = torch.cat(
442-
[
443-
tree.rollout["observation"][:1],
444-
tree.rollout["next", "observation"],
445-
]
446-
)
447-
s = s.tolist()
448-
return f"{tree.node_id}: {s}"
449-
return f"{tree.node_id}"
450-
451453
# subtree.plot(make_labels=make_labels)
452454
# tree.plot(make_labels=make_labels)
453455
assert tree.get_vertex_by_id(2).num_children == 2

0 commit comments

Comments
 (0)