Commit 33858fb

merge ocp fp8 (#2931)
1 parent 25a5adf commit 33858fb

8 files changed (+396, -175 lines)

third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc (+15, -1)
@@ -168,8 +168,22 @@ class GemmAutotuner {
     se::DeviceMemoryBase a_scale_buffer, b_scale_buffer, c_scale_buffer,
         d_scale_buffer, d_amax_buffer, bias_buffer, aux_buffer;
 
+    int64_t input_buffer_idx = 2;  // lhs is at 0, rhs is at 1
     if (has_vector_bias) {
-      bias_buffer = rz_buffers_.input_buffers().at(has_matrix_bias ? 3 : 2);
+      if (has_matrix_bias) {
+        input_buffer_idx++;
+      }
+      bias_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
+    }
+    // In the current GemmRewriter design for FP8, the a/b scales remain active
+    // even when they are not used. Consequently, we must inform the autotuner
+    // so it can choose algorithms that properly support a/b scales.
+    if (xla::primitive_util::IsF8Type(
+            gemm->operand(0)->shape().element_type()) &&
+        xla::primitive_util::IsF8Type(
+            gemm->operand(1)->shape().element_type())) {
+      a_scale_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
+      b_scale_buffer = rz_buffers_.input_buffers().at(input_buffer_idx++);
     }
     if (has_aux_output) {
       aux_buffer = rz_buffers_.output_buffers().at(1);

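The change above replaces the hard-coded bias index with a running input_buffer_idx so the a/b scale buffers can follow the optional matrix and vector biases. A minimal standalone sketch of that index layout (hypothetical names, not the XLA autotuner API):

#include <cstdint>
#include <iostream>

// Buffer slots an FP8 GEMM consumes, mirroring the indexing in the hunk above:
// lhs at 0, rhs at 1, then an optional matrix bias, an optional vector bias,
// and the a/b scales whenever both operands are FP8. -1 marks "not present".
struct GemmBufferIndices {
  int64_t bias = -1;
  int64_t a_scale = -1;
  int64_t b_scale = -1;
};

GemmBufferIndices ComputeIndices(bool has_matrix_bias, bool has_vector_bias,
                                 bool both_operands_fp8) {
  GemmBufferIndices out;
  int64_t idx = 2;  // lhs is at 0, rhs is at 1
  if (has_vector_bias) {
    if (has_matrix_bias) idx++;  // skip the matrix-bias slot
    out.bias = idx++;
  }
  if (both_operands_fp8) {  // scales stay active for FP8 GEMMs
    out.a_scale = idx++;
    out.b_scale = idx++;
  }
  return out;
}

int main() {
  GemmBufferIndices i = ComputeIndices(/*has_matrix_bias=*/true,
                                       /*has_vector_bias=*/true,
                                       /*both_operands_fp8=*/true);
  std::cout << i.bias << " " << i.a_scale << " " << i.b_scale << "\n";  // 3 4 5
}

With a matrix bias, a vector bias, and FP8 operands this prints 3 4 5, matching the sequence of .at(input_buffer_idx++) calls in the diff.
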
third_party/xla/xla/service/gpu/buffer_comparator.cc (+2, -2)
@@ -187,7 +187,7 @@ absl::StatusOr<bool> BufferComparator::CompareEqual(
       stream, current, expected};
 
   switch (shape_.element_type()) {
-#if GOOGLE_CUDA // not available for ROCm yet..
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
     case xla::F8E4M3FN:
       return CompareEqualParameterized<tsl::float8_e4m3fn, float>(
           "fp8_e4m3fn_comparison", buffer_comparator::fp8_e4m3fn_comparison(),
@@ -196,7 +196,7 @@ absl::StatusOr<bool> BufferComparator::CompareEqual(
       return CompareEqualParameterized<tsl::float8_e5m2, float>(
           "fp8_e5m2_comparison", buffer_comparator::fp8_e5m2_comparison(),
           params);
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
     case xla::F8E4M3FNUZ:
       return CompareEqualParameterized<tsl::float8_e4m3fnuz, float>(

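The new guard relies on && binding tighter than || in preprocessor arithmetic, so it reads as GOOGLE_CUDA || (TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300): the OCP FP8 cases compile on CUDA builds and on ROCm builds from version 6.3 onward. A self-contained sketch with hypothetical macro values (not the real build configuration):

#include <iostream>

// Hypothetical build configuration for illustration only.
#define GOOGLE_CUDA 0
#define TENSORFLOW_USE_ROCM 1
#define TF_ROCM_VERSION 60300

int main() {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
  std::cout << "FP8E4M3FN/FP8E5M2 comparison paths are compiled in\n";
#else
  std::cout << "OCP FP8 comparison paths are not available\n";
#endif
}
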
third_party/xla/xla/service/gpu/buffer_comparator.cu.cc (+33, -15)
@@ -54,20 +54,29 @@ __device__ __inline__ float Canonicalize(float input) {
   return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
 }
 
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+__global__ void xla_fp8_e4m3fn_comparison(
 #if GOOGLE_CUDA
-__global__ void xla_fp8_e4m3fn_comparison(__nv_fp8_storage_t* buffer_a,
-                                          __nv_fp8_storage_t* buffer_b,
-                                          float rel_error_threshold,
-                                          uint64_t buffer_length,
-                                          int* mismatch_count) {
+    __nv_fp8_storage_t* buffer_a, __nv_fp8_storage_t* buffer_b,
+#else // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+    __hip_fp8_storage_t* buffer_a, __hip_fp8_storage_t* buffer_b,
+#endif
+    float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) {
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
   // TODO(philipphack): Replace with direct conversion to float when this
   // functionality becomes available.
+#if GOOGLE_CUDA
   float elem_a =
       __half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E4M3));
   float elem_b =
       __half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E4M3));
+#else // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+  float elem_a =
+      __half2float(__hip_cvt_fp8_to_halfraw(buffer_a[idx], __HIP_E4M3));
+  float elem_b =
+      __half2float(__hip_cvt_fp8_to_halfraw(buffer_b[idx], __HIP_E4M3));
+#endif
   elem_a = Canonicalize(elem_a);
   elem_b = Canonicalize(elem_b);
   if (isnan(elem_a) && isnan(elem_b)) return;
@@ -78,19 +87,28 @@ __global__ void xla_fp8_e4m3fn_comparison(__nv_fp8_storage_t* buffer_a,
     atomicAdd(mismatch_count, 1);
 }
 
-__global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
-                                        __nv_fp8_storage_t* buffer_b,
-                                        float rel_error_threshold,
-                                        uint64_t buffer_length,
-                                        int* mismatch_count) {
+__global__ void xla_fp8_e5m2_comparison(
+#if GOOGLE_CUDA
+    __nv_fp8_storage_t* buffer_a, __nv_fp8_storage_t* buffer_b,
+#else // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+    __hip_fp8_storage_t* buffer_a, __hip_fp8_storage_t* buffer_b,
+#endif
+    float rel_error_threshold, uint64_t buffer_length, int* mismatch_count) {
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
-  // TODO(philipphack): Replace with direct conversion to float when this
-  // functionality becomes available.
+  // TODO(philipphack): Replace with direct conversion to float when this
+  // functionality becomes available.
+#if GOOGLE_CUDA
   float elem_a =
       __half2float(__nv_cvt_fp8_to_halfraw(buffer_a[idx], __NV_E5M2));
   float elem_b =
       __half2float(__nv_cvt_fp8_to_halfraw(buffer_b[idx], __NV_E5M2));
+#else // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
+  float elem_a =
+      __half2float(__hip_cvt_fp8_to_halfraw(buffer_a[idx], __HIP_E5M2));
+  float elem_b =
+      __half2float(__hip_cvt_fp8_to_halfraw(buffer_b[idx], __HIP_E5M2));
+#endif
   elem_a = Canonicalize(elem_a);
   elem_b = Canonicalize(elem_b);
   if (isnan(elem_a) && isnan(elem_b)) return;
@@ -100,7 +118,7 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
   if (rel_error > rel_error_threshold || isnan(rel_error))
     atomicAdd(mismatch_count, 1);
 }
-#endif // GOOGLE_CUDA
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
 
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
 
@@ -262,15 +280,15 @@ __global__ void xla_int32_comparison(int* buffer_a, int* buffer_b,
 
 }  // namespace
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
 void* fp8_e4m3fn_comparison() {
   return reinterpret_cast<void*>(&xla_fp8_e4m3fn_comparison);
 }
 
 void* fp8_e5m2_comparison() {
   return reinterpret_cast<void*>(&xla_fp8_e5m2_comparison);
 }
-#endif
+#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60300
 
 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
 void* fp8_e4m3fnuz_comparison() {

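For reference, the per-element rule these kernels implement can be restated on the host. The sketch below is illustrative only: it uses plain floats instead of FP8 storage types, and the relative-error formula is an assumption, since the hunks only show the threshold test; |a - b| / (max(|a|, |b|) + 1) is used here.

#include <algorithm>
#include <cmath>
#include <iostream>

// Clamp non-NaN values, as the device-side Canonicalize does.
float Canonicalize(float input) {
  return std::isnan(input) ? input
                           : std::max(-65505.0f, std::min(input, 65505.0f));
}

// Returns true when an element pair should count as a mismatch.
bool Mismatch(float a, float b, float rel_error_threshold) {
  a = Canonicalize(a);
  b = Canonicalize(b);
  if (std::isnan(a) && std::isnan(b)) return false;  // both NaN compare equal
  // Assumed formula; the actual kernels compute rel_error outside these hunks.
  float rel_error =
      std::abs(a - b) / (std::max(std::abs(a), std::abs(b)) + 1.0f);
  return rel_error > rel_error_threshold || std::isnan(rel_error);
}

int main() {
  std::cout << std::boolalpha << Mismatch(1.0f, 1.05f, 0.1f) << "\n"  // false
            << Mismatch(1.0f, 2.0f, 0.1f) << "\n";                    // true
}
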
third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc (+162, -65)
@@ -1056,22 +1056,52 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     }
 
     if (IsRocm(gpu_version_)) {
-      if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
-        VLOG(1)
-            << "Failed to rewrite " << instr->ToShortString()
-            << " into FP8 Custom Call. The element type of one of the operands "
-               "must be F8E4M3FNUZ.";
-        return false;
+      TF_ASSIGN_OR_RETURN(auto rocm_compute_capability,
+                          GetRocmComputeCapability(gpu_version_));
+      if (rocm_compute_capability.has_ocp_fp8_support()) {
+        if (a_type == F8E5M2 && b_type == F8E5M2) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For "
+                  << rocm_compute_capability.gfx_version()
+                  << " arch, one of the input types must be F8E4M3FN, but got "
+                  << PrimitiveType_Name(a_type) << " and "
+                  << PrimitiveType_Name(b_type);
+          return false;
+        }
+        if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
+            (b_type != F8E5M2 && b_type != F8E4M3FN)) {
+          VLOG(1)
+              << "Failed to rewrite " << instr->ToShortString()
+              << " into FP8 Custom Call. For "
+              << rocm_compute_capability.gfx_version()
+              << " arch, the input types must be F8E5M2 or F8E4M3FN, but got "
+              << PrimitiveType_Name(a_type) << " and "
+              << PrimitiveType_Name(b_type);
+          return false;
+        }
       }
-      if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
-          (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
-        VLOG(1)
-            << "Failed to rewrite " << instr->ToShortString()
-            << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
-               "F8E4M3FNUZ, but got "
-            << PrimitiveType_Name(a_type) << " and "
-            << PrimitiveType_Name(b_type);
-        return false;
+      if (rocm_compute_capability.has_nanoo_fp8_support()) {
+        if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
+          VLOG(1)
+              << "Failed to rewrite " << instr->ToShortString()
+              << " into FP8 Custom Call. For "
+              << rocm_compute_capability.gfx_version()
+              << " arch, one of the input types must be F8E4M3FNUZ, but got "
+              << PrimitiveType_Name(a_type) << " and "
+              << PrimitiveType_Name(b_type);
+          return false;
+        }
+        if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
+            (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For "
+                  << rocm_compute_capability.gfx_version()
+                  << " arch, the input types must be F8E5M2FNUZ or F8E4M3FNUZ, "
+                     "but got "
+                  << PrimitiveType_Name(a_type) << " and "
+                  << PrimitiveType_Name(b_type);
+          return false;
+        }
       }
     }
 

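The rewritten operand check dispatches on the FP8 flavor the ROCm device supports: with OCP support both operands must be F8E4M3FN or F8E5M2 and at least one must be F8E4M3FN; with NANOO support the same rule applies to F8E4M3FNUZ and F8E5M2FNUZ. A standalone sketch of that rule, with a hypothetical enum standing in for the XLA primitive types:

#include <iostream>

enum class F8 { E4M3FN, E5M2, E4M3FNUZ, E5M2FNUZ };

// Mirrors the hunk above: an E5M2/E5M2 pair is rejected (no E4M3 member), and
// any type outside the supported flavor's pair is rejected.
bool OperandsSupported(F8 a, F8 b, bool ocp_support, bool nanoo_support) {
  if (ocp_support) {
    if (a == F8::E5M2 && b == F8::E5M2) return false;
    if ((a != F8::E5M2 && a != F8::E4M3FN) ||
        (b != F8::E5M2 && b != F8::E4M3FN)) {
      return false;
    }
  }
  if (nanoo_support) {
    if (a == F8::E5M2FNUZ && b == F8::E5M2FNUZ) return false;
    if ((a != F8::E5M2FNUZ && a != F8::E4M3FNUZ) ||
        (b != F8::E5M2FNUZ && b != F8::E4M3FNUZ)) {
      return false;
    }
  }
  return true;
}

int main() {
  std::cout << std::boolalpha
            << OperandsSupported(F8::E4M3FN, F8::E5M2, true, false) << "\n"  // true
            << OperandsSupported(F8::E5M2, F8::E5M2, true, false) << "\n";   // false
}
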
@@ -1112,25 +1142,56 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     }
 
     PrimitiveType d_type = instr->shape().element_type();
-    bool supported_d_type = (d_type == BF16 || d_type == F16 || d_type == F32);
-    if (IsCuda(gpu_version_) && (d_type == F8E4M3FN || d_type == F8E5M2)) {
-      supported_d_type = true;
-    }
-    if (IsRocm(gpu_version_) &&
-        toolkit_version_ >= stream_executor::SemanticVersion{6, 2, 0} &&
-        (d_type == F8E4M3FNUZ || d_type == F8E5M2FNUZ)) {
-      supported_d_type = true;
+    std::unordered_set<PrimitiveType> supported_d_types = {BF16, F16, F32};
+    if (IsCuda(gpu_version_)) {
+      supported_d_types.insert(F8E4M3FN);
+      supported_d_types.insert(F8E5M2);
+      if (supported_d_types.find(d_type) == supported_d_types.end()) {
+        VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                << " into FP8 Custom Call. Output type must be "
+                   "F8E4M3FN, F8E5M2, BF16, F16 or F32, but got "
+                << PrimitiveType_Name(d_type);
+        return false;
+      }
     }
-    if (!supported_d_type) {
-      VLOG(1) << "Failed to rewrite " << instr->ToShortString()
-              << " into FP8 Custom Call. Output element type must be "
-              << (IsCuda(gpu_version_) ? "F8E4M3FN, F8E5M2, BF16, F16 or F32. "
-                  : toolkit_version_ >=
-                          stream_executor::SemanticVersion{6, 2, 0}
-                      ? "F8E4M3FNUZ, F8E5M2FNUZ, BF16, F16 or F32. "
-                      : "BF16, F16 or F32. ")
-              << "Actual element type is " << PrimitiveType_Name(d_type);
-      return false;
+    if (IsRocm(gpu_version_)) {
+      if (toolkit_version_ < stream_executor::SemanticVersion{6, 2, 0}) {
+        if (supported_d_types.find(d_type) == supported_d_types.end()) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For ROCm version < 6.2, output "
+                     "type must be BF16, F16 or F32, but got "
+                  << PrimitiveType_Name(d_type);
+          return false;
+        }
+      }
+      TF_ASSIGN_OR_RETURN(auto rocm_compute_capability,
+                          GetRocmComputeCapability(gpu_version_));
+      if (rocm_compute_capability.has_ocp_fp8_support()) {
+        supported_d_types.insert(F8E4M3FN);
+        supported_d_types.insert(F8E5M2);
+        if (supported_d_types.find(d_type) == supported_d_types.end()) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For "
+                  << rocm_compute_capability.gfx_version()
+                  << " arch output type must be F8E4M3FN, F8E5M2, BF16, F16 or "
+                     "F32, but got "
+                  << PrimitiveType_Name(d_type);
+          return false;
+        }
+      }
+      if (rocm_compute_capability.has_nanoo_fp8_support()) {
+        supported_d_types.insert(F8E4M3FNUZ);
+        supported_d_types.insert(F8E5M2FNUZ);
+        if (supported_d_types.find(d_type) == supported_d_types.end()) {
+          VLOG(1) << "Failed to rewrite " << instr->ToShortString()
+                  << " into FP8 Custom Call. For "
+                  << rocm_compute_capability.gfx_version()
+                  << " arch output type must be F8E4M3FNUZ, F8E5M2FNUZ, BF16, "
+                     "F16 or F32, but got "
+                  << PrimitiveType_Name(d_type);
+          return false;
+        }
+      }
     }
 
     // Each operand must have exactly one contracting and one non-contracting
@@ -1322,6 +1383,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
                           HloInstruction *d_scale, HloInstruction *clamp_lower,
                           HloInstruction *clamp_upper,
                           bool mult_scale = false) {
+    // TODO: add ROCm support to this fusion pattern
+    if (IsRocm(gpu_version_)) {
+      return absl::OkStatus();
+    }
     // Verify the data types and the operands of clamp.
     if (instr->shape().element_type() == F8E4M3FN) {
       if (!clamp_lower->literal().IsAllFloat(static_cast<float>(
@@ -2073,38 +2138,70 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       return true;
     }
     const TypeCombinations supported_hipblas_type_combinations = {
-        // FP8 types:
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
-
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
-         PrimitiveType::F8E5M2FNUZ, DataType::kFloat},
-
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
-        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
-         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
-    };
+        // OCP FP8 types:
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E4M3FN, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E4M3FN, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E4M3FN, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kF8E4M3FN},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kF8E5M2},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FN,
+         PrimitiveType::F8E5M2, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kF8E5M2},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2,
+         PrimitiveType::F8E4M3FN, DataType::kFloat},
+
+        // NANOO FP8 types:
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E4M3FNUZ,
+         PrimitiveType::F8E5M2FNUZ, DataType::kFloat},
+
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kBF16},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kHalf},
+        {ComputationType::kF32, DataType::kFloat, PrimitiveType::F8E5M2FNUZ,
+         PrimitiveType::F8E4M3FNUZ, DataType::kFloat},
+    };
     if (IsRocm(gpu_version_) &&
         absl::c_linear_search(supported_hipblas_type_combinations,
                               std::tuple{compute_type, scale_type, a_dtype,

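The expanded table feeds the absl::c_linear_search call in the trailing context lines: the (compute type, scale type, a type, b type, output type) tuple of the candidate GEMM must appear verbatim in the list. A minimal illustration of that lookup with hypothetical enums in place of the XLA/StreamExecutor types:

#include <algorithm>
#include <iostream>
#include <tuple>
#include <vector>

enum class Comp { kF32 };
enum class DType { kFloat, kBF16, kHalf, kF8E4M3FN, kF8E5M2 };
enum class PType { F8E4M3FN, F8E5M2 };

// (compute type, scale type, a type, b type, output type)
using Combo = std::tuple<Comp, DType, PType, PType, DType>;

int main() {
  const std::vector<Combo> supported = {
      {Comp::kF32, DType::kFloat, PType::F8E4M3FN, PType::F8E4M3FN, DType::kBF16},
      {Comp::kF32, DType::kFloat, PType::F8E4M3FN, PType::F8E5M2, DType::kHalf},
      {Comp::kF32, DType::kFloat, PType::F8E5M2, PType::F8E4M3FN, DType::kFloat},
  };
  const Combo query{Comp::kF32, DType::kFloat, PType::F8E4M3FN, PType::F8E5M2,
                    DType::kHalf};
  const bool ok =
      std::find(supported.begin(), supported.end(), query) != supported.end();
  std::cout << std::boolalpha << ok << "\n";  // true
}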