Fix seed updating logic when running in a multihost environment.
PiperOrigin-RevId: 676570757
mjanusz authored and copybara-github committed Sep 19, 2024
1 parent 76254f5 commit e4e3a34
Showing 3 changed files with 21 additions and 7 deletions.
8 changes: 5 additions & 3 deletions ffn/jax/input_pipeline.py
@@ -335,21 +335,23 @@ def __next__(self):
         'weight': batched_weights,
     }
 
-  def update_seeds(self, batched_seeds: np.ndarray | jax.Array):
+  def update_seeds(self, batched_seeds: list[jax.Array]):
     """Propagates data from `batched_seeds` back to the example generators."""
 
     def _update(
         seeds: list[np.ndarray],
-        batched_seeds: np.ndarray | jax.Array,
+        batched_seeds: list[jax.Array],
         current: list[int],
     ):
-      # Transfer data from device to host if using a JAX array.
+      # Transfer data from device to host.
       batched_seeds = np.array(batched_seeds)
       # Fold batch dimensions back to a single one.
       batched_seeds = np.reshape(
           batched_seeds, [-1] + list(batched_seeds.shape[-4:])
       )
+
+      assert batched_seeds.shape[0] == len(seeds)
 
       dx = self._info.input_seed_size[0] - self._info.pred_mask_size[0]
       dy = self._info.input_seed_size[1] - self._info.pred_mask_size[1]
       dz = self._info.input_seed_size[2] - self._info.pred_mask_size[2]
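The reshape in `_update` is the inverse of the batching applied on the way in: a list of per-device seed arrays is stacked and then flattened back into a single batch dimension before the new assert checks it against the generator count. A minimal standalone sketch of that round trip (device count and patch shapes are hypothetical, not taken from the repo):

import numpy as np

# Hypothetical setup: 2 local devices, per-device batch of 4, and
# seed patches of [z, y, x, c] = [16, 32, 32, 1].
per_device = [np.zeros((4, 16, 32, 32, 1), np.float32) for _ in range(2)]

# np.array() stacks the list (and would pull jax.Array shards to host).
stacked = np.array(per_device)            # [2, 4, 16, 32, 32, 1]

# Fold all leading device/batch dimensions into one, keeping the
# trailing four patch dimensions -- the same reshape as in _update().
flat = np.reshape(stacked, [-1] + list(stacked.shape[-4:]))
assert flat.shape == (8, 16, 32, 32, 1)   # 8 would be len(seeds) here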
18 changes: 15 additions & 3 deletions ffn/jax/train.py
@@ -62,7 +62,8 @@ class TrainState(flax.struct.PyTreeNode):  # pytype: disable=invalid-function-de


 DataIterator = TypeVar(
-    'DataIterator', tf.data.Iterator  #
+    'DataIterator',
+    tf.data.Iterator,  #
 )


@@ -511,7 +512,7 @@ def train_fn(state, batch, loss_scale):
   shard_out = (
       replicate_sharding,  # state
       replicate_sharding,  # metrics
-      replicate_sharding,  # logits
+      batch_sharding,  # logits
       replicate_sharding,  # loss scale
   )
   p_train_step = jax.jit(train_fn, shard_in, shard_out)
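The one-line change above stops forcing logits back to a replicated layout and instead keeps them sharded along the batch axis, like the inputs they were computed from. A hedged sketch of the same pattern with jax.jit shardings; the mesh construction and the step function body are illustrative assumptions, not the repo's code:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical 1-D mesh over all devices, with a single 'batch' axis.
mesh = Mesh(np.asarray(jax.devices()), ('batch',))
batch_sharding = NamedSharding(mesh, PartitionSpec('batch'))
replicate_sharding = NamedSharding(mesh, PartitionSpec())

def step(x):
  # Returns (replicated scalar metric, batch-sharded per-example output).
  return x.sum(), x * 2.0

p_step = jax.jit(
    step,
    in_shardings=(batch_sharding,),
    out_shardings=(replicate_sharding, batch_sharding),
)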
@@ -610,7 +611,18 @@ def _reshape(x):
       )
 
       with training.MeasureTime(timings, 'update_seed'):
-        batch_iter.update_seeds(updated_seed)  # pytype: disable=wrong-arg-types  # jnp-type
+        host_local_seeds = []  # [b, z, y, x, 1] * num_devices
+        dev_to_slice = batch_sharding.addressable_devices_indices_map(
+            updated_seed.shape
+        )
+
+        # Ensure device order is the same as that used to build the
+        # global array in postprocess_batch().
+        assert list(dev_to_slice.keys()) == list(mesh.local_devices)
+        for slc in dev_to_slice.values():
+          host_local_seeds.append(updated_seed[slc])
+
+        batch_iter.update_seeds(host_local_seeds)
 
       with training.MeasureTime(timings, 'admin'):
         if checkpoint_manager.should_save(step) or is_last_step:
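This hunk is the heart of the multihost fix: updated_seed is a global jax.Array whose shards may live on devices spread across several hosts, so each host slices out only the shards held by its own addressable devices before handing them back to its host-local input pipeline. A standalone sketch of the slicing step, run on a single host with assumed shapes (the mesh and array construction here are illustrative):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()), ('batch',))
batch_sharding = NamedSharding(mesh, PartitionSpec('batch'))

# Hypothetical global seed array, sharded along the batch axis.
global_shape = (8 * jax.device_count(), 16, 32, 32, 1)
updated_seed = jax.device_put(
    np.zeros(global_shape, np.float32), batch_sharding)

# Map each addressable (host-local) device to the index of the global
# array it owns, then slice out the corresponding host-local shards.
dev_to_slice = batch_sharding.addressable_devices_indices_map(global_shape)
host_local_seeds = [updated_seed[slc] for slc in dev_to_slice.values()]
assert len(host_local_seeds) == jax.local_device_count()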
2 changes: 1 addition & 1 deletion ffn/training/examples.py
@@ -126,7 +126,7 @@ def __next__(self):
     return (batched_seeds, np.concatenate(patches), np.concatenate(labels),
             batched_weights)
 
-  def update_seeds(self, batched_seeds: np.ndarray):
+  def update_seeds(self, batched_seeds: np.typing.ArrayLike):
     """Distributes updated predictions back to the generator buffers.
 
     Args:
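Widening the annotation from np.ndarray to np.typing.ArrayLike matches the new call site, which now passes a list of per-device arrays rather than one array: anything np.asarray() accepts is valid. A small illustrative sketch; the function body below is an assumption for demonstration, not the repo's implementation:

import numpy as np
import numpy.typing as npt

def update_seeds(batched_seeds: npt.ArrayLike) -> None:
  # A plain list of per-device arrays converts cleanly; np.asarray()
  # stacks it into a single host ndarray.
  batched_seeds = np.asarray(batched_seeds)
  print(batched_seeds.shape)

update_seeds([np.ones((2, 3)), np.zeros((2, 3))])  # prints (2, 2, 3)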
