diff --git a/model.py b/model.py index 87d700d..28e4eb0 100644 --- a/model.py +++ b/model.py @@ -44,6 +44,47 @@ def shape(self): return self.weight.shape +def rescale_quantized_weight(weight: jax.Array, scales: jax.Array) -> jax.Array: + """ + Automatically handle broadcasting when total + number of model shards is less than 8. + + Params: + weight: quantized weight array + scales: coefficients for restoring weights. + + Returns: + Array with same shape as weight and same dtype as scales. + """ + + shape_w = weight.shape + shape_s = scales.shape + + # Insert new axis at each mismatched axis. + shape_w_expanded = [] + shape_s_expanded = [] + + # Insert length_w if matched. + # Otherwise, insert (length_s, length_w // length_s) to emulate sharding + for length_w, length_s in zip(shape_w, shape_s): + if (length_w != length_s) and (length_s > 1): + assert length_w % length_s == 0, (length_w, length_s) + shape_w_expanded.extend((length_s, length_w // length_s)) + shape_s_expanded.extend((length_s, 1)) + else: + shape_w_expanded.extend((length_w,)) + shape_s_expanded.extend((length_s,)) + + # Reshape weight along each mismatched axis. + w_expanded = weight.reshape(shape_w_expanded) + s_expanded = scales.reshape(shape_s_expanded) + + output_expanded = w_expanded.astype(s_expanded.dtype) * s_expanded + output = output_expanded.reshape(shape_w) + return output + + + tree_util.register_pytree_node( QuantizedWeight8bit, lambda qw: ([qw.weight, qw.scales], ()), @@ -330,7 +371,7 @@ def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = check_rep=False, ) def moe_slow_matmul1(input, weight, scales, index, prob): - weight = weight * scales + weight = rescale_quantized_weight(weight, scales) one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0) all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight) output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output) @@ -350,7 +391,7 @@ def moe_slow_matmul1(input, weight, scales, index, prob): check_rep=False, ) def moe_slow_matmul2(input, weight, scales, index, prob): - weight = weight * scales + weight = rescale_quantized_weight(weight, scales) one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0) all_expert_output = jnp.einsum("mk,bkn->bmn", input, weight) output = jnp.einsum("bm,bmn->mn", one_hot_indices, all_expert_output) @@ -570,7 +611,7 @@ def __call__( check_rep=False, ) def mul(w, s): - return w.astype(s.dtype) * s + return rescale_quantized_weight(w, s) w = mul(w.weight, w.scales) out = jnp.dot(inputs, w.astype(fprop_dtype)) diff --git a/test_modelling.py b/test_modelling.py new file mode 100644 index 0000000..893bc9c --- /dev/null +++ b/test_modelling.py @@ -0,0 +1,23 @@ +import numpy as np + +from model import rescale_quantized_weight + + +def test_rescale(): + weight = np.arange(42).reshape((6, 7)).astype(np.float16) + + # Each row of scales is applied to + # three consecutive rows of weight. + scales = np.arange(2 * 7).reshape((2, 7)).astype(np.int32) + + rescaled_array = rescale_quantized_weight(weight, scales) + assert rescaled_array.shape == weight.shape + assert rescaled_array[:, 0].flatten().tolist() == [ + 0 * 0, + 0 * 7, + 0 * 14, + 7 * 21, + 7 * 28, + 7 * 35, + ] + assert rescaled_array.dtype == np.int32