I think it's not too bad to implement pipeline parallelism directly in `Stacked`. The basic idea is that we map the `Layers` axis of a `Stacked` to a (new) physical axis (called `stage` here and in the link), then we reshape our batch into microbatches and push them through the pipeline.

Example implementation: https://github.com/tensorflow/lingvo/blob/master/lingvo/jax/layers/pipeline.py (which looks a lot like `accumulate_gradients_sharded`)
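To make the idea concrete, here's a minimal sketch of the microbatch schedule in plain JAX (no `Stacked`/Haliax, no actual device sharding; `pipeline_forward`, `layer_fn`, and the single-linear-layer stage are illustrative assumptions, not anything from the issue). Layer weights are stacked along a leading stage axis, the batch is reshaped into microbatches, and a shift buffer pushes each microbatch through the stages one step at a time, GPipe-style:

```python
import jax
import jax.numpy as jnp

def pipeline_forward(stacked_w, batch, num_microbatches):
    """Hypothetical GPipe-style forward pass.

    stacked_w: (num_stages, d, d) -- per-stage layer weights, stacked
               along the "stage" axis the issue describes.
    batch:     (batch_size, d)    -- the macro batch.
    """
    num_stages, d, _ = stacked_w.shape
    micro = batch.reshape(num_microbatches, -1, d)      # (M, micro_bs, d)
    # Pad the input stream so the pipeline can drain.
    pad = jnp.zeros((num_stages - 1, *micro.shape[1:]))
    feed = jnp.concatenate([micro, pad])                # (M + S - 1, mb, d)

    def layer_fn(w, x):
        # One pipeline stage: a single linear layer + nonlinearity
        # (stand-in for a real transformer block).
        return jnp.tanh(x @ w)

    # vmap over the stage axis: every stage processes its current slot
    # of the shift buffer "in parallel".
    stage_step = jax.vmap(layer_fn)

    def step(carry, inp):
        # Shift: the new microbatch enters stage 0; stage s receives
        # stage s-1's output from the previous step.
        buf = jnp.concatenate([inp[None], carry[:-1]])
        y = stage_step(stacked_w, buf)
        return y, y[-1]                                 # emit last stage

    init = jnp.zeros((num_stages, micro.shape[1], d))
    _, outs = jax.lax.scan(step, init, feed)
    # The first num_stages - 1 emissions are the pipeline "bubble";
    # the real outputs follow, one microbatch per step.
    return outs[num_stages - 1:].reshape(batch.shape)
```

The result matches running the stacked layers sequentially over the whole batch; the point of the schedule is only that, once `stage` is a physical mesh axis, the per-step `vmap` body runs on different devices concurrently.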
The biggest thing that's not clear to me is how to partition the (macro) batch itself. The easiest thing to do is to replicate it across the `stage` axis, but I don't think that's ideal; we should take a look at an existing implementation of pipeline parallelism.
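For reference, the "replicate across the stage axis" option could look something like this with plain `jax.sharding` (mesh axis names `stage`/`data` and the weight/batch layouts are my assumptions, not a decided design):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange whatever devices are available into a (stage, data) grid.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("stage", "data"))

# Stacked layer weights: split along the new "stage" physical axis,
# so each stage group holds only its own layers.
w_sharding = NamedSharding(mesh, P("stage", None, None))

# Macro batch: replicated over "stage" (the easy option discussed
# above), sharded over "data".
batch_sharding = NamedSharding(mesh, P("data", None))

batch = jax.device_put(jnp.zeros((8, 4)), batch_sharding)
```

Replication means every stage group holds the full batch even though it only consumes the microbatches passing through it, which is the memory cost that makes this option feel non-ideal.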