Commit d48b0b3

Internal change
PiperOrigin-RevId: 716777419
Change-Id: I69c13e0fdf4f81c62afa1da4acd784a471112680
Brax Team authored and btaba committed Jan 17, 2025
1 parent 69637a3 commit d48b0b3
Showing 9 changed files with 68 additions and 35 deletions.
1 change: 1 addition & 0 deletions brax/training/acting.py
@@ -145,6 +145,7 @@ def run_evaluation(
for name, value in eval_metrics.episode_metrics.items()
})
metrics['eval/avg_episode_length'] = np.mean(eval_metrics.episode_steps)
metrics['eval/std_episode_length'] = np.std(eval_metrics.episode_steps)
metrics['eval/epoch_eval_time'] = epoch_eval_time
metrics['eval/sps'] = self._steps_per_unroll / epoch_eval_time
self._eval_walltime = self._eval_walltime + epoch_eval_time
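The new eval/std_episode_length metric mirrors the existing mean. A minimal sketch of the aggregation with made-up per-episode lengths (only the metric names come from the diff; this is not the full run_evaluation body):

import numpy as np

episode_steps = np.array([1000, 1000, 750, 640])  # hypothetical per-episode lengths
metrics = {
    'eval/avg_episode_length': np.mean(episode_steps),
    'eval/std_episode_length': np.std(episode_steps),  # newly reported spread of episode lengths
}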
6 changes: 3 additions & 3 deletions brax/training/agents/ars/train.py
@@ -47,7 +47,7 @@ class TrainingState:

normalizer_params: running_statistics.RunningStatisticsState
policy_params: Params
num_env_steps: jax.Array
num_env_steps: int


# TODO: Pass the network as argument.
@@ -289,7 +289,7 @@ def training_epoch(
TrainingState( # type: ignore # jnp-type
normalizer_params=normalizer_params,
policy_params=policy_params,
num_env_steps=jnp.array(num_env_steps, dtype=jnp.int64),
num_env_steps=num_env_steps,
),
metrics,
)
@@ -323,7 +323,7 @@ def training_epoch_with_timing(
training_state = TrainingState(
normalizer_params=normalizer_params,
policy_params=policy_params,
num_env_steps=jnp.array(0, dtype=jnp.int64),
num_env_steps=0,
)

if not eval_env:
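The ARS step counter moves from an int64 jnp array to a plain Python int. One reason this matters under JAX's default configuration (a sketch, not code from the repository): with jax_enable_x64 off, an explicit int64 request is downcast to int32 (with a warning), which caps the counter at 2**31 - 1, while a Python int has no such limit.

import jax.numpy as jnp

steps = jnp.array(0, dtype=jnp.int64)
print(steps.dtype)             # int32 under the default config (x64 disabled)

num_env_steps = 0              # plain Python int: exact, no dtype to manage
num_env_steps += 4096 * 1000   # hypothetical directions * episode_length increment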
8 changes: 4 additions & 4 deletions brax/training/agents/es/train.py
@@ -49,7 +49,7 @@ class TrainingState:
normalizer_params: running_statistics.RunningStatisticsState
optimizer_state: optax.OptState
policy_params: Params
num_env_steps: jax.Array
num_env_steps: int


# Centered rank from: https://arxiv.org/pdf/1703.03864.pdf
@@ -336,7 +336,7 @@ def training_epoch(

num_env_steps = (
training_state.num_env_steps
+ jnp.sum(obs_weights, dtype=jnp.int64) * action_repeat
+ jnp.sum(obs_weights, dtype=jnp.int32) * action_repeat
)

metrics = {
@@ -350,7 +350,7 @@ def training_epoch(
normalizer_params=normalizer_params,
optimizer_state=optimizer_state,
policy_params=policy_params,
num_env_steps=jnp.array(num_env_steps, dtype=jnp.int64),
num_env_steps=num_env_steps,
),
metrics,
)
@@ -386,7 +386,7 @@ def training_epoch_with_timing(
normalizer_params=normalizer_params,
optimizer_state=optimizer_state,
policy_params=policy_params,
num_env_steps=jnp.array(0, dtype=jnp.int64),
num_env_steps=0,
)

if not eval_env:
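In the ES epoch the per-episode step mask is now summed directly to int32 rather than int64. A small illustrative example of the jnp.sum dtype argument, with made-up shapes:

import jax.numpy as jnp

action_repeat = 1
obs_weights = jnp.ones((8, 100))   # hypothetical mask of valid environment steps
steps = jnp.sum(obs_weights, dtype=jnp.int32) * action_repeat
print(steps, steps.dtype)          # 800 int32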
6 changes: 2 additions & 4 deletions brax/training/agents/ppo/train.py
@@ -463,9 +463,7 @@ def f(carry, unused_t):
optimizer_state=optimizer_state,
params=params,
normalizer_params=normalizer_params,
env_steps=jnp.array(
training_state.env_steps + env_step_per_training_step,
dtype=jnp.int64),
env_steps=training_state.env_steps + env_step_per_training_step,
)
return (new_training_state, state, new_key), metrics

@@ -525,7 +523,7 @@ def training_epoch_with_timing(
normalizer_params=running_statistics.init_state(
_remove_pixels(obs_shape)
),
env_steps=jnp.array(0, dtype=jnp.int64),
env_steps=0,
)

if (
13 changes: 4 additions & 9 deletions brax/training/agents/sac/train.py
@@ -101,7 +101,7 @@ def _init_training_state(
q_params=q_params,
target_q_params=q_params,
gradient_steps=jnp.zeros(()),
env_steps=jnp.zeros((), dtype=jnp.int64),
env_steps=jnp.zeros(()),
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=log_alpha,
normalizer_params=normalizer_params,
@@ -314,7 +314,7 @@ def sgd_step(
q_params=q_params,
target_q_params=new_target_q_params,
gradient_steps=training_state.gradient_steps + 1,
env_steps=jnp.array(training_state.env_steps, dtype=jnp.int64),
env_steps=training_state.env_steps,
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=alpha_params,
normalizer_params=training_state.normalizer_params,
@@ -367,9 +367,7 @@ def training_step(
)
training_state = training_state.replace(
normalizer_params=normalizer_params,
env_steps=jnp.array(
training_state.env_steps + env_steps_per_actor_step, dtype=jnp.int64
),
env_steps=training_state.env_steps + env_steps_per_actor_step,
)

buffer_state, transitions = replay_buffer.sample(buffer_state)
@@ -406,10 +404,7 @@ def f(carry, unused):
)
new_training_state = training_state.replace(
normalizer_params=new_normalizer_params,
env_steps=jnp.array(
training_state.env_steps + env_steps_per_actor_step,
dtype=jnp.int64,
),
env_steps=training_state.env_steps + env_steps_per_actor_step,
)
return (new_training_state, env_state, buffer_state, new_key), ()

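The SAC counter initializer drops the explicit int64 dtype. A quick dtype check under the default JAX configuration (x64 disabled), which is the setting these changes appear to target:

import jax.numpy as jnp

print(jnp.zeros(()).dtype)                   # float32: default dtype of the new initializer
print(jnp.zeros((), dtype=jnp.int64).dtype)  # int32: int64 is unavailable without x64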
59 changes: 49 additions & 10 deletions brax/training/learner.py
@@ -152,6 +152,15 @@
_PPO_POLICY_HIDDEN_LAYER_SIZES = flags.DEFINE_string(
'ppo_policy_hidden_layer_sizes', None, 'PPO policy hidden layer sizes.'
)
_PPO_VALUE_HIDDEN_LAYER_SIZES = flags.DEFINE_string(
'ppo_value_hidden_layer_sizes', None, 'PPO value hidden layer sizes.'
)
_PPO_POLICY_OBS_KEY = flags.DEFINE_string(
'ppo_policy_obs_key', None, 'PPO policy obs key.'
)
_PPO_VALUE_OBS_KEY = flags.DEFINE_string(
'ppo_value_obs_key', None, 'PPO value obs key.'
)
# ARS hps.
_NUMBER_OF_DIRECTIONS = flags.DEFINE_integer(
'number_of_directions',
@@ -178,22 +187,23 @@
)


def get_env_factory():
def get_env_factory(env_name: str):
"""Returns a function that creates an environment."""
wrap_fn = None
randomizer_fn = None
if _CUSTOM_WRAP_ENV.value:
pass
else:
wrap_env_fn = None
get_environment = functools.partial(
envs.get_environment, backend=_BACKEND.value
)
return get_environment, wrap_env_fn
return get_environment, wrap_fn, randomizer_fn


def main(unused_argv):
logdir = _LOGDIR.value

get_environment, wrap_env_fn = get_env_factory()
get_environment, wrap_fn, randomizer_fn = get_env_factory(_ENV.value)
with metrics.Writer(logdir) as writer:
writer.write_hparams({
'num_evals': _NUM_EVALS.value,
@@ -208,7 +218,9 @@ def main(unused_argv):
)
make_policy, params, _ = sac.train(
environment=get_environment(_ENV.value),
wrap_env_fn=wrap_env_fn,
eval_env=get_environment(_ENV.value),
wrap_env_fn=wrap_fn,
randomization_fn=randomizer_fn,
num_envs=_NUM_ENVS.value,
action_repeat=_ACTION_REPEAT.value,
normalize_observations=_NORMALIZE_OBSERVATIONS.value,
@@ -230,7 +242,9 @@ def main(unused_argv):
elif _LEARNER.value == 'es':
make_policy, params, _ = es.train(
environment=get_environment(_ENV.value),
wrap_env_fn=wrap_env_fn,
eval_env=get_environment(_ENV.value),
wrap_env_fn=wrap_fn,
randomization_fn=randomizer_fn,
num_timesteps=_TOTAL_ENV_STEPS.value,
fitness_shaping=es.FitnessShaping[_FITNESS_SHAPING.value.upper()],
population_size=_POPULATION_SIZE.value,
@@ -253,12 +267,32 @@ def main(unused_argv):
int(x) for x in _PPO_POLICY_HIDDEN_LAYER_SIZES.value.split(',')
]
network_factory = functools.partial(
ppo_networks.make_ppo_networks,
network_factory,
policy_hidden_layer_sizes=policy_hidden_layer_sizes,
)
if _PPO_VALUE_HIDDEN_LAYER_SIZES.value is not None:
value_hidden_layer_sizes = [
int(x) for x in _PPO_VALUE_HIDDEN_LAYER_SIZES.value.split(',')
]
network_factory = functools.partial(
network_factory,
value_hidden_layer_sizes=value_hidden_layer_sizes,
)
if _PPO_POLICY_OBS_KEY.value is not None:
network_factory = functools.partial(
network_factory,
policy_obs_key=_PPO_POLICY_OBS_KEY.value,
)
if _PPO_VALUE_OBS_KEY.value is not None:
network_factory = functools.partial(
network_factory,
value_obs_key=_PPO_VALUE_OBS_KEY.value,
)
make_policy, params, _ = ppo.train(
environment=get_environment(_ENV.value),
wrap_env_fn=wrap_env_fn,
eval_env=get_environment(_ENV.value),
wrap_env_fn=wrap_fn,
randomization_fn=randomizer_fn,
num_timesteps=_TOTAL_ENV_STEPS.value,
episode_length=_EPISODE_LENGTH.value,
network_factory=network_factory,
@@ -284,7 +318,9 @@ def main(unused_argv):
elif _LEARNER.value == 'apg':
make_policy, params, _ = apg.train(
environment=get_environment(_ENV.value),
wrap_env_fn=wrap_env_fn,
eval_env=get_environment(_ENV.value),
wrap_env_fn=wrap_fn,
randomization_fn=randomizer_fn,
policy_updates=_POLICY_UPDATES.value,
num_envs=_NUM_ENVS.value,
action_repeat=_ACTION_REPEAT.value,
@@ -300,7 +336,9 @@ def main(unused_argv):
elif _LEARNER.value == 'ars':
make_policy, params, _ = ars.train(
environment=get_environment(_ENV.value),
wrap_env_fn=wrap_env_fn,
eval_env=get_environment(_ENV.value),
wrap_env_fn=wrap_fn,
randomization_fn=randomizer_fn,
number_of_directions=_NUMBER_OF_DIRECTIONS.value,
max_devices_per_host=_MAX_DEVICES_PER_HOST.value,
action_repeat=_ACTION_REPEAT.value,
@@ -324,6 +362,7 @@ def main(unused_argv):
model.save_params(path, params)

# Output an episode trajectory.
get_environment, *_ = get_env_factory(_ENV.value)
env = get_environment(_ENV.value)

@jax.jit
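learner.py now stacks the optional PPO flags onto the network factory one functools.partial at a time, so any flag left unset leaves the library default untouched. A sketch of that pattern with hypothetical layer sizes (the new obs-key flags wrap the factory the same way):

import functools
from brax.training.agents.ppo import networks as ppo_networks

network_factory = ppo_networks.make_ppo_networks
network_factory = functools.partial(
    network_factory,
    policy_hidden_layer_sizes=(32, 32, 32, 32),  # from --ppo_policy_hidden_layer_sizes
)
network_factory = functools.partial(
    network_factory,
    value_hidden_layer_sizes=(256, 256),         # from --ppo_value_hidden_layer_sizes
)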
4 changes: 2 additions & 2 deletions brax/v1/experimental/braxlines/common/dist_utils.py
@@ -24,12 +24,12 @@ def clipped_onehot_categorical(logits: jnp.ndarray, clip_range: float = 0):
if clip_range:
assert clip_range > 0.0, clip_range
logits -= jnp.max(logits, axis=-1, keepdims=True)
logits = jnp.clip(logits, a_min=-clip_range)
logits = jnp.clip(logits, min=-clip_range)
return tfd.OneHotCategorical(logits=logits)


def clipped_bernoulli(logits: jnp.ndarray, clip_range: float = 0):
if clip_range:
assert clip_range > 0.0, clip_range
logits = jnp.clip(logits, a_min=-clip_range, a_max=clip_range)
logits = jnp.clip(logits, min=-clip_range, max=clip_range)
return tfd.Bernoulli(logits=logits)
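The a_min/a_max keyword names were deprecated in newer jax.numpy releases in favor of min/max (matching NumPy 2.0's array API), which is what these braxlines call sites switch to. Illustrative values only:

import jax.numpy as jnp

logits = jnp.array([-30.0, 0.5, 12.0])
clipped = jnp.clip(logits, min=-10.0, max=10.0)   # new keyword names
print(clipped)                                    # [-10.   0.5  10. ]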
2 changes: 1 addition & 1 deletion brax/v1/experimental/braxlines/irl_smm/evaluators.py
@@ -45,7 +45,7 @@ def visualize_disc(
distgrid = disc.dist(
datagrid[..., :len(disc.obs_indices)], params=params['extra'])
probsgrid = jax.nn.sigmoid(distgrid.logits)
colors = jnp.clip(jnp.array([[-2, 0, 2]]) * (probsgrid - 0.5), a_min=0)
colors = jnp.clip(jnp.array([[-2, 0, 2]]) * (probsgrid - 0.5), min=0)
if fig is None or axs is None:
fig, ax = plt.subplots(ncols=1, figsize=figsize)
axs = [ax]
4 changes: 2 additions & 2 deletions brax/v1/experimental/composer/reward_functions.py
@@ -129,7 +129,7 @@ def direction_reward(action: jnp.ndarray,
agent_sign, lambda x: lax.cond(x, lambda y: jnp.sum(vel0 * y, axis=-1),
lambda y: jnp.zeros_like(x), vel1),
jnp.zeros_like, opp_sign)
return jnp.clip(inner_product, a_min=0.0), jnp.zeros_like(inner_product)
return jnp.clip(inner_product, min=0.0), jnp.zeros_like(inner_product)


def norm_reward(action: jnp.ndarray, obs_dict: Dict[str, jnp.ndarray],
@@ -166,7 +166,7 @@ def distance_reward(action: jnp.ndarray,
delta = obs1 - obs2
dist = jnp.linalg.norm(delta, axis=-1, **norm_kwargs)
# instead of clipping, terminate
# dist = jnp.clip(dist, a_min=min_dist, a_max=max_dist)
# dist = jnp.clip(dist, min=min_dist, max=max_dist)
done = jnp.zeros_like(dist)
done = jnp.where(dist < min_dist, jnp.ones_like(done), done)
done = jnp.where(dist > max_dist, jnp.ones_like(done), done)
