Closed
Description
Hi,
I'm getting error RuntimeError: Cannot call @triton.jit'd outside of the scope of a kernel
when running the code block
def sum_spec(x: Float32[Tensor, "4 200"]) -> Float32[Tensor, "4"]:
return x.sum(1)
@triton.jit
def sum_kernel(x_ptr, z_ptr, N0, N1, T, B0: tl.constexpr, B1: tl.constexpr):
pid_0 = tl.program_id(0)
offsets_x = tl.arange(0, B0) + pid_0 * B0
mask_x = offsets_x < N0
#### don't error out
# z = tl.load(z_ptr + offsets_x, mask=mask_x)
# z = 0
#### error out
z = tl.zeros((B0,), dtype=tl.float32)
for i in range(0, T, B1):
offsets_y = tl.arange(0, B1) + i
mask = (offsets_x[:, None] < N0) & (offsets_y[None, :] < T)
x = tl.load(x_ptr + offsets_x[:, None] * T + offsets_y[None, :], mask=mask, other=0.0)
z += tl.sum(x, axis=1)
tl.store(z_ptr + offsets_x, z, mask=mask_x)
return
test(sum_kernel, sum_spec, B={"B0": 1, "B1": 32}, nelem={"N0": 4, "N1": 32, "T": 200})
and it errors out at the line when creating tl.zeros