diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index a222c9ce74c8f..c42e68ad5318e 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -3,6 +3,7 @@
 
 #include
 #include
+#include
 #include
 #include
 
@@ -186,6 +187,9 @@ bool Context::userEnabledOverrideableSDP() const {
 
 static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
 static constexpr const std::array cublas_deterministic_configs = {":4096:8", ":16:8"};
+#ifdef USE_ROCM
+static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
+#endif
 
 bool Context::checkCuBLASConfigDeterministic() {
   // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -237,10 +241,24 @@ void Context::setBenchmarkLimitCuDNN(int b) {
 }
 
 bool Context::allowTF32CuBLAS() const {
+#ifdef USE_ROCM
+  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  if (allow_tf32 != true) {
+    return false;
+  }
+#endif
   return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
 }
 
 void Context::setAllowTF32CuBLAS(bool b) {
+#ifdef USE_ROCM
+  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
+  if (allow_tf32 != true) {
+    C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
+                              << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
+    return;
+  }
+#endif
   float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
 }
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 8a4ec2671dbe8..37c28bd2086ff 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -336,11 +336,9 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
   } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
     abcType = CUDA_C_64F;
     computeType = CUBLAS_COMPUTE_64F;
@@ -1236,11 +1234,9 @@ void gemm_and_bias(
     computeType = CUBLAS_COMPUTE_64F;
     scaleType = CUDA_R_64F;
   } else if constexpr (std::is_same_v<Dtype, float>) {
-#ifndef USE_ROCM
     if (at::globalContext().allowTF32CuBLAS()) {
       computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
     }
-#endif
     abcType = CUDA_R_32F;
   } else if constexpr (std::is_same_v<Dtype, at::Half>) {
     abcType = CUDA_R_16F;
diff --git a/test/dynamo/test_graph_region_tracker.py b/test/dynamo/test_graph_region_tracker.py
index c701ede3d4c68..04962bc4f8f8c 100644
--- a/test/dynamo/test_graph_region_tracker.py
+++ b/test/dynamo/test_graph_region_tracker.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import contextlib
+import os
 
 import torch
 import torch.fx
@@ -213,6 +214,21 @@ def fn(x, y, z):
         )
 
     def test_mismatched_global_state(self):
+        @contextlib.contextmanager
+        def _hip_allow_tf32():
+            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
+            # and only for MI300+
+            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+            try:
+                yield
+            finally:
+                if hip_allow_tf32 is not None:
+                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+                else:
+                    del os.environ["HIPBLASLT_ALLOW_TF32"]
+
         def inner_fn(x, y):
             x1 = x * 1
             y1 = y + 1
@@ -253,29 +269,31 @@ def set_default_dtype_bfloat16():
         def reset_default_dtype():
             torch.set_default_dtype(old_dtype)
 
-        for ctx in [
-            lambda: torch.set_grad_enabled(False),
-            torch.autograd.grad_mode.inference_mode,
-            lambda: torch.autograd.graph.disable_saved_tensors_hooks(
-                "This is not supported"
-            ),
-            # lambda: torch.set_num_threads(2), : Unsupported
-            (set_default_dtype_bfloat16, reset_default_dtype),
-            (
-                lambda: torch.use_deterministic_algorithms(True),
-                lambda: torch.use_deterministic_algorithms(False),
-            ),
-            # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
-            # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
-            create_toggle_fns("allow_bf16_reduced_precision_reduction"),
-            create_toggle_fns("allow_fp16_reduced_precision_reduction"),
-            create_toggle_fns("allow_tf32"),
-        ]:
-            self.assertExpectedInline(
-                self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
-                """[[['y1_2', 'sum_3', 'x1_2', 'o0'], ['y1_3', 'sum_4', 'x1_3', 'o2']], \
-[['y1', 'sum_1', 'x1', 'o4'], ['y1_1', 'sum_2', 'x1_1', 'o5']]]""",
-            )
+        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+        with tf32_ctx():
+            for ctx in [
+                lambda: torch.set_grad_enabled(False),
+                torch.autograd.grad_mode.inference_mode,
+                lambda: torch.autograd.graph.disable_saved_tensors_hooks(
+                    "This is not supported"
+                ),
+                # lambda: torch.set_num_threads(2), : Unsupported
+                (set_default_dtype_bfloat16, reset_default_dtype),
+                (
+                    lambda: torch.use_deterministic_algorithms(True),
+                    lambda: torch.use_deterministic_algorithms(False),
+                ),
+                # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
+                # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
+                create_toggle_fns("allow_bf16_reduced_precision_reduction"),
+                create_toggle_fns("allow_fp16_reduced_precision_reduction"),
+                create_toggle_fns("allow_tf32"),
+            ]:
+                self.assertExpectedInline(
+                    self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
+                    """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
+[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
+                )
 
 
 if __name__ == "__main__":
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 202158d5ed1c5..83ec332612aac 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7962,24 +7962,43 @@ def write_state(state):
         def fn(x):
             return x + 1
 
-        initial_state = read_state()
-        y = torch.randn(10)
-        try:
-            for round in range(3):
-                for i in range(len(initial_state)):
-                    new_state = [False] * len(initial_state)
-                    new_state[i] = True
-                    write_state(new_state)
-                    assert read_state() == new_state
-                    last_state.clear()
-                    fn(y)
-                    assert last_state == new_state
-                    if round == 0:
-                        assert cnt == i + 1
-                    else:
-                        assert cnt == len(initial_state)
-        finally:
-            write_state(initial_state)
+        import contextlib
+
+        @contextlib.contextmanager
+        def _hip_allow_tf32():
+            # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
+            # and only for MI300+
+            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+            try:
+                yield
+            finally:
+                if hip_allow_tf32 is not None:
+                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+                else:
+                    del os.environ["HIPBLASLT_ALLOW_TF32"]
+
+        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+        with tf32_ctx():
+            initial_state = read_state()
+            y = torch.randn(10)
+            try:
+                for round in range(3):
+                    for i in range(len(initial_state)):
+                        new_state = [False] * len(initial_state)
+                        new_state[i] = True
+                        write_state(new_state)
+                        assert read_state() == new_state
+                        last_state.clear()
+                        fn(y)
+                        assert last_state == new_state
+                        if round == 0:
+                            assert cnt == i + 1
+                        else:
+                            assert cnt == len(initial_state)
+            finally:
+                write_state(initial_state)
 
     def test_grad_state_mutated(self):
         prior = torch.is_grad_enabled()
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 58b255bf96089..e2741aebafcb4 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -485,7 +485,33 @@ def check_workspace_size(inp):
 
         torch._C._cuda_clearCublasWorkspaces()
 
+    @contextlib.contextmanager
+    def _hip_allow_tf32(self):
+        # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
+        # and only for MI300+
+        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
+
+        try:
+            yield
+        finally:
+            if hip_allow_tf32 is not None:
+                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+            else:
+                del os.environ["HIPBLASLT_ALLOW_TF32"]
+
     def test_cublas_allow_tf32_get_set(self):
+        """
+        We only turn on TF32 for MI300 with a special env var. This is because TF32
+        is only available in MI300+ and is in experimental mode (hipblaslt support
+        is currently WIP)
+        """
+        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+
+        with tf32_ctx():
+            self._test_cublas_allow_tf32_get_set_inner()
+
+    def _test_cublas_allow_tf32_get_set_inner(self):
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
         )
@@ -500,6 +526,12 @@ def test_cublas_allow_tf32_get_set(self):
             torch.backends.cuda.matmul.allow_tf32 = orig
 
     def test_float32_matmul_precision_get_set(self):
+        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
+
+        with tf32_ctx():
+            self._test_float32_matmul_precision_get_set_inner()
+
+    def _test_float32_matmul_precision_get_set_inner(self):
         orig = torch.get_float32_matmul_precision()
         skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
             os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
@@ -511,6 +543,7 @@ def test_float32_matmul_precision_get_set(self):
             self.assertEqual(torch.get_float32_matmul_precision(), "highest")
         else:
             self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
+
         for p in ("medium", "high"):
             torch.set_float32_matmul_precision(p)
             self.assertEqual(torch.get_float32_matmul_precision(), p)
diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py
index bb9257079a798..d8952403201bb 100644
--- a/torch/utils/hipify/cuda_to_hip_mappings.py
+++ b/torch/utils/hipify/cuda_to_hip_mappings.py
@@ -7292,6 +7292,10 @@
         "CUBLAS_COMPUTE_32F",
         ("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS)
     ),
+    (
+        "CUBLAS_COMPUTE_32F_FAST_TF32",
+        ("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS)
+    ),
     (
         "CUBLAS_COMPUTE_64F",
         ("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS)
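
Usage note (editorial addition, not part of the patch): the sketch below illustrates how the HIPBLASLT_ALLOW_TF32 gate added above behaves from the Python side, mirroring the _hip_allow_tf32 helpers in the tests. The MI300-and-newer restriction and the hipblaslt dependency are taken from the comments in this diff rather than verified independently. Because Context.cpp caches the c10::utils::check_env lookup in a function-local static, the variable must be set before the process first queries or sets allow_tf32.

    import os

    # Set the flag before torch first reads the TF32 state; the C++ side caches
    # the HIPBLASLT_ALLOW_TF32 lookup in a function-local static on first use.
    os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

    import torch

    if torch.version.hip:
        # With the env var set, the setter is honored on ROCm builds.
        torch.backends.cuda.matmul.allow_tf32 = True
        assert torch.backends.cuda.matmul.allow_tf32
    # Without HIPBLASLT_ALLOW_TF32=1, setAllowTF32CuBLAS logs an INFO message and
    # returns without changing anything, and allowTF32CuBLAS reports False on ROCm.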