[JAX] Allow enabling partial custom calls through the environment variable (NVIDIA#1007)

* Add enabled() to BasePrimitive

* Add layernorm/rmsnorm fallback

* Add cast_fp8 fallback

* Add transpose/cast_transpose XLA fallback

* Act_lu fallback

* Add transpose fallback

* Add softmax fallback

* Unify the use of _cast_fp8

* Add tests for NVTE_JAX_CUSTOM_CALLS_RE

---------

Signed-off-by: Reese Wang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
zlsh80826 and phu0ngng authored Jul 17, 2024
1 parent 210e57d commit 6c57926
Showing 12 changed files with 369 additions and 47 deletions.
3 changes: 3 additions & 0 deletions qa/L0_jax_unittest/test.sh
@@ -9,6 +9,9 @@ pip install pytest==8.2.1

pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'

# Test without custom calls
NVTE_JAX_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py

pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

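Note: a minimal Python sketch of what this empty-pattern run exercises, assuming the `enabled()` check added to `base.py` below; the in-process environment assignment is illustrative, the script above sets the variable on the command line instead.

# An empty NVTE_JAX_CUSTOM_CALLS_RE pattern fullmatches no primitive name,
# so every custom call is treated as disabled and the pure-JAX fallbacks run.
import os

os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = ""  # set before enabled() is queried

from transformer_engine.jax.cpp_extensions.activation import ActLuPrimitive

assert not ActLuPrimitive.enabled()  # re.fullmatch("", name) is None for any non-empty name
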
33 changes: 4 additions & 29 deletions tests/jax/test_custom_call_compute.py
@@ -19,8 +19,10 @@
from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
from transformer_engine.jax import cpp_extensions as tex


GEMM_CASES = [
(256, 256, 512),
(32, 32, 32),
@@ -34,21 +36,6 @@
is_fp8_supported, reason = is_fp8_available()


def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "quick_gelu":
return lambda x: nn.gelu(x, approximate=True)
if fn_or_string == "squared_relu":
return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"don't know how to convert {fn_or_string} to an activation function")


class TestFP8Dot:

@staticmethod
@@ -293,14 +280,7 @@ def layernorm_fp8_mlp_ref(
bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
linear_1_out += jnp.reshape(bias_1, bias_1_shape)

x = jnp.split(linear_1_out, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)

x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16)
x = _jax_act_lu(linear_1_out, activation_type)

fp8_meta_pkg_2 = FP8MetaPackage(
amax_list_2[0],
@@ -443,12 +423,7 @@ class TestActivationLu:
def ref_func(self, x, activation_type):

def ref_act_lu(inputs):
x = jnp.split(inputs, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = functools.reduce(operator.mul, acts)
x = _jax_act_lu(inputs, activation_type)
return jnp.mean(x)

ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
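
Note: the reference paths above now reuse `_jax_act_lu` from the activation extension instead of a local helper. A minimal sketch of the agreement the tests build on, assuming a TE build where `act_lu` is callable and the ("gelu", "linear") gated pair is registered; shapes and tolerances are illustrative.

# Compare the act_lu wrapper (custom call, or fallback when disabled) against
# the shared pure-JAX reference; the size-2 middle axis holds the gated halves.
import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions.activation import act_lu, _jax_act_lu

x = jnp.ones((16, 2, 64), jnp.bfloat16)   # gated input: (N, 2, H)
activation_type = ("gelu", "linear")      # i.e. GeGLU

out = act_lu(x, activation_type)
ref = _jax_act_lu(x, activation_type)     # pure-JAX reference, shape (N, H)
assert jnp.allclose(out, ref, rtol=1e-2, atol=1e-2)
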
6 changes: 2 additions & 4 deletions tests/jax/test_softmax.py
@@ -123,14 +123,12 @@ def grad_func(func, *args, **kwargs):

# Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,)
)
value_and_grad(lambda logits, *args: grad_func(softmax, logits, *args, **kwargs), (0,))
)
jitted_reference = jit(
value_and_grad(
lambda logits, *args: grad_func(
__class__.reference_softmax, self.logits, *args, **kwargs
__class__.reference_softmax, logits, *args, **kwargs
),
(0,),
)
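
Note: the change above fixes a closure bug: the lambdas ignored their `logits` parameter and closed over `self.logits`, so `value_and_grad(..., (0,))` differentiated a function that never used argument 0. A standalone illustration of the pitfall, independent of the TE API.

# grad w.r.t. argument 0 is identically zero when the function closes over a
# captured array instead of using its parameter.
import jax
import jax.numpy as jnp

captured = jnp.arange(4.0)
bad = jax.value_and_grad(lambda x: jnp.sum(captured**2), (0,))   # ignores x
good = jax.value_and_grad(lambda x: jnp.sum(x**2), (0,))         # uses x

_, (g_bad,) = bad(captured)
_, (g_good,) = good(captured)
assert not jnp.any(g_bad)                  # all zeros
assert jnp.allclose(g_good, 2 * captured)
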
46 changes: 45 additions & 1 deletion transformer_engine/jax/cpp_extensions/activation.py
@@ -4,8 +4,9 @@
"""JAX/TE custom ops for activation"""
from typing import Tuple, Sequence, Union, Callable
import operator
from functools import reduce
from functools import reduce, partial

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters.mlir import ir
@@ -22,6 +23,7 @@
jax_dtype_to_ir_dtype,
get_padded_spec,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP


@@ -42,6 +44,35 @@
}


def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
if fn_or_string == "linear":
return lambda x: x
if fn_or_string == "quick_gelu":
return lambda x: jax.nn.sigmoid(1.702 * x) * x
if fn_or_string == "squared_relu":
return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(jax.nn, fn_or_string)
if callable(fn_or_string):
return fn_or_string
raise ValueError(f"Unsupported {fn_or_string} to an activation function")


def _jax_act_lu(inputs, activation_type):
"""
JAX native activation implementation
"""
x = jnp.split(inputs, len(activation_type), axis=-2)
acts = []
for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i)
x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2)
return x


class ActLuPrimitive(BasePrimitive):
"""
Activation Forward Primitive
@@ -155,6 +186,9 @@ def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]])
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
"""
if not ActLuPrimitive.enabled():
return _jax_act_lu(inputs, activation_type)

act_type_id = ActivationEnum[activation_type]
return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)

@@ -286,6 +320,11 @@ def dact_lu(
dact_lu fusion wrapper
Return dgated_act_lu(inputs)
"""

if not DActLuPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
return vjp_func(inputs)[0]

act_type_id = ActivationEnum[activation_type]
return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)

@@ -443,6 +482,11 @@ def act_lu_fp8(
Input shape: (N, 1, H) for non-gated activations
(N, 2, H) for gated activations
"""
if not ActLuFp8Primitive.enabled():
act_lu_output = _jax_act_lu(x, activation_type)
casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
return casted_output, updated_amax

act_type_id = ActivationEnum[activation_type]
return ActLuFp8Primitive.outer_primitive.bind(
x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
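
Note: the backward fallbacks above (e.g. `dact_lu`) recover gradients from the forward pure-JAX helpers through `jax.vjp` rather than hand-written derivatives. A standalone sketch of that pattern; `gated_act` is an illustrative stand-in, not TE code.

# Derive a backward fallback from a forward JAX implementation via jax.vjp,
# mirroring how dact_lu falls back when DActLuPrimitive is disabled.
from functools import partial
import jax
import jax.numpy as jnp

def gated_act(inputs, activation_type=("gelu", "linear")):
    # stand-in for _jax_act_lu: split the (N, 2, H) input and multiply the halves
    a, b = jnp.split(inputs, 2, axis=-2)
    return jnp.squeeze(jax.nn.gelu(a) * b, axis=-2)

x = jnp.ones((8, 2, 32))
dz = jnp.ones((8, 32))        # cotangent matching the forward output shape

_, vjp_func = jax.vjp(partial(gated_act, activation_type=("gelu", "linear")), x)
(dx,) = vjp_func(dz)          # gradient w.r.t. the gated input
assert dx.shape == x.shape
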
17 changes: 17 additions & 0 deletions transformer_engine/jax/cpp_extensions/base.py
@@ -2,6 +2,8 @@
#
# See LICENSE for license information.
"""JAX/TE base custom ops"""
import os
import re
from abc import ABCMeta, abstractmethod
from functools import partial

@@ -17,6 +19,21 @@ class BasePrimitive(metaclass=ABCMeta):
jax primitive
"""

name = None

@classmethod
def enabled(cls):
"""
A custom call is marked as disabled if the `cls.name` does not fully match the
`NVTE_JAX_CUSTOM_CALLS_RE` pattern.
By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names.
For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!te_act_lu$).+$'` to disable `te_act_lu`.
"""
pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*")
pattern = re.compile(pattern)
is_enabled = pattern.fullmatch(cls.name) is not None
return is_enabled

@staticmethod
@abstractmethod
def abstract():
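
Note: a minimal sketch of the matching semantics; the selective pattern is the one from the docstring above, and the second primitive name is illustrative.

# NVTE_JAX_CUSTOM_CALLS_RE is applied with re.fullmatch, so a primitive is
# enabled only if its whole registered name matches the pattern.
import re

selective = re.compile(r"^(?!te_act_lu$).+$")    # disable only te_act_lu
assert selective.fullmatch("te_act_lu") is None                  # disabled
assert selective.fullmatch("te_layernorm_forward") is not None   # still enabled

default = re.compile(r".*")                      # the default enables everything
assert default.fullmatch("te_act_lu") is not None
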
102 changes: 102 additions & 0 deletions transformer_engine/jax/cpp_extensions/normalization.py
@@ -7,6 +7,7 @@
import os
import warnings

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters import mlir
@@ -25,6 +26,7 @@
jax_dtype_to_ir_dtype,
te_dtype_to_jax_dtype,
)
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp


@@ -239,12 +241,77 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
register_primitive(LayerNormFwdPrimitive)


def _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps):
"""
JAX native layernorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
gamma += 1.0
return jnp.asarray(normed_input * gamma + beta).astype(x.dtype)


def _jax_rmsnorm(x, gamma, zero_centered_gamma, eps):
"""
JAX native rmsnorm implementation
"""
x_ = jnp.asarray(x, jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
normed_input = x_ * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
gamma += 1.0
return jnp.asarray(normed_input * gamma).astype(x.dtype)


def _jax_layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps):
"""
JAX native layernorm fp8 implementation
"""
x_ = jnp.asarray(x, jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + eps)
normed_input = (x_ - mean) * rsigma
if zero_centered_gamma:
gamma += 1.0
output = normed_input * gamma + beta
casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax


def _jax_rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps):
"""
JAX native rmsnorm fp8 implementation
"""
x_ = jnp.asarray(x, jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + eps)
normed_input = x_ * rsigma
if zero_centered_gamma:
gamma += 1.0
output = normed_input * gamma
casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax


def layernorm_fwd(
x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float
):
"""
Wrapper for TE layernorm fwd
"""
if not LayerNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
mu = jnp.mean(x_, axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon)
return (
_jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon),
jnp.squeeze(mu, axis=-1),
jnp.squeeze(rsigma, axis=-1),
)
return LayerNormFwdPrimitive.outer_primitive.bind(
x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
@@ -468,12 +535,21 @@ def layernorm_bwd(
mu: jnp.ndarray,
rsigma: jnp.ndarray,
gamma: jnp.ndarray,
beta: jnp.ndarray,
zero_centered_gamma: bool,
epsilon: float,
):
"""
Wrapper for TE layernorm bwd
"""
if not LayerNormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon),
x,
gamma,
beta,
)
return vjp_func(dz)
return LayerNormBwdPrimitive.outer_primitive.bind(
dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
)
@@ -655,6 +731,12 @@ def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
"""
Wrapper for TE rmsnorm fwd
"""
if not RmsNormFwdPrimitive.enabled():
x_ = jnp.asarray(x, jnp.float32)
rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon)
return _jax_rmsnorm(x, gamma, zero_centered_gamma=False, eps=epsilon), jnp.squeeze(
rsigma, axis=-1
)
return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)


@@ -852,6 +934,11 @@ def rmsnorm_bwd(
"""
Wrapper for TE rmsnorm bwd
"""
if not RmsNormBwdPrimitive.enabled():
_, vjp_func = jax.vjp(
partial(_jax_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, gamma
)
return vjp_func(dz)
return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)


@@ -1148,6 +1235,17 @@ def layernorm_fwd_fp8(
"""
Wrapper for TE layernorm fwd (fp8 out)
"""
if not LayerNormFwdFp8Primitive.enabled():
return _jax_layernorm_fp8(
x,
gamma,
beta,
scale,
amax,
out_dtype=out_dtype,
zero_centered_gamma=zero_centered_gamma,
eps=epsilon,
)
return LayerNormFwdFp8Primitive.outer_primitive.bind(
x,
gamma,
@@ -1387,6 +1485,10 @@ def rmsnorm_fwd_fp8(
"""
Wrapper for TE rmsnorm fwd (fp8 out)
"""
if not RmsNormFwdFp8Primitive.enabled():
return _jax_rmsnorm_fp8(
x, gamma, scale, amax, out_dtype=out_dtype, zero_centered_gamma=False, eps=epsilon
)
return RmsNormFwdFp8Primitive.outer_primitive.bind(
x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
)
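
Note: a minimal sketch of the zero_centered_gamma convention used by the fallbacks above: gamma is stored as an offset from 1, so the zero-centered path matches the plain path with gamma + 1. The import path follows this diff; the shapes and values are illustrative.

# zero_centered_gamma=True adds 1.0 to gamma before scaling, as in _jax_layernorm.
import jax
import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm

x = jax.random.normal(jax.random.PRNGKey(0), (4, 128), jnp.float32)
gamma = jnp.full((128,), 0.1, jnp.float32)
beta = jnp.zeros((128,), jnp.float32)

y_zero_centered = _jax_layernorm(x, gamma, beta, zero_centered_gamma=True, eps=1e-6)
y_shifted = _jax_layernorm(x, gamma + 1.0, beta, zero_centered_gamma=False, eps=1e-6)
assert jnp.allclose(y_zero_centered, y_shifted)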