Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 8, 2024
2 parents 1224f40 + b8a3823 commit bdc54d2
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict:
def _make_forest(self) -> MCTSForest:
r0, r1, r2, r3, r4 = self.dummy_rollouts()
assert r0.shape
forest = MCTSForest(consolidated=True)
forest = MCTSForest()
forest.extend(r0)
forest.extend(r1)
forest.extend(r2)
Expand Down Expand Up @@ -363,10 +363,24 @@ def _make_forest_intersect(self) -> MCTSForest:
forest.extend(rollout5)
return forest

@staticmethod
def make_labels(tree):
if tree.rollout is not None:
s = torch.cat(
[
tree.rollout["observation"][:1],
tree.rollout["next", "observation"],
]
)
s = s.tolist()
return f"{tree.node_id}: {s}"
return f"{tree.node_id}"

def test_forest_build(self):
r0, *_ = self.dummy_rollouts()
forest = self._make_forest()
tree = forest.get_tree(r0[0])
# tree.plot(make_labels=self.make_labels)

def test_forest_vertices(self):
r0, *_ = self.dummy_rollouts()
Expand Down Expand Up @@ -436,18 +450,6 @@ def test_forest_intersect(self):
tree = forest.get_tree(state0)
subtree = forest.get_tree(TensorDict(observation=19))

def make_labels(tree):
if tree.rollout is not None:
s = torch.cat(
[
tree.rollout["observation"][:1],
tree.rollout["next", "observation"],
]
)
s = s.tolist()
return f"{tree.node_id}: {s}"
return f"{tree.node_id}"

# subtree.plot(make_labels=make_labels)
# tree.plot(make_labels=make_labels)
assert tree.get_vertex_by_id(2).num_children == 2
Expand Down

0 comments on commit bdc54d2

Please sign in to comment.