From 863892146e836d705f97e2b055c32115523371db Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 17 Jan 2025 13:29:50 +0000 Subject: [PATCH 1/4] Update [ghstack-poisoned] --- torchrl/envs/custom/chess.py | 238 ++++++++++++++++++++++++++++++----- 1 file changed, 210 insertions(+), 28 deletions(-) diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py index 4dc5dbe5321..5dc8111411b 100644 --- a/torchrl/envs/custom/chess.py +++ b/torchrl/envs/custom/chess.py @@ -4,18 +4,43 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import importlib.util +import io from typing import Dict, Optional import torch +from PIL import Image from tensordict import TensorDict, TensorDictBase from torchrl.data import Categorical, Composite, NonTensor, Unbounded from torchrl.envs import EnvBase +from torchrl.envs.common import _EnvPostInit from torchrl.envs.utils import _classproperty -class ChessEnv(EnvBase): +class _HashMeta(_EnvPostInit): + def __call__(cls, *args, **kwargs): + instance = super().__call__(*args, **kwargs) + if kwargs.get("include_hash"): + from torchrl.envs import Hash + + in_keys = [] + out_keys = [] + if instance.include_san: + in_keys.append("san") + out_keys.append("san_hash") + if instance.include_fen: + in_keys.append("fen") + out_keys.append("fen_hash") + if instance.include_pgn: + in_keys.append("pgn") + out_keys.append("pgn_hash") + return instance.append_transform(Hash(in_keys, out_keys)) + return instance + + +class ChessEnv(EnvBase, metaclass=_HashMeta): """A chess environment that follows the TorchRL API. Requires: the `chess` library. More info `here `__. @@ -23,7 +48,7 @@ class ChessEnv(EnvBase): Args: stateful (bool): Whether to keep track of the internal state of the board. If False, the state will be stored in the observation and passed back - to the environment on each call. Default: ``False``. + to the environment on each call. Default: ``True``. .. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape. Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves, @@ -90,28 +115,76 @@ class ChessEnv(EnvBase): """ _hash_table: Dict[int, str] = {} + _PNG_RESTART = """[Event "?"] +[Site "?"] +[Date "????.??.??"] +[Round "?"] +[White "?"] +[Black "?"] +[Result "*"] + +*""" @_classproperty def lib(cls): try: import chess + import chess.pgn except ImportError: raise ImportError( "The `chess` library could not be found. Make sure you installed it through `pip install chess`." ) return chess - def __init__(self, stateful: bool = False): + def __init__( + self, + *, + stateful: bool = True, + include_san: bool = False, + include_fen: bool = False, + include_pgn: bool = False, + include_hash: bool = False, + pixels: bool = False, + ): chess = self.lib super().__init__() self.full_observation_spec = Composite( - hashing=Unbounded(shape=(), dtype=torch.int64), - fen=NonTensor(shape=()), turn=Categorical(n=2, dtype=torch.bool, shape=()), ) + self.include_san = include_san + self.include_fen = include_fen + self.include_pgn = include_pgn + if include_san: + self.full_observation_spec["san"] = NonTensor(shape=(), example_data="Nc6") + if include_pgn: + self.full_observation_spec["pgn"] = NonTensor( + shape=(), example_data=self._PNG_RESTART + ) + if include_fen: + self.full_observation_spec["fen"] = NonTensor(shape=(), example_data="any") + if not stateful and not (include_pgn or include_fen): + raise RuntimeError( + "At least one state representation (pgn or fen) must be enabled when stateful " + f"is {stateful}." + ) + self.stateful = stateful + if not self.stateful: self.full_state_spec = self.full_observation_spec.clone() + + self.pixels = pixels + if pixels: + if importlib.util.find_spec("cairosvg") is None: + raise ImportError( + "Please install cairosvg to use this environment with pixel rendering." + ) + if importlib.util.find_spec("torchvision") is None: + raise ImportError( + "Please install torchvision to use this environment with pixel rendering." + ) + self.full_observation_spec["pixels"] = Unbounded(shape=()) + self.full_action_spec = Composite( action=Categorical(n=-1, shape=(), dtype=torch.int64) ) @@ -132,27 +205,88 @@ def _is_done(self, board): def _reset(self, tensordict=None): fen = None + pgn = None if tensordict is not None: - fen = self._get_fen(tensordict).data - dest = tensordict.empty() + if self.include_fen: + fen = self._get_fen(tensordict).data + dest = tensordict.empty() + if self.include_pgn: + fen = self._get_pgn(tensordict).data + dest = tensordict.empty() else: dest = TensorDict() - if fen is None: + if fen is None and pgn is None: self.board.reset() - fen = self.board.fen() + if self.include_fen and fen is None: + fen = self.board.fen() + if self.include_pgn and pgn is None: + pgn = self._PNG_RESTART else: - self.board.set_fen(fen) - if self._is_done(self.board): - raise ValueError( - "Cannot reset to a fen that is a gameover state." f" fen: {fen}" - ) - - hashing = hash(fen) + if fen is not None: + self.board.set_fen(fen) + if self._is_done(self.board): + raise ValueError( + "Cannot reset to a fen that is a gameover state." f" fen: {fen}" + ) + elif pgn is not None: + self.board = self._pgn_to_board(pgn) self._set_action_space() turn = self.board.turn - return dest.set("fen", fen).set("hashing", hashing).set("turn", turn) + if self.include_san: + dest.set("san", "[SAN][START]") + if self.include_fen: + if fen is None: + fen = self.board.fen() + dest.set("fen", fen) + if self.include_pgn: + if pgn is None: + pgn = self._board_to_pgn(self.board) + dest.set("pgn", pgn) + dest.set("turn", turn) + if self.pixels: + dest.set("pixels", self._get_tensor_image(board=self.board)) + return dest + + _cairosvg_lib = None + + @_classproperty + def _cairosvg(cls): + csvg = cls._cairosvg_lib + if csvg is None: + import cairosvg + + csvg = cls._cairosvg_lib = cairosvg + return csvg + + _torchvision_lib = None + + @_classproperty + def _torchvision(cls): + tv = cls._torchvision_lib + if tv is None: + import torchvision + + tv = cls._torchvision_lib = torchvision + return tv + + @classmethod + def _get_tensor_image(cls, board): + try: + svg = board._repr_svg_() + # Convert SVG to PNG using cairosvg + png_data = io.BytesIO() + cls._cairosvg.svg2png(bytestring=svg.encode("utf-8"), write_to=png_data) + png_data.seek(0) + # Open the PNG image using Pillow + img = Image.open(png_data) + img = cls._torchvision.transforms.functional.pil_to_tensor(img) + except ImportError: + raise ImportError( + "Chess rendering requires cairosvg and torchvision to be installed." + ) + return img def _set_action_space(self, tensordict: TensorDict | None = None): if not self.stateful and tensordict is not None: @@ -160,13 +294,37 @@ def _set_action_space(self, tensordict: TensorDict | None = None): self.board.set_fen(fen) self.action_spec.set_provisional_n(self.board.legal_moves.count()) + @classmethod + def _pgn_to_board( + cls, pgn_string: str, board: "chess.Board" | None = None + ) -> "chess.Board": + pgn_io = io.StringIO(pgn_string) + game = cls.lib.pgn.read_game(pgn_io) + if board is None: + board = cls.Board() + else: + board.reset() + for move in game.mainline_moves(): + board.push(move) + return board + + @classmethod + def _board_to_pgn(cls, board: "chess.Board") -> str: + # Create a new Game object + game = cls.lib.pgn.Game() + + # Add the moves to the game + node = game + for move in board.move_stack: + node = node.add_variation(move) + + # Generate the PGN string + pgn_string = str(game) + return pgn_string + @classmethod def _get_fen(cls, tensordict): fen = tensordict.get("fen", None) - if fen is None: - hashing = tensordict.get("hashing", None) - if hashing is not None: - fen = cls._hash_table.get(hashing.item()) return fen def get_legal_moves(self, tensordict=None, uci=False): @@ -205,19 +363,40 @@ def _step(self, tensordict): # action action = tensordict.get("action") board = self.board + if not self.stateful: - fen = self._get_fen(tensordict).data - board.set_fen(fen) + if self.include_fen: + fen = self._get_fen(tensordict).data + board.set_fen(fen) + elif self.include_pgn: + pgn = self._get_pgn(tensordict).data + self._pgn_to_board(pgn, board) + else: + raise RuntimeError( + "Not enough information to deduce the board. If stateful=False, include_pgn or include_fen must be True." + ) + action = list(board.legal_moves)[action] + san = None + if self.include_san: + san = board.san(action) board.push(action) + self._set_action_space() - # Collect data - fen = self.board.fen() dest = tensordict.empty() - hashing = hash(fen) - dest.set("fen", fen) - dest.set("hashing", hashing) + + # Collect data + if self.include_fen: + fen = board.fen() + dest.set("fen", fen) + + if self.include_pgn: + pgn = self._board_to_pgn(board) + dest.set("pgn", pgn) + + if san is not None: + dest.set("san", san) turn = torch.tensor(board.turn) if board.is_checkmate(): @@ -226,12 +405,15 @@ def _step(self, tensordict): reward_val = 1 if winner == self.lib.WHITE else -1 else: reward_val = 0 + reward = torch.tensor([reward_val], dtype=torch.int32) done = self._is_done(board) dest.set("reward", reward) dest.set("turn", turn) dest.set("done", [done]) dest.set("terminated", [done]) + if self.pixels: + dest.set("pixels", self._get_tensor_image(board=self.board)) return dest def _set_seed(self, *args, **kwargs): From 759ea2769fecd749a1084d3437c138cbc82e10bf Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 21 Jan 2025 10:19:29 +0000 Subject: [PATCH 2/4] Update [ghstack-poisoned] --- torchrl/envs/transforms/transforms.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 73d42e91d9f..f1c433c3014 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -9974,20 +9974,3 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase: ) return (self.weights * reward).sum(dim=-1) - -class ConditionalPolicySwitch(Transform): - def __init__(self, policy: Callable[[TensorDictBase], TensorDictBase], condition: Callable[[TensorDictBase], bool]): - super().__init__([], []) - self.__dict__["policy"] = policy - self.condition = condition - def _step( - self, tensordict: TensorDictBase, next_tensordict: TensorDictBase - ) -> TensorDictBase: - if self.condition(tensordict): - parent: TransformedEnv = self.parent - tensordict = parent.step(tensordict) - tensordict_ = parent.step_mdp(tensordict) - tensordict_ = self.policy(tensordict_) - return parent.step(tensordict_) - return tensordict - return \ No newline at end of file From 7ad50dd851b0755e83bb6e1c35f04d995ee1554d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 26 Jan 2025 15:43:48 -0800 Subject: [PATCH 3/4] Update [ghstack-poisoned] --- test/test_env.py | 30 ++++++-- torchrl/data/tensor_specs.py | 2 +- torchrl/envs/custom/chess.py | 104 ++++++++++++++------------ torchrl/envs/transforms/transforms.py | 58 +++++++++++++- 4 files changed, 139 insertions(+), 55 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index 12bdc0bc9ad..d8bf36cdf98 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -131,6 +131,7 @@ from torchrl.envs.transforms.transforms import ( AutoResetEnv, AutoResetTransform, + Tokenizer, Transform, ) from torchrl.envs.utils import ( @@ -3346,10 +3347,6 @@ def test_batched_dynamic(self, break_when_any_done): ) del env_no_buffers gc.collect() - # print(dummy_rollouts) - # print(rollout_no_buffers_serial) - # # for a, b in zip(dummy_rollouts.exclude("action").unbind(0), rollout_no_buffers_serial.exclude("action").unbind(0)): - # assert_allclose_td(a, b) assert_allclose_td( dummy_rollouts.exclude("action"), rollout_no_buffers_serial.exclude("action"), @@ -3463,6 +3460,8 @@ def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san include_hash=include_hash, include_san=include_san, ) + # Because we always use mask_actions=True + assert isinstance(env, TransformedEnv) check_env_specs(env) if include_hash: if include_fen: @@ -3560,8 +3559,8 @@ def test_reset_white_to_move(self, stateful, include_pgn, include_fen): ) fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1" td = env.reset(TensorDict({"fen": fen})) - assert td["fen"] == fen if include_fen: + assert td["fen"] == fen assert env.board.fen() == fen assert td["turn"] == env.lib.WHITE assert not td["done"] @@ -3666,6 +3665,27 @@ def test_reward( assert td["reward"] == expected_reward assert td["turn"] == (not expected_turn) + def test_chess_tokenized(self): + env = ChessEnv(include_fen=True, stateful=True, include_san=True) + assert isinstance(env.observation_spec["fen"], NonTensor) + env = env.append_transform( + Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"]) + ) + assert isinstance(env.observation_spec["fen"], NonTensor) + env.transform.transform_output_spec(env.base_env.output_spec) + env.transform.transform_input_spec(env.base_env.input_spec) + r = env.rollout(10, return_contiguous=False) + assert "fen_tokenized" in r + assert "fen" in r + assert "fen_tokenized" in r["next"] + assert "fen" in r["next"] + ftd = env.fake_tensordict() + assert "fen_tokenized" in ftd + assert "fen" in ftd + assert "fen_tokenized" in ftd["next"] + assert "fen" in ftd["next"] + env.check_env_specs() + class TestCustomEnvs: def test_tictactoe_env(self): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 3d4198ae234..95aaaebd936 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -5042,7 +5042,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase: def __eq__(self, other): return ( - type(self) is type(other) + type(self) == type(other) and self.shape == other.shape and self._device == other._device and set(self._specs.keys()) == set(other._specs.keys()) diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py index 45d5e765d3b..63265376f66 100644 --- a/torchrl/envs/custom/chess.py +++ b/torchrl/envs/custom/chess.py @@ -7,12 +7,12 @@ import importlib.util import io import pathlib -from typing import Dict, Optional +from typing import Dict import torch from PIL import Image from tensordict import TensorDict, TensorDictBase -from torchrl.data import Bounded, Categorical, Composite, NonTensor, Unbounded +from torchrl.data import Binary, Bounded, Categorical, Composite, NonTensor, Unbounded from torchrl.envs import EnvBase from torchrl.envs.common import _EnvPostInit @@ -20,7 +20,7 @@ from torchrl.envs.utils import _classproperty -class _HashMeta(_EnvPostInit): +class _ChessMeta(_EnvPostInit): def __call__(cls, *args, **kwargs): instance = super().__call__(*args, **kwargs) if kwargs.get("include_hash"): @@ -37,11 +37,15 @@ def __call__(cls, *args, **kwargs): if instance.include_pgn: in_keys.append("pgn") out_keys.append("pgn_hash") - return instance.append_transform(Hash(in_keys, out_keys)) + instance = instance.append_transform(Hash(in_keys, out_keys)) + if kwargs.get("mask_actions", True): + from torchrl.envs import ActionMask + + instance = instance.append_transform(ActionMask()) return instance -class ChessEnv(EnvBase, metaclass=_HashMeta): +class ChessEnv(EnvBase, metaclass=_ChessMeta): r"""A chess environment that follows the TorchRL API. This environment simulates a chess game using the `chess` library. It supports various state representations @@ -63,6 +67,8 @@ class ChessEnv(EnvBase, metaclass=_HashMeta): include_pgn (bool): Whether to include PGN (Portable Game Notation) in the observations. Default: ``False``. include_legal_moves (bool): Whether to include legal moves in the observations. Default: ``False``. include_hash (bool): Whether to include hash transformations in the environment. Default: ``False``. + mask_actions (bool): if ``True``, a :class:`~torchrl.envs.ActionMask` transform will be appended + to the env to make sure that the actions are properly masked. Default: ``True``. pixels (bool): Whether to include pixel-based observations of the board. Default: ``False``. .. note:: The action spec is a :class:`~torchrl.data.Categorical` with a number of actions equal to the number of possible SAN moves. @@ -202,16 +208,15 @@ def _legal_moves_to_index( ) -> torch.Tensor: if not self.stateful: if tensordict is None: - raise RuntimeError( - "rand_action requires a tensordict when stateful is False." - ) - if self.include_fen: - fen = self._get_fen(tensordict) + # trust the board + pass + elif self.include_fen: + fen = tensordict.get("fen", None) fen = fen.data self.board.set_fen(fen) board = self.board elif self.include_pgn: - pgn = self._get_pgn(tensordict) + pgn = tensordict.get("pgn") pgn = pgn.data board = self._pgn_to_board(pgn, self.board) @@ -224,15 +229,19 @@ def _legal_moves_to_index( ) if return_mask: - return torch.zeros(len(self.san_moves), dtype=torch.bool).index_fill_( - 0, indices, True - ) + return self._move_index_to_mask(indices) if pad: indices = torch.nn.functional.pad( indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves) ) return indices + @classmethod + def _move_index_to_mask(cls, indices: torch.Tensor) -> torch.Tensor: + return torch.zeros(len(cls.san_moves), dtype=torch.bool).index_fill_( + 0, indices, True + ) + def __init__( self, *, @@ -242,6 +251,7 @@ def __init__( include_pgn: bool = False, include_legal_moves: bool = False, include_hash: bool = False, + mask_actions: bool = True, pixels: bool = False, ): chess = self.lib @@ -252,6 +262,7 @@ def __init__( self.include_san = include_san self.include_fen = include_fen self.include_pgn = include_pgn + self.mask_actions = mask_actions self.include_legal_moves = include_legal_moves if include_legal_moves: # 218 max possible legal moves per chess board position @@ -276,8 +287,10 @@ def __init__( self.stateful = stateful - if not self.stateful: - self.full_state_spec = self.full_observation_spec.clone() + # state_spec is loosely defined as such - it's not really an issue that extra keys + # can go missing but it allows us to reset the env using fen passed to the reset + # method. + self.full_state_spec = self.full_observation_spec.clone() self.pixels = pixels if pixels: @@ -297,16 +310,16 @@ def __init__( self.full_reward_spec = Composite( reward=Unbounded(shape=(1,), dtype=torch.float32) ) + if self.mask_actions: + self.full_observation_spec["action_mask"] = Binary( + n=len(self.san_moves), dtype=torch.bool + ) + # done spec generated automatically self.board = chess.Board() if self.stateful: self.action_spec.set_provisional_n(len(list(self.board.legal_moves))) - def rand_action(self, tensordict: Optional[TensorDictBase] = None): - mask = self._legal_moves_to_index(tensordict, return_mask=True) - self.action_spec.update_mask(mask) - return super().rand_action(tensordict) - def _is_done(self, board): return board.is_game_over() | board.is_fifty_moves() @@ -316,11 +329,11 @@ def _reset(self, tensordict=None): if tensordict is not None: dest = tensordict.empty() if self.include_fen: - fen = self._get_fen(tensordict) + fen = tensordict.get("fen", None) if fen is not None: fen = fen.data elif self.include_pgn: - pgn = self._get_pgn(tensordict) + pgn = tensordict.get("pgn", None) if pgn is not None: pgn = pgn.data else: @@ -360,13 +373,18 @@ def _reset(self, tensordict=None): if self.include_legal_moves: moves_idx = self._legal_moves_to_index(board=self.board, pad=True) dest.set("legal_moves", moves_idx) + if self.mask_actions: + dest.set("action_mask", self._move_index_to_mask(moves_idx)) + elif self.mask_actions: + dest.set( + "action_mask", + self._legal_moves_to_index( + board=self.board, pad=True, return_mask=True + ), + ) + if self.pixels: dest.set("pixels", self._get_tensor_image(board=self.board)) - - if self.stateful: - mask = self._legal_moves_to_index(dest, return_mask=True) - self.action_spec.update_mask(mask) - return dest _cairosvg_lib = None @@ -437,16 +455,6 @@ def _board_to_pgn(cls, board: "chess.Board") -> str: # noqa: F821 pgn_string = str(game) return pgn_string - @classmethod - def _get_fen(cls, tensordict): - fen = tensordict.get("fen", None) - return fen - - @classmethod - def _get_pgn(cls, tensordict): - pgn = tensordict.get("pgn", None) - return pgn - def get_legal_moves(self, tensordict=None, uci=False): """List the legal moves in a position. @@ -470,7 +478,7 @@ def get_legal_moves(self, tensordict=None, uci=False): raise ValueError( "tensordict must be given since this env is not stateful" ) - fen = self._get_fen(tensordict).data + fen = tensordict.get("fen").data board.set_fen(fen) moves = board.legal_moves @@ -488,10 +496,10 @@ def _step(self, tensordict): fen = None if not self.stateful: if self.include_fen: - fen = self._get_fen(tensordict).data + fen = tensordict.get("fen").data board.set_fen(fen) elif self.include_pgn: - pgn = self._get_pgn(tensordict).data + pgn = tensordict.get("pgn").data board = self._pgn_to_board(pgn, board) else: raise RuntimeError( @@ -521,6 +529,15 @@ def _step(self, tensordict): if self.include_legal_moves: moves_idx = self._legal_moves_to_index(board=board, pad=True) dest.set("legal_moves", moves_idx) + if self.mask_actions: + dest.set("action_mask", self._move_index_to_mask(moves_idx)) + elif self.mask_actions: + dest.set( + "action_mask", + self._legal_moves_to_index( + board=self.board, pad=True, return_mask=True + ), + ) turn = torch.tensor(board.turn) done = self._is_done(board) @@ -540,11 +557,6 @@ def _step(self, tensordict): dest.set("terminated", [done]) if self.pixels: dest.set("pixels", self._get_tensor_image(board=self.board)) - - if self.stateful: - mask = self._legal_moves_to_index(dest, return_mask=True) - self.action_spec.update_mask(mask) - return dest def _set_seed(self, *args, **kwargs): diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 3cba7d2bd1f..65eda4bc6ec 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -861,7 +861,9 @@ def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict f"The rand_action method from the base env {self.base_env.__class__.__name__} " "has been overwritten, but the transforms appended to the environment modify " "the action. To call the base env rand_action method, we should then invert the " - "action transform, which is (in general) not doable." + "action transform, which is (in general) not doable. " + f"The full action spec of the base env is: {self.base_env.full_action_spec}, \n" + f"the full action spec of the transformed env is {self.full_action_spec}." ) return self.base_env.rand_action(tensordict) return super().rand_action(tensordict) @@ -5070,23 +5072,73 @@ def transform_input_spec(self, input_spec: Composite) -> Composite: # We need to cap the spec to generate valid random strings for out_key in self.out_keys_inv: if out_key in input_spec["full_state_spec"].keys(True, True): + new_shape = input_spec["full_state_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) input_spec["full_state_spec"][out_key] = Bounded( 0, self.tokenizer.vocab_size, - shape=input_spec["full_state_spec"][out_key].shape, + shape=new_shape, device=input_spec["full_state_spec"][out_key].device, dtype=input_spec["full_state_spec"][out_key].dtype, ) elif out_key in input_spec["full_action_spec"].keys(True, True): + new_shape = input_spec["full_action_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) input_spec["full_action_spec"][out_key] = Bounded( 0, self.tokenizer.vocab_size, - shape=input_spec["full_action_spec"][out_key].shape, + shape=new_shape, device=input_spec["full_action_spec"][out_key].device, dtype=input_spec["full_action_spec"][out_key].dtype, ) return input_spec + def transform_output_spec(self, output_spec: Composite) -> Composite: + output_spec = super().transform_output_spec(output_spec) + # We need to cap the spec to generate valid random strings + for out_key in self.out_keys: + if out_key in output_spec["full_observation_spec"].keys(True, True): + new_shape = output_spec["full_observation_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) + output_spec["full_observation_spec"][out_key] = Bounded( + 0, + self.tokenizer.vocab_size, + shape=new_shape, + device=output_spec["full_observation_spec"][out_key].device, + dtype=output_spec["full_observation_spec"][out_key].dtype, + ) + elif out_key in output_spec["full_reward_spec"].keys(True, True): + new_shape = output_spec["full_reward_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) + output_spec["full_reward_spec"][out_key] = Bounded( + 0, + self.tokenizer.vocab_size, + shape=new_shape, + device=output_spec["full_reward_spec"][out_key].device, + dtype=output_spec["full_reward_spec"][out_key].dtype, + ) + elif out_key in output_spec["full_done_spec"].keys(True, True): + new_shape = output_spec["full_done_spec"][out_key].shape + if self.max_length is None: + # Then we can't tell what the shape will be + new_shape = new_shape[:-1] + torch.Size((-1,)) + output_spec["full_done_spec"][out_key] = Bounded( + 0, + self.tokenizer.vocab_size, + shape=new_shape, + device=output_spec["full_done_spec"][out_key].device, + dtype=output_spec["full_done_spec"][out_key].dtype, + ) + return output_spec + class Stack(Transform): """Stacks tensors and tensordicts. From 2f4b30227b828f7ff39dc7136a04baebfd2db05e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 26 Jan 2025 15:53:36 -0800 Subject: [PATCH 4/4] Update [ghstack-poisoned] --- torchrl/envs/custom/chess.py | 95 ++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py index 63265376f66..ebd23e18452 100644 --- a/torchrl/envs/custom/chess.py +++ b/torchrl/envs/custom/chess.py @@ -76,19 +76,28 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta): being a subset of this space. The environment uses a mask to ensure only legal moves are selected. Examples: + >>> import torch + >>> from torchrl.envs import ChessEnv + >>> _ = torch.manual_seed(0) >>> env = ChessEnv(include_fen=True, include_san=True, include_pgn=True, include_legal_moves=True) + >>> print(env) + TransformedEnv( + env=ChessEnv(), + transform=ActionMask(keys=['action', 'action_mask'])) >>> r = env.reset() - >>> env.rand_step(r) + >>> print(env.rand_step(r)) TensorDict( fields={ action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None), legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ + action_mask: Tensor(shape=torch.Size([29275]), device=cpu, dtype=torch.bool, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), - fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None), + fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/5P2/8/PPPPP1PP/RNBQKBNR b KQkq - 0 1, batch_size=torch.Size([]), device=None), legal_moves: Tensor(shape=torch.Size([219]), device=cpu, dtype=torch.int64, is_shared=False), pgn: NonTensorData(data=[Event "?"] [Site "?"] @@ -97,9 +106,10 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta): [White "?"] [Black "?"] [Result "*"] - 1. b3 *, batch_size=torch.Size([]), device=None), + + 1. f4 *, batch_size=torch.Size([]), device=None), reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), - san: NonTensorData(data=b3, batch_size=torch.Size([]), device=None), + san: NonTensorData(data=f4, batch_size=torch.Size([]), device=None), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), @@ -112,56 +122,59 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta): [White "?"] [Black "?"] [Result "*"] + *, batch_size=torch.Size([]), device=None), - san: NonTensorData(data=[SAN][START], batch_size=torch.Size([]), device=None), + san: NonTensorData(data=, batch_size=torch.Size([]), device=None), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) - >>> env.rollout(1000) + >>> print(env.rollout(1000)) TensorDict( fields={ - action: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.int64, is_shared=False), - done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False), + action: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.int64, is_shared=False), + action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False), fen: NonTensorStack( ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ..., - batch_size=torch.Size([352]), + batch_size=torch.Size([96]), device=None), - legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False), + legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False), next: TensorDict( fields={ - done: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False), + action_mask: Tensor(shape=torch.Size([96, 29275]), device=cpu, dtype=torch.bool, is_shared=False), + done: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False), fen: NonTensorStack( - ['rnbqkbnr/pppppppp/8/8/8/N7/PPPPPPPP/R1BQKBNR b K..., - batch_size=torch.Size([352]), + ['rnbqkbnr/pppppppp/8/8/8/5N2/PPPPPPPP/RNBQKB1R b ..., + batch_size=torch.Size([96]), device=None), - legal_moves: Tensor(shape=torch.Size([352, 219]), device=cpu, dtype=torch.int64, is_shared=False), + legal_moves: Tensor(shape=torch.Size([96, 219]), device=cpu, dtype=torch.int64, is_shared=False), pgn: NonTensorStack( ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R..., - batch_size=torch.Size([352]), + batch_size=torch.Size([96]), device=None), - reward: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.float32, is_shared=False), san: NonTensorStack( - ['Na3', 'a5', 'Nb1', 'Nc6', 'a3', 'g6', 'd4', 'd6'..., - batch_size=torch.Size([352]), + ['Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8', 'Na3', 'Ra..., + batch_size=torch.Size([96]), device=None), - terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False), - turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([352]), + terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([96]), device=None, is_shared=False), pgn: NonTensorStack( ['[Event "?"]\n[Site "?"]\n[Date "????.??.??"]\n[R..., - batch_size=torch.Size([352]), + batch_size=torch.Size([96]), device=None), san: NonTensorStack( - ['[SAN][START]', 'Na3', 'a5', 'Nb1', 'Nc6', 'a3', ..., - batch_size=torch.Size([352]), + ['', 'Nf3', 'Na6', 'c4', 'f6', 'h4', 'Rb8',..., + batch_size=torch.Size([96]), device=None), - terminated: Tensor(shape=torch.Size([352, 1]), device=cpu, dtype=torch.bool, is_shared=False), - turn: Tensor(shape=torch.Size([352]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([352]), + terminated: Tensor(shape=torch.Size([96, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([96]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([96]), device=None, is_shared=False) @@ -227,13 +240,15 @@ def _legal_moves_to_index( [self._san_moves.index(board.san(m)) for m in board.legal_moves], dtype=torch.int64, ) - + mask = None if return_mask: - return self._move_index_to_mask(indices) + mask = self._move_index_to_mask(indices) if pad: indices = torch.nn.functional.pad( indices, [0, 218 - indices.numel() + 1], value=len(self.san_moves) ) + if return_mask: + return indices, mask return indices @classmethod @@ -371,16 +386,19 @@ def _reset(self, tensordict=None): dest.set("pgn", pgn) dest.set("turn", turn) if self.include_legal_moves: - moves_idx = self._legal_moves_to_index(board=self.board, pad=True) - dest.set("legal_moves", moves_idx) + moves_idx = self._legal_moves_to_index( + board=self.board, pad=True, return_mask=self.mask_actions + ) if self.mask_actions: - dest.set("action_mask", self._move_index_to_mask(moves_idx)) + moves_idx, mask = moves_idx + dest.set("action_mask", mask) + dest.set("legal_moves", moves_idx) elif self.mask_actions: dest.set( "action_mask", self._legal_moves_to_index( board=self.board, pad=True, return_mask=True - ), + )[1], ) if self.pixels: @@ -527,16 +545,19 @@ def _step(self, tensordict): dest.set("san", san) if self.include_legal_moves: - moves_idx = self._legal_moves_to_index(board=board, pad=True) - dest.set("legal_moves", moves_idx) + moves_idx = self._legal_moves_to_index( + board=board, pad=True, return_mask=self.mask_actions + ) if self.mask_actions: - dest.set("action_mask", self._move_index_to_mask(moves_idx)) + moves_idx, mask = moves_idx + dest.set("action_mask", mask) + dest.set("legal_moves", moves_idx) elif self.mask_actions: dest.set( "action_mask", self._legal_moves_to_index( board=self.board, pad=True, return_mask=True - ), + )[1], ) turn = torch.tensor(board.turn)