Spirv.ConstantOp does not support lowering from arith.ConstantOp with bf16 #657

Open
Stonepia opened this issue Jul 11, 2023 · 0 comments

When an arith.constant op of type bf16 is lowered to SPIR-V, the verifier rejects the generated spirv.Constant:

%22 = "spirv.Constant"() <{value = 1.000000e+00 : bf16}> : () -> bf16

loc("-":4:12): error: 'spirv.Constant' op result #0 must be void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float values of length 2/3/4/8/16 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V joint matrix type or any SPIR-V matrix type or any SPIR-V sampled image type, but got 'bf16'

A workaround is to convert the bf16 value to fp32, perform the lowering in fp32, and then convert the result back to bf16.
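The workaround depends on converting between fp32 and bf16. Because bf16 keeps the upper 16 bits of an IEEE-754 fp32 value, widening is a 16-bit shift and narrowing is a rounded truncation. Below is a minimal standalone C++ sketch. It assumes round-to-nearest-even and represents a bf16 value by its raw 16-bit pattern. The function names are hypothetical and do not come from Triton or MLIR.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper: widen a bf16 bit pattern to fp32.
// bf16 is exactly the upper half of an IEEE-754 binary32, so widening is lossless.
inline float bf16BitsToFloat(std::uint16_t bits) {
  std::uint32_t wide = static_cast<std::uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &wide, sizeof(f));
  return f;
}

// Hypothetical helper: narrow fp32 to a bf16 bit pattern using
// round-to-nearest-even on the 16 discarded low bits.
inline std::uint16_t floatToBf16Bits(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if (std::isnan(f)) {
    // Keep the sign and force the quiet-NaN bit so truncation cannot turn a NaN into Inf.
    return static_cast<std::uint16_t>((bits >> 16) | 0x0040u);
  }
  std::uint32_t lsb = (bits >> 16) & 1u;            // lowest bit that survives truncation
  std::uint32_t roundingBias = 0x7FFFu + lsb;       // exact ties round toward an even result
  return static_cast<std::uint16_t>((bits + roundingBias) >> 16);
}
```

For example, floatToBf16Bits(1.0f) returns 0x3F80, the bf16 bit pattern of the constant 1.000000e+00 in the failing op. When the lowering path cannot express the bf16 type, the constant can be built in fp32 and narrowed back with the same rounding. Using the shift-and-round approach avoids relying on any compiler-specific bf16 type.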

Triton already applies this workaround in the code linked below. Lowering arith.constant on the SPIR-V side needs the same handling.

https://github.com/intel/intel-xpu-backend-for-triton/blob/cbde09a65422e166cc69549a1e941b0207a43c49/lib/Conversion/TritonGPUToSPIRV/ViewOpToSPIRV.cpp#L73-L101
