Fix dot to use preferred_element_type #39

Open
wants to merge 1 commit into base: main
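For context, `preferred_element_type` is the `jax.lax.dot_general` parameter this PR plumbs through to the Triton lowering: it requests an output/accumulation dtype independent of the input dtype. A minimal illustration (not part of this diff; shapes and values are arbitrary):

```python
# With float16 inputs, dot_general's default output dtype is float16;
# requesting float32 via `preferred_element_type` yields a float32 result.
import jax.numpy as jnp
from jax import lax

a = jnp.ones((8, 16), dtype=jnp.float16)
b = jnp.ones((16, 4), dtype=jnp.float16)
dims = (((1,), (0,)), ((), ()))  # plain matmul: contract a's dim 1 with b's dim 0

out_default = lax.dot_general(a, b, dimension_numbers=dims)
out_f32 = lax.dot_general(a, b, dimension_numbers=dims,
                          preferred_element_type=jnp.float32)
print(out_default.dtype, out_f32.dtype)  # float16 float32
```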
36 changes: 21 additions & 15 deletions jax_triton/pallas/lowering.py
@@ -33,7 +33,6 @@
from jax._src.state import primitives as sp
from jax._src.state import discharge
from jax._src.state import ShapedArrayRef
from jax_triton.triton_call import get_triton_python_ir
import jax.numpy as jnp
import triton
import triton.language as tl
@@ -42,6 +41,8 @@
import triton._C.libtriton.triton as _triton

import jax_triton as jt
from jax_triton.triton_call import get_triton_python_ir
from jax_triton.triton_call import get_triton_element_type
from jax_triton.pallas import primitives

map, unsafe_map = util.safe_map, map
@@ -247,13 +248,8 @@ def _convert_element_type_lowering_rule(ctx: TritonLoweringRuleContext, a, *,
new_dtype, weak_type):
if new_dtype == ctx.avals_in[0].dtype:
return a
if new_dtype == jnp.float32:
new_dtype = tl.float32
elif new_dtype == jnp.float16:
new_dtype = tl.float16
elif new_dtype == jnp.bfloat16:
new_dtype = tl.bfloat16
return tl.semantic.cast(a, new_dtype, ctx.builder)
triton_eltype = get_triton_element_type(ctx.avals_in[0].dtype)
return tl.semantic.cast(a, triton_eltype, ctx.builder)
triton_lowering_rules[jax.lax.convert_element_type_p] = _convert_element_type_lowering_rule

def max_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
@@ -312,6 +308,7 @@ def _offset_ptr(ptr, idx: primitives.NDIndexer, shape, builder, is_scalar):
other_shape = indexer_shape[len(idx.int_indexer_shape):]
bcast_indices = []
other_shape_idx = 0
dest_shape = map(tl.constexpr, idx.get_indexer_shape())
for stride, index, dim_size, is_sc in zip(strides, indices, shape, is_scalar):
if isinstance(index, primitives.Slice):
index_size = index.size
@@ -344,12 +341,10 @@ def _offset_ptr(ptr, idx: primitives.NDIndexer, shape, builder, is_scalar):
index = tl.broadcast_to(index, desired_shape, _builder=builder)
else:
index = tl.reshape(index, desired_shape, _builder=builder)
if dest_shape != index.shape:
index = tl.broadcast_to(index, dest_shape, _builder=builder)
stride_size = tl.core._to_tensor(int(stride), builder)
bcast_indices.append(index.__mul__(stride_size, _builder=builder))
dest_shape = map(tl.constexpr, idx.get_indexer_shape())
bcast_indices = [
tl.broadcast_to(index, dest_shape, _builder=builder) if dest_shape != index.shape
else index for index in bcast_indices]
for bcast_idx in bcast_indices:
ptr = ptr.__add__(bcast_idx, _builder=builder)
return ptr
@@ -466,15 +461,26 @@ def _addupdate_lowering_rule(ctx: TritonLoweringRuleContext, ptr, value,

def _dot_general_lowering(ctx: TritonLoweringRuleContext, a, b, *,
dimension_numbers, precision, preferred_element_type):
if preferred_element_type is None:
preferred_element_type = ctx.avals_out[0].dtype
contract_dims, batch_dims = dimension_numbers
assert batch_dims == ((), ())
if batch_dims != ((), ()):
raise NotImplementedError("`batch_dims` currently unsupported.")
if len(contract_dims[0]) != 1 or len(contract_dims[1]) != 1:
raise NotImplementedError("Multiple contraction dimensions currently unsupported.")
a_contract_dim, = contract_dims[0]
b_contract_dim, = contract_dims[1]
trans_a = a_contract_dim == 0
trans_b = b_contract_dim == 1
allow_tf32 = precision == lax.Precision.HIGH or precision == lax.Precision.DEFAULT
return tl.dot(a, b, _builder=ctx.builder, trans_a=trans_a, trans_b=trans_b,
allow_tf32=allow_tf32)
out = tl.dot(a, b, _builder=ctx.builder, trans_a=trans_a, trans_b=trans_b,
allow_tf32=allow_tf32)
out_eltype = get_triton_element_type(preferred_element_type)
if out_eltype != out.dtype:
# `tl.dot` accumulates in f32 by default; cast the result to the dtype JAX
# expects.
out = tl.semantic.cast(out, out_eltype, ctx.builder)
return out
triton_lowering_rules[jax.lax.dot_general_p] = _dot_general_lowering

def _reduce_lowering(triton_op, ctx: TritonLoweringRuleContext, a, *, axes):
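The `_dot_general_lowering` change above keeps the existing mapping from `dimension_numbers` to Triton's `trans_a`/`trans_b` flags. A standalone sketch of that mapping (the helper name is illustrative, not from the diff):

```python
# Sketch of how the lowering derives Triton transpose flags from the single
# contraction dimension on each operand (batch dims assumed empty, as in the
# lowering above). `contract_dims` is dimension_numbers[0] of lax.dot_general.
def transpose_flags(contract_dims):
    (a_contract_dim,), (b_contract_dim,) = contract_dims
    trans_a = a_contract_dim == 0  # contracting over a's rows => a acts transposed
    trans_b = b_contract_dim == 1  # contracting over b's cols => b acts transposed
    return trans_a, trans_b

# (m, k) @ (k, n): contract a's dim 1 with b's dim 0 -> no transposes.
assert transpose_flags(((1,), (0,))) == (False, False)
# a stored as (k, m): contracting a's dim 0 means a is treated as transposed.
assert transpose_flags(((0,), (0,))) == (True, False)
```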
6 changes: 5 additions & 1 deletion jax_triton/pallas/primitives.py
@@ -380,7 +380,11 @@ def store(x_ref, idx, val, *, mask=None, eviction_policy="") -> None:
def dot(a, b, trans_a=False, trans_b=False, allow_tf32=True):
rhs_contract_dim = int(trans_b)
lhs_contract_dim = int(not trans_a)
# `pl.dot`, like `tl.dot`, does accumulation in f32.
preferred_element_type = None
if jnp.issubdtype(a.dtype, jnp.floating):
preferred_element_type = jnp.dtype("float32")
return jax.lax.dot_general(
a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
precision=lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST,
preferred_element_type=None)
preferred_element_type=preferred_element_type)
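With this change `pl.dot` requests f32 accumulation for floating-point inputs, so kernels that want a narrower output cast explicitly, as the updated tests do. A usage sketch under that assumption; the `pallas_call` invocation follows the pattern in this repo's tests and its exact keyword arguments may differ between versions:

```python
# Kernel sketch: f32-accumulated pl.dot, cast back to the f16 output ref.
import functools
import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

size = 64

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct((size, size), jnp.float16),
    grid=1)  # single program instance, as in the dot tests
def dot_kernel(x_ref, y_ref, o_ref):
    x = x_ref[:, :]
    y = y_ref[:, :]
    # pl.dot accumulates in f32; cast to the output ref's dtype.
    o_ref[:, :] = pl.dot(x, y).astype(o_ref.dtype)

x = jnp.ones((size, size), jnp.float16)
y = jnp.ones((size, size), jnp.float16)
out = dot_kernel(x, y)
```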
55 changes: 33 additions & 22 deletions jax_triton/triton_call.py
@@ -57,30 +57,41 @@

triton_type_mappings = {}

def get_triton_type(obj: Any) -> str:
type_map = {
jnp.dtype("bfloat16"): "bf16",
jnp.dtype("float64"): "fp64",
jnp.dtype("float32"): "fp32",
jnp.dtype("float16"): "fp16",
# Triton has 'fp8' as well, which JAX doesn't support yet.

jnp.dtype("int64"): "i64",
jnp.dtype("int32"): "i32",
jnp.dtype("int16"): "i16",
jnp.dtype("int8"): "i8",

jnp.dtype("uint64"): "u64",
jnp.dtype("uint32"): "u32",
jnp.dtype("uint16"): "u16",
jnp.dtype("uint8"): "u8",

# Triton defines a 'B' type, which is an alias for both i1 and bool.
jnp.dtype("bool"): "B",
}
_element_type_map = {
jnp.dtype("bfloat16"): (tl.bfloat16, "bf16"),
jnp.dtype("float64"): (tl.float64, "fp64"),
jnp.dtype("float32"): (tl.float32, "fp32"),
jnp.dtype("float16"): (tl.float16, "fp16"),
# Triton has 'fp8' as well, which JAX doesn't support yet.

jnp.dtype("int64"): (tl.int64, "i64"),
jnp.dtype("int32"): (tl.int32, "i32"),
jnp.dtype("int16"): (tl.int16, "i16"),
jnp.dtype("int8"): (tl.int8, "i8"),

jnp.dtype("uint64"): (tl.uint64, "u64"),
jnp.dtype("uint32"): (tl.uint32, "u32"),
jnp.dtype("uint16"): (tl.uint16, "u16"),
jnp.dtype("uint8"): (tl.uint8, "u8"),

# Triton defines a 'B' type, which is an alias for both i1 and bool.
jnp.dtype("bool"): (tl.int32, "B"),
}

def get_triton_element_type(dtype: jnp.dtype) -> tl.dtype:
if dtype not in _element_type_map:
raise NotImplementedError(f"Unknown dtype: {dtype}")
return _element_type_map[dtype][0]

def get_triton_element_type_as_str(dtype: jnp.dtype) -> str:
if dtype not in _element_type_map:
raise NotImplementedError(f"Unknown dtype: {dtype}")
return _element_type_map[dtype][1]

def get_triton_type(obj: Any) -> str:
if isinstance(obj, (jax.core.ShapedArray, state.ShapedArrayRef)):
return f"*{type_map[obj.dtype]}"
eltype = get_triton_element_type_as_str(obj.dtype)
return f"*{eltype}"
if isinstance(obj, tl.constexpr):
obj = obj.value
if isinstance(obj, int):
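The helpers added to triton_call.py are lookups into `_element_type_map`, raising `NotImplementedError` for dtypes outside the table. A quick usage sketch (assumes this PR's version of the module; `complex64` is just an example of an unmapped dtype):

```python
# Both helpers map a JAX dtype to its Triton counterpart; unknown dtypes raise.
import jax.numpy as jnp
import triton.language as tl
from jax_triton.triton_call import (get_triton_element_type,
                                    get_triton_element_type_as_str)

assert get_triton_element_type(jnp.dtype("float16")) is tl.float16
assert get_triton_element_type_as_str(jnp.dtype("bfloat16")) == "bf16"

try:
    get_triton_element_type(jnp.dtype("complex64"))  # not in the table
except NotImplementedError as e:
    print(e)  # Unknown dtype: complex64
```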
8 changes: 4 additions & 4 deletions tests/pallas_test.py
@@ -144,7 +144,7 @@ def body(i, acc_ref):
jax.lax.broadcast_in_dim(idx_k, (bk, bn), (0,)),
jax.lax.broadcast_in_dim(idx_n, (bk, bn), (1,)))
x_block, y_block = x_ref[x_idx], y_ref[y_idx]
out = jnp.dot(x_block, y_block)
out = pl.dot(x_block, y_block)
acc_ref[:, :] += out
acc = for_loop(k // bk, body, acc).astype(o_ref.dtype)
o_idx = (
@@ -157,7 +157,7 @@ def body(i, acc_ref):
x = random.normal(k1, (m, k), dtype=dtype)
y = random.normal(k2, (k, n), dtype=dtype)
out, expected = matmul(x, y), jnp.matmul(x, y)
np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
np.testing.assert_allclose(out, expected, atol=0.03, rtol=0.03)

@parameterized.named_parameters(*(
dict(testcase_name=f"{size}_{dtype}", size=size, dtype=dtype)
@@ -177,13 +177,13 @@ def test_dot(self, size, dtype):
def dot(x_ref, y_ref, o_ref):
x = x_ref[:, :]
y = y_ref[:, :]
o_ref[:, :] = pl.dot(x, y)
o_ref[:, :] = pl.dot(x, y).astype(o_ref.dtype)

k1, k2 = random.split(random.PRNGKey(0))
x = random.normal(k1, (size, size), dtype=dtype)
y = random.normal(k2, (size, size), dtype=dtype)
out, expected = dot(x, y), jnp.dot(x, y)
np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
np.testing.assert_allclose(out, expected, atol=0.02, rtol=0.02)

@parameterized.named_parameters(*(
dict(testcase_name=f"{batch_size}_{size}_{block_size}_{dtype}",