diff --git a/test/test_storage_map.py b/test/test_storage_map.py index b2b1a3ed8cb..5fb9e71cbf2 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -318,7 +318,7 @@ def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict: def _make_forest(self) -> MCTSForest: r0, r1, r2, r3, r4 = self.dummy_rollouts() assert r0.shape - forest = MCTSForest(consolidated=True) + forest = MCTSForest() forest.extend(r0) forest.extend(r1) forest.extend(r2) @@ -363,10 +363,24 @@ def _make_forest_intersect(self) -> MCTSForest: forest.extend(rollout5) return forest + @staticmethod + def make_labels(tree): + if tree.rollout is not None: + s = torch.cat( + [ + tree.rollout["observation"][:1], + tree.rollout["next", "observation"], + ] + ) + s = s.tolist() + return f"{tree.node_id}: {s}" + return f"{tree.node_id}" + def test_forest_build(self): r0, *_ = self.dummy_rollouts() forest = self._make_forest() tree = forest.get_tree(r0[0]) + # tree.plot(make_labels=self.make_labels) def test_forest_vertices(self): r0, *_ = self.dummy_rollouts() @@ -436,18 +450,6 @@ def test_forest_intersect(self): tree = forest.get_tree(state0) subtree = forest.get_tree(TensorDict(observation=19)) - def make_labels(tree): - if tree.rollout is not None: - s = torch.cat( - [ - tree.rollout["observation"][:1], - tree.rollout["next", "observation"], - ] - ) - s = s.tolist() - return f"{tree.node_id}: {s}" - return f"{tree.node_id}" - # subtree.plot(make_labels=make_labels) # tree.plot(make_labels=make_labels) assert tree.get_vertex_by_id(2).num_children == 2