diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py
index 4c46eafb4c..2190c6df6c 100644
--- a/transformer_engine/jax/flax/module.py
+++ b/transformer_engine/jax/flax/module.py
@@ -8,8 +8,8 @@
 import operator
 from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
 
-import jax.numpy as jnp
 import numpy as np
+import jax.numpy as jnp
 from flax import linen as nn
 from flax.linen import partitioning as nn_partitioning
 from jax import lax
@@ -57,14 +57,18 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
 
 
 def _create_layernorm_parameters(
-    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype
+    layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype, weight_dtype
 ):
-    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
+    scale = nn_partitioning.param_with_axes(
+        "scale", scale_init, shape, weight_dtype, axes=scale_axes
+    )
     scale = scale.astype(dtype)
 
     layernorm_type = canonicalize_layernorm_type(layernorm_type)
     if layernorm_type == "layernorm":
-        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
+        bias = nn_partitioning.param_with_axes(
+            "ln_bias", bias_init, shape, weight_dtype, axes=bias_axes
+        )
         bias = bias.astype(dtype)
     else:
         assert layernorm_type == "rmsnorm"
@@ -256,8 +260,10 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
 
     Optimization parameters
     -----------------------
-    dtype : jax.numpy.dtype, default = jax.numpy.float32
-        the data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     transpose_batch_sequence : bool, default = False
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
@@ -272,6 +278,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
     bias_init: Initializer = nn.initializers.zeros
     bias_axes: Tuple[str, ...] = ("embed",)
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     transpose_batch_sequence: bool = False
 
     def __post_init__(self):
@@ -307,6 +314,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
             self.bias_init,
             self.bias_axes,
             self.dtype,
+            self.weight_dtype,
         )
         return layernorm(
             x,
@@ -399,8 +407,10 @@ class DenseGeneral(TransformerEngineBase):
 
     Optimization parameters
     -----------------------
-    dtype : jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     transpose_batch_sequence : bool, default = True
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
@@ -418,12 +428,13 @@ class DenseGeneral(TransformerEngineBase):
     low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     transpose_batch_sequence: bool = False
 
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
             )
         super().__post_init__()
 
@@ -452,13 +463,13 @@ def __call__(self, inputs: Array) -> Array:
         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
         kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+            "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
         )
         kernel = kernel.astype(self.dtype)
 
         if self.use_bias:
             bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+                "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
             )
             bias = bias.astype(self.dtype)
         else:
@@ -489,7 +500,7 @@ def __call__(self, inputs: Array) -> Array:
                 "lora_a_kernel",
                 self.kernel_init,
                 lora_a_kernel_init_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=lora_a_kernel_axes,
             )
             lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -501,7 +512,7 @@ def __call__(self, inputs: Array) -> Array:
                 "lora_b_kernel",
                 nn.initializers.zeros,
                 lora_b_kernel_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=lora_b_kernel_axes,
             )
             lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -594,8 +605,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
 
     Optimization parameters
     -----------------------
-    dtype : jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     transpose_batch_sequence : bool, default = True
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
@@ -625,6 +638,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     transpose_batch_sequence: bool = True
     layernorm_input_axes: Tuple[str, ...] = None
     dot_input_axes: Tuple[str, ...] = None
@@ -633,7 +647,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+                1.0,
+                "fan_in",
+                "truncated_normal",
+                dtype=self.weight_dtype,
             )
         self.scale_init = _obtain_default_layernorm_scale_init_if_need(
             self.scale_init,
@@ -683,6 +700,7 @@ def __call__(self, inputs: Array) -> Array:
                 self.ln_bias_init,
                 self.ln_bias_axes,
                 self.dtype,
+                self.weight_dtype,
             )
 
             if not fuse_layernorm:
@@ -712,7 +730,7 @@ def __call__(self, inputs: Array) -> Array:
         kernel_shape = tuple(y.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
         kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+            "kernel", self.kernel_init, kernel_shape, self.weight_dtype, axes=self.kernel_axes
         )
         kernel = kernel.astype(self.dtype)
 
@@ -757,7 +775,7 @@ def __call__(self, inputs: Array) -> Array:
                 "lora_a_kernel",
                 self.kernel_init,
                 lora_a_kernel_init_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=lora_a_kernel_axes,
             )
             lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
@@ -769,7 +787,7 @@ def __call__(self, inputs: Array) -> Array:
                 "lora_b_kernel",
                 nn.initializers.zeros,
                 lora_b_kernel_shape,
-                self.dtype,
+                self.weight_dtype,
                 axes=lora_b_kernel_axes,
             )
             lora_b_kernel = lora_b_kernel.astype(self.dtype)
@@ -781,7 +799,7 @@ def __call__(self, inputs: Array) -> Array:
         bias = None
         if self.use_bias:
             bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+                "bias", self.bias_init, features, self.weight_dtype, axes=self.bias_axes
             )
             bias = bias.astype(self.dtype)
 
@@ -896,8 +914,10 @@ class LayerNormMLP(TransformerEngineBase):
 
     Optimization parameters
     -----------------------
-    dtype : jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
     transpose_batch_sequence : bool, default = True
         Indicate whether the input tensors were switched axis of batch
         and sequence length dimension. If set to True, the input tensors
@@ -930,6 +950,7 @@ class LayerNormMLP(TransformerEngineBase):
     low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
     transpose_batch_sequence: bool = True
     layernorm_input_axes: Tuple[str, ...] = None
     dot_1_input_axes: Tuple[str, ...] = None
@@ -938,7 +959,7 @@ class LayerNormMLP(TransformerEngineBase):
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
             )
         self.scale_init = _obtain_default_layernorm_scale_init_if_need(
             self.scale_init,
@@ -1015,6 +1036,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
                 self.ln_bias_init,
                 self.ln_bias_axes,
                 self.dtype,
+                self.weight_dtype,
             )
 
             if not fuse_layernorm:
@@ -1061,7 +1083,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            num_activations,
            -2,
            kernel_1_each_shape,
-            self.dtype,
+            self.weight_dtype,
            axes=self.kernel_axes_1,
        )
        kernel_1 = jnp.reshape(kernel_1, kernel_1_shape)
@@ -1074,7 +1096,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            "wo_kernel",
            self.kernel_init,
            kernel_2_param_shape,
-            self.dtype,
+            self.weight_dtype,
            axes=self.kernel_axes_2,
        )
        kernel_2 = jnp.reshape(kernel_2, kernel_2_shape)
@@ -1090,13 +1112,21 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            if self.use_bias:
                bias_1_shape = intermediate_dim
                bias_1 = nn_partitioning.param_with_axes(
-                    "wi_bias", self.bias_init, bias_1_shape, self.dtype, axes=self.bias_axes_1
+                    "wi_bias",
+                    self.bias_init,
+                    bias_1_shape,
+                    self.weight_dtype,
+                    axes=self.bias_axes_1,
                )
                bias_1 = bias_1.astype(self.dtype)
 
                bias_2_shape = (hidden_size,)
                bias_2 = nn_partitioning.param_with_axes(
-                    "wo_bias", self.bias_init, bias_2_shape, self.dtype, axes=self.bias_axes_2
+                    "wo_bias",
+                    self.bias_init,
+                    bias_2_shape,
+                    self.weight_dtype,
+                    axes=self.bias_axes_2,
                )
                bias_2 = bias_2.astype(self.dtype)
            else:
@@ -1165,7 +1195,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                    num_activations,
                    -2,
                    wi_lora_a_kernel_init_each_shape,
-                    self.dtype,
+                    self.weight_dtype,
                    axes=wi_lora_a_kernel_axes,
                )
                wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
@@ -1181,7 +1211,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                    "wi_lora_b_kernel",
                    nn.initializers.zeros,
                    wi_lora_b_kernel_shape,
-                    self.dtype,
+                    self.weight_dtype,
                    axes=wi_lora_b_kernel_axes,
                )
                wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
@@ -1198,7 +1228,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            bias_1 = None
            if self.use_bias:
                bias_1 = nn_partitioning.param_with_axes(
-                    "wi_bias", self.bias_init, intermediate_dim, self.dtype, axes=self.bias_axes_1
+                    "wi_bias",
+                    self.bias_init,
+                    intermediate_dim,
+                    self.weight_dtype,
+                    axes=self.bias_axes_1,
                )
                bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape
                bias_1 = bias_1.astype(self.dtype)
@@ -1240,7 +1274,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                    "wo_lora_a_kernel",
                    self.kernel_init,
                    wo_lora_a_kernel_shape,
-                    self.dtype,
+                    self.weight_dtype,
                    axes=wo_lora_a_kernel_axes,
                )
                wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
@@ -1251,7 +1285,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
                    "wo_lora_b_kernel",
                    nn.initializers.zeros,
                    wo_lora_b_kernel_shape,
-                    self.dtype,
+                    self.weight_dtype,
                    axes=wo_lora_b_kernel_axes,
                )
                wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
@@ -1268,7 +1302,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
            bias_2 = None
            if self.use_bias:
                bias_2 = nn_partitioning.param_with_axes(
-                    "wo_bias", self.bias_init, (hidden_size,), self.dtype, axes=self.bias_axes_2
+                    "wo_bias",
+                    self.bias_init,
+                    (hidden_size,),
+                    self.weight_dtype,
+                    axes=self.bias_axes_2,
                )
                bias_2 = bias_2.astype(self.dtype)
                out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,))
diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index 89278f720b..6c96e7ba1a 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -115,6 +115,7 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
    float32_logits: bool = False
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = True
@@ -261,6 +262,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
    attn_mask_type: AttnMaskType = AttnMaskType.CAUSAL_MASK
    attn_bias_type: Optional[AttnBiasType] = None
    dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
    qkv_layout: QKVLayout = QKVLayout.BSHD_BSHD_BSHD
    scale_factor: Optional[float] = None
    transpose_batch_sequence: bool = False
@@ -480,8 +482,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
 
    Optimization parameters
    -----------------------
-    dtype: jax.numpy.dtype, default  = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
    """
 
    head_dim: int
@@ -491,6 +495,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
    attn_mask_type: AttnMaskType = "causal"
    attn_bias_type: AttnBiasType = None
    dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
    dropout_rng_name: str = "dropout"
    float32_logits: bool = False
    qkv_layout: str = "bshd_bshd_bshd"
@@ -615,6 +620,7 @@ def __call__(
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                float32_logits=self.float32_logits,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
@@ -626,6 +632,7 @@ def __call__(
                attn_mask_type=attn_mask_type,
                attn_bias_type=attn_bias_type,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                scale_factor=scale_factor,
                transpose_batch_sequence=self.transpose_batch_sequence,
                qkv_layout=qkv_layout,
@@ -880,8 +887,10 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
 
    Optimization parameters
    -----------------------
-    dtype: jax.numpy.dtype, default  = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+    dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
    fuse_qkv_params: bool, default = True
        If set to True, this module exposes a single fused
        parameter for query-key-value for self-attention and key-value for
@@ -927,6 +936,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False
@@ -977,7 +987,7 @@ def __post_init__(self):
 
        if self.kernel_init is None:
            self.kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "normal", dtype=self.dtype
+                1.0, "fan_in", "normal", self.weight_dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
@@ -1105,6 +1115,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name="qkv",
                    dtype=self.dtype,
+                    weight_dtype=self.weight_dtype,
                )(inputs_q)
                qkv_proj = checkpoint_name(qkv_proj, "combined_qkv_proj")
                qkv_layout = QKVLayout.BS3HD
@@ -1128,6 +1139,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    dtype=self.dtype,
+                    weight_dtype=self.weight_dtype,
                    kernel_init=query_init,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
@@ -1152,6 +1164,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    name="kv",
                    dtype=self.dtype,
+                    weight_dtype=self.weight_dtype,
                )(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, "combined_kv_proj")
                qkv_layout = QKVLayout.BSHD_BS2HD
@@ -1169,6 +1182,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
            )
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=self.input_layernorm,
@@ -1189,6 +1203,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                kernel_init=query_init,
                layernorm_input_axes=inputs_logical_axes_maybe_sp,
                dot_input_axes=inputs_logical_axes_no_sp,
@@ -1326,6 +1341,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
            attn_bias_type=self.attn_bias_type,
            attention_dropout=self.attention_dropout,
            dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
            dropout_rng_name=self.dropout_rng_name,
            float32_logits=self.float32_logits,
            qkv_layout=qkv_layout.name,
@@ -1351,6 +1367,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq):
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
            name="out",
        )(x)
        out = checkpoint_name(out, "out_proj")
@@ -1379,7 +1396,9 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
    """
 
    num_buckets: int
@@ -1388,6 +1407,7 @@ class RelativePositionBiases(nn.Module): # pylint: disable=too-few-public-metho
    embedding_init: Callable[..., Array] = nn.linear.default_embed_init
    embedding_axes: Tuple[str, ...] = ("heads", "relpos_buckets")
    dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
 
    @nn.compact
    def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
@@ -1440,7 +1460,7 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True):
            "rel_embedding",
            self.embedding_init,
            (self.num_attention_heads, self.num_buckets),
-            self.dtype,
+            self.weight_dtype,
            axes=self.embedding_axes,
        )
 
@@ -1613,7 +1633,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
    Optimization parameters
    -----------------------
    dtype: jax.numpy.dtype, default = jax.numpy.float32
-        The data type used to allocate the initial parameters.
+        The data type used for computation.
+    weight_dtype: jax.numpy.dtype, default = jax.numpy.float32
+        The data type of the module parameters.
    drop_path: float, default = 0.0
        When > 0.0, applies stochastic depth per sample in the main path
        of the residual block.
@@ -1666,6 +1688,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    dtype: DType = jnp.float32
+    weight_dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = False
@@ -1677,11 +1700,11 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
    def __post_init__(self):
        if self.mha_kernel_init is None:
            self.mha_kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "normal", dtype=self.dtype
+                1.0, "fan_in", "normal", dtype=self.weight_dtype
            )
        if self.mlp_kernel_init is None:
            self.mlp_kernel_init = nn.initializers.variance_scaling(
-                1.0, "fan_in", "truncated_normal", dtype=self.dtype
+                1.0, "fan_in", "truncated_normal", dtype=self.weight_dtype
            )
        if self.num_gqa_groups is None:
            self.num_gqa_groups = self.num_attention_heads
@@ -1771,6 +1794,7 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None):
                max_distance=128,
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
                name="relpos_bias",
            )
@@ -1804,6 +1828,7 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None):
        x, ln_out = MultiHeadAttention(
            num_attention_heads=self.num_attention_heads,
            dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
            head_dim=head_dim,
            num_gqa_groups=self.num_gqa_groups,
            transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1882,6 +1907,7 @@ def hidden_dropout(x, deterministic):
            y, ln_out = MultiHeadAttention(
                num_attention_heads=self.num_attention_heads,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                head_dim=head_dim,
                num_gqa_groups=self.num_gqa_groups,
                transpose_batch_sequence=self.transpose_batch_sequence,
@@ -1947,6 +1973,7 @@ def hidden_dropout(x, deterministic):
            intermediate_dropout_rate=self.intermediate_dropout,
            intermediate_hidden_dropout_dims=self.intermediate_dropout_dims,
            dtype=self.dtype,
+            weight_dtype=self.weight_dtype,
            scale_axes=(W_NO_SHARD_AXES,),
            ln_bias_axes=(W_NO_SHARD_AXES,),
            kernel_init=self.mlp_kernel_init,
@@ -1996,6 +2023,7 @@ def hidden_dropout(x, deterministic):
                bias_axes=(W_NO_SHARD_AXES,),
                transpose_batch_sequence=self.transpose_batch_sequence,
                dtype=self.dtype,
+                weight_dtype=self.weight_dtype,
                name="output_layernorm",
            )(z)