diff --git a/hopper/benchmark_flash_attention_fp8.py b/hopper/benchmark_flash_attention_fp8.py index 8490d00b3..48fe7b1b2 100644 --- a/hopper/benchmark_flash_attention_fp8.py +++ b/hopper/benchmark_flash_attention_fp8.py @@ -44,7 +44,7 @@ def convert_to_cudnn_type(torch_type): return cudnn.data_type.INT64 elif torch_type == torch.float8_e4m3fn: return cudnn.data_type.FP8_E4M3 - elif torch_type == torch.float8_e4m3fn: + elif torch_type == torch.float8_e5m2: return cudnn.data_type.FP8_E5M2 else: raise ValueError("Unsupported tensor data type.")