Skip to content

Commit

Permalink
fp8 type support (#357)
Browse files Browse the repository at this point in the history
* add two fp8 data types `tl.float8e4b8` and `tl.float8e5b16` to triton.
* add SW type conversion between `tl.float8e4b8/tl.float8e5b16` and `fp16`
* change flashattention to support fp8 in q/k.
  • Loading branch information
scxiao authored Nov 2, 2023
1 parent 38f9136 commit 79bebc4
Show file tree
Hide file tree
Showing 11 changed files with 443 additions and 129 deletions.
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TritonTypeDef<string name, string _mnemonic>
}

// Floating-point Type
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8E4M3FNUZ, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;

Expand Down
247 changes: 189 additions & 58 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,6 @@ Fp16_to_Fp8E5M2(Location loc, ConversionPatternRewriter &rewriter,

Value a0 = bitcast(fp16x2Vec0, i32_ty);
Value a1 = bitcast(fp16x2Vec1, i32_ty);
Value sign0 = and_(i32_ty, a0, i32_val(0x80008000));
Value sign1 = and_(i32_ty, a1, i32_val(0x80008000));

a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
a1 = add(i32_ty, a1, i32_val(0x00800080));

a0 = or_(i32_ty, a0, sign0);
a1 = or_(i32_ty, a1, sign1);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
a0 = bitcast(a0, fp8x4VecTy);
Expand All @@ -54,6 +44,57 @@ const std::string Fp16_to_Fp8E5M2 =
"}";
#endif

#ifdef USE_ROCM
// Converts a single fp16 value to one Float8E5M2FNUZ byte (software path).
//
// The fp16 bits are reinterpreted as i16, adjusted in place, and the upper
// byte of the 16-bit result is returned; the lower byte (the eight
// least-significant fp16 mantissa bits) is dropped, i.e. the conversion
// truncates rather than rounds.
//
//  - exponent field != 0: the exponent field is incremented by one
//    (0x0400). Presumably this accounts for the bias difference between
//    fp16 (15) and E5M2FNUZ (16) -- NOTE(review): confirm against the
//    format definition.
//  - exponent field == 0: the 10-bit mantissa is shifted left by one so its
//    top bit lands in the lowest exponent bit.
//  - exponent field all ones (fp16 inf/NaN): the input bits are passed
//    through unchanged, exactly as before. NOTE(review): FNUZ formats have
//    no infinity and use 0x80 for NaN, so this path probably deserves a
//    dedicated mapping -- left untouched here.
//
// Fixes relative to the previous revision:
//  - the magnitude mask was written 0x7FFFF (five hex digits), which does
//    not fit in 16 bits; it is now 0x7FFF.
//  - when the kept byte has no magnitude bits, the sign is no longer OR-ed
//    in. Previously -0.0 and tiny negative inputs produced the byte 0x80,
//    which FNUZ formats reserve for NaN.
//  - the `| 1` on the subnormal path was removed; it only affected bit 0,
//    which lies in the discarded low byte.
static Value convert_val_Fp16_to_Fp8E5M2FNUZ(
    Location loc, ConversionPatternRewriter &rewriter, Value v) {
  Value vi16 = bitcast(v, i16_ty);
  Value zero = int_val(16, 0);
  Value e = and_(i16_ty, vi16, int_val(16, 0x7C00));
  Value sign = and_(i16_ty, vi16, int_val(16, 0x8000));

  // Normal value: strip the sign, then bump the exponent field by one.
  Value a = and_(i16_ty, vi16, int_val(16, 0x7FFF));
  Value normal = add(i16_ty, a, int_val(16, 0x0400));

  // Subnormal value (exponent field is 0): shift the mantissa up by one bit.
  Value m = and_(i16_ty, vi16, int_val(16, 0x03FF));
  Value subnormal = shl(i16_ty, m, int_val(16, 1));

  Value e_is_zero = icmp_eq(e, zero);
  Value e_is_all1 = icmp_eq(e, int_val(16, 0x7C00));
  Value magnitude = select(e_is_zero, subnormal, normal);

  // Only the upper byte survives. If it carries no magnitude bits the result
  // must be plain zero: attaching the sign would emit 0x80.
  Value keptBits = and_(i16_ty, magnitude, int_val(16, 0x7F00));
  Value truncatesToZero = icmp_eq(keptBits, zero);
  Value signedMagnitude = or_(i16_ty, magnitude, sign);
  Value finite = select(truncatesToZero, zero, signedMagnitude);

  // fp16 inf/NaN: forward the original bits (behaviour unchanged).
  Value o = select(e_is_all1, vi16, finite);

  auto fp8x2VecTy = vec_ty(i8_ty, 2);
  Value res = bitcast(o, fp8x2VecTy);

  // Element 1 is the upper byte of the i16.
  return extract_element(i8_ty, res, i32_val(1));
}

// Converts four fp16 values (v[0]..v[3]) to four Float8E5M2FNUZ bytes by
// applying the scalar helper to each element, preserving order.
static SmallVector<Value>
Fp16_to_Fp8E5M2FNUZ(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {
  constexpr unsigned kNumElems = 4;
  SmallVector<Value> converted(kNumElems);
  for (unsigned idx = 0; idx < kNumElems; ++idx)
    converted[idx] = convert_val_Fp16_to_Fp8E5M2FNUZ(loc, rewriter, v[idx]);
  return converted;
}
#else
// PTX inline-asm template for fp16 -> Float8E5M2FNUZ, used when USE_ROCM is
// not defined. Operands $1 and $2 are the two 32-bit inputs and $0 is the
// 32-bit output. Each input is masked with 0xfffefffe, has 0x00800080 added,
// and prmt then gathers bytes 1, 3, 5, 7 of the a0/a1 pair into $0.
// NOTE(review): no exponent adjustment is visible in this template, whereas
// the USE_ROCM helper for the same conversion adds 0x0400 to the exponent
// field -- confirm whether the two paths are meant to agree.
const std::string Fp16_to_Fp8E5M2FNUZ =
"{ \n"
".reg .b32 a<2>; \n"
"and.b32 a0, $1, 0xfffefffe; \n" // a0 &= 0xfffefffe
"and.b32 a1, $2, 0xfffefffe; \n" // (strip lowest bit)
"add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
"add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
"prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
"}";
#endif

#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -89,6 +130,61 @@ const std::string Fp8E5M2_to_Fp16 = "{ \n"
"}";
#endif

#ifdef USE_ROCM

// Converts a single Float8E5M2FNUZ byte to fp16 (software path).
//
// The byte is placed in the upper half of an i16 (lower byte zero) so that
// its sign / 5-bit exponent / 2-bit mantissa line up with the top bits of
// the fp16 layout. Three cases are then distinguished on the exponent field:
//   - field == 0: result is sign | (mantissa >> 1)
//   - field == 1: result is sign | ((mantissa >> 1) + 0x0200)
//   - otherwise : result is sign | mantissa | (exponent - 0x0400)
// NOTE(review): the byte 0x80 (sign bit only) takes the first case and
// yields just the sign bit. FNUZ formats reserve that encoding for NaN --
// confirm whether decoding it as a signed zero is intended.
static Value convert_val_Fp8E5M2FNUZ_to_Fp16(
    Location loc, ConversionPatternRewriter &rewriter, Value v) {
  // Build <2 x i8> {0, v} and view it as one i16 with v in the high byte.
  auto byteVecTy = vec_ty(i8_ty, 2);
  Value bits = undef(byteVecTy);
  bits = insert_element(byteVecTy, bits, int_val(8, 0), i32_val(0));
  bits = insert_element(byteVecTy, bits, v, i32_val(1));
  bits = bitcast(bits, i16_ty);

  // Split into fields.
  auto expField = and_(i16_ty, bits, int_val(16, 0x7C00));
  auto mantField = and_(i16_ty, bits, int_val(16, 0x0300));
  auto signBit = and_(i16_ty, bits, int_val(16, 0x8000));

  auto expIsZero = icmp_eq(expField, int_val(16, 0x0));

  // Exponent field 0: the mantissa moves right by one bit.
  auto mantHalved = lshr(i16_ty, mantField, int_val(16, 1));
  auto resExpZero = or_(i16_ty, signBit, mantHalved);

  // Exponent field > 1: decrement the exponent field by one.
  auto expMinusOne = sub(i16_ty, expField, int_val(16, 0x0400));

  // Exponent field 1: halved mantissa plus 0x0200, exponent field left 0.
  auto expIsOne = icmp_eq(expField, int_val(16, 0x0400));
  auto mantWithLeadingOne = add(i16_ty, mantHalved, int_val(16, 0x0200));

  auto magnitudeGeneral = or_(i16_ty, mantField, expMinusOne);
  auto resGeneral = or_(i16_ty, signBit, magnitudeGeneral);
  auto resExpOne = or_(i16_ty, signBit, mantWithLeadingOne);

  auto resNonZeroExp = select(expIsOne, resExpOne, resGeneral);
  auto resBits = select(expIsZero, resExpZero, resNonZeroExp);

  return bitcast(resBits, f16_ty);
}

// Converts four Float8E5M2FNUZ bytes (v[0]..v[3]) to four fp16 values by
// applying the scalar helper to each element, preserving order.
static SmallVector<Value>
Fp8E5M2FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
                    const SmallVector<Value> &v) {
  constexpr unsigned kNumElems = 4;
  SmallVector<Value> converted(kNumElems);
  for (unsigned idx = 0; idx < kNumElems; ++idx)
    converted[idx] = convert_val_Fp8E5M2FNUZ_to_Fp16(loc, rewriter, v[idx]);
  return converted;
}
#else
// PTX inline-asm template for Float8E5M2FNUZ -> fp16, used when USE_ROCM is
// not defined. Two prmt.b32 instructions each combine the constant 0 with
// input operand $2 and write one 32-bit output ($0 and $1 respectively).
// NOTE(review): only byte permutation is visible here; the USE_ROCM helper
// for the same conversion additionally adjusts the exponent field -- confirm
// whether the two paths are meant to agree.
const std::string Fp8E5M2FNUZ_to_Fp16 = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
#endif

#ifdef USE_ROCM
static SmallVector<Value>
Fp8E5M2_to_Bf16(Location loc, ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -510,36 +606,50 @@ const std::string Fp16_to_Fp8E4M3B15x4 =
// does not handle denormals and has
// more than a single NaN values.

// Fp8E4M3 -> Fp16 (packed)
#ifdef USE_ROCM
// Converts a single Float8E4M3FNUZ byte to fp16 (software path).
//
// The byte is placed in the upper half of an i16 (lower byte zero): bit 15
// is the sign, bits 14..11 the 4-bit exponent field, bits 10..8 the 3-bit
// mantissa.
//
//  - exponent field != 0: the magnitude is shifted right by one so the
//    fields line up with fp16's 5-bit exponent / 10-bit mantissa, and the
//    exponent field is increased by 7 (0x1C00), the difference between the
//    fp16 bias (15) and the E4M3FNUZ bias (8).
//  - exponent field == 0: the value is mantissa * 2^-10, which is a normal
//    number in fp16. It is normalised by locating the highest set mantissa
//    bit:
//        mantissa 1xx -> exponent field 7, fraction = low two bits
//        mantissa 01x -> exponent field 6, fraction = low bit
//        mantissa 001 -> exponent field 5, fraction = 0
//        mantissa 000 -> 0
//
// Fixes relative to the previous revision:
//  - the zero-exponent path used to add 0x0C00 to the shifted bits, so the
//    byte 0x00 decoded to 2^-12 instead of 0 and subnormals decoded to
//    unrelated values.
//  - the exponent test used the mask 0x7A00, which includes mantissa bit 9;
//    it now uses 0x7800.
//  - an unused local holding the mantissa field was removed.
//
// NOTE(review): the byte 0x80 (sign bit only) decodes to the sign bit alone.
// FNUZ formats reserve that encoding for NaN; it is not mapped to an fp16
// NaN here because the fp16 -> fp8 helper in this file can emit 0x80 for
// negative zero -- confirm the intended treatment.
static Value convert_val_Fp8E4M3FNUZ_to_Fp16(
    Location loc, ConversionPatternRewriter &rewriter, Value v) {
  // Build <2 x i8> {0, v} and view it as one i16 with v in the high byte.
  auto fp8x2VecTy = vec_ty(i8_ty, 2);
  Value a = undef(fp8x2VecTy);
  a = insert_element(fp8x2VecTy, a, int_val(8, 0), i32_val(0));
  a = insert_element(fp8x2VecTy, a, v, i32_val(1));
  a = bitcast(a, i16_ty);

  Value zero = int_val(16, 0);
  Value sign = and_(i16_ty, a, int_val(16, 0x8000));
  Value e = and_(i16_ty, a, int_val(16, 0x7800));
  Value e_is_zero = icmp_eq(e, zero);

  // Case 1, exponent field is nonzero: realign and add 7 to the exponent.
  Value b = and_(i16_ty, a, int_val(16, 0x7FFF));
  Value b1 = lshr(i16_ty, b, int_val(16, 1));
  Value normal = add(i16_ty, b1, int_val(16, 0x1C00));

  // Case 2, exponent field is zero: normalise mantissa * 2^-10.
  Value mantBit2 = and_(i16_ty, a, int_val(16, 0x0400));
  Value mantBit1 = and_(i16_ty, a, int_val(16, 0x0200));
  Value mantBit0 = and_(i16_ty, a, int_val(16, 0x0100));
  Value bit2Clear = icmp_eq(mantBit2, zero);
  Value bit1Clear = icmp_eq(mantBit1, zero);
  Value bit0Clear = icmp_eq(mantBit0, zero);

  // 1xx: 2^-8 * (1 + xx/4); the two fraction bits already sit at bits 9..8.
  Value lowTwo = and_(i16_ty, a, int_val(16, 0x0300));
  Value sub4 = or_(i16_ty, int_val(16, 0x1C00), lowTwo);
  // 01x: 2^-9 * (1 + x/2); move the fraction bit from bit 8 to bit 9.
  Value lowOneShifted = shl(i16_ty, mantBit0, int_val(16, 1));
  Value sub2 = or_(i16_ty, int_val(16, 0x1800), lowOneShifted);
  // 001: 2^-10; 000: zero.
  Value sub1 = select(bit0Clear, zero, int_val(16, 0x1400));

  Value sub21 = select(bit1Clear, sub1, sub2);
  Value subnormal = select(bit2Clear, sub21, sub4);

  Value magnitude = select(e_is_zero, subnormal, normal);
  Value io = or_(i16_ty, magnitude, sign);
  return bitcast(io, f16_ty);
}

// Fp8E4M3FNUZ -> Fp16 (packed)
static SmallVector<Value>
Fp8E4M3_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
Fp8E4M3FNUZ_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp8x4VecTy = vec_ty(i8_ty, 4);
Value a0 = undef(fp8x4VecTy);
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(0));
a0 = insert_element(fp8x4VecTy, a0, v[0], i32_val(1));
a0 = insert_element(fp8x4VecTy, a0, int_val(8,0), i32_val(2));
a0 = insert_element(fp8x4VecTy, a0, v[1], i32_val(3));
a0 = bitcast(a0, i32_ty);

Value b0 = and_(i32_ty, a0, i32_val(0x7fff7fff));

b0 = lshr(i32_ty, b0, i32_val(1));

b0 = add(i32_ty, b0, i32_val(0x20002000));

b0 = or_( i32_ty, b0, and_(i32_ty, a0, i32_val(0x80008000)) );

auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2Vec0 = bitcast(b0, fp16x2VecTy);
SmallVector<Value> result(2);
result[0] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[0]);
result[1] = convert_val_Fp8E4M3FNUZ_to_Fp16(loc, rewriter, v[1]);

return { extract_element(f16_ty, fp16x2Vec0, i32_val(0)),
extract_element(f16_ty, fp16x2Vec0, i32_val(1))
};
return result;
}
#else
const std::string Fp8E4M3_to_Fp16 =
const std::string Fp8E4M3FNUZ_to_Fp16 =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // if input = 0xf1f2f3f4
"prmt.b32 a0, 0, $2, 0x5040; \n" // a0 = 0xf300f400
Expand All @@ -558,32 +668,50 @@ const std::string Fp8E4M3_to_Fp16 =

// Fp16 -> Fp8E4M3 (packed)
#ifdef USE_ROCM
// Converts a single fp16 value to one Float8E4M3FNUZ byte (software path).
//
// The fp16 bits are reinterpreted as i16, the result is assembled in the
// upper byte of an i16 (sign at bit 15, 4-bit exponent at bits 14..11,
// 3-bit mantissa at bits 10..8) and that upper byte is returned. Mantissa
// bits that do not fit are dropped (truncation, no rounding).
//
// With e the fp16 exponent field:
//  1) e > 21: the exponent field is clamped to 1111 and the top three
//     mantissa bits are kept (unchanged from the previous revision).
//     NOTE(review): this also covers fp16 inf/NaN, and for e > 22 it does
//     not saturate the mantissa -- confirm the intended overflow policy.
//  2) e <= 7: the value is below the smallest normal of the target format
//     and is encoded as a subnormal: mantissa = (0x400 | fp16 mantissa) >>
//     (15 - e), i.e. floor(|x| / 2^-10). Values below 2^-10 become zero.
//  3) otherwise: exponent field = e - 7, mantissa = top three fp16 mantissa
//     bits.
//
// Fixes relative to the previous revision:
//  - case 2 used to copy the top three fp16 mantissa bits unscaled into the
//    subnormal mantissa, so tiny inputs could come out far larger than
//    they were.
//  - when the resulting magnitude is zero the sign is no longer attached.
//    Previously -0.0 and negative underflow produced the byte 0x80, which
//    FNUZ formats reserve for NaN.
static Value convert_val_Fp16_to_Fp8E4M3FNUZ(
    Location loc, ConversionPatternRewriter &rewriter, Value v) {
  Value vi16 = bitcast(v, i16_ty);
  Value zero = int_val(16, 0);
  Value e10 = and_(i16_ty, vi16, int_val(16, 0x7C00));
  Value e = lshr(i16_ty, e10, int_val(16, 10));

  Value s = and_(i16_ty, vi16, int_val(16, 0x8000));

  // Top three fp16 mantissa bits (9..7) moved to bits 10..8.
  Value m7 = and_(i16_ty, vi16, int_val(16, 0x0380));
  Value m = shl(i16_ty, m7, int_val(16, 1));

  // Normal / overflow exponent: (e - 7) placed at bits 14..11, clamped to
  // 1111 when e > 21.
  Value e31 = sub(i16_ty, e10, int_val(16, 0x1C00));
  Value e3 = shl(i16_ty, e31, int_val(16, 1));
  Value isLarge = icmp_sgt(e, int_val(16, 21));
  Value eSel = select(isLarge, int_val(16, 0x7800), e3);
  Value normal = or_(i16_ty, eSel, m);

  // Subnormal result for e <= 7. The shift amount is derived from the low
  // three bits of e so it always stays within 8..15, even on inputs where
  // this path is not selected.
  Value frac = and_(i16_ty, vi16, int_val(16, 0x03FF));
  Value significand = or_(i16_ty, frac, int_val(16, 0x0400));
  Value eLow = and_(i16_ty, e, int_val(16, 7));
  Value shift = sub(i16_ty, int_val(16, 15), eLow);
  Value subMant = lshr(i16_ty, significand, shift);
  Value subnormal = shl(i16_ty, subMant, int_val(16, 8));

  Value isSmall = icmp_sle(e, int_val(16, 7));
  Value magnitude = select(isSmall, subnormal, normal);

  // A zero magnitude must stay unsigned: sign-only (0x80) is not a zero.
  Value magIsZero = icmp_eq(magnitude, zero);
  Value signedMag = or_(i16_ty, s, magnitude);
  Value r = select(magIsZero, zero, signedMag);

  auto fp8x2VecTy = vec_ty(i8_ty, 2);
  Value res = bitcast(r, fp8x2VecTy);

  // Element 1 is the upper byte of the i16.
  return extract_element(i8_ty, res, i32_val(1));
}

static SmallVector<Value>
Fp16_to_Fp8E4M3(Location loc, ConversionPatternRewriter &rewriter,
Fp16_to_Fp8E4M3FNUZ(Location loc, ConversionPatternRewriter &rewriter,
const SmallVector<Value> &v) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);

fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));

fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec0 = sub(i32_ty, fp16x2Vec0, i32_val(0x20002000));

Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );
SmallVector<Value> result(2);
result[0] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[0]);
result[1] = convert_val_Fp16_to_Fp8E4M3FNUZ(loc, rewriter, v[1]);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);

return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3))
};
return result;
}
#else
const std::string Fp16_to_Fp8E4M3 =
const std::string Fp16_to_Fp8E4M3FNUZ =
"{ \n"
".reg .b32 a<2>, b<2>; \n" // see Fp8E4M3x4ToFp16x4
"sub.u32 a0, $1, 0x20002000; \n" // a0 = input0 - 0x20002000
Expand Down Expand Up @@ -1215,9 +1343,10 @@ struct FpToFpOpConversion

ConverterT getConversionFunc(Type srcTy, Type dstTy) const {
auto F8E4M3B15TyID = TypeID::get<mlir::Float8E4M3B11FNUZType>();
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNUZType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F8E4M3FNUZTyID = TypeID::get<mlir::Float8E4M3FNUZType>();
auto F8E4M3FNTyID = TypeID::get<mlir::Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F8E5M2FNUZTyID = TypeID::get<mlir::Float8E5M2FNUZType>();
auto F16TyID = TypeID::get<mlir::Float16Type>();
auto BF16TyID = TypeID::get<mlir::BFloat16Type>();
auto F32TyID = TypeID::get<mlir::Float32Type>();
Expand All @@ -1230,17 +1359,19 @@ struct FpToFpOpConversion
// F8 -> F16
{{F8E4M3B15TyID, F16TyID}, Fp8E4M3B15_to_Fp16},
{{F8E4M3FNTyID, F16TyID}, Fp8E4M3B15x4_to_Fp16},
{{F8E4M3TyID, F16TyID}, Fp8E4M3_to_Fp16},
{{F8E4M3FNUZTyID, F16TyID}, Fp8E4M3FNUZ_to_Fp16},
{{F8E5M2TyID, F16TyID}, Fp8E5M2_to_Fp16},
{{F8E5M2FNUZTyID, F16TyID}, Fp8E5M2FNUZ_to_Fp16},
// F16 -> F8
#ifdef USE_ROCM
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15},
#else
{{F16TyID, F8E4M3B15TyID}, Fp16_to_Fp8E4M3B15(computeCapability >= 80)},
#endif
{{F16TyID, F8E4M3FNTyID}, Fp16_to_Fp8E4M3B15x4},
{{F16TyID, F8E4M3TyID}, Fp16_to_Fp8E4M3},
{{F16TyID, F8E4M3FNUZTyID}, Fp16_to_Fp8E4M3FNUZ},
{{F16TyID, F8E5M2TyID}, Fp16_to_Fp8E5M2},
{{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ},
// F8 -> BF16
{{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16},
// BF16 -> F8
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
addConversion([&](mlir::Float8E5M2Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E5M2FNUZType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 16);
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
} // namespace LLVM

bool isF8(Type eType) {
return eType.isFloat8E5M2FNUZ() or eType.isFloat8E4M3FNUZ() or
return eType.isFloat8E4M3FNUZ() or eType.isFloat8E4M3FN() or
eType.isFloat8E5M2() or eType.isFloat8E5M2FNUZ();
}

Expand Down
8 changes: 8 additions & 0 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,10 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E4M3FNUZType>();
})
.def("get_fp8e4b8_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E4M3FNUZType>();
})
.def("get_fp8e4b15_ty",
[](TritonOpBuilder &self) -> mlir::Type {
// TODO: upstream FP8E4B15 into MLIR, or find a way to externally
Expand All @@ -827,6 +831,10 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E5M2Type>();
})
.def("get_fp8e5b16_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getType<mlir::Float8E5M2FNUZType>();
})
.def("get_half_ty",
[](TritonOpBuilder &self) -> mlir::Type {
return self.getBuilder().getF16Type();
Expand Down
Loading

0 comments on commit 79bebc4

Please sign in to comment.