Skip to content

Commit

Permalink
Replace jax.tree_map to jax.tree.map.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621135413
Change-Id: Iec8bfe6e50e6a4561c4079548fab55d55cbcb352
  • Loading branch information
cdoersch committed Apr 2, 2024
1 parent 30d3e77 commit f6143e1
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def _update_func(
if updates is None:
updates = task_updates
else:
updates = jax.tree.map(jnp.add, updates, task_updates)
updates = jax.tree_map(jnp.add, updates, task_updates)
scalars['gradient_norm'] = optax.global_norm(grads)

# Grab the learning rate to log before performing the step.
Expand Down
2 changes: 1 addition & 1 deletion optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def update_fn(updates, state, params):

u_in, u_ex = hk.data_structures.partition(include, updates)
p_in, _ = hk.data_structures.partition(include, params)
u_in = jax.tree.map(lambda g, p: g + weight_decay * p, u_in, p_in)
u_in = jax.tree_map(lambda g, p: g + weight_decay * p, u_in, p_in)
updates = hk.data_structures.merge(u_ex, u_in)
return updates, state

Expand Down
6 changes: 3 additions & 3 deletions supervised_point_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,14 +930,14 @@ def _eval_epoch(
logging.info('eval batch: %d', batch_id)

# Accumulate the sum of scalars for each step.
scalars = jax.tree.map(lambda x: jnp.sum(x, axis=0), scalars)
scalars = jax.tree_map(lambda x: jnp.sum(x, axis=0), scalars)
if summed_scalars is None:
summed_scalars = scalars
else:
summed_scalars = jax.tree.map(jnp.add, summed_scalars, scalars)
summed_scalars = jax.tree_map(jnp.add, summed_scalars, scalars)

if 'eval_jhmdb' not in mode:
mean_scalars = jax.tree.map(lambda x: x / num_samples, summed_scalars)
mean_scalars = jax.tree_map(lambda x: x / num_samples, summed_scalars)
logging.info(mean_scalars)
logging.info(evaluation_datasets.latex_table(mean_scalars))

Expand Down
14 changes: 7 additions & 7 deletions tapir_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def update(
)

updates, new_opt_state = optimiser.update(gradients, state.opt_state)
updates = jax.tree.map(lambda x: x * lr_mul, updates)
updates = jax.tree_map(lambda x: x * lr_mul, updates)
new_params = optax.apply_updates(state.params, updates)

new_state = TrainingState(
Expand Down Expand Up @@ -567,14 +567,14 @@ def compute_clusters(
separation_tracks = separation_tracks[enough_visible]
separation_visibility = separation_visibility[enough_visible]
if query_features is not None:
query_features = jax.tree.map(
query_features = jax.tree_map(
lambda x: x[:, enough_visible] if len(x.shape) > 1 else x,
query_features,
)
separation_tracks_dict = jax.tree.map(
separation_tracks_dict = jax.tree_map(
lambda x: x[enough_visible], separation_tracks_dict
)
separation_visibility_dict = jax.tree.map(
separation_visibility_dict = jax.tree_map(
lambda x: x[enough_visible], separation_visibility_dict
)

Expand Down Expand Up @@ -846,7 +846,7 @@ def construct_fake_causal_state(
fake_ret = {k: np.zeros(v) for k, v in value_shapes.items()}
fake_ret = [fake_ret] * num_resolutions * 4
if convert_to_jax:
fake_ret = jax.tree.map(jnp.array, fake_ret)
fake_ret = jax.tree_map(jnp.array, fake_ret)
return fake_ret


Expand Down Expand Up @@ -1132,7 +1132,7 @@ def merge_struct(query_features, tmp_query_points):
causal_context=causal_state,
)

prediction = jax.tree.map(np.array, prediction)
prediction = jax.tree_map(np.array, prediction)

res = predictions_to_tracks_visibility(prediction)
separation_tracks.append(res[0])
Expand Down Expand Up @@ -1170,7 +1170,7 @@ def merge_struct(query_features, tmp_query_points):
'video_shape': {
x: separation_video_shapes[i] for i, x in enumerate(demo_episode_ids)
},
'query_features': jax.tree.map(np.array, out_query_features),
'query_features': jax.tree_map(np.array, out_query_features),
'demo_episode_ids': demo_episode_ids,
'query_points': out_query_points,
}
14 changes: 7 additions & 7 deletions tapir_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def refine_pips(
)
x = einshape('bnfc->(bn)fc', mlp_input)
if causal_context is not None:
causal_context = jax.tree.map(
causal_context = jax.tree_map(
lambda x: einshape('bn...->(bn)...', x), causal_context
)
res, new_causal_context = self.pips_mixer(
Expand All @@ -598,7 +598,7 @@ def refine_pips(

res = einshape('(bn)fc->bnfc', res, b=mlp_input.shape[0])
if get_causal_context:
new_causal_context = jax.tree.map(
new_causal_context = jax.tree_map(
lambda x: einshape('(bn)...->bn...', x, b=mlp_input.shape[0]),
new_causal_context,
)
Expand Down Expand Up @@ -938,7 +938,7 @@ def train2orig(x):
perm_chunk = perm[ch : ch + query_chunk_size]
chunk = query_features.lowres[0][:, perm_chunk] + barrier
if causal_context is not None:
cc_chunk = jax.tree.map(lambda x: x[:, perm_chunk], causal_context) # pylint: disable=cell-var-from-loop
cc_chunk = jax.tree_map(lambda x: x[:, perm_chunk], causal_context) # pylint: disable=cell-var-from-loop
if query_points_in_video is not None:
infer_query_points = query_points_in_video[
:, perm[ch : ch + query_chunk_size]
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def train2orig(x):
expd.append(jnp.concatenate(expd_iters[i], axis=1)[:, inv_perm])

for i in range(len(new_causal_context)):
new_causal_context[i] = jax.tree.map(
new_causal_context[i] = jax.tree_map(
lambda *x: jnp.concatenate(x, axis=1)[:, inv_perm],
*new_causal_context[i],
)
Expand Down Expand Up @@ -1163,10 +1163,10 @@ def upd(s1, s2):
return s1.at[:, idx_to_update].set(s2)

query_features = QueryFeatures(
lowres=jax.tree.map(
lowres=jax.tree_map(
upd, query_features.lowres, new_query_features.lowres
),
hires=jax.tree.map(upd, query_features.hires, new_query_features.hires),
hires=jax.tree_map(upd, query_features.hires, new_query_features.hires),
resolutions=query_features.resolutions,
)

Expand All @@ -1175,7 +1175,7 @@ def upd(s1, s2):
len(idx_to_update), len(query_features.resolutions) - 1
)

causal_state = jax.tree.map(upd, causal_state, init_causal_state)
causal_state = jax.tree_map(upd, causal_state, init_causal_state)

return query_features, causal_state

Expand Down
2 changes: 1 addition & 1 deletion utils/experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def save(self, ckpt_series: str) -> None:
if name == 'global_step':
raise ValueError(
'global_step attribute would overwrite jaxline global step')
np_params = jax.tree.map(f_np, getattr(exp_mod, attr))
np_params = jax.tree_map(f_np, getattr(exp_mod, attr))
to_save[name] = np_params
to_save['global_step'] = global_step

Expand Down

0 comments on commit f6143e1

Please sign in to comment.