Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix seed updating logic when running in a multihost environment. #88

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ffn/jax/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,21 +335,23 @@ def __next__(self):
'weight': batched_weights,
}

def update_seeds(self, batched_seeds: np.ndarray | jax.Array):
def update_seeds(self, batched_seeds: list[jax.Array]):
"""Propagates data from `batched_seeds` back to the example generators."""

def _update(
seeds: list[np.ndarray],
batched_seeds: np.ndarray | jax.Array,
batched_seeds: list[jax.Array],
current: list[int],
):
# Transfer data from device to host if using a JAX array.
# Transfer data from device to host.
batched_seeds = np.array(batched_seeds)
# Fold batch dimensions back to a single one.
batched_seeds = np.reshape(
batched_seeds, [-1] + list(batched_seeds.shape[-4:])
)

assert batched_seeds.shape[0] == len(seeds)

dx = self._info.input_seed_size[0] - self._info.pred_mask_size[0]
dy = self._info.input_seed_size[1] - self._info.pred_mask_size[1]
dz = self._info.input_seed_size[2] - self._info.pred_mask_size[2]
Expand Down
18 changes: 15 additions & 3 deletions ffn/jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class TrainState(flax.struct.PyTreeNode): # pytype: disable=invalid-function-de


DataIterator = TypeVar(
'DataIterator', tf.data.Iterator #
'DataIterator',
tf.data.Iterator,
)


Expand Down Expand Up @@ -511,7 +512,7 @@ def train_fn(state, batch, loss_scale):
shard_out = (
replicate_sharding, # state
replicate_sharding, # metrics
replicate_sharding, # logits
batch_sharding, # logits
replicate_sharding, # loss scale
)
p_train_step = jax.jit(train_fn, shard_in, shard_out)
Expand Down Expand Up @@ -610,7 +611,18 @@ def _reshape(x):
)

with training.MeasureTime(timings, 'update_seed'):
batch_iter.update_seeds(updated_seed) # pytype: disable=wrong-arg-types # jnp-type
host_local_seeds = [] # [b, z, y, x, 1] * num_devices
dev_to_slice = batch_sharding.addressable_devices_indices_map(
updated_seed.shape
)

# Ensure device order is the same as that used to build the
# global array in postprocess_batch().
assert list(dev_to_slice.keys()) == list(mesh.local_devices)
for slc in dev_to_slice.values():
host_local_seeds.append(updated_seed[slc])

batch_iter.update_seeds(host_local_seeds)

with training.MeasureTime(timings, 'admin'):
if checkpoint_manager.should_save(step) or is_last_step:
Expand Down
2 changes: 1 addition & 1 deletion ffn/training/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __next__(self):
return (batched_seeds, np.concatenate(patches), np.concatenate(labels),
batched_weights)

def update_seeds(self, batched_seeds: np.ndarray):
def update_seeds(self, batched_seeds: np.typing.ArrayLike):
"""Distributes updated predictions back to the generator buffers.
Args:
Expand Down
Loading