Skip to content

[Feature,Refactor] Chess improvements: fen, pgn, pixels, san #2702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 163 additions & 22 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
from torchrl.envs.transforms.transforms import (
AutoResetEnv,
AutoResetTransform,
Tokenizer,
Transform,
)
from torchrl.envs.utils import (
Expand Down Expand Up @@ -3441,35 +3442,148 @@ def test_partial_rest(self, batched):

# fen strings for board positions generated with:
# https://lichess.org/editor
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.skipif(not _has_chess, reason="chess not found")
class TestChessEnv:
def test_env(self, stateful):
env = ChessEnv(stateful=stateful)
check_env_specs(env)
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.parametrize("include_hash", [False, True])
@pytest.mark.parametrize("include_san", [False, True])
def test_env(self, stateful, include_pgn, include_fen, include_hash, include_san):
    """Build ``ChessEnv`` under every flag combination and validate its specs.

    When the env is stateless and exposes neither a PGN nor a FEN, the
    construction is expected to raise a ``RuntimeError``; in every other
    case the env is built and its specs (and optional hash keys) are checked.
    """
    lacks_state_repr = not (stateful or include_pgn or include_fen)
    if lacks_state_repr:
        guard = pytest.raises(
            RuntimeError, match="At least one state representation"
        )
    else:
        guard = contextlib.nullcontext()

    with guard:
        env = ChessEnv(
            stateful=stateful,
            include_pgn=include_pgn,
            include_fen=include_fen,
            include_hash=include_hash,
            include_san=include_san,
        )
        # Because we always use mask_actions=True
        assert isinstance(env, TransformedEnv)
        check_env_specs(env)
        if include_hash:
            # Each enabled representation must come with its hashed counterpart.
            hashed_entries = (
                (include_fen, "fen_hash"),
                (include_pgn, "pgn_hash"),
                (include_san, "san_hash"),
            )
            for enabled, hash_key in hashed_entries:
                if enabled:
                    assert hash_key in env.observation_spec.keys()

def test_pgn_bijectivity(self):
    """Check that board <-> PGN conversions round-trip along a random game.

    Starting from ``ChessEnv._PGN_RESTART``, ten randomly chosen legal moves
    are pushed; after each one the PGN must change, survive a
    PGN -> board -> PGN round trip, and match the incrementally built PGN.
    """
    np.random.seed(0)
    previous_pgn = ChessEnv._PGN_RESTART
    board = ChessEnv._pgn_to_board(previous_pgn)
    for _step in range(10):
        candidates = list(board.legal_moves)
        chosen = np.random.choice(candidates)
        board.push(chosen)
        current_pgn = ChessEnv._board_to_pgn(board)
        # A pushed move must be reflected in the PGN.
        assert current_pgn != previous_pgn
        # Converting back and forth must be lossless.
        round_tripped = ChessEnv._board_to_pgn(ChessEnv._pgn_to_board(current_pgn))
        assert current_pgn == round_tripped
        # Appending the move to the previous PGN must give the same string.
        assert current_pgn == ChessEnv._add_move_to_pgn(previous_pgn, chosen)
        previous_pgn = current_pgn

def test_consistency(self):
    """Check that identically seeded rollouts pick the same actions.

    Three representation configs (PGN+FEN, FEN only, PGN only) are compared
    across each other and across the stateful / stateless variants.
    """
    # Config index -> (include_pgn, include_fen)
    configs = {0: (True, True), 1: (False, True), 2: (True, False)}

    def build(stateful, index):
        with_pgn, with_fen = configs[index]
        return ChessEnv(
            stateful=stateful, include_pgn=with_pgn, include_fen=with_fen
        )

    def seeded_rollout(env):
        # Re-seed before every rollout so all envs see the same random stream.
        torch.manual_seed(0)
        return env.rollout(50, break_when_any_done=False)

    stateful_envs = {index: build(True, index) for index in (0, 1, 2)}
    stateless_envs = {index: build(False, index) for index in (0, 1, 2)}

    stateless_runs = {}
    stateful_runs = {}
    for index in (1, 2, 0):
        stateless_runs[index] = seeded_rollout(stateless_envs[index])
        stateful_runs[index] = seeded_rollout(stateful_envs[index])

    reference = stateless_runs[0]["action"]
    assert (reference == stateless_runs[1]["action"]).all()
    assert (reference == stateless_runs[2]["action"]).all()
    assert (reference == stateful_runs[0]["action"]).all()
    assert (stateless_runs[1]["action"] == stateful_runs[1]["action"]).all()
    assert (stateless_runs[2]["action"] == stateful_runs[2]["action"]).all()

@pytest.mark.parametrize(
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_san(self, stateful, include_fen, include_pgn):
torch.manual_seed(0)
env = ChessEnv(
stateful=stateful,
include_pgn=include_pgn,
include_fen=include_fen,
include_san=True,
)
r = env.rollout(100, break_when_any_done=False)
sans = r["next", "san"]
actions = [env.san_moves.index(san) for san in sans]
i = 0

def policy(td):
nonlocal i
td["action"] = actions[i]
i += 1
return td

def test_rollout(self, stateful):
env = ChessEnv(stateful=stateful)
env.rollout(5000)
r2 = env.rollout(100, policy=policy, break_when_any_done=False)
assert_allclose_td(r, r2)

def test_reset_white_to_move(self, stateful):
env = ChessEnv(stateful=stateful)
@pytest.mark.parametrize(
    "include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_rollout(self, stateful, include_pgn, include_fen):
    """Run a 500-step rollout that continues past done states and check its shape."""
    torch.manual_seed(0)
    env_kwargs = {
        "stateful": stateful,
        "include_pgn": include_pgn,
        "include_fen": include_fen,
    }
    env = ChessEnv(**env_kwargs)
    trajectory = env.rollout(500, break_when_any_done=False)
    assert trajectory.shape == (500,)

@pytest.mark.parametrize(
"include_fen,include_pgn", [[True, False], [False, True], [True, True]]
)
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
env = ChessEnv(
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
td = env.reset(TensorDict({"fen": fen}))
assert td["fen"] == fen
if include_fen:
assert td["fen"] == fen
assert env.board.fen() == fen
assert td["turn"] == env.lib.WHITE
assert not td["done"]

def test_reset_black_to_move(self, stateful):
env = ChessEnv(stateful=stateful)
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_black_to_move(self, stateful, include_pgn, include_fen):
    """Reset from a FEN with black to move and check the reported state."""
    start_fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
    env = ChessEnv(
        stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
    )
    reset_td = env.reset(TensorDict({"fen": start_fen}))
    # The FEN must be echoed back and mirrored by the underlying board.
    assert reset_td["fen"] == start_fen
    assert env.board.fen() == start_fen
    assert reset_td["turn"] == env.lib.BLACK
    assert not reset_td["done"]

def test_reset_done_error(self, stateful):
env = ChessEnv(stateful=stateful)
@pytest.mark.parametrize("include_fen,include_pgn", [[True, False], [True, True]])
@pytest.mark.parametrize("stateful", [False, True])
def test_reset_done_error(self, stateful, include_pgn, include_fen):
env = ChessEnv(
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
with pytest.raises(ValueError) as e_info:
env.reset(TensorDict({"fen": fen}))
Expand All @@ -3480,12 +3594,19 @@ def test_reset_done_error(self, stateful):
@pytest.mark.parametrize(
"endstate", ["white win", "black win", "stalemate", "50 move", "insufficient"]
)
def test_reward(self, stateful, reset_without_fen, endstate):
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("include_fen", [True])
@pytest.mark.parametrize("stateful", [False, True])
def test_reward(
self, stateful, reset_without_fen, endstate, include_pgn, include_fen
):
if stateful and reset_without_fen:
# reset_without_fen is only used for stateless env
return

env = ChessEnv(stateful=stateful)
env = ChessEnv(
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)

if endstate == "white win":
fen = "5k2/2R5/8/8/8/1R6/2K5/8 w - - 0 1"
Expand All @@ -3498,28 +3619,28 @@ def test_reward(self, stateful, reset_without_fen, endstate):
fen = "5k2/6r1/8/8/8/8/7r/1K6 b - - 0 1"
expected_turn = env.lib.BLACK
move = "Rg1#"
expected_reward = -1
expected_reward = 1
expected_done = True

elif endstate == "stalemate":
fen = "5k2/6r1/8/8/8/8/7r/K7 b - - 0 1"
expected_turn = env.lib.BLACK
move = "Rb7"
expected_reward = 0
expected_reward = 0.5
expected_done = True

elif endstate == "insufficient":
fen = "5k2/8/8/8/3r4/2K5/8/8 w - - 0 1"
expected_turn = env.lib.WHITE
move = "Kxd4"
expected_reward = 0
expected_reward = 0.5
expected_done = True

elif endstate == "50 move":
fen = "5k2/8/1R6/8/6r1/2K5/8/8 b - - 99 123"
expected_turn = env.lib.BLACK
move = "Kf7"
expected_reward = 0
expected_reward = 0.5
expected_done = True

elif endstate == "not_done":
Expand All @@ -3538,13 +3659,33 @@ def test_reward(self, stateful, reset_without_fen, endstate):
td = env.reset(TensorDict({"fen": fen}))
assert td["turn"] == expected_turn

moves = env.get_legal_moves(None if stateful else td)
td["action"] = moves.index(move)
td["action"] = env._san_moves.index(move)
td = env.step(td)["next"]
assert td["done"] == expected_done
assert td["reward"] == expected_reward
assert td["turn"] == (not expected_turn)

def test_chess_tokenized(self):
    """Append a ``Tokenizer`` on the FEN and check both keys are exposed.

    The raw ``"fen"`` entry must remain a ``NonTensor`` spec after the
    transform is appended, and both ``"fen"`` and ``"fen_tokenized"`` must
    appear at the root and under ``"next"`` in rollouts and fake tensordicts.
    """
    expected_keys = ("fen_tokenized", "fen")

    env = ChessEnv(include_fen=True, stateful=True, include_san=True)
    assert isinstance(env.observation_spec["fen"], NonTensor)

    tokenizer = Tokenizer(in_keys=["fen"], out_keys=["fen_tokenized"])
    env = env.append_transform(tokenizer)
    assert isinstance(env.observation_spec["fen"], NonTensor)

    # Spec transforms are invoked directly; their return values are not used.
    env.transform.transform_output_spec(env.base_env.output_spec)
    env.transform.transform_input_spec(env.base_env.input_spec)

    rollout = env.rollout(10, return_contiguous=False)
    for container in (rollout, rollout["next"]):
        for key in expected_keys:
            assert key in container

    fake_td = env.fake_tensordict()
    for container in (fake_td, fake_td["next"]):
        for key in expected_keys:
            assert key in container

    env.check_env_specs()


class TestCustomEnvs:
def test_tictactoe_env(self):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5042,7 +5042,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:

def __eq__(self, other):
return (
type(self) is type(other)
type(self) == type(other)
and self.shape == other.shape
and self._device == other._device
and set(self._specs.keys()) == set(other._specs.keys())
Expand Down
23 changes: 12 additions & 11 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,17 +718,13 @@ def _create_td(self) -> None:
env_output_keys = set()
env_obs_keys = set()
for meta_data in self.meta_data:
env_obs_keys = env_obs_keys.union(
key
for key in meta_data.specs["output_spec"][
"full_observation_spec"
].keys(True, True)
)
env_output_keys = env_output_keys.union(
meta_data.specs["output_spec"]["full_observation_spec"].keys(
True, True
)
keys = meta_data.specs["output_spec"]["full_observation_spec"].keys(
True, True
)
keys = list(keys)
env_obs_keys = env_obs_keys.union(keys)

env_output_keys = env_output_keys.union(keys)
env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys)
self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys)
self._env_input_keys = sorted(env_input_keys, key=_sort_keys)
Expand Down Expand Up @@ -1003,7 +999,12 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
for i, _env in enumerate(self._envs):
if not needs_resetting[i]:
if out_tds is not None and tensordict is not None:
out_tds[i] = tensordict[i].exclude(*self._envs[i].reset_keys)
ftd = _env.observation_spec.zero()
if self.device is None:
ftd.clear_device_()
else:
ftd = ftd.to(self.device)
out_tds[i] = ftd
continue
if tensordict is not None:
tensordict_ = tensordict[i]
Expand Down
21 changes: 18 additions & 3 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2505,11 +2505,26 @@ def reset(
Returns:
a tensordict (or the input tensordict, if any), modified in place with the resulting observations.

.. note:: `reset` should not be overridden by :class:`~torchrl.envs.EnvBase` subclasses. The method to
override is :meth:`~torchrl.envs.EnvBase._reset`.

"""
if tensordict is not None:
self._assert_tensordict_shape(tensordict)

tensordict_reset = self._reset(tensordict, **kwargs)
select_reset_only = kwargs.pop("select_reset_only", False)
if select_reset_only and tensordict is not None:
# When making rollouts with step_and_maybe_reset, a tensordict may carry keys that reset uses
# to optionally set the reset state (e.g., the fen in chess). If that's the case and we don't
# discard them here, reset is effectively a no-op: it puts the env back in the state reached
# during the previous step.
# Therefore, maybe_reset tells reset to temporarily hide the non-reset keys.
# To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used.
tensordict_reset = self._reset(
tensordict.select(*self.reset_keys, strict=False), **kwargs
)
else:
tensordict_reset = self._reset(tensordict, **kwargs)
# We assume that this is done properly
# if reset.device != self.device:
# reset = reset.to(self.device, non_blocking=True)
Expand Down Expand Up @@ -3293,7 +3308,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
else:
any_done = False
if any_done:
tensordict._set_str(
tensordict = tensordict._set_str(
"_reset",
done.clone(),
validated=True,
Expand All @@ -3307,7 +3322,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
key="_reset",
)
if any_done:
tensordict = self.reset(tensordict)
return self.reset(tensordict, select_reset_only=True)
return tensordict

def empty_cache(self):
Expand Down
Loading
Loading