Skip to content

Commit

Permalink
[SparrowMahjong] Transpose observation (15 x 11 to 11 x 15) (#1010)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Aug 21, 2023
1 parent 29e6ab0 commit 75abefe
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 68 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Use `pgx.available_envs() -> Tuple[EnvId]` to see the list of currently availabl
|<a href="https://github.com/kenjyoung/MinAtar">MinAtar/SpaceInvaders</a><br>`"minatar-space_invaders"` |<img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/minatar-space_invaders.gif" width="50px">| `v0` | *Alien shooter game, dodge bullets.* |
|<a href="https://en.wikipedia.org/wiki/Reversi">Othello</a><br>`"othello"` |<img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/othello_dark.gif" width="60px"><img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/othello_light.gif" width="60px">| `v0` | *Flip and conquer opponent's pieces.* |
|<a href="https://en.wikipedia.org/wiki/Shogi">Shogi</a><br>`"shogi"` |<img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/shogi_dark.gif" width="60px"><img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/shogi_light.gif" width="60px"> | `v0` | *Japanese chess with captured pieces.* |
|<a href="https://sugorokuya.jp/p/suzume-jong">Sparrow Mahjong</a><br>`"sparrow_mahjong"` |<img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/sparrow_mahjong_dark.svg" width="60px"><img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/sparrow_mahjong_light.svg" width="60px">| `v0` | *A simplified, children-friendly Mahjong.* |
|<a href="https://sugorokuya.jp/p/suzume-jong">Sparrow Mahjong</a><br>`"sparrow_mahjong"` |<img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/sparrow_mahjong_dark.svg" width="60px"><img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/sparrow_mahjong_light.svg" width="60px">| `v1` | *A simplified, children-friendly Mahjong.* |
|<a href="https://en.wikipedia.org/wiki/Tic-tac-toe">Tic-tac-toe</a><br>`"tic_tac_toe"` |<img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/tic_tac_toe_dark.gif" width="60px"><img src="https://raw.githubusercontent.com/sotetsuk/pgx/main/docs/assets/tic_tac_toe_light.gif" width="60px">| `v0` | *Three in a row wins.* |

- <a href="https://en.wikipedia.org/wiki/Japanese_mahjong">Mahjong</a> environments are under development 🚧. If you have any requests for new environments, please let us know by [opening an issue](https://github.com/sotetsuk/pgx/issues/new).
Expand Down
6 changes: 3 additions & 3 deletions docs/sparrow_mahjong.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ Pgx implementation is simplified as follows:

| Name | Value |
|:---|:----:|
| Version | `v0` |
| Version | `v1` |
| Number of players | `3` |
| Number of actions | `11` |
| Observation shape | `(15, 11)` |
| Observation shape | `(11, 15)` |
| Observation type | `bool` |
| Rewards | `[-1, 1]` |

Expand Down Expand Up @@ -97,4 +97,4 @@ Terminates when either player wins or the wall becomes empty.

## Version History

- `v0` : Initial release (v1.0.0)
- `v1` : Transposed observation from `(15, 11)` to `(11, 15)` (#1010)
4 changes: 2 additions & 2 deletions pgx/sparrow_mahjong.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def id(self) -> v1.EnvId:

@property
def version(self) -> str:
return "v0"
return "v1"

@property
def num_players(self) -> int:
Expand Down Expand Up @@ -498,7 +498,7 @@ def _observe(state: State, player_id: jnp.ndarray):
),
lambda: obs,
)
return obs
return jnp.transpose(obs)


def _tile_type_to_str(tile_type) -> str:
Expand Down
124 changes: 62 additions & 62 deletions tests/test_sparrow_mahjong.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,17 +500,17 @@ def test_observe():
state = step(state, jnp.int32(1))
print(_to_str(state))
obs = observe(state, player_id=jnp.int8(2))
assert obs.shape[0] == 15
assert obs.shape[1] == 11
assert jnp.all(obs[0] == jnp.bool_([0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]))
assert jnp.all(obs[1] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]))
assert jnp.all(obs[5] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[6] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert obs.shape[0] == 11
assert obs.shape[1] == 15
assert jnp.all(obs[:, 0] == jnp.bool_([0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0]))
assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]))
assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[:, 6] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

seed = 5
key = jax.random.PRNGKey(seed)
Expand All @@ -529,59 +529,59 @@ def test_observe():
[1] 1 2*3 4 5 : r*_ _ _ _ _ _ _ _ _
"""
obs = observe(state, player_id=jnp.int8(0))
assert obs.shape[0] == 15
assert obs.shape[1] == 11
assert jnp.all(obs[0] == jnp.bool_([1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0]))
assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[6] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
assert jnp.all(obs[7] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[9] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[10] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[13] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert obs.shape[0] == 11
assert obs.shape[1] == 15
assert jnp.all(obs[:, 0] == jnp.bool_([1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0]))
assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[:, 13] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
obs = observe(state, player_id=jnp.int8(1))
assert obs.shape[0] == 15
assert obs.shape[1] == 11
assert jnp.all(obs[0] == jnp.bool_([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[4] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[6] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
assert jnp.all(obs[8] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[10] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[12] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[13] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert obs.shape[0] == 11
assert obs.shape[1] == 15
assert jnp.all(obs[:, 0] == jnp.bool_([1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 4] == jnp.bool_([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 13] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
obs = observe(state, player_id=jnp.int8(2))
assert obs.shape[0] == 15
assert obs.shape[1] == 11
assert jnp.all(obs[0] == jnp.bool_([0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0]))
assert jnp.all(obs[1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]))
assert jnp.all(obs[2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[4] == jnp.bool_([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0]))
assert jnp.all(obs[5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[6] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[8] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
assert jnp.all(obs[9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[10] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[13] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
assert jnp.all(obs[14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert obs.shape[0] == 11
assert obs.shape[1] == 15
assert jnp.all(obs[:, 0] == jnp.bool_([0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0]))
assert jnp.all(obs[:, 1] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]))
assert jnp.all(obs[:, 2] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 3] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 4] == jnp.bool_([0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0]))
assert jnp.all(obs[:, 5] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 6] == jnp.bool_([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 7] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[:, 8] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]))
assert jnp.all(obs[:, 9] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[:, 10] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 11] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
assert jnp.all(obs[:, 12] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]))
assert jnp.all(obs[:, 13] == jnp.bool_([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]))
assert jnp.all(obs[:, 14] == jnp.bool_([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))


def test_api():
Expand Down

0 comments on commit 75abefe

Please sign in to comment.