
Spirv.ConstantOp does not support lowering from arith.ConstantOp with bf16 #657

Open
@Stonepia

Description

@Stonepia

When lowering an arith.ConstantOp with a bf16 type to SPIR-V, verification of the generated spirv.Constant op fails:

%22 = "spirv.Constant"() <{value = 1.000000e+00 : bf16}> : () -> bf16

loc("-":4:12): error: 'spirv.Constant' op result #0 must be void or bool or 8/16/32/64-bit integer or 16/32/64-bit float or vector of bool or 8/16/32/64-bit integer or 16/32/64-bit float values of length 2/3/4/8/16 or any SPIR-V pointer type or any SPIR-V array type or any SPIR-V runtime array type or any SPIR-V struct type or any SPIR-V cooperative matrix type or any SPIR-V joint matrix type or any SPIR-V matrix type or any SPIR-V sampled image type, but got 'bf16'

A workaround is to convert bf16 values to fp32, perform the lowering in fp32, and then convert the result back to bf16.

Triton already applies this workaround (linked below), but support for lowering arith.ConstantOp with bf16 is also needed on the SPIR-V side.

https://github.com/intel/intel-xpu-backend-for-triton/blob/cbde09a65422e166cc69549a1e941b0207a43c49/lib/Conversion/TritonGPUToSPIRV/ViewOpToSPIRV.cpp#L73-L101
