@@ -1056,22 +1056,52 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
1056
1056
}
1057
1057
1058
1058
if (IsRocm (gpu_version_)) {
1059
- if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
1060
- VLOG (1 )
1061
- << " Failed to rewrite " << instr->ToShortString ()
1062
- << " into FP8 Custom Call. The element type of one of the operands "
1063
- " must be F8E4M3FNUZ." ;
1064
- return false ;
1059
+ TF_ASSIGN_OR_RETURN (auto rocm_compute_capability,
1060
+ GetRocmComputeCapability (gpu_version_));
1061
+ if (rocm_compute_capability.has_ocp_fp8_support ()) {
1062
+ if (a_type == F8E5M2 && b_type == F8E5M2) {
1063
+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1064
+ << " into FP8 Custom Call. For "
1065
+ << rocm_compute_capability.gfx_version ()
1066
+ << " arch, one of the input types must be F8E4M3FN, but got "
1067
+ << PrimitiveType_Name (a_type) << " and "
1068
+ << PrimitiveType_Name (b_type);
1069
+ return false ;
1070
+ }
1071
+ if ((a_type != F8E5M2 && a_type != F8E4M3FN) ||
1072
+ (b_type != F8E5M2 && b_type != F8E4M3FN)) {
1073
+ VLOG (1 )
1074
+ << " Failed to rewrite " << instr->ToShortString ()
1075
+ << " into FP8 Custom Call. For "
1076
+ << rocm_compute_capability.gfx_version ()
1077
+ << " arch, the input types must be F8E5M2 or F8E4M3FN, but got "
1078
+ << PrimitiveType_Name (a_type) << " and "
1079
+ << PrimitiveType_Name (b_type);
1080
+ return false ;
1081
+ }
1065
1082
}
1066
- if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
1067
- (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
1068
- VLOG (1 )
1069
- << " Failed to rewrite " << instr->ToShortString ()
1070
- << " into FP8 Custom Call. The input types must be F8E5M2FNUZ or "
1071
- " F8E4M3FNUZ, but got "
1072
- << PrimitiveType_Name (a_type) << " and "
1073
- << PrimitiveType_Name (b_type);
1074
- return false ;
1083
+ if (rocm_compute_capability.has_nanoo_fp8_support ()) {
1084
+ if (a_type == F8E5M2FNUZ && b_type == F8E5M2FNUZ) {
1085
+ VLOG (1 )
1086
+ << " Failed to rewrite " << instr->ToShortString ()
1087
+ << " into FP8 Custom Call. For "
1088
+ << rocm_compute_capability.gfx_version ()
1089
+ << " arch, one of the input types must be F8E4M3FNUZ, but got "
1090
+ << PrimitiveType_Name (a_type) << " and "
1091
+ << PrimitiveType_Name (b_type);
1092
+ return false ;
1093
+ }
1094
+ if ((a_type != F8E5M2FNUZ && a_type != F8E4M3FNUZ) ||
1095
+ (b_type != F8E5M2FNUZ && b_type != F8E4M3FNUZ)) {
1096
+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1097
+ << " into FP8 Custom Call. For "
1098
+ << rocm_compute_capability.gfx_version ()
1099
+ << " arch, the input types must be F8E5M2FNUZ or F8E4M3FNUZ, "
1100
+ " but got "
1101
+ << PrimitiveType_Name (a_type) << " and "
1102
+ << PrimitiveType_Name (b_type);
1103
+ return false ;
1104
+ }
1075
1105
}
1076
1106
}
1077
1107
@@ -1112,25 +1142,56 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
1112
1142
}
1113
1143
1114
1144
PrimitiveType d_type = instr->shape ().element_type ();
1115
- bool supported_d_type = (d_type == BF16 || d_type == F16 || d_type == F32);
1116
- if (IsCuda (gpu_version_) && (d_type == F8E4M3FN || d_type == F8E5M2)) {
1117
- supported_d_type = true ;
1118
- }
1119
- if (IsRocm (gpu_version_) &&
1120
- toolkit_version_ >= stream_executor::SemanticVersion{6 , 2 , 0 } &&
1121
- (d_type == F8E4M3FNUZ || d_type == F8E5M2FNUZ)) {
1122
- supported_d_type = true ;
1145
+ std::unordered_set<PrimitiveType> supported_d_types = {BF16, F16, F32};
1146
+ if (IsCuda (gpu_version_)) {
1147
+ supported_d_types.insert (F8E4M3FN);
1148
+ supported_d_types.insert (F8E5M2);
1149
+ if (supported_d_types.find (d_type) == supported_d_types.end ()) {
1150
+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1151
+ << " into FP8 Custom Call. Output type must be "
1152
+ " F8E4M3FN, F8E5M2, BF16, F16 or F32, but got "
1153
+ << PrimitiveType_Name (d_type);
1154
+ return false ;
1155
+ }
1123
1156
}
1124
- if (!supported_d_type) {
1125
- VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1126
- << " into FP8 Custom Call. Output element type must be "
1127
- << (IsCuda (gpu_version_) ? " F8E4M3FN, F8E5M2, BF16, F16 or F32. "
1128
- : toolkit_version_ >=
1129
- stream_executor::SemanticVersion{6 , 2 , 0 }
1130
- ? " F8E4M3FNUZ, F8E5M2FNUZ, BF16, F16 or F32. "
1131
- : " BF16, F16 or F32. " )
1132
- << " Actual element type is " << PrimitiveType_Name (d_type);
1133
- return false ;
1157
+ if (IsRocm (gpu_version_)) {
1158
+ if (toolkit_version_ < stream_executor::SemanticVersion{6 , 2 , 0 }) {
1159
+ if (supported_d_types.find (d_type) == supported_d_types.end ()) {
1160
+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1161
+ << " into FP8 Custom Call. For ROCm version < 6.2, output "
1162
+ " type must be BF16, F16 or F32, but got "
1163
+ << PrimitiveType_Name (d_type);
1164
+ return false ;
1165
+ }
1166
+ }
1167
+ TF_ASSIGN_OR_RETURN (auto rocm_compute_capability,
1168
+ GetRocmComputeCapability (gpu_version_));
1169
+ if (rocm_compute_capability.has_ocp_fp8_support ()) {
1170
+ supported_d_types.insert (F8E4M3FN);
1171
+ supported_d_types.insert (F8E5M2);
1172
+ if (supported_d_types.find (d_type) == supported_d_types.end ()) {
1173
+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1174
+ << " into FP8 Custom Call. For "
1175
+ << rocm_compute_capability.gfx_version ()
1176
+ << " arch output type must be F8E4M3FN, F8E5M2, BF16, F16 or "
1177
+ " F32, but got "
1178
+ << PrimitiveType_Name (d_type);
1179
+ return false ;
1180
+ }
1181
+ }
1182
+ if (rocm_compute_capability.has_nanoo_fp8_support ()) {
1183
+ supported_d_types.insert (F8E4M3FNUZ);
1184
+ supported_d_types.insert (F8E5M2FNUZ);
1185
+ if (supported_d_types.find (d_type) == supported_d_types.end ()) {
1186
+ VLOG (1 ) << " Failed to rewrite " << instr->ToShortString ()
1187
+ << " into FP8 Custom Call. For "
1188
+ << rocm_compute_capability.gfx_version ()
1189
+ << " arch output type must be F8E4M3FNUZ, F8E5M2FNUZ, BF16, "
1190
+ " F16 or F32, but got "
1191
+ << PrimitiveType_Name (d_type);
1192
+ return false ;
1193
+ }
1194
+ }
1134
1195
}
1135
1196
1136
1197
// Each operand must have exactly one contracting and one non-contracting
@@ -1322,6 +1383,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
1322
1383
HloInstruction *d_scale, HloInstruction *clamp_lower,
1323
1384
HloInstruction *clamp_upper,
1324
1385
bool mult_scale = false ) {
1386
+ // TODO: add ROCm support to this fusion pattern
1387
+ if (IsRocm (gpu_version_)) {
1388
+ return absl::OkStatus ();
1389
+ }
1325
1390
// Verify the data types and the operands of clamp.
1326
1391
if (instr->shape ().element_type () == F8E4M3FN) {
1327
1392
if (!clamp_lower->literal ().IsAllFloat (static_cast <float >(
@@ -2073,38 +2138,70 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
2073
2138
return true ;
2074
2139
}
2075
2140
const TypeCombinations supported_hipblas_type_combinations = {
2076
- // FP8 types:
2077
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2078
- PrimitiveType::F8E4M3FNUZ, DataType::kBF16 },
2079
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2080
- PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ },
2081
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2082
- PrimitiveType::F8E4M3FNUZ, DataType::kHalf },
2083
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2084
- PrimitiveType::F8E4M3FNUZ, DataType::kFloat },
2085
-
2086
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2087
- PrimitiveType::F8E5M2FNUZ, DataType::kBF16 },
2088
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2089
- PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ },
2090
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2091
- PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ },
2092
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2093
- PrimitiveType::F8E5M2FNUZ, DataType::kHalf },
2094
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2095
- PrimitiveType::F8E5M2FNUZ, DataType::kFloat },
2096
-
2097
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2098
- PrimitiveType::F8E4M3FNUZ, DataType::kBF16 },
2099
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2100
- PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ },
2101
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2102
- PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ },
2103
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2104
- PrimitiveType::F8E4M3FNUZ, DataType::kHalf },
2105
- {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2106
- PrimitiveType::F8E4M3FNUZ, DataType::kFloat },
2107
- };
2141
+ // OCP FP8 types:
2142
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2143
+ PrimitiveType::F8E4M3FN, DataType::kBF16 },
2144
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2145
+ PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN },
2146
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2147
+ PrimitiveType::F8E4M3FN, DataType::kHalf },
2148
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2149
+ PrimitiveType::F8E4M3FN, DataType::kFloat },
2150
+
2151
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2152
+ PrimitiveType::F8E5M2, DataType::kBF16 },
2153
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2154
+ PrimitiveType::F8E5M2, DataType::kF8E4M3FN },
2155
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2156
+ PrimitiveType::F8E5M2, DataType::kF8E5M2 },
2157
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2158
+ PrimitiveType::F8E5M2, DataType::kHalf },
2159
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FN,
2160
+ PrimitiveType::F8E5M2, DataType::kFloat },
2161
+
2162
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2,
2163
+ PrimitiveType::F8E4M3FN, DataType::kBF16 },
2164
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2,
2165
+ PrimitiveType::F8E4M3FN, DataType::kF8E4M3FN },
2166
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2,
2167
+ PrimitiveType::F8E4M3FN, DataType::kF8E5M2 },
2168
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2,
2169
+ PrimitiveType::F8E4M3FN, DataType::kHalf },
2170
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2,
2171
+ PrimitiveType::F8E4M3FN, DataType::kFloat },
2172
+
2173
+ // NANOO FP8 types:
2174
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2175
+ PrimitiveType::F8E4M3FNUZ, DataType::kBF16 },
2176
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2177
+ PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ },
2178
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2179
+ PrimitiveType::F8E4M3FNUZ, DataType::kHalf },
2180
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2181
+ PrimitiveType::F8E4M3FNUZ, DataType::kFloat },
2182
+
2183
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2184
+ PrimitiveType::F8E5M2FNUZ, DataType::kBF16 },
2185
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2186
+ PrimitiveType::F8E5M2FNUZ, DataType::kF8E4M3FNUZ },
2187
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2188
+ PrimitiveType::F8E5M2FNUZ, DataType::kF8E5M2FNUZ },
2189
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2190
+ PrimitiveType::F8E5M2FNUZ, DataType::kHalf },
2191
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E4M3FNUZ,
2192
+ PrimitiveType::F8E5M2FNUZ, DataType::kFloat },
2193
+
2194
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2195
+ PrimitiveType::F8E4M3FNUZ, DataType::kBF16 },
2196
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2197
+ PrimitiveType::F8E4M3FNUZ, DataType::kF8E4M3FNUZ },
2198
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2199
+ PrimitiveType::F8E4M3FNUZ, DataType::kF8E5M2FNUZ },
2200
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2201
+ PrimitiveType::F8E4M3FNUZ, DataType::kHalf },
2202
+ {ComputationType::kF32 , DataType::kFloat , PrimitiveType::F8E5M2FNUZ,
2203
+ PrimitiveType::F8E4M3FNUZ, DataType::kFloat },
2204
+ };
2108
2205
if (IsRocm (gpu_version_) &&
2109
2206
absl::c_linear_search (supported_hipblas_type_combinations,
2110
2207
std::tuple{compute_type, scale_type, a_dtype,
0 commit comments