diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index b8c4cc00b85b..d353fe6c1935 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -4,6 +4,10 @@ using namespace mlir;
 using namespace mlir::triton;
 using ::mlir::triton::gpu::getTotalElemsPerThread;
 
+typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
+                                         const SmallVector<Value> &)>
+    ConverterT;
+
 /* ----- FP8E5M2 ------ */
 // This data-type is the standard FP8E5M2 format
 #ifdef USE_ROCM
@@ -50,8 +54,93 @@ static const std::string Fp16_to_Fp8E5M2(bool hasNativeFP) {
 }
 #endif
 
+// ROCM utility functions for data type conversion
+#ifdef USE_ROCM
+static Value cvtFp16ToFp32(Location loc,
+                           ConversionPatternRewriter &rewriter,
+                           const Value &v) {
+  GCNBuilder builder;
+  auto &cvt = *builder.create("v_cvt_f32_f16");
+  auto res = builder.newOperand("=v");
+  auto operand = builder.newOperand(v, "v");
+  cvt(res, operand);
+  return builder.launch(rewriter, loc, f32_ty, false);
+}
+
+static Value cvtFp32ToFp16(Location loc,
+                           ConversionPatternRewriter &rewriter,
+                           const Value &v) {
+  GCNBuilder builder;
+  auto &cvt = *builder.create("v_cvt_f16_f32");
+  auto res = builder.newOperand("=v");
+  auto operand = builder.newOperand(v, "v");
+  cvt(res, operand);
+  return builder.launch(rewriter, loc, f16_ty, false);
+}
+
+static SmallVector<Value> convert_val_Fp16_to_Fp8(
+    Location loc, ConversionPatternRewriter &rewriter,
+    Value v0, Value v1, const std::string& fp8_format) {
+  assert(fp8_format == "fp8" or fp8_format == "bf8");
+  std::string ins_str = "v_cvt_pk_" + fp8_format + "_f32";
+
+  auto f32_0 = cvtFp16ToFp32(loc, rewriter, v0);
+  auto f32_1 = cvtFp16ToFp32(loc, rewriter, v1);
+
+  GCNBuilder builder;
+  auto &cvt = *builder.create(ins_str);
+  auto res = builder.newOperand("=v");
+  auto operand0 = builder.newOperand(f32_0, "v");
+  auto operand1 = builder.newOperand(f32_1, "v");
+  cvt(res, operand0, operand1);
+  auto fp8x4Vec = builder.launch(rewriter, loc, i32_ty, false);
+
+  auto fp8x4VecTy = vec_ty(i8_ty, 4);
+  auto a1 = bitcast(fp8x4Vec, fp8x4VecTy);
+
+  SmallVector<Value> ret(2);
+  ret[0] = extract_element(i8_ty, a1, i32_val(0));
+  ret[1] = extract_element(i8_ty, a1, i32_val(1));
+
+  return ret;
+}
+
+static SmallVector<Value> convert_val_Fp8_to_Fp16(
+    Location loc, ConversionPatternRewriter &rewriter,
+    Value v0, Value v1, const std::string& fp8_format) {
+  assert(fp8_format == "fp8" or fp8_format == "bf8");
+  std::string ins_str = "v_cvt_pk_f32_" + fp8_format;
+
+  auto fp8x4VecTy = vec_ty(i8_ty, 4);
+  Value fp8x4Vec = undef(fp8x4VecTy);
+  fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v0, i32_val(0));
+  fp8x4Vec = insert_element(fp8x4VecTy, fp8x4Vec, v1, i32_val(1));
+  auto i32v = bitcast(fp8x4Vec, i32_ty);
+
+  GCNBuilder builder1;
+  auto &cvt = *builder1.create(ins_str);
+  auto res = builder1.newOperand("=v");
+  auto operand = builder1.newOperand(i32v, "v");
+  cvt(res, operand);
+  auto i64v = builder1.launch(rewriter, loc, i64_ty, false);
+  auto fp32x2VecTy = vec_ty(f32_ty, 2);
+  auto fp32x2Vec = bitcast(i64v, fp32x2VecTy);
+
+  auto f32_0 = extract_element(f32_ty, fp32x2Vec, i32_val(0));
+  auto f32_1 = extract_element(f32_ty, fp32x2Vec, i32_val(1));
+
+  SmallVector<Value> ret(2);
+  ret[0] = cvtFp32ToFp16(loc, rewriter, f32_0);
+  ret[1] = cvtFp32ToFp16(loc, rewriter, f32_1);
+
+  return ret;
+}
+#endif
+
 #ifdef USE_ROCM
-static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
+// Depending on whether we focus more on performance, we may skip
+// the processing of subnormal values
+static Value Fp16_to_Fp8E5M2FNUZ_oneValue(
     Location loc, ConversionPatternRewriter &rewriter, Value v) {
   auto vi16 = bitcast(v, i16_ty);
   auto e = and_(i16_ty, vi16, int_val(16, 0x7C00));
@@ -79,28 +168,26 @@ static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
 }
 
 static SmallVector<Value>
-Fp16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
+Fp16_to_Fp8E5M2FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter,
                     const SmallVector<Value> &v) {
-  SmallVector<Value> result(4);
-  result[0] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[0]);
-  result[1] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[1]);
-  result[2] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[2]);
-  result[3] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[3]);
-
+  SmallVector<Value> result(2);
+  result[0] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[0]);
+  result[1] = Fp16_to_Fp8E5M2FNUZ_oneValue(loc, rewriter, v[1]);
   return result;
 }
-#else
-const std::string Fp16_to_Fp8E5M2FNUZ =
-    "{                            \n"
-    ".reg .b32 a<2>;              \n"
-    "and.b32 a0, $1, 0xfffefffe;  \n"   // a0 &= 0xfffefffe
-    "and.b32 a1, $2, 0xfffefffe;  \n"   // (strip lowest bit)
-    "add.u32 a0, a0, 0x00800080;  \n"   // a0 += 0x00800080
-    "add.u32 a1, a1, 0x00800080;  \n"   // (round to nearest)
-    "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
-    "}";
+
+static SmallVector<Value> Fp16_to_Fp8E5M2FNUZ_HW(
+    Location loc, ConversionPatternRewriter &rewriter,
+    const SmallVector<Value>& v) {
+  return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "bf8");
+}
+
+ConverterT Fp16_to_Fp8E5M2FNUZ(int computeCapability) {
+  return computeCapability >= 300 ? Fp16_to_Fp8E5M2FNUZ_HW : Fp16_to_Fp8E5M2FNUZ_SW;
+}
 #endif
+
 #ifdef USE_ROCM
 static SmallVector<Value>
 Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
@@ -145,8 +232,7 @@ static const std::string Fp8E5M2_to_Fp16(bool hasNativeFP) {
 #endif
 
 #ifdef USE_ROCM
-
-static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
+static Value Fp8E5M2FNUZ_to_Fp16_oneValue(
     Location loc, ConversionPatternRewriter &rewriter, Value v) {
   auto fp8x2VecTy = vec_ty(i8_ty, 2);
   Value a = undef(fp8x2VecTy);
@@ -181,16 +267,22 @@ static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
 }
 
 static SmallVector<Value>
-Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
+Fp8E5M2FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter,
                     const SmallVector<Value> &v) {
+  SmallVector<Value> result(2);
+  result[0] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]);
+  result[1] = Fp8E5M2FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]);
+  return result;
+}
 
-  SmallVector<Value> result(4);
-  result[0] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[0]);
-  result[1] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[1]);
-  result[2] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[2]);
-  result[3] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[3]);
+static SmallVector<Value>
+Fp8E5M2FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter,
+                       const SmallVector<Value> &v) {
+  return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "bf8");
+}
 
-  return result;
+ConverterT Fp8E5M2FNUZ_to_Fp16(int computeCapability) {
+  return (computeCapability >= 300) ? Fp8E5M2FNUZ_to_Fp16_HW : Fp8E5M2FNUZ_to_Fp16_SW;
 }
 #endif
 
@@ -655,7 +747,7 @@ static const std::string Fp16_to_Fp8E4M3B15x4 =
 // more than a single NaN values.
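Editorial aside, not part of the patch: the converters above target AMD's FNUZ FP8 encodings, with FP8E5M2FNUZ routed to the gfx94x "bf8" packed instructions here and FP8E4M3FNUZ routed to "fp8" further below. For readers checking the bit-twiddling in the _SW fallbacks, here is a minimal Python reference decoder for a single FP8E5M2FNUZ byte, assuming the usual FNUZ convention (exponent bias 16, no infinities, 0x80 as the only NaN, no negative zero):

```python
# Hedged reference model only; not derived from the patch itself.
def fp8e5m2fnuz_to_float(b: int) -> float:
    if b == 0x80:                      # the single NaN encoding
        return float("nan")
    sign = -1.0 if (b & 0x80) else 1.0
    exp = (b >> 2) & 0x1F              # 5 exponent bits
    man = b & 0x03                     # 2 mantissa bits
    if exp == 0:                       # zero and subnormals
        return sign * (man / 4.0) * 2.0 ** -15
    return sign * (1.0 + man / 4.0) * 2.0 ** (exp - 16)

assert fp8e5m2fnuz_to_float(0x00) == 0.0
assert fp8e5m2fnuz_to_float(0x7F) == 57344.0   # largest finite value
```

The _HW variants produce the same encoding but let v_cvt_pk_bf8_f32 / v_cvt_pk_f32_bf8 handle two values per instruction instead of emitting the bit manipulation in LLVM IR.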
#ifdef USE_ROCM -static Value convert_val_Fp8E4M3FNUZ_to_Fp16( +static Value Fp8E4M3FNUZ_to_Fp16_oneValue( Location loc, ConversionPatternRewriter &rewriter, Value v) { auto fp8x2VecTy = vec_ty(i8_ty, 2); Value a = undef(fp8x2VecTy); @@ -686,37 +778,30 @@ static Value convert_val_Fp8E4M3FNUZ_to_Fp16( return bitcast(io, f16_ty); } -// Fp8E4M3FNUZ -> Fp16 (packed) static SmallVector -Fp8E4M3FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter, +Fp8E4M3FNUZ_to_Fp16_SW(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { SmallVector result(2); - result[0] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[0]); - result[1] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[1]); - + result[0] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[0]); + result[1] = Fp8E4M3FNUZ_to_Fp16_oneValue(loc, rewriter, v[1]); return result; } -#else -const std::string Fp8E4M3FNUZ_to_Fp16 = - "{ \n" - ".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4 - "prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400 - "prmt.b32 a1, 0, $2, 0x7060; \n" // a1 = 0xf100f200 - "lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n" // b0 = a0 & 0x7fff7fff - "lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign) - "shr.b32 b0, b0, 1; \n" // b0 >>= 1 - "shr.b32 b1, b1, 1; \n" // shift into fp16 position - "add.u32 b0, b0, 0x20002000; \n" // b0.exp += 2**4-2**3 - // exponent compensate = 8 - "add.u32 b1, b1, 0x20002000; \n" // b1 += 8<<10 | 8<<10<<16 - "lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n" // out0 = b0|(0x80008000&a0) - "lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n" // (restore sign) - "}"; + +static SmallVector +Fp8E4M3FNUZ_to_Fp16_HW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + return convert_val_Fp8_to_Fp16(loc, rewriter, v[0], v[1], "fp8"); +} + +static ConverterT +Fp8E4M3FNUZ_to_Fp16(int computeCapability) { + return computeCapability >= 300 ? 
Fp8E4M3FNUZ_to_Fp16_HW : Fp8E4M3FNUZ_to_Fp16_SW; +} #endif // Fp16 -> Fp8E4M3 (packed) #ifdef USE_ROCM -static Value convert_val_Fp16_to_Fp8E4M3FNUZ( +static Value Fp16_to_Fp8E4M3FNUZ_oneValue( Location loc, ConversionPatternRewriter &rewriter, Value v) { auto vi16 = bitcast(v, i16_ty); auto e10 = and_(vi16, int_val(16, 0x7C00)); @@ -749,33 +834,25 @@ static Value convert_val_Fp16_to_Fp8E4M3FNUZ( } static SmallVector -Fp16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter, +Fp16_to_Fp8E4M3FNUZ_SW(Location loc, ConversionPatternRewriter &rewriter, const SmallVector &v) { - SmallVector result(2); - result[0] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[0]); - result[1] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[1]); + result[0] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[0]); + result[1] = Fp16_to_Fp8E4M3FNUZ_oneValue(loc, rewriter, v[1]); return result; } -#else -const std::string Fp16_to_Fp8E4M3FNUZ = - "{ \n" - ".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4 - "sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000 - // (compensate offset) - "sub.u32 a1, $2, 0x20002000; \n" // a1 = input1 - 0x20002000 - // (8 << 10 | 8 << 10 << 16) - "shl.b32 a0, a0, 1; \n" // a0 <<= 1 - "shl.b32 a1, a1, 1; \n" // shift into fp8e4 position - "lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n" // a0 &= 0x7fff7fff - "lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n" // (strip sign) - "add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080 - "add.u32 a1, a1, 0x00800080; \n" // (round to nearest) - "lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n" // b0 = a0|(0x80008000&in0) - "lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n" // (restore sign) - "prmt.b32 $0, b0, b1, 0x7531; \n" // output = b1b0 - "}"; + +static SmallVector +Fp16_to_Fp8E4M3FNUZ_HW(Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) { + return convert_val_Fp16_to_Fp8(loc, rewriter, v[0], v[1], "fp8"); +} + +static ConverterT +Fp16_to_Fp8E4M3FNUZ(int computeCapability) { + return computeCapability >= 300 ? 
Fp16_to_Fp8E4M3FNUZ_HW : Fp16_to_Fp8E4M3FNUZ_SW; +} #endif // WARN: subnormal (0bs0000xxx) are not handled @@ -1144,10 +1221,6 @@ inline SmallVector packI32(const SmallVector &inValues, return outValues; } -typedef std::function(Location, ConversionPatternRewriter &, - const SmallVector &)> - ConverterT; - static ConverterT makeConverterFromPtx(const std::string &ptxAsm, Type inType, Type outType, const int inVecWidthBits = 32, @@ -1351,12 +1424,7 @@ struct FpToFpOpConversion ConversionPatternRewriter &rewriter, const Value &v) { #ifdef USE_ROCM - GCNBuilder builder; - auto &cvt = *builder.create("v_cvt_f32_f16"); - auto res = builder.newOperand("=v"); - auto operand = builder.newOperand(v, "v"); - cvt(res, operand); - return builder.launch(rewriter, loc, f32_ty, false); + return cvtFp16ToFp32(loc, rewriter, v); #else PTXBuilder builder; auto &cvt = *builder.create("cvt.f32.f16"); @@ -1402,12 +1470,7 @@ struct FpToFpOpConversion ConversionPatternRewriter &rewriter, const Value &v) { #ifdef USE_ROCM - GCNBuilder builder; - auto &cvt = *builder.create("v_cvt_f16_f32"); - auto res = builder.newOperand("=v"); - auto operand = builder.newOperand(v, "v"); - cvt(res, operand); - return builder.launch(rewriter, loc, f16_ty, false); + return cvtFp32ToFp16(loc, rewriter, v); #else PTXBuilder builder; auto &cvt = *builder.create("cvt.rn.f16.f32"); @@ -1420,7 +1483,11 @@ struct FpToFpOpConversion ConverterT getConversionFunc(Type srcTy, Type dstTy) const { auto F8E4M3B15TyID = TypeID::get(); +#ifdef USE_ROCM auto F8E4M3FNUZTyID = TypeID::get(); +#else + auto F8E4M3TyID = TypeID::get(); +#endif auto F8E4M3FNTyID = TypeID::get(); auto F8E5M2TyID = TypeID::get(); auto F8E5M2FNUZTyID = TypeID::get(); @@ -1436,37 +1503,37 @@ struct FpToFpOpConversion // F8 -> F16 {{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16}, {{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16}, - {{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16}, #ifdef USE_ROCM + {{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16(computeCapability)}, + {{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16(computeCapability)}, {{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16}, - {{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16}, #else {{F8E4M3TyID, F16TyID}, Fp8E4M3Nv_to_Fp16}, {{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16(computeCapability >= 90)}, #endif + // F16 -> F8 + {{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4}, #ifdef USE_ROCM {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15}, -#else - {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)}, -#endif - {{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4}, - {{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ}, -#ifdef USE_ROCM + {{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ(computeCapability)}, + {{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ(computeCapability)}, {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2}, - {{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ}, #else + {{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)}, {{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3Nv}, {{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2(computeCapability >= 90)}, #endif - // F8 -> BF16 + + // F8 -> BF16 #ifdef USE_ROCM - {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16}, + {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16}, #else - {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)}, + {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16(computeCapability >= 90)}, {{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16}, #endif - // BF16 -> F8 + + // BF16 -> F8 #ifdef USE_ROCM {{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2}, #else @@ -1477,6 +1544,16 @@ struct 
FpToFpOpConversion {{F32TyID, F8E5M2TyID}, Fp32_to_Fp8E5M2}, #endif }; + + std::pair key = {srcTy.getTypeID(), dstTy.getTypeID()}; + if (srcMap.count(key) == 0) { + llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy + << "\n"; + llvm_unreachable(""); + } +#ifdef USE_ROCM + return srcMap.lookup(key); +#else int inVecWidthBits = 32; int outVecWidthBits = 32; if (srcTy.isFloat8E4M3FNUZ() || @@ -1490,15 +1567,6 @@ struct FpToFpOpConversion outVecWidthBits = 16; } - std::pair key = {srcTy.getTypeID(), dstTy.getTypeID()}; - if (srcMap.count(key) == 0) { - llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy - << "\n"; - llvm_unreachable(""); - } -#ifdef USE_ROCM - return srcMap.lookup(key); -#else if (computeCapability < 90 && (srcTy.isFloat8E4M3FNUZ() || dstTy.isFloat8E4M3FNUZ())) { llvm::errs() << "Conversion from/to f8e4m3nv is only supported on " @@ -1523,14 +1591,24 @@ struct FpToFpOpConversion size_t numElements = 4; if (srcElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E4M3FNUZ() || +#ifdef USE_ROCM + srcElementType.isFloat8E5M2FNUZ() || + dstElementType.isFloat8E5M2FNUZ()) +#else (computeCapability >= 90 && - (srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) { + (srcElementType.isFloat8E5M2() || dstElementType.isFloat8E5M2()))) +#endif + { numElements = 2; } bool useFP16IntermediateSrc = +#ifdef USE_ROCM + srcElementType.isF32(); +#else srcElementType.isF32() && !(computeCapability >= 90 && (dstElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E5M2())); +#endif bool isDstFP32 = dstElementType.isF32(); auto cvtFunc = getConversionFunc(useFP16IntermediateSrc ? f16_ty : srcElementType, diff --git a/python/test/unit/language/test_core_amd.py b/python/test/unit/language/test_core_amd.py index 18bd831f3164..bdc3ad6ca908 100644 --- a/python/test/unit/language/test_core_amd.py +++ b/python/test/unit/language/test_core_amd.py @@ -1075,15 +1075,16 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): if TORCH_HAS_FP8E4B8: tl_to_torch_types[tl.float8e4b8] = torch.float8_e4m3fnuz -@triton.jit -def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - input = tl.load(input_ptr + offsets, mask=mask) - output = input - tl.store(output_ptr + offsets, output, mask=mask) def gen_input(M, N, d_type, seed, device='cuda'): + @triton.jit + def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + input = tl.load(input_ptr + offsets, mask=mask) + output = input + tl.store(output_ptr + offsets, output, mask=mask) + torch.manual_seed(seed) torch.cuda.manual_seed(seed) if d_type == tl.float16: @@ -1246,7 +1247,8 @@ def matmul(a, b, c_type): def test_gemm_amd_fp8_inputs(M, N, K, a_type, b_type, out_dtype, device = 'cuda'): check_type_supported(out_dtype, device) - if triton.language.semantic.gpu_matrix_core_version() != 3: + backend = triton.common.backend.get_backend("hip") + if backend.get_matrix_core_version() != 3: pytest.skip("fp8 data type is not available on hardware") @triton.jit @@ -1630,7 +1632,7 @@ def kernel(X, stride_xm, stride_xn, ('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')]] - if triton.language.semantic.gpu_matrix_core_version() == 0 else + if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0 else 
# MFMA Test Dot tests [(*shape, 2, False, False, epilogue, allow_tf32, in_dtype, out_dtype, non_k_dim) for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] @@ -1881,7 +1883,8 @@ def kernel(X, stride_xm, stride_xk, # added atol, to loose precision for float16xfloat16->float32 case np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) if torch.version.hip is not None: - if triton.language.semantic.gpu_matrix_core_version() > 0: + backend = triton.common.backend.get_backend("hip") + if backend.get_matrix_core_version() > 0: ttgir = pgm.asm['ttgir'] if non_k_dim == 16: assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir @@ -1890,9 +1893,9 @@ def kernel(X, stride_xm, stride_xk, assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir gcn = pgm.asm['amdgcn'] - if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16: + if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16: assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn - if triton.language.semantic.gpu_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8: + if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e4b8: assert "v_mfma_f32_32x32x16_fp8_fp8" in gcn or "v_mfma_f32_16x16x32_fp8_fp8" in gcn return # make sure ld/st are vectorized @@ -2727,7 +2730,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): if transposeA and not transposeB: pytest.skip() - if triton.language.semantic.gpu_matrix_core_version() == 0: + if triton.common.backend.get_backend("hip").get_matrix_core_version() == 0: pytest.skip("mfma is not available on hardware") # source code for following ttgir: @@ -2817,7 +2820,8 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB): kernel = triton.compile(f.name, device_type="hip", cc=capabilities) import triton.language.semantic as sem - if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0: + # if torch.version.hip is not None and sem.gpu_matrix_core_version() > 0: + if torch.version.hip is not None and backend.get_matrix_core_version() > 0: kernel[(1, 1, 1)](x_tri, y_tri, z_tri) np.testing.assert_allclose(z_np, to_numpy(z_tri), rtol=0.01, atol=1e-3) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index b849ad2653d1..7fcc7554f821 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -68,8 +68,7 @@ def ttir_compute_capability_rewrite(mod, target): if _is_cuda(target): pm.add_rewrite_tensor_pointer_pass(target.capability, False) elif is_hip(): - capability = 90 - pm.add_rewrite_tensor_pointer_pass(capability, True) + pm.add_rewrite_tensor_pointer_pass(target["capability"], True) else: assert(False, "unsupported target") pm.run(mod) @@ -121,14 +120,14 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, pm.add_tritongpu_accelerate_matmul_pass(capability) # TODO change interface of accelerate_matmul_pass if is_hip(): - matrix_core_version = gpu_matrix_core_version() + matrix_core_version = target["matrix_core_version"] matrix_inst_size = matrix_inst_type pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size) pm.add_tritongpu_remove_layout_conversions_pass() if optimize_epilogue: pm.add_tritongpu_optimize_epilogue_pass() pm.add_tritongpu_optimize_dot_operands_pass() - if num_stages == 0 and is_hip() and gpu_matrix_core_version() != 0: + if num_stages 
== 0 and is_hip() and target["matrix_core_version"] != 0: pm.add_tritongpu_stream_pipeline_pass() pm.add_canonicalizer_pass() ws_enabled = False @@ -192,7 +191,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos, waves_per_eu=0): if _is_cuda(target): return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM, waves_per_eu) else: - return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu) + return translate_triton_gpu_to_llvmir(mod, target["capability"], TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu) # PTX translation @@ -351,8 +350,6 @@ def is_hip(): raise ImportError("Triton requires PyTorch to be installed") return torch.version.hip is not None -from ..language.semantic import gpu_matrix_core_version - def get_cuda_capability(capability): if capability is None: device = get_current_device() diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index dea0787139bc..6372a3374b36 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1288,32 +1288,6 @@ def is_hip(): raise ImportError("Triton requires PyTorch to be installed") return torch.version.hip is not None - -def gpu_matrix_core_version() -> int: - """ Determine matrix core type available on current GPU. - - 0 means no tensor cores are available - 1 corresponds to MFMA in CDNA 1 architecture - 2 corresponds to MFMA in CDNA 2 architecture - 3 corresponds to MFMA in CDNA 3 architecture - """ - - if not is_hip(): - return 0 - arch_info = _triton.get_arch_info() - gfx_arch_details = re.search('amd.*', arch_info) - if gfx_arch_details is None: - return 0 - gfx_arch_details = gfx_arch_details.group(0).strip().split('--') - gpu_name = gfx_arch_details[1].split(':')[0] - if gpu_name in ['gfx908']: - return 1 - if gpu_name in ['gfx90a']: - return 2 - if gpu_name in ['gfx940', 'gfx941', 'gfx942']: - return 3 - return 0 - def mfma_supported_granularity(m, n, k) -> bool: # todo make this gran_type matrix element type sensitive for gran_type in [(32, 8), (16, 16)]: @@ -1326,8 +1300,8 @@ def mfma_supported_granularity(m, n, k) -> bool: return True return False -def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: - matrix_core_version = gpu_matrix_core_version() +def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty, target) -> bool: + matrix_core_version = target["matrix_core_version"] if matrix_core_version not in [1, 2, 3]: return False if not mfma_supported_granularity(M, N ,K): @@ -1343,10 +1317,18 @@ def dot(lhs: tl.tensor, builder: ir.builder) -> tl.tensor: def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): # Checks for non-cuda archs - if not _is_cuda(target): + if is_hip(): assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \ (lhs.type.scalar.is_fp16() and rhs.type.scalar.is_fp8()) or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8()), \ f"First input ({lhs.dtype}) and second input ({rhs.dtype}) must have the same dtype!" 
+ if lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp8(): + assert lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16() or lhs.type.scalar.is_fp8e5(),\ + f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8" + assert rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16() or rhs.type.scalar.is_fp8e5(),\ + f"Only hip fp8 or f8e5 types are accepted for both inputs of fp8" + return + + if not _is_cuda(target): return # Checks for cuda archs @@ -1381,13 +1363,18 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): # hip for now converts fp8 to fp16 for mixed input if is_hip(): - fp8_supported = gpu_matrix_core_version() == 3 + target = builder.target + assert "matrix_core_version" in target + fp8_supported = target["matrix_core_version"] == 3 + # gfx940 data type + lhs_hip_fp8 = lhs.type.scalar.is_fp8e4b8() or lhs.type.scalar.is_fp8e5b16() + rhs_hip_fp8 = rhs.type.scalar.is_fp8e4b8() or rhs.type.scalar.is_fp8e5b16() lhs_fp8 = lhs.type.scalar.is_fp8() rhs_fp8 = rhs.type.scalar.is_fp8() - supported_fp8_dot = fp8_supported and lhs_fp8 and rhs_fp8 - if not supported_fp8_dot and lhs_fp8: + supported_fp8_dot = fp8_supported and lhs_hip_fp8 and rhs_hip_fp8 + if (not supported_fp8_dot) and lhs_fp8: lhs = cast(lhs, tl.float16, builder) - if not supported_fp8_dot and rhs_fp8: + if (not supported_fp8_dot) and rhs_fp8: rhs = cast(rhs, tl.float16, builder) if lhs.type.scalar.is_int(): @@ -1409,7 +1396,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): N = rhs.type.shape[1] # Cast operands of types f16 and i8 for configurations where FMA only supported. - if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty): + if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty, builder.target): # max_num_imprecise_acc does not yet apply to hip if is_hip(): max_num_imprecise_acc = 0 @@ -1426,7 +1413,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32, max_num_imprecise_acc), ret_ty) return cast(ret, ret_scalar_ty, builder) - if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth <= 32: + if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty, builder.target) and ret_scalar_ty.primitive_bitwidth <= 32: # max_num_imprecise_acc does not yet apply to hip if is_hip(): max_num_imprecise_acc = 0 diff --git a/python/triton/third_party/hip/hip_backend.py b/python/triton/third_party/hip/hip_backend.py index 53a0c1d32ef9..324a7df2c0af 100644 --- a/python/triton/third_party/hip/hip_backend.py +++ b/python/triton/third_party/hip/hip_backend.py @@ -273,6 +273,30 @@ def get_amdgcn_bitcode_paths(gfx_arch: str): return amdgcn_bitcode_paths +def gpu_matrix_core_version() -> int: + """ Determine matrix core type available on current GPU. 
+ + 0 means no tensor cores are available + 1 corresponds to MFMA in CDNA 1 architecture + 2 corresponds to MFMA in CDNA 2 architecture + 3 corresponds to MFMA in CDNA 3 architecture + """ + + arch_info = _triton.get_arch_info() + gfx_arch_details = re.search('amd.*', arch_info) + if gfx_arch_details is None: + return 0 + gfx_arch_details = gfx_arch_details.group(0).strip().split('--') + gpu_name = gfx_arch_details[1].split(':')[0] + if gpu_name in ['gfx908']: + return 1 + if gpu_name in ['gfx90a']: + return 2 + if gpu_name in ['gfx940', 'gfx941', 'gfx942']: + return 3 + return 0 + + def get_amdgpu_arch_fulldetails(): # print("get_amdgpu_arch_fulldetails") """ @@ -294,7 +318,11 @@ def get_amdgpu_arch_fulldetails(): if gfx_arch is None: raise RuntimeError('gfx_arch is None (not specified)') - return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features} + mat_core_ver = gpu_matrix_core_version() + capability = gpu_matrix_core_version() * 100 + + return {"gfx_triple": arch_triple, "gfx_arch": gfx_arch, "gfx_features": arch_features,\ + "capability": capability, "matrix_core_version": mat_core_ver} except BaseException: return None @@ -487,3 +515,6 @@ def get_num_warps(self, module): return _triton.get_num_warps(module) else: return _triton.get_num_warps(module) + + def get_matrix_core_version(self): + return gpu_matrix_core_version() \ No newline at end of file
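Editorial note on how the pieces fit together: get_amdgpu_arch_fulldetails() now reports both matrix_core_version and a derived capability (version * 100), and compiler.py / semantic.py read these fields from the target dict instead of calling gpu_matrix_core_version() directly. Below is a hedged sketch of that relationship; demo_hip_target is a hypothetical stand-in for the real _triton.get_arch_info() query.

```python
# Minimal sketch (not part of the patch) of the fields the HIP backend reports
# and the thresholds other parts of this patch check against them.
def demo_hip_target(gpu_name: str) -> dict:
    versions = {"gfx908": 1, "gfx90a": 2, "gfx940": 3, "gfx941": 3, "gfx942": 3}
    mat_core_ver = versions.get(gpu_name, 0)
    return {
        "matrix_core_version": mat_core_ver,
        # compiler.py forwards this as the ROCm "compute capability", so CDNA3
        # (matrix core version 3) becomes 300, the threshold the C++ converters
        # use to pick the packed v_cvt_pk_* hardware path over the SW fallback.
        "capability": mat_core_ver * 100,
    }

assert demo_hip_target("gfx942")["capability"] >= 300  # hardware FP8 conversion
assert demo_hip_target("gfx90a")["capability"] < 300   # software bit-twiddling path
```

This is also why the C++ converters above test computeCapability >= 300: only CDNA3 parts (gfx940/941/942) report a capability of 300 and expose the packed FP8 conversion instructions.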