
Commit

move pmap out
mariogeiger committed Jun 1, 2023
1 parent 55def9b commit 273a22b
Showing 1 changed file with 12 additions and 11 deletions.
mace_jax/tools/train.py: 12 additions & 11 deletions
@@ -1,6 +1,7 @@
 import itertools
 import logging
 import time
+from functools import partial
 from typing import Any, Callable, Dict, Optional, Tuple
 
 import jax
@@ -31,24 +32,24 @@ def train(
 
     logging.info("Started training")
 
+    @partial(jax.pmap, in_axes=(None, 0), out_axes=0)
+    def grad_fn(params, graph: jraph.GraphsTuple):
+        # graph is assumed to be padded by jraph.pad_with_graphs
+        mask = jraph.get_graph_padding_mask(graph)  # [n_graphs,]
+        loss, grad = jax.value_and_grad(
+            lambda params: jnp.sum(jnp.where(mask, total_loss_fn(params, graph), 0.0))
+        )(params)
+        return jnp.sum(mask), loss, grad
+
+    # jit-of-pmap is not recommended but so far it seems faster
     @jax.jit
     def update_fn(
         params, optimizer_state, ema_params, num_updates: int, graph: jraph.GraphsTuple
     ) -> Tuple[float, Any, Any]:
         if graph.n_node.ndim == 1:
             graph = jax.tree_map(lambda x: x[None, ...], graph)
 
-        def grad_fn(graph):
-            # graph is assumed to be padded by jraph.pad_with_graphs
-            mask = jraph.get_graph_padding_mask(graph)  # [n_graphs,]
-            loss, grad = jax.value_and_grad(
-                lambda params: jnp.sum(
-                    jnp.where(mask, total_loss_fn(params, graph), 0.0)
-                )
-            )(params)
-            return jnp.sum(mask), loss, grad
-
-        n, loss, grad = jax.pmap(grad_fn, axis_name="batch")(graph)
+        n, loss, grad = grad_fn(params, graph)
         loss = jnp.sum(loss) / jnp.sum(n)
         grad = jax.tree_map(lambda x: jnp.sum(x, axis=0) / jnp.sum(n), grad)
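For readers less familiar with this layout: the commit hoists jax.pmap out of the jitted update step, so grad_fn is pmapped once (params broadcast with in_axes=None, per-device data sharded with in_axes=0) instead of being re-wrapped on every call inside update_fn, and the cross-device reduction stays outside the pmap. Below is a minimal, self-contained sketch of that pattern in isolation; the toy linear model, loss, batch shapes, and plain-SGD step are illustrative assumptions, not the mace_jax code.

# Sketch of the "pmap moved out" pattern, under the assumptions stated above.
from functools import partial

import jax
import jax.numpy as jnp


def total_loss_fn(params, x, y):
    # Toy squared-error loss for a linear model (stand-in for the per-graph
    # loss used in train.py).
    pred = x @ params["w"] + params["b"]
    return jnp.sum((pred - y) ** 2)


@partial(jax.pmap, in_axes=(None, 0, 0), out_axes=0)
def grad_fn(params, x, y):
    # Each device receives one shard of the batch; params are broadcast.
    loss, grad = jax.value_and_grad(lambda p: total_loss_fn(p, x, y))(params)
    return loss, grad


def update_fn(params, x, y, lr=1e-2):
    # The cross-device reduction happens outside the pmap: sum the
    # per-device losses/gradients and divide by the device count.
    loss, grad = grad_fn(params, x, y)
    n = loss.shape[0]
    loss = jnp.sum(loss) / n
    grad = jax.tree_util.tree_map(lambda g: jnp.sum(g, axis=0) / n, grad)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad)
    return loss, params


if __name__ == "__main__":
    n_dev = jax.local_device_count()
    params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
    # Inputs carry a leading device axis, mirroring the x[None, ...] reshape
    # that update_fn applies in train.py when the batch is unsharded.
    x = jnp.ones((n_dev, 8, 3))
    y = jnp.ones((n_dev, 8))
    loss, params = update_fn(params, x, y)
    print(loss)

In the actual commit, update_fn remains decorated with @jax.jit and calls the already-pmapped grad_fn; the added comment notes that jit-of-pmap is not generally recommended but was observed to be faster here.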
