diff --git a/ffn/jax/input_pipeline.py b/ffn/jax/input_pipeline.py
index 4ed99d2..f858eac 100644
--- a/ffn/jax/input_pipeline.py
+++ b/ffn/jax/input_pipeline.py
@@ -335,21 +335,23 @@ def __next__(self):
         'weight': batched_weights,
     }
 
-  def update_seeds(self, batched_seeds: np.ndarray | jax.Array):
+  def update_seeds(self, batched_seeds: list[jax.Array]):
     """Propagates data from `batched_seeds` back to the example generators."""
 
     def _update(
         seeds: list[np.ndarray],
-        batched_seeds: np.ndarray | jax.Array,
+        batched_seeds: list[jax.Array],
         current: list[int],
     ):
-      # Transfer data from device to host if using a JAX array.
+      # Transfer data from device to host.
      batched_seeds = np.array(batched_seeds)
 
       # Fold batch dimensions back to a single one.
       batched_seeds = np.reshape(
           batched_seeds, [-1] + list(batched_seeds.shape[-4:])
       )
+      assert batched_seeds.shape[0] == len(seeds)
+
       dx = self._info.input_seed_size[0] - self._info.pred_mask_size[0]
       dy = self._info.input_seed_size[1] - self._info.pred_mask_size[1]
       dz = self._info.input_seed_size[2] - self._info.pred_mask_size[2]
diff --git a/ffn/jax/train.py b/ffn/jax/train.py
index 29ee016..698e9af 100644
--- a/ffn/jax/train.py
+++ b/ffn/jax/train.py
@@ -62,7 +62,8 @@ class TrainState(flax.struct.PyTreeNode):  # pytype: disable=invalid-function-de
 
 
 DataIterator = TypeVar(
-    'DataIterator', tf.data.Iterator  #
+    'DataIterator',
+    tf.data.Iterator,
 )
 
 
@@ -511,7 +512,7 @@ def train_fn(state, batch, loss_scale):
   shard_out = (
       replicate_sharding,  # state
       replicate_sharding,  # metrics
-      replicate_sharding,  # logits
+      batch_sharding,  # logits
       replicate_sharding,  # loss scale
   )
   p_train_step = jax.jit(train_fn, shard_in, shard_out)
@@ -610,7 +611,18 @@ def _reshape(x):
     )
 
     with training.MeasureTime(timings, 'update_seed'):
-      batch_iter.update_seeds(updated_seed)  # pytype: disable=wrong-arg-types  # jnp-type
+      host_local_seeds = []  # [b, z, y, x, 1] * num_devices
+      dev_to_slice = batch_sharding.addressable_devices_indices_map(
+          updated_seed.shape
+      )
+
+      # Ensure device order is the same as that used to build the
+      # global array in postprocess_batch().
+      assert list(dev_to_slice.keys()) == list(mesh.local_devices)
+      for slc in dev_to_slice.values():
+        host_local_seeds.append(updated_seed[slc])
+
+      batch_iter.update_seeds(host_local_seeds)
 
     with training.MeasureTime(timings, 'admin'):
       if checkpoint_manager.should_save(step) or is_last_step:
diff --git a/ffn/training/examples.py b/ffn/training/examples.py
index a729e94..f10e615 100644
--- a/ffn/training/examples.py
+++ b/ffn/training/examples.py
@@ -126,7 +126,7 @@ def __next__(self):
     return (batched_seeds, np.concatenate(patches), np.concatenate(labels),
             batched_weights)
 
-  def update_seeds(self, batched_seeds: np.ndarray):
+  def update_seeds(self, batched_seeds: np.typing.ArrayLike):
     """Distributes updated predictions back to the generator buffers.
 
     Args:
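
The `shard_out` change in train.py stops replicating the logits across devices and instead keeps them batch-sharded, which is what makes the per-device slicing in the new `update_seed` block possible. Below is a minimal sketch of that pattern, assuming a one-axis `'batch'` mesh; all names and shapes here are illustrative, not the ones defined elsewhere in train.py:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional device mesh over all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
# Replicated spec (used for state/metrics/loss scale in the real shard_out).
replicate_sharding = NamedSharding(mesh, PartitionSpec())
# Batch-sharded spec: dimension 0 is split across the 'batch' mesh axis.
batch_sharding = NamedSharding(mesh, PartitionSpec('batch'))

def train_step(batch):
  # Stand-in for the real forward/backward pass; returns "logits".
  return batch * 2.0

p_train_step = jax.jit(
    train_step,
    in_shardings=(batch_sharding,),
    out_shardings=batch_sharding,  # logits stay distributed, as in shard_out
)

# Global batch whose leading dimension divides the device count evenly.
x = jax.device_put(jnp.ones((jax.device_count() * 2, 4)), batch_sharding)
logits = p_train_step(x)  # a single global jax.Array, sharded on dim 0
```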
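
Because `updated_seed` is now a global, batch-sharded array rather than a host-local one, the training loop has to recover one block per local device before handing the data back to the host-side input pipeline. A sketch of the same slicing logic as the new `update_seed` block, under the same assumptions as above:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
batch_sharding = NamedSharding(mesh, PartitionSpec('batch'))

# Hypothetical global array standing in for `updated_seed` ([b, z, y, x, 1]).
updated_seed = jax.device_put(
    jnp.zeros((jax.device_count() * 2, 16, 16, 16, 1)), batch_sharding
)

# Maps each addressable (process-local) device to the tuple of index slices
# of the global array that it owns, e.g.
#   {CpuDevice(id=0): (slice(0, 2), slice(None), ...), ...}
dev_to_slice = batch_sharding.addressable_devices_indices_map(
    updated_seed.shape
)

# Indexing the global array with a device's slice tuple yields that device's
# shard; np.asarray then transfers it to the host.
host_local_seeds = [
    np.asarray(updated_seed[slc]) for slc in dev_to_slice.values()
]
assert len(host_local_seeds) == len(mesh.local_devices)
```

The assert in train.py matters because `update_seeds` matches shards back to example generators purely by position, so the iteration order of `dev_to_slice` must agree with the device order used when the global batch was assembled.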
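
On the input-pipeline side, `_update` writes each returned seed back into its generator buffer; the `dx/dy/dz` deltas exist because the model may predict only the central `pred_mask_size` region of an `input_seed_size` seed. A toy illustration of that center-crop arithmetic, assuming the usual FFN convention (the sizes are made up; the real values come from `self._info`):

```python
import numpy as np

input_seed_size = (49, 49, 49)  # (x, y, z), as in self._info.input_seed_size
pred_mask_size = (33, 33, 33)   # (x, y, z), as in self._info.pred_mask_size

# Per-axis size difference between the stored seed and the prediction.
dx = input_seed_size[0] - pred_mask_size[0]
dy = input_seed_size[1] - pred_mask_size[1]
dz = input_seed_size[2] - pred_mask_size[2]

seed = np.zeros((1, 49, 49, 49, 1), np.float32)   # [b, z, y, x, 1] buffer
update = np.ones((1, 33, 33, 33, 1), np.float32)  # prediction for the center

# Write the prediction into the center of the seed, offset by half the
# size difference along each spatial axis.
seed[:,
     dz // 2 : dz // 2 + update.shape[1],
     dy // 2 : dy // 2 + update.shape[2],
     dx // 2 : dx // 2 + update.shape[3],
     :] = update
```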