Commit 150721a
CONSOLIDATED COMMITS: Turn on TF32 for aten::mm and De-noise tf32 warnings

==========================================================================

[reland][attempt2][AMD] Turn on TF32 for aten::mm (pytorch#144145)

Summary:
pytorch#143549 was reverted due to an internal/OSS tooling issue. Relanding.

hipblaslt supports TF32, so this adds TF32 support for aten::mm on ROCm.
Original PR: pytorch#139869

Test Plan: CI

Differential Revision: D67785496

Pull Request resolved: pytorch#144145
Approved by: https://github.com/jianyuh

(cherry picked from commit 3d3a079)
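
For illustration, a minimal sketch of how TF32 would be enabled for aten::mm on a ROCm build after this change (assumes an MI300-class GPU is available, that the environment variable is set before the TF32 flags are first queried since Context.cpp caches the check, and arbitrary tensor sizes):

# Hedged sketch: enabling TF32 for aten::mm on a ROCm build of PyTorch.
import os
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"  # must be set before the TF32 flags are first read

import torch

torch.backends.cuda.matmul.allow_tf32 = True  # takes effect on ROCm only with the env var set

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
c = torch.mm(a, b)  # dispatched through hipblaslt with TF32 compute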

[AMD] De-noise tf32 warnings (pytorch#144797)

Summary: This warning is way too noisy, especially during unit tests, so just log it once.

Test Plan: OSS CI. Tested on a unit test and now only one line appears (hard to notice :) ).

Differential Revision: D68167633

Pull Request resolved: pytorch#144797
Approved by: https://github.com/jianyuh, https://github.com/leitian, https://github.com/yoyoyocmu

(cherry picked from commit 6ba53a5)
xw285cornell authored and jithunnair-amd committed Feb 19, 2025
1 parent 9e1ee51 commit 150721a
Showing 6 changed files with 133 additions and 45 deletions.
18 changes: 18 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -3,6 +3,7 @@
#include <ATen/Context.h>

#include <c10/core/CPUAllocator.h>
#include <c10/util/Logging.h>

#include <algorithm>
#include <array>
@@ -186,6 +187,9 @@ bool Context::userEnabledOverrideableSDP() const {

static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
static constexpr const std::array<const char*, 2> cublas_deterministic_configs = {":4096:8", ":16:8"};
#ifdef USE_ROCM
static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
#endif

bool Context::checkCuBLASConfigDeterministic() {
// If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -237,10 +241,24 @@ void Context::setBenchmarkLimitCuDNN(int b) {
}

bool Context::allowTF32CuBLAS() const {
#ifdef USE_ROCM
const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
if (allow_tf32 != true) {
return false;
}
#endif
return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
}

void Context::setAllowTF32CuBLAS(bool b) {
#ifdef USE_ROCM
const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
if (allow_tf32 != true) {
C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
<< "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
return;
}
#endif
float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
}

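Conversely, a hedged sketch of the gating above as observed from Python on a ROCm build when HIPBLASLT_ALLOW_TF32 is not set (the info message is emitted only for the first few calls and may not be visible under default logging settings):

# Hedged sketch: on ROCm without HIPBLASLT_ALLOW_TF32, the TF32 switch is a no-op.
import os
os.environ.pop("HIPBLASLT_ALLOW_TF32", None)  # ensure the env var is unset

import torch

torch.backends.cuda.matmul.allow_tf32 = True   # logs an info message, then returns without effect
print(torch.backends.cuda.matmul.allow_tf32)   # False: allowTF32CuBLAS() short-circuits on ROCm
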
4 changes: 0 additions & 4 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -336,11 +336,9 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
#endif
} else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
abcType = CUDA_C_64F;
computeType = CUBLAS_COMPUTE_64F;
@@ -1236,11 +1234,9 @@ void gemm_and_bias(
computeType = CUBLAS_COMPUTE_64F;
scaleType = CUDA_R_64F;
} else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
if (at::globalContext().allowTF32CuBLAS()) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
#endif
abcType = CUDA_R_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
abcType = CUDA_R_16F;
64 changes: 41 additions & 23 deletions test/dynamo/test_graph_region_tracker.py
@@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import contextlib
import os

import torch
import torch.fx
@@ -213,6 +214,21 @@ def fn(x, y, z):
)

def test_mismatched_global_state(self):
@contextlib.contextmanager
def _hip_allow_tf32():
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

try:
yield
finally:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]

def inner_fn(x, y):
x1 = x * 1
y1 = y + 1
@@ -253,29 +269,31 @@ def set_default_dtype_bfloat16():
def reset_default_dtype():
torch.set_default_dtype(old_dtype)

for ctx in [
lambda: torch.set_grad_enabled(False),
torch.autograd.grad_mode.inference_mode,
lambda: torch.autograd.graph.disable_saved_tensors_hooks(
"This is not supported"
),
# lambda: torch.set_num_threads(2), : Unsupported
(set_default_dtype_bfloat16, reset_default_dtype),
(
lambda: torch.use_deterministic_algorithms(True),
lambda: torch.use_deterministic_algorithms(False),
),
# (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
# lambda: torch.use_deterministic_algorithms(False)), : Unsupported
create_toggle_fns("allow_bf16_reduced_precision_reduction"),
create_toggle_fns("allow_fp16_reduced_precision_reduction"),
create_toggle_fns("allow_tf32"),
]:
self.assertExpectedInline(
self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
"""[[['y1_2', 'sum_3', 'x1_2', 'o0'], ['y1_3', 'sum_4', 'x1_3', 'o2']], \
[['y1', 'sum_1', 'x1', 'o4'], ['y1_1', 'sum_2', 'x1_1', 'o5']]]""",
)
tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
with tf32_ctx():
for ctx in [
lambda: torch.set_grad_enabled(False),
torch.autograd.grad_mode.inference_mode,
lambda: torch.autograd.graph.disable_saved_tensors_hooks(
"This is not supported"
),
# lambda: torch.set_num_threads(2), : Unsupported
(set_default_dtype_bfloat16, reset_default_dtype),
(
lambda: torch.use_deterministic_algorithms(True),
lambda: torch.use_deterministic_algorithms(False),
),
# (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
# lambda: torch.use_deterministic_algorithms(False)), : Unsupported
create_toggle_fns("allow_bf16_reduced_precision_reduction"),
create_toggle_fns("allow_fp16_reduced_precision_reduction"),
create_toggle_fns("allow_tf32"),
]:
self.assertExpectedInline(
self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
"""[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
)


if __name__ == "__main__":
55 changes: 37 additions & 18 deletions test/dynamo/test_misc.py
@@ -7962,24 +7962,43 @@ def write_state(state):
def fn(x):
return x + 1

initial_state = read_state()
y = torch.randn(10)
try:
for round in range(3):
for i in range(len(initial_state)):
new_state = [False] * len(initial_state)
new_state[i] = True
write_state(new_state)
assert read_state() == new_state
last_state.clear()
fn(y)
assert last_state == new_state
if round == 0:
assert cnt == i + 1
else:
assert cnt == len(initial_state)
finally:
write_state(initial_state)
import contextlib

@contextlib.contextmanager
def _hip_allow_tf32():
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

try:
yield
finally:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]

tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
with tf32_ctx():
initial_state = read_state()
y = torch.randn(10)
try:
for round in range(3):
for i in range(len(initial_state)):
new_state = [False] * len(initial_state)
new_state[i] = True
write_state(new_state)
assert read_state() == new_state
last_state.clear()
fn(y)
assert last_state == new_state
if round == 0:
assert cnt == i + 1
else:
assert cnt == len(initial_state)
finally:
write_state(initial_state)

def test_grad_state_mutated(self):
prior = torch.is_grad_enabled()
33 changes: 33 additions & 0 deletions test/test_cuda.py
@@ -485,7 +485,33 @@ def check_workspace_size(inp):

torch._C._cuda_clearCublasWorkspaces()

@contextlib.contextmanager
def _hip_allow_tf32(self):
# for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
# and only for MI300+
hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

try:
yield
finally:
if hip_allow_tf32 is not None:
os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
else:
del os.environ["HIPBLASLT_ALLOW_TF32"]

def test_cublas_allow_tf32_get_set(self):
"""
We only turn on TF32 for MI300 with a special env var. This is because TF32
is only available on MI300+ and is in experimental mode (hipblaslt support
is currently WIP).
"""
tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext

with tf32_ctx():
self._test_cublas_allow_tf32_get_set_inner()

def _test_cublas_allow_tf32_get_set_inner(self):
skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
)
@@ -500,6 +526,12 @@ def test_cublas_allow_tf32_get_set(self):
torch.backends.cuda.matmul.allow_tf32 = orig

def test_float32_matmul_precision_get_set(self):
tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext

with tf32_ctx():
self._test_float32_matmul_precision_get_set_inner()

def _test_float32_matmul_precision_get_set_inner(self):
orig = torch.get_float32_matmul_precision()
skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
@@ -511,6 +543,7 @@ def test_float32_matmul_precision_get_set(self):
self.assertEqual(torch.get_float32_matmul_precision(), "highest")
else:
self.assertTrue(torch.backends.cuda.matmul.allow_tf32)

for p in ("medium", "high"):
torch.set_float32_matmul_precision(p)
self.assertEqual(torch.get_float32_matmul_precision(), p)
4 changes: 4 additions & 0 deletions torch/utils/hipify/cuda_to_hip_mappings.py
@@ -7292,6 +7292,10 @@
"CUBLAS_COMPUTE_32F",
("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS)
),
(
"CUBLAS_COMPUTE_32F_FAST_TF32",
("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS)
),
(
"CUBLAS_COMPUTE_64F",
("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS)
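For reference, a hedged illustration of what the new entry enables: hipify performs a textual substitution of the cuBLAS enum name based on these mapping tables, roughly equivalent to the following (hypothetical source line, not taken from this diff):

# Hypothetical illustration of the mapping added above: hipify rewrites the
# cuBLAS compute-type constant to its hipBLAS counterpart for ROCm builds.
cuda_line = "computeType = CUBLAS_COMPUTE_32F_FAST_TF32;"
hip_line = cuda_line.replace("CUBLAS_COMPUTE_32F_FAST_TF32", "HIPBLAS_COMPUTE_32F_FAST_TF32")
print(hip_line)  # -> computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;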
