Implemented automated broadcasting in weight rescale when number of model shards is fewer than number of experts #265
There are use cases where one might want to shard the quantized model across fewer devices than the number of experts. However, doing so currently results in a shape mismatch when attempting to re-scale the weights back to `bfloat16` during inference. For example, when generating with the grok-1 weights on a single model-parallel shard, one would observe a `TypeError` from `jax.numpy`. Below is a brief analysis of why this exception happens and how this PR addresses the issue.
In the quantized weights released with this repo, each tensor is represented as a `QuantizedWeight8bit`. For example, the `w` parameter of `decoder_layer_0/moe/linear` consists of:

- a `weight` tensor of shape `(8, 32768, 6144)` and dtype `int8`, and
- a `scales` tensor of shape `(8, 8, 6144)` and dtype `bfloat16`.
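As a quick shape-level sketch (this is only an illustration; `jax.ShapeDtypeStruct` is used so the roughly 1.5 GiB `int8` tensor is never actually materialized):

```python
import jax
import jax.numpy as jnp

# Components of the "w" parameter of decoder_layer_0/moe/linear,
# described by shape/dtype only (no device memory is allocated).
weight = jax.ShapeDtypeStruct((8, 32768, 6144), jnp.int8)    # quantized values
scales = jax.ShapeDtypeStruct((8, 8, 6144), jnp.bfloat16)    # per-expert scales

print(weight.shape, weight.dtype)  # (8, 32768, 6144) int8
print(scales.shape, scales.dtype)  # (8, 8, 6144) bfloat16
```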
The modelling code in the grok-1 repo leverages the `jax.experimental.shard_map.shard_map` decorator to ensure that re-scaling the weight matrix does not require cross-device communication. Specifically, `moe_slow_matmul1` and `moe_slow_matmul2` are wrapped in the `shard_map` decorator so that each shard handles the parameters of one expert at a time. Note that, as seen in the example above, `scales` is not directly broadcastable to `weight` when computing `weight = weight * scale`. Rather, `shard_map` partitions the weight tensor into eight `(8, 4096, 6144)` blocks and the scales tensor into eight `(8, 1, 6144)` blocks before supplying the partitioned tensors to `moe_slow_matmul1`. Each block of the scales tensor is then broadcast along `axis=1` of the corresponding block of the weight tensor.
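A shape-level sketch of why this works, using `jnp.broadcast_shapes` on the block shapes quoted above (this only illustrates the broadcasting rules, not the actual `shard_map` call):

```python
import jax.numpy as jnp

# Full (unpartitioned) tensors: axis 1 is 32768 vs. 8, so they do not broadcast.
try:
    jnp.broadcast_shapes((8, 32768, 6144), (8, 8, 6144))
except ValueError as err:
    print("full tensors do not broadcast:", err)

# Per-shard blocks with 8 model-parallel shards (one expert per shard):
# axis 1 of the scales block is 1, so it broadcasts against the weight block.
print(jnp.broadcast_shapes((8, 4096, 6144), (8, 1, 6144)))  # -> (8, 4096, 6144)
```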
This approach works as expected as long as each model-parallel partition contains exactly one expert. However, when partitioning the pretrained model across fewer devices than experts, the inputs to `moe_slow_matmul1` are no longer broadcastable. For example, when running on a total of 4 devices with 2 experts per device, the tensors supplied to `moe_slow_matmul1` would be of shape `(8, 8192, 6144)` for the weights and `(8, 2, 6144)` for the scales. `jax.numpy` would then complain that the two inputs cannot be broadcast together in the multiplication `weight = weight * scale` (source).
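A downscaled sketch of that failure (the hidden dimensions are shrunk so the snippet runs instantly; the real error is raised from inside `moe_slow_matmul1` with the full-size blocks):

```python
import jax.numpy as jnp

# Stand-ins for the per-shard blocks in the 4-device case
# (real shapes: weight (8, 8192, 6144) int8, scales (8, 2, 6144) bfloat16).
weight = jnp.ones((8, 2 * 4, 6), dtype=jnp.int8)   # 2 experts * 4 rows each
scales = jnp.ones((8, 2, 6), dtype=jnp.bfloat16)

try:
    weight = weight.astype(jnp.bfloat16) * scales  # axis 1: 8 vs. 2 -> fails
except TypeError as err:
    print("TypeError:", err)
```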
This PR proposes a workaround that reshapes the tensors prior to re-scaling. Since the proposed changes are wrapped entirely inside the `shard_map` decorator, the reshape logic does not require communication between devices. When the number of experts matches the number of model-parallel shards, the proposed behavior is equivalent to that of the original reference implementation.
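A minimal standalone sketch of the reshape idea, outside of `shard_map` and with a hypothetical helper name (not the PR code itself), assuming each shard's weight block stacks its experts contiguously along axis 1: split the expert axis out of axis 1, broadcast the scales against the per-expert sub-blocks, then reshape back.

```python
import jax.numpy as jnp

def rescale_block(weight, scales):
    """Re-scale a per-shard weight block that may hold several experts.

    weight: (B, experts_per_shard * rows_per_expert, H) int8 block
    scales: (B, experts_per_shard, H) bfloat16 block
    Hypothetical helper; a sketch of the reshape-based workaround.
    """
    b, rows, h = weight.shape
    experts_per_shard = scales.shape[1]
    rows_per_expert = rows // experts_per_shard

    # Expose the per-expert sub-blocks so scales broadcast along rows_per_expert.
    w = weight.astype(scales.dtype).reshape(b, experts_per_shard, rows_per_expert, h)
    w = w * scales[:, :, None, :]
    # Collapse back to the original block layout expected by the matmul.
    return w.reshape(b, rows, h)

# Downscaled example mirroring the 4-device / 2-experts-per-shard case.
weight = jnp.ones((8, 2 * 4, 6), dtype=jnp.int8)
scales = jnp.full((8, 2, 6), 0.5, dtype=jnp.bfloat16)
print(rescale_block(weight, scales).shape)  # (8, 8, 6)
```

When `experts_per_shard` is 1, the reshape is a no-op and the multiplication reduces to the original per-expert broadcast, which is why the behavior matches the reference implementation in that case.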