Skip to content

Commit

Permalink
Merge pull request #79 from HazyResearch/fp8
Browse files Browse the repository at this point in the history
torch scaled
  • Loading branch information
simran-arora authored Jan 4, 2025
2 parents de5cf82 + b108138 commit 0cd5947
Show file tree
Hide file tree
Showing 14 changed files with 1,749 additions and 10 deletions.
233 changes: 231 additions & 2 deletions include/ops/group/wgmma/base/64x128.impl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ struct base<T_D, T_AB, 128, trans_a, trans_b> {
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>),
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
Expand Down Expand Up @@ -284,6 +286,119 @@ struct base<T_D, T_AB, 128, trans_a, trans_b> {
"n"(scale_b)
);
}
// ----- FP8 (e4m3), FP8 (e4m3) -> FP16 ----- //
// Compile-time branch taken when the accumulator type T_D is half and the
// operand type T_AB is fp8e4m3. It issues a single
// wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 instruction with A
// supplied as registers and B supplied through a descriptor.
//
// Positional operand map of the asm statement below:
//   %0  - %31 : 32 read-write 32-bit accumulator registers, taken from
//               dst.tiles[0][0..7].data[0..3]
//   %32 - %35 : four 32-bit A registers, taken from a_rt.data[0..3]
//   %36       : 64-bit B descriptor (b_st_desc)
//   %37       : scale_d (runtime value, turned into predicate p)
//   %38       : scale_b (compile-time immediate)
// NOTE(review): each 32-bit accumulator register presumably packs two half
// values; the tile layout is not visible here, so confirm against the rt
// definition.
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %37, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
"{%32, %33, %34, %35}, " \
"%36, " \
"p, 1, %38;\n" \
"}\n"
// Instruction operand order: d_regs, a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b.
// scale-d is the predicate p = (scale_d != 0); imm-scale-a is hard-coded to 1.
// No transpose immediate is passed (trans_b is commented out below) -
// presumably because the FP8 form of wgmma takes none; confirm against the PTX ISA.

// Outputs ("+r" = read-write 32-bit register): the accumulator fragment.
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[3])

// Inputs: A fragment registers (%32-%35).
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

// B descriptor (%36), runtime scale_d (%37), immediate scale_b (%38).
"l"(b_st_desc),
"r"(scale_d),
// "n"(trans_b),
"n"(scale_b)
);
}
// ----- FP8 (e5m2), FP8 (e5m2) -> FP16 ----- //
// Compile-time branch taken when the accumulator type T_D is half and the
// operand type T_AB is fp8e5m2. It issues a single
// wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 instruction with A
// supplied as registers and B supplied through a descriptor.
//
// Positional operand map of the asm statement below:
//   %0  - %31 : 32 read-write 32-bit accumulator registers, taken from
//               dst.tiles[0][0..7].data[0..3]
//   %32 - %35 : four 32-bit A registers, taken from a_rt.data[0..3]
//   %36       : 64-bit B descriptor (b_st_desc)
//   %37       : scale_d (runtime value, turned into predicate p)
//   %38       : scale_b (compile-time immediate)
// NOTE(review): each 32-bit accumulator register presumably packs two half
// values; the tile layout is not visible here, so confirm against the rt
// definition.
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %37, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
"{%32, %33, %34, %35}, " \
"%36, " \
"p, 1, %38;\n" \
"}\n"
// Instruction operand order: d_regs, a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b.
// scale-d is the predicate p = (scale_d != 0); imm-scale-a is hard-coded to 1.
// No transpose immediate is passed (trans_b is commented out below) -
// presumably because the FP8 form of wgmma takes none; confirm against the PTX ISA.

// Outputs ("+r" = read-write 32-bit register): the accumulator fragment.
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[3])

// Inputs: A fragment registers (%32-%35).
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

// B descriptor (%36), runtime scale_d (%37), immediate scale_b (%38).
"l"(b_st_desc),
"r"(scale_d),
// "n"(trans_b),
"n"(scale_b)
);
}

}
template<int scale_b=1> __device__ static inline void st_st(
rt<T_D, 16, 128, ducks::rt_layout::row> &dst,
Expand All @@ -296,7 +411,9 @@ struct base<T_D, T_AB, 128, trans_a, trans_b> {
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>),
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
Expand Down Expand Up @@ -580,5 +697,117 @@ struct base<T_D, T_AB, 128, trans_a, trans_b> {
"n"(scale_b)
);
}
// ----- FP8 (e4m3), FP8 (e4m3) -> FP16 ----- //
// Compile-time branch taken when the accumulator type T_D is half and the
// operand type T_AB is fp8e4m3. It issues a single
// wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 instruction with
// both A and B supplied through 64-bit descriptors.
//
// Positional operand map of the asm statement below:
//   %0 - %31 : 32 read-write 32-bit accumulator registers, taken from
//              dst.tiles[0][0..7].data[0..3]
//   %32      : 64-bit A descriptor (a_st_desc)
//   %33      : 64-bit B descriptor (b_st_desc)
//   %34      : scale_d (runtime value, turned into predicate p)
//   %35      : scale_b (compile-time immediate)
// NOTE(review): each 32-bit accumulator register presumably packs two half
// values; the tile layout is not visible here, so confirm against the rt
// definition.
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %34, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
"%32, " \
"%33, " \
"p, 1, %35;\n" \
"}\n"
// Instruction operand order: d_regs, a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b.
// scale-d is the predicate p = (scale_d != 0); imm-scale-a is hard-coded to 1.
// No transpose immediates are passed (trans_a / trans_b are commented out
// below) - presumably because the FP8 form of wgmma takes none; confirm
// against the PTX ISA.

// Outputs ("+r" = read-write 32-bit register): the accumulator fragment.
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[3])

// Inputs: A descriptor (%32) and B descriptor (%33), both 64-bit ("l").
: "l"(a_st_desc),
"l"(b_st_desc),

// Runtime scale_d (%34) and immediate scale_b (%35).
"r"(scale_d),
// "n"(trans_a),
// "n"(trans_b),
"n"(scale_b)
);
}
// ----- FP8 (e5m2), FP8 (e5m2) -> FP16 ----- //
// Compile-time branch taken when the accumulator type T_D is half and the
// operand type T_AB is fp8e5m2. It issues a single
// wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 instruction with
// both A and B supplied through 64-bit descriptors.
//
// Positional operand map of the asm statement below:
//   %0 - %31 : 32 read-write 32-bit accumulator registers, taken from
//              dst.tiles[0][0..7].data[0..3]
//   %32      : 64-bit A descriptor (a_st_desc)
//   %33      : 64-bit B descriptor (b_st_desc)
//   %34      : scale_d (runtime value, turned into predicate p)
//   %35      : scale_b (compile-time immediate)
// NOTE(review): each 32-bit accumulator register presumably packs two half
// values; the tile layout is not visible here, so confirm against the rt
// definition.
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %34, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
"%32, " \
"%33, " \
"p, 1, %35;\n" \
"}\n"
// Instruction operand order: d_regs, a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b.
// scale-d is the predicate p = (scale_d != 0); imm-scale-a is hard-coded to 1.
// No transpose immediates are passed (trans_a / trans_b are commented out
// below) - presumably because the FP8 form of wgmma takes none; confirm
// against the PTX ISA.

// Outputs ("+r" = read-write 32-bit register): the accumulator fragment.
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][7].data[3])

// Inputs: A descriptor (%32) and B descriptor (%33), both 64-bit ("l").
: "l"(a_st_desc),
"l"(b_st_desc),

// Runtime scale_d (%34) and immediate scale_b (%35).
"r"(scale_d),
// "n"(trans_a),
// "n"(trans_b),
"n"(scale_b)
);
}
}
};
Loading

0 comments on commit 0cd5947

Please sign in to comment.