forked from NTT123/a0-jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
60 lines (42 loc) · 1.46 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Useful functions."""
import importlib
from functools import partial
from typing import Tuple
import chex
import jax
import jax.numpy as jnp
import pax
from env import Enviroment as E
@pax.pure
def batched_policy(agent, states):
    """Run the agent's policy on a whole batch of states.

    The agent is invoked as ``agent(states, batched=True)``.

    Returns:
        A pair ``(agent, outputs)``: the (possibly updated) agent, followed by
        whatever the agent returned for the batch.
    """
    outputs = agent(states, batched=True)
    return agent, outputs
def replicate(value: chex.ArrayTree, repeat: int) -> chex.ArrayTree:
    """Replicate every leaf of a pytree along a new leading axis.

    Args:
        value: A pytree whose leaves are array-like.
        repeat: Number of copies, i.e. the size of the new first axis. Zero is
            allowed and produces leaves with an empty leading axis.

    Returns:
        A pytree with the same structure as ``value`` in which a leaf of shape
        ``s`` becomes an array of shape ``(repeat,) + s``.
    """

    def _tile(leaf):
        # Broadcasting onto a new leading axis yields the same values as
        # stacking `repeat` copies, without building a Python list of
        # references, and it also works when repeat == 0 (stacking an empty
        # list raises).
        return jnp.broadcast_to(leaf, (repeat, *jnp.shape(leaf)))

    return jax.tree_util.tree_map(_tile, value)
@pax.pure
def reset_env(env: E) -> E:
    """Reset an environment and hand it back.

    Calls ``env.reset()`` and then returns ``env``.
    """
    env.reset()
    return env
@jax.jit
def env_step(env: E, action: chex.Array) -> Tuple[E, chex.Array]:
    """Advance the environment by a single action (jit-compiled).

    Returns:
        The two values produced by ``env.step(action)``, as a tuple: the
        environment after the step and the reward.
    """
    next_env, reward = env.step(action)
    return next_env, reward
def import_class(path: str) -> E:
    """Import a class from a python file, given its dotted path.

    For example:

    >> Game = import_class("connect_two_game.Connect2Game")

    Game is the Connect2Game class from `connect_two_game.py`.

    Args:
        path: Dotted path of the form ``"module.ClassName"``; the module part
            may itself contain dots (``"pkg.module.ClassName"``).

    Returns:
        The attribute named by the last component of ``path``, looked up on
        the module named by everything before the last dot.

    Raises:
        ValueError: If ``path`` has no module component (e.g. ``"ClassName"``).
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    mod_path, _, class_name = path.rpartition(".")
    if not mod_path:
        # Without this guard importlib fails with an opaque
        # "Empty module name" error that does not mention the bad path.
        raise ValueError(
            f"expected a dotted path like 'module.ClassName', got {path!r}"
        )
    mod = importlib.import_module(mod_path)
    return getattr(mod, class_name)
def select_tree(pred: jnp.ndarray, a, b):
    """Pick one of two pytrees according to a scalar boolean predicate.

    Each pair of corresponding leaves of ``a`` and ``b`` is passed through
    ``jax.lax.select`` with ``pred`` as the condition, so the result takes its
    leaves from ``a`` when ``pred`` is true and from ``b`` otherwise.
    """
    assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar"

    def _choose(leaf_a, leaf_b):
        return jax.lax.select(pred, leaf_a, leaf_b)

    return jax.tree_util.tree_map(_choose, a, b)