From a1b627945567e7c235d733b0f584cbc617406601 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 19 Feb 2025 13:36:27 -0800 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- examples/trees/mcts.py | 220 +++++++++++++++++++++++++++++++++++++++ torchrl/data/map/tree.py | 7 ++ 2 files changed, 227 insertions(+) create mode 100644 examples/trees/mcts.py diff --git a/examples/trees/mcts.py b/examples/trees/mcts.py new file mode 100644 index 00000000000..52949ffe28a --- /dev/null +++ b/examples/trees/mcts.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torchrl +from tensordict import TensorDict + +pgn_or_fen = "fen" + +env = torchrl.envs.ChessEnv( + include_pgn=False, + include_fen=True, + include_hash=True, + include_hash_inv=True, + include_san=True, + stateful=True, + mask_actions=True, +) + + +def transform_reward(td): + if "reward" not in td: + return td + reward = td["reward"] + if reward == 0.5: + td["reward"] = 0 + elif reward == 1 and td["turn"]: + td["reward"] = -td["reward"] + return td + + +# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player. +# Need to transform the reward to be: +# white win = 1 +# draw = 0 +# black win = -1 +env.append_transform(transform_reward) + +forest = torchrl.data.MCTSForest() +forest.reward_keys = env.reward_keys + ["_visits", "_reward_sum"] +forest.done_keys = env.done_keys +forest.action_keys = env.action_keys +forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"] + +C = 2.0**0.5 + + +def traversal_priority_UCB1(tree): + if tree.rollout[-1]["next", "_visits"] == 0: + res = float("inf") + else: + if tree.parent.rollout is None: + parent_visits = 0 + for child in tree.parent.subtree: + parent_visits += child.rollout[-1]["next", "_visits"] + else: + parent_visits = tree.parent.rollout[-1]["next", "_visits"] + assert parent_visits > 0 + + value_avg = ( + tree.rollout[-1]["next", "_reward_sum"] + / tree.rollout[-1]["next", "_visits"] + ) + + # If it's black's turn, flip the reward, since black wants to optimize + # for the lowest reward. 
+ if not tree.rollout[0]["turn"]: + value_avg = -value_avg + + res = ( + value_avg + + C + * torch.sqrt(torch.log(parent_visits) / tree.rollout[-1]["next", "_visits"]) + ).item() + + return res + + +def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps): + done = False + trees_visited = [] + + while not done: + if tree.subtree is None: + td_tree = tree.rollout[-1]["next"] + + if (td_tree["_visits"] > 0 or tree.parent is None) and not td_tree["done"]: + actions = env.all_actions(td_tree) + subtrees = [] + + for action in actions: + td = env.step(env.reset(td_tree.clone()).update(action)).update( + TensorDict( + { + ("next", "_visits"): 0, + ("next", "_reward_sum"): env.reward_spec.zeros(), + } + ) + ) + + new_node = torchrl.data.Tree( + rollout=td.unsqueeze(0), + node_data=td["next"].select(*forest.node_map.in_keys), + ) + subtrees.append(new_node) + + tree.subtree = TensorDict.lazy_stack(subtrees) + chosen_idx = torch.randint(0, len(subtrees), ()).item() + rollout_state = subtrees[chosen_idx].rollout[-1]["next"] + + else: + rollout_state = td_tree + + if rollout_state["done"]: + rollout_reward = rollout_state["reward"] + else: + rollout = env.rollout( + max_steps=max_rollout_steps, + tensordict=rollout_state, + ) + rollout_reward = rollout[-1]["next", "reward"] + done = True + + else: + priorities = torch.tensor( + [traversal_priority_UCB1(subtree) for subtree in tree.subtree] + ) + chosen_idx = torch.argmax(priorities).item() + tree = tree.subtree[chosen_idx] + trees_visited.append(tree) + + for tree in trees_visited: + td = tree.rollout[-1]["next"] + td["_visits"] += 1 + td["_reward_sum"] += rollout_reward + + +def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps): + """Performs Monte-Carlo tree search in an environment. + + Args: + forest (MCTSForest): Forest of the tree to update. If the tree does not + exist yet, it is added. + root (TensorDict): The root step of the tree to update. + env (EnvBase): Environment to performs actions in. + num_steps (int): Number of iterations to traverse. + max_rollout_steps (int): Maximum number of steps for each rollout. 
+ """ + if root not in forest: + for action in env.all_actions(root.clone()): + td = env.step(env.reset(root.clone()).update(action)).update( + TensorDict( + { + ("next", "_visits"): 0, + ("next", "_reward_sum"): env.reward_spec.zeros(), + } + ) + ) + forest.extend(td.unsqueeze(0)) + + tree = forest.get_tree(root) + + for _ in range(num_steps): + _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps) + + return tree + + +def tree_format_fn(tree): + td = tree.rollout[-1]["next"] + return [ + td["san"], + td[pgn_or_fen].split("\n")[-1], + td["_reward_sum"].item(), + td["_visits"].item(), + ] + + +def get_best_move(fen, mcts_steps, rollout_steps): + root = env.reset(TensorDict({"fen": fen})) + tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps) + + # print('------------------------------') + # print(tree.to_string(tree_format_fn)) + # print('------------------------------') + + moves = [] + + for subtree in tree.subtree: + san = subtree.rollout[0]["next", "san"] + reward_sum = subtree.rollout[-1]["next", "_reward_sum"] + visits = subtree.rollout[-1]["next", "_visits"] + value_avg = (reward_sum / visits).item() + if not subtree.rollout[0]["turn"]: + value_avg = -value_avg + moves.append((value_avg, san)) + + moves = sorted(moves, key=lambda x: -x[0]) + + print("------------------") + for value_avg, san in moves: + print(f" {value_avg:0.02f} {san}") + print("------------------") + + return moves[0][1] + + +# White has M1, best move Rd8#. Any other moves lose to M2 or M1. +fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1" +assert get_best_move(fen0, 100, 10) == "Rd8#" + +# Black has M1, best move Qg6#. Other moves give rough equality or worse. +fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1" +assert get_best_move(fen1, 100, 10) == "Qg6#" + +# White has M2, best move Rxg8+. Any other move loses. +fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1" +assert get_best_move(fen2, 1000, 10) == "Rxg8+" diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 283bd99bd52..02c47b39ca1 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -330,6 +330,8 @@ def maybe_flatten_list(maybe_nested_list): return TensorDict.lazy_stack( [self._from_tensordict(r) for r in parent_result] ) + if parent_result is None: + return None return self._from_tensordict(parent_result) @property @@ -1227,6 +1229,11 @@ def valid_paths(cls, tree: Tree): def __len__(self): return len(self.data_map) + def __contains__(self, root: TensorDictBase): + if self.node_map is None: + return False + return root.select(*self.node_map.in_keys) in self.node_map + def to_string(self, td_root, node_format_fn): """Generates a string representation of a tree in the forest. 
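
A note on the selection rule in patch 1: the traversal_priority_UCB1 function above follows the standard UCB1 bandit rule. A child that has never been visited gets infinite priority; otherwise its priority is the child's average value plus C * sqrt(ln(parent visits) / child visits), with C = sqrt(2). The sketch below restates that rule on plain tensors; the function and variable names are illustrative and are not part of the patch.

import torch

C = 2.0**0.5  # exploration constant used in the example

def ucb1_priorities(child_value_sums, child_visits, parent_visits):
    # Exploit the running average value; explore children with few visits.
    value_avg = child_value_sums / child_visits.clamp(min=1)
    explore = C * torch.sqrt(torch.log(parent_visits) / child_visits.clamp(min=1))
    priority = value_avg + explore
    # A child that has never been visited is always selected first.
    return torch.where(child_visits == 0, torch.full_like(priority, float("inf")), priority)

# Example: a parent visited 9 times, with three children visited 4, 5, and 0 times.
sums = torch.tensor([2.0, 1.0, 0.0])
visits = torch.tensor([4.0, 5.0, 0.0])
print(ucb1_priorities(sums, visits, torch.tensor(9.0)))  # the unvisited child gets inf

In the patch itself, value_avg is additionally negated when it is black's turn, since black wants to minimize the white-positive reward.
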
From 165b055e654965786f39ea0c80d2f65b448cc88c Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 19 Feb 2025 13:48:37 -0800 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- torchrl/data/map/tree.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 02c47b39ca1..b1d0934f9ea 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -330,8 +330,6 @@ def maybe_flatten_list(maybe_nested_list): return TensorDict.lazy_stack( [self._from_tensordict(r) for r in parent_result] ) - if parent_result is None: - return None return self._from_tensordict(parent_result) @property From 201f71b6eb47b3f837b3da1c0d3a8e5f676b79bd Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 19 Feb 2025 15:44:25 -0800 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- examples/trees/mcts.py | 73 +++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 40 deletions(-) diff --git a/examples/trees/mcts.py b/examples/trees/mcts.py index 52949ffe28a..aa2f08106fe 100644 --- a/examples/trees/mcts.py +++ b/examples/trees/mcts.py @@ -47,51 +47,41 @@ def transform_reward(td): C = 2.0**0.5 -def traversal_priority_UCB1(tree): - if tree.rollout[-1]["next", "_visits"] == 0: - res = float("inf") +def traversal_priority_UCB1(tree, root_visits): + subtree = tree.subtree + td_subtree = subtree.rollout[:, -1]["next"] + visits = td_subtree["_visits"] + reward_sum = td_subtree["_reward_sum"].clone() + + # If it's black's turn, flip the reward, since black wants to + # optimize for the lowest reward, not highest. + if not subtree.rollout[0, 0]["turn"]: + reward_sum = -reward_sum + + if tree.rollout is None: + parent_visits = root_visits else: - if tree.parent.rollout is None: - parent_visits = 0 - for child in tree.parent.subtree: - parent_visits += child.rollout[-1]["next", "_visits"] - else: - parent_visits = tree.parent.rollout[-1]["next", "_visits"] - assert parent_visits > 0 - - value_avg = ( - tree.rollout[-1]["next", "_reward_sum"] - / tree.rollout[-1]["next", "_visits"] - ) - - # If it's black's turn, flip the reward, since black wants to optimize - # for the lowest reward. - if not tree.rollout[0]["turn"]: - value_avg = -value_avg + parent_visits = tree.rollout[-1]["next", "_visits"] + reward_sum = reward_sum.squeeze(-1) + priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits + priority[visits == 0] = float("inf") + return priority - res = ( - value_avg - + C - * torch.sqrt(torch.log(parent_visits) / tree.rollout[-1]["next", "_visits"]) - ).item() - return res - - -def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps): +def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps, root_visits): done = False - trees_visited = [] + td_trees_visited = [] while not done: if tree.subtree is None: - td_tree = tree.rollout[-1]["next"] + td_tree = tree.rollout[-1]["next"].clone() if (td_tree["_visits"] > 0 or tree.parent is None) and not td_tree["done"]: actions = env.all_actions(td_tree) subtrees = [] for action in actions: - td = env.step(env.reset(td_tree.clone()).update(action)).update( + td = env.step(env.reset(td_tree).update(action)).update( TensorDict( { ("next", "_visits"): 0, @@ -106,6 +96,8 @@ def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps): ) subtrees.append(new_node) + # NOTE: This whole script runs about 2x faster with lazy stack + # versus eager stack. 
tree.subtree = TensorDict.lazy_stack(subtrees) chosen_idx = torch.randint(0, len(subtrees), ()).item() rollout_state = subtrees[chosen_idx].rollout[-1]["next"] @@ -124,15 +116,12 @@ def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps): done = True else: - priorities = torch.tensor( - [traversal_priority_UCB1(subtree) for subtree in tree.subtree] - ) + priorities = traversal_priority_UCB1(tree, root_visits) chosen_idx = torch.argmax(priorities).item() tree = tree.subtree[chosen_idx] - trees_visited.append(tree) + td_trees_visited.append(tree.rollout[-1]["next"]) - for tree in trees_visited: - td = tree.rollout[-1]["next"] + for td in td_trees_visited: td["_visits"] += 1 td["_reward_sum"] += rollout_reward @@ -149,7 +138,7 @@ def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps): max_rollout_steps (int): Maximum number of steps for each rollout. """ if root not in forest: - for action in env.all_actions(root.clone()): + for action in env.all_actions(root): td = env.step(env.reset(root.clone()).update(action)).update( TensorDict( { @@ -162,8 +151,12 @@ def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps): tree = forest.get_tree(root) + # TODO: Add this to the root node + root_visits = torch.tensor([0]) + for _ in range(num_steps): - _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps) + _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps, root_visits) + root_visits += 1 return tree From 6cc8eed033ef0fcb1011ad773989810025e0e5ef Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Mon, 12 May 2025 19:22:40 -0700 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- examples/trees/mcts.py | 65 +++++++++++++++++++++--------------- torchrl/modules/mcts/mcts.py | 15 +++++---- 2 files changed, 46 insertions(+), 34 deletions(-) diff --git a/examples/trees/mcts.py b/examples/trees/mcts.py index 343ab4552f5..cd670001672 100644 --- a/examples/trees/mcts.py +++ b/examples/trees/mcts.py @@ -5,13 +5,12 @@ import time +import torch import torchrl import torchrl.envs import torchrl.modules.mcts from tensordict import TensorDict -start_time = time.time() - pgn_or_fen = "fen" mask_actions = True @@ -30,33 +29,34 @@ class TransformReward: def __init__(self): self.first_turn = None - def __call__(self, td): - if self.first_turn is None and "turn" in td: - self.first_turn = td["turn"] - print(f"first turn: {self.first_turn}") + def reset(self, *args): + self.first_turn = None + def __call__(self, td): if "reward" not in td: return td + reward = td["reward"] + + if self.first_turn is None: + self.first_turn = td["turn"] + if reward == 0.5: reward = 0 - # elif reward == 1 and td["turn"] == env.lib.WHITE: - elif reward == 1 and td["turn"] == self.first_turn: + elif reward == 1 and td["turn"]: reward = -reward td["reward"] = reward return td - def reset(self, td): - self.first_turn = None - # ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player. 
# Need to transform the reward to be: # white win = 1 # draw = 0 # black win = -1 -env = env.append_transform(TransformReward()) +transform_reward = TransformReward() +env = env.append_transform(transform_reward) forest = torchrl.data.MCTSForest() forest.reward_keys = env.reward_keys @@ -80,6 +80,7 @@ def tree_format_fn(tree): def get_best_move(fen, mcts_steps, rollout_steps): + transform_reward.reset() root = env.reset(TensorDict({"fen": fen})) tree = torchrl.modules.mcts.MCTS(forest, root, env, mcts_steps, rollout_steps) moves = [] @@ -89,10 +90,8 @@ def get_best_move(fen, mcts_steps, rollout_steps): reward_sum = subtree.wins visits = subtree.visits value_avg = (reward_sum / visits).item() - - if not subtree.rollout[0]["turn"]: + if not root["turn"]: value_avg = -value_avg - moves.append((value_avg, san)) moves = sorted(moves, key=lambda x: -x[0]) @@ -107,19 +106,31 @@ def get_best_move(fen, mcts_steps, rollout_steps): return moves[0][1] -# White has M1, best move Rd8#. Any other moves lose to M2 or M1. -fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1" -assert get_best_move(fen0, 100, 10) == "Rd8#" +for idx in range(3): + print("==========") + print(idx) + print("==========") + torch.manual_seed(idx) + + start_time = time.time() + + # White has M1, best move Rd8#. Any other moves lose to M2 or M1. + fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1" + assert get_best_move(fen0, 40, 10) == "Rd8#" + + # Black has M1, best move Qg6#. Other moves give rough equality or worse. + fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1" + assert get_best_move(fen1, 40, 10) == "Qg6#" -# Black has M1, best move Qg6#. Other moves give rough equality or worse. -fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1" -assert get_best_move(fen1, 100, 10) == "Qg6#" + # White has M2, best move Rxg8+. Any other move loses. + fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1" + assert get_best_move(fen2, 600, 10) == "Rxg8+" -# White has M2, best move Rxg8+. Any other move loses. -fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1" -assert get_best_move(fen2, 1000, 10) == "Rxg8+" + # Black has M2, best move Rxg1+. Any other move loses. + fen3 = "2r5/5R2/8/8/8/7k/5P1P/2r3QK b - - 0 1" + assert get_best_move(fen3, 600, 10) == "Rxg1+" -end_time = time.time() -total_time = end_time - start_time + end_time = time.time() + total_time = end_time - start_time -print(f"Took {total_time} s") + print(f"Took {total_time} s") diff --git a/torchrl/modules/mcts/mcts.py b/torchrl/modules/mcts/mcts.py index 3d3521e2e9e..99ccc4eed96 100644 --- a/torchrl/modules/mcts/mcts.py +++ b/torchrl/modules/mcts/mcts.py @@ -19,9 +19,10 @@ def _traversal_priority_UCB1(tree): visits = subtree.visits reward_sum = subtree.wins - # TODO: Remove this in favor of a reward transform in the example - # If it's black's turn, flip the reward, since black wants to - # optimize for the lowest reward, not highest. + # If it's black's turn, flip the reward, since black wants to optimize for + # the lowest reward, not highest. + # TODO: Need a more generic way to do this, since not all use cases of MCTS + # will be two player turn based games. if not subtree.rollout[0, 0]["turn"]: reward_sum = -reward_sum @@ -101,12 +102,12 @@ def MCTS( num_steps (int): Number of iterations to traverse. max_rollout_steps (int): Maximum number of steps for each rollout. 
""" - if root not in forest: - for action in env.all_actions(root): - td = env.step(env.reset(root.clone()).update(action)) - forest.extend(td.unsqueeze(0)) + for action in env.all_actions(root): + td = env.step(env.reset(root.clone()).update(action)) + forest.extend(td.unsqueeze(0)) tree = forest.get_tree(root) + tree.wins = torch.zeros_like(td["next", env.reward_key]) for subtree in tree.subtree: subtree.wins = torch.zeros_like(td["next", env.reward_key]) From e7dc7d34ae3720a81968164761208522d0f3d02d Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Mon, 12 May 2025 19:28:26 -0700 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- examples/trees/mcts.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/trees/mcts.py b/examples/trees/mcts.py index cd670001672..b2b0cd78b50 100644 --- a/examples/trees/mcts.py +++ b/examples/trees/mcts.py @@ -26,21 +26,12 @@ class TransformReward: - def __init__(self): - self.first_turn = None - - def reset(self, *args): - self.first_turn = None - def __call__(self, td): if "reward" not in td: return td reward = td["reward"] - if self.first_turn is None: - self.first_turn = td["turn"] - if reward == 0.5: reward = 0 elif reward == 1 and td["turn"]: @@ -80,7 +71,6 @@ def tree_format_fn(tree): def get_best_move(fen, mcts_steps, rollout_steps): - transform_reward.reset() root = env.reset(TensorDict({"fen": fen})) tree = torchrl.modules.mcts.MCTS(forest, root, env, mcts_steps, rollout_steps) moves = [] From 29f9c23b3e57aa5d00114f63c7e2d59ae10598db Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 14 May 2025 12:51:39 -0700 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- examples/trees/mcts.py | 5 +- torchrl/modules/mcts/mcts.py | 212 +++++++++++++++++++---------------- 2 files changed, 117 insertions(+), 100 deletions(-) diff --git a/examples/trees/mcts.py b/examples/trees/mcts.py index b2b0cd78b50..f56e0bb642a 100644 --- a/examples/trees/mcts.py +++ b/examples/trees/mcts.py @@ -72,7 +72,8 @@ def tree_format_fn(tree): def get_best_move(fen, mcts_steps, rollout_steps): root = env.reset(TensorDict({"fen": fen})) - tree = torchrl.modules.mcts.MCTS(forest, root, env, mcts_steps, rollout_steps) + mcts = torchrl.modules.mcts.MCTS(mcts_steps, rollout_steps) + tree = mcts(forest, root, env) moves = [] for subtree in tree.subtree: @@ -96,7 +97,7 @@ def get_best_move(fen, mcts_steps, rollout_steps): return moves[0][1] -for idx in range(3): +for idx in range(30): print("==========") print(idx) print("==========") diff --git a/torchrl/modules/mcts/mcts.py b/torchrl/modules/mcts/mcts.py index 99ccc4eed96..107be61fca5 100644 --- a/torchrl/modules/mcts/mcts.py +++ b/torchrl/modules/mcts/mcts.py @@ -6,6 +6,7 @@ import torch import torchrl from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase from torchrl.data.map import MCTSForest, Tree from torchrl.envs import EnvBase @@ -13,106 +14,121 @@ C = 2.0**0.5 -# TODO: Allow user to specify different priority functions with PR #2358 -def _traversal_priority_UCB1(tree): - subtree = tree.subtree - visits = subtree.visits - reward_sum = subtree.wins - - # If it's black's turn, flip the reward, since black wants to optimize for - # the lowest reward, not highest. - # TODO: Need a more generic way to do this, since not all use cases of MCTS - # will be two player turn based games. 
- if not subtree.rollout[0, 0]["turn"]: - reward_sum = -reward_sum - - parent_visits = tree.visits - reward_sum = reward_sum.squeeze(-1) - priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits - priority[visits == 0] = float("inf") - return priority - - -def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps): - done = False - trees_visited = [tree] - - while not done: - if tree.subtree is None: - td_tree = tree.rollout[-1]["next"].clone() - - if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]: - actions = env.all_actions(td_tree) - subtrees = [] - - for action in actions: - td = env.step(env.reset(td_tree).update(action)) - new_node = torchrl.data.Tree( - rollout=td.unsqueeze(0), - node_data=td["next"].select(*forest.node_map.in_keys), - count=torch.tensor(0), - wins=torch.zeros_like(td["next", env.reward_key]), - ) - subtrees.append(new_node) - - # NOTE: This whole script runs about 2x faster with lazy stack - # versus eager stack. - tree.subtree = TensorDict.lazy_stack(subtrees) - chosen_idx = torch.randint(0, len(subtrees), ()).item() - rollout_state = subtrees[chosen_idx].rollout[-1]["next"] +class MCTS(TensorDictModuleBase): + """Monte-Carlo tree search. - else: - rollout_state = td_tree + Attributes: + num_traversals (int): Number of times to traverse the tree. + rollout_max_steps (int): Maximum number of steps for each rollout. - if rollout_state["done"]: - rollout_reward = rollout_state[env.reward_key] - else: - rollout = env.rollout( - max_steps=max_rollout_steps, - tensordict=rollout_state, - ) - rollout_reward = rollout[-1]["next", env.reward_key] - done = True - - else: - priorities = _traversal_priority_UCB1(tree) - chosen_idx = torch.argmax(priorities).item() - tree = tree.subtree[chosen_idx] - trees_visited.append(tree) - - for tree in trees_visited: - tree.visits += 1 - tree.wins += rollout_reward - - -def MCTS( - forest: MCTSForest, - root: TensorDictBase, - env: EnvBase, - num_steps: int, - max_rollout_steps: int | None = None, -) -> Tree: - """Performs Monte-Carlo tree search in an environment. - - Args: - forest (MCTSForest): Forest of the tree to update. If the tree does not - exist yet, it is added. - root (TensorDict): The root step of the tree to update. - env (EnvBase): Environment to performs actions in. - num_steps (int): Number of iterations to traverse. - max_rollout_steps (int): Maximum number of steps for each rollout. + Methods: + forward: Runs the tree search. """ - for action in env.all_actions(root): - td = env.step(env.reset(root.clone()).update(action)) - forest.extend(td.unsqueeze(0)) - tree = forest.get_tree(root) - - tree.wins = torch.zeros_like(td["next", env.reward_key]) - for subtree in tree.subtree: - subtree.wins = torch.zeros_like(td["next", env.reward_key]) - - for _ in range(num_steps): - _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps) + def __init__( + self, + num_traversals: int, + rollout_max_steps: int | None = None, + ): + super().__init__() + self.num_traversals = num_traversals + self.rollout_max_steps = rollout_max_steps + + def forward( + self, + forest: MCTSForest, + root: TensorDictBase, + env: EnvBase, + ) -> Tree: + """Performs Monte-Carlo tree search in an environment. + + Args: + forest (MCTSForest): Forest of the tree to update. If the tree does not + exist yet, it is added. + root (TensorDict): The root step of the tree to update. + env (EnvBase): Environment to performs actions in. 
+ """ + for action in env.all_actions(root): + td = env.step(env.reset(root.clone()).update(action)) + forest.extend(td.unsqueeze(0)) + + tree = forest.get_tree(root) + + tree.wins = torch.zeros_like(td["next", env.reward_key]) + for subtree in tree.subtree: + subtree.wins = torch.zeros_like(td["next", env.reward_key]) + + for _ in range(self.num_traversals): + self._traverse_MCTS_one_step(forest, tree, env, self.rollout_max_steps) + + return tree + + def _traverse_MCTS_one_step(self, forest, tree, env, rollout_max_steps): + done = False + trees_visited = [tree] + + while not done: + if tree.subtree is None: + td_tree = tree.rollout[-1]["next"].clone() + + if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]: + actions = env.all_actions(td_tree) + subtrees = [] + + for action in actions: + td = env.step(env.reset(td_tree).update(action)) + new_node = torchrl.data.Tree( + rollout=td.unsqueeze(0), + node_data=td["next"].select(*forest.node_map.in_keys), + count=torch.tensor(0), + wins=torch.zeros_like(td["next", env.reward_key]), + ) + subtrees.append(new_node) + + # NOTE: This whole script runs about 2x faster with lazy stack + # versus eager stack. + tree.subtree = TensorDict.lazy_stack(subtrees) + chosen_idx = torch.randint(0, len(subtrees), ()).item() + rollout_state = subtrees[chosen_idx].rollout[-1]["next"] + + else: + rollout_state = td_tree + + if rollout_state["done"]: + rollout_reward = rollout_state[env.reward_key] + else: + rollout = env.rollout( + max_steps=rollout_max_steps, + tensordict=rollout_state, + ) + rollout_reward = rollout[-1]["next", env.reward_key] + done = True - return tree + else: + priorities = self._traversal_priority_UCB1(tree) + chosen_idx = torch.argmax(priorities).item() + tree = tree.subtree[chosen_idx] + trees_visited.append(tree) + + for tree in trees_visited: + tree.visits += 1 + tree.wins += rollout_reward + + # TODO: Allow user to specify different priority functions with PR #2358 + def _traversal_priority_UCB1(self, tree): + subtree = tree.subtree + visits = subtree.visits + reward_sum = subtree.wins + + # If it's black's turn, flip the reward, since black wants to optimize for + # the lowest reward, not highest. + # TODO: Need a more generic way to do this, since not all use cases of MCTS + # will be two player turn based games. + if not subtree.rollout[0, 0]["turn"]: + reward_sum = -reward_sum + + parent_visits = tree.visits + reward_sum = reward_sum.squeeze(-1) + priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits + priority[visits == 0] = float("inf") + return priority From 73d8e5678c10ddbaff8eac5643a208e7e4771ed1 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Wed, 14 May 2025 12:54:43 -0700 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- torchrl/modules/mcts/mcts.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchrl/modules/mcts/mcts.py b/torchrl/modules/mcts/mcts.py index 107be61fca5..8e65542b41a 100644 --- a/torchrl/modules/mcts/mcts.py +++ b/torchrl/modules/mcts/mcts.py @@ -6,15 +6,13 @@ import torch import torchrl from tensordict import TensorDict, TensorDictBase -from tensordict.nn import TensorDictModuleBase +from torch import nn from torchrl.data.map import MCTSForest, Tree from torchrl.envs import EnvBase -C = 2.0**0.5 - -class MCTS(TensorDictModuleBase): +class MCTS(nn.Module): """Monte-Carlo tree search. 
Attributes: @@ -129,6 +127,7 @@ def _traversal_priority_UCB1(self, tree): parent_visits = tree.visits reward_sum = reward_sum.squeeze(-1) + C = 2.0**0.5 priority = (reward_sum + C * torch.sqrt(torch.log(parent_visits))) / visits priority[visits == 0] = float("inf") return priority
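
Taken together, the series ends with a reusable torchrl.modules.mcts.MCTS module (patches 6-7) driven by the chess example in examples/trees/mcts.py. The condensed sketch below shows how that final API is meant to be used; it assumes the ChessEnv, MCTSForest, and MCTS interfaces exactly as they appear in these patches, and the step counts and FEN position are taken from the example.

import torchrl
import torchrl.envs
import torchrl.modules.mcts
from tensordict import TensorDict

# Chess environment with FEN observations, hashes for the forest's node map,
# SAN move strings, and action masking, as in the example script.
env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=True,
)

def transform_reward(td):
    # Map ChessEnv's reward (0.5 for a draw, 1 for a win by either side) to
    # white win = 1, draw = 0, black win = -1, as the example does.
    if "reward" not in td:
        return td
    reward = td["reward"]
    if reward == 0.5:
        reward = 0
    elif reward == 1 and td["turn"]:
        reward = -reward
    td["reward"] = reward
    return td

env = env.append_transform(transform_reward)

# Forest that stores the search tree, keyed on the FEN hash.
forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys
forest.observation_keys = ["fen_hash", "turn", "action_mask"]

# 40 tree traversals, rollouts capped at 10 steps (the example's settings for a mate-in-one).
mcts = torchrl.modules.mcts.MCTS(num_traversals=40, rollout_max_steps=10)

# White to move with mate in one; the best move is Rd8#.
root = env.reset(TensorDict({"fen": "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"}))
tree = mcts(forest, root, env)

# Rank the root's children by average value, flipping the sign when black is to move.
moves = []
for subtree in tree.subtree:
    value_avg = (subtree.wins / subtree.visits).item()
    if not root["turn"]:
        value_avg = -value_avg
    moves.append((value_avg, subtree.rollout[0]["next", "san"]))

for value_avg, san in sorted(moves, key=lambda x: -x[0]):
    print(f"{value_avg:0.02f} {san}")

With the patches applied, the top-ranked move for this position should be Rd8#, matching the assertion in the example.
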