Skip to content

Commit

Permalink
[MinAtar] Add baseline models (#1023)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotetsuk authored Aug 31, 2023
1 parent 328b406 commit 716e98a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 15 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ install-dev:
ipython \
jax[cpu] \
dm-haiku \
pytest-cov
pytest-cov \
pgx-minatar

install-fmt:
python3 -m pip install \
Expand Down
99 changes: 86 additions & 13 deletions pgx/_src/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,38 @@

def make_baseline_model(
    model_id: BaselineModelId, download_dir: str = "baselines"
):
    """Build the baseline model identified by ``model_id``.

    Dispatches on the model id to the matching family-specific builder
    and returns whatever that builder returns.

    Args:
        model_id: Identifier of the baseline model to build.
        download_dir: Directory handed to the builder for the model files.

    Returns:
        The object produced by ``_make_az_baseline_model`` or
        ``_make_minatar_baseline_model`` for the given id.

    Raises:
        ValueError: If ``model_id`` is not one of the known baseline ids.
    """
    if model_id in (
        "animal_shogi_v0",
        "gardner_chess_v0",
        "go_9x9_v0",
        "hex_v0",
        "othello_v0",
    ):
        return _make_az_baseline_model(model_id, download_dir)
    if model_id in (
        "minatar-asterix_v0",
        "minatar-breakout_v0",
        "minatar-freeway_v0",
        "minatar-seaquest_v0",
        "minatar-space_invaders_v0",
    ):
        return _make_minatar_baseline_model(model_id, download_dir)
    # An explicit raise (rather than ``assert False``) still fires when
    # Python runs with -O, where assert statements are stripped.
    raise ValueError(f"Unknown baseline model id: {model_id!r}")


def _make_az_baseline_model(
model_id: BaselineModelId, download_dir: str = "baselines"
):
import haiku as hk

create_model_fn = _make_create_model_fn(model_id)
model_args, model_params, model_state = _load_baseline_model(
model_id, download_dir
)

def forward_fn(x, is_eval=False):
net = create_model_fn(**model_args)
net = _create_az_model_v0(**model_args)
policy_out, value_out = net(
x, is_training=not is_eval, test_local_stats=False
)
Expand All @@ -44,17 +66,63 @@ def apply(obs):
return apply


def _make_create_model_fn(baseline_model: BaselineModelId):
    """Return the model constructor for ``baseline_model``.

    Args:
        baseline_model: Identifier of the baseline model.

    Returns:
        ``_create_az_model_v0`` for every id listed below.

    Raises:
        ValueError: If ``baseline_model`` is not one of the listed ids.
    """
    if baseline_model in (
        "animal_shogi_v0",
        "gardner_chess_v0",
        "go_9x9_v0",
        "hex_v0",
        "othello_v0",
    ):
        return _create_az_model_v0
    # An explicit raise (rather than ``assert False``) still fires when
    # Python runs with -O, where assert statements are stripped.
    raise ValueError(f"Unknown baseline model id: {baseline_model!r}")
def _make_minatar_baseline_model(
    model_id: BaselineModelId, download_dir: str = "baselines"
):
    """Build the apply-function for a MinAtar baseline model.

    The constructor arguments and parameters come from
    ``_load_baseline_model``; the model state it also returns is not used.

    Args:
        model_id: Identifier of the baseline model to load.
        download_dir: Directory passed through to ``_load_baseline_model``.

    Returns:
        A function ``apply(obs)`` returning ``(logits, value)``.
    """
    import haiku as hk

    # The third element (model state) is discarded.
    model_args, model_params, _ = _load_baseline_model(model_id, download_dir)

    class ActorCritic(hk.Module):
        """Conv + linear trunk feeding separate policy and value heads."""

        def __init__(self, num_actions, activation="tanh"):
            super().__init__()
            self.num_actions = num_actions
            self.activation = activation
            assert activation in ["relu", "tanh"]

        def __call__(self, x):
            # Only the heads use the configurable activation; the trunk
            # always applies ReLU.
            head_act = (
                jax.nn.relu if self.activation == "relu" else jax.nn.tanh
            )

            def mlp_head(features, out_size):
                # Two hidden layers of width 64, then the output layer.
                hidden = head_act(hk.Linear(64)(features))
                hidden = head_act(hk.Linear(64)(hidden))
                return hk.Linear(out_size)(hidden)

            feats = x.astype(jnp.float32)
            feats = jax.nn.relu(hk.Conv2D(32, kernel_shape=2)(feats))
            feats = hk.avg_pool(
                feats, window_shape=(2, 2), strides=(2, 2), padding="VALID"
            )
            # Flatten everything except the leading axis.
            feats = feats.reshape((feats.shape[0], -1))
            feats = jax.nn.relu(hk.Linear(64)(feats))

            # Policy head is built before the value head, keeping the
            # same module creation order as the layer-by-layer form.
            policy_logits = mlp_head(feats, self.num_actions)
            state_value = mlp_head(feats, 1)

            return policy_logits, jnp.squeeze(state_value, axis=-1)

    def forward_fn(x):
        return ActorCritic(**model_args)(x)

    forward = hk.without_apply_rng(hk.transform(forward_fn))

    def apply(obs):
        policy_logits, state_value = forward.apply(model_params, obs)
        return policy_logits, state_value

    return apply


def _load_baseline_model(
Expand All @@ -81,6 +149,11 @@ def _get_download_url(baseline_model: BaselineModelId) -> str:
"go_9x9_v0": "https://drive.google.com/uc?id=1hXMicBALW3WU43NquDoX4zthY4-KjiVu",
"hex_v0": "https://drive.google.com/uc?id=11qpLAT4_0NgPrKRcJCPE7RdN92VP8Ws3",
"othello_v0": "https://drive.google.com/uc?id=1mY40mWoPuYCOrlfMQk_6DPGEFaQcvNAM",
"minatar-asterix_v0": "https://drive.google.com/uc?id=1ohUxhZTYQCwyH-WJRH_Ma9BV3M1WoY0N",
"minatar-breakout_v0": "https://drive.google.com/uc?id=1ED1-p3Gmi4PZEH3hF-9NZzPyNkiPCnvT",
"minatar-freeway_v0": "https://drive.google.com/uc?id=1rbnJGxlzAWkt5DtF7tiYwoNkSqk0l2kD",
"minatar-seaquest_v0": "https://drive.google.com/uc?id=1740nIi00Z8fQWRbA-52GiSGkW7rcqM8o",
"minatar-space_invaders_v0": "https://drive.google.com/uc?id=1I7kJ8GEhY9K3rAFnbnYtlI5KQusFReq9",
}
assert baseline_model in urls
return urls[baseline_model]
Expand Down
7 changes: 6 additions & 1 deletion tests/test_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ def test_az_basline():
("gardner_chess", "gardner_chess_v0"),
("go_9x9", "go_9x9_v0"),
("hex", "hex_v0"),
("othello", "othello_v0")
("othello", "othello_v0"),
("minatar-asterix", "minatar-asterix_v0"),
("minatar-breakout", "minatar-breakout_v0"),
("minatar-freeway", "minatar-freeway_v0"),
("minatar-seaquest", "minatar-seaquest_v0"),
("minatar-space_invaders", "minatar-space_invaders_v0")
)

for env_id, model_id in test_cases:
Expand Down

0 comments on commit 716e98a

Please sign in to comment.