Skip to content

Commit 746500d

Browse files
Revert "[cuDNN] Enable cuDNN Frontend v8 API by Default (pytorch#84948)"
This reverts commit 427e0a6. Reverted pytorch#84948 on behalf of https://github.com/malfet due to Broke SM86 sanity
1 parent 2cfc4cb commit 746500d

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

aten/src/ATen/native/ConvUtils.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,11 @@ namespace {
6666
}
6767

6868
static inline bool cudnnv8_enabled_check_debug() {
69-
static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_DISABLED") != true;
69+
static bool cudnnv8_flag = c10::utils::check_env("TORCH_CUDNN_V8_API_ENABLED") == true;
7070
static bool cudnnv8_debug = c10::utils::check_env("TORCH_CUDNN_V8_API_DEBUG") == true;
7171
static uint8_t cudnnv8_debugcount = 0;
7272
if (cudnnv8_debug == 1 && cudnnv8_debugcount < 10) {
73-
TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8 ON: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", cudnnv8_heuristic_mode_b);
73+
TORCH_WARN("TORCH_CUDNN_V8_DEBUG ON, V8_FLAG: ", cudnnv8_flag, " TORCH_CUDNN_USE_HEURISTIC_MODE B: ", cudnnv8_heuristic_mode_b);
7474
cudnnv8_debugcount++;
7575
}
7676
return cudnnv8_flag == 1;

aten/src/ATen/native/cudnn/ConvShared.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ void raw_cudnn_convolution_add_relu_fallback_out(
113113

114114
#if HAS_CUDNN_V8()
115115
// v7 functions are preserved here to allow for runtime switching to v7
116-
// (e.g., TORCH_CUDNN_V8_API_DISABLED=1).
116+
// (e.g., TORCH_CUDNN_V8_API_ENABLED=0).
117117
// Note that v7 forward/backward out can have different behavior from the v8
118118
// versions, as v7 explicitly splits large tensors as a 32-bit indexing
119119
// workaround whereas v8 expects cuDNN to handle large tensors.

test/test_cuda.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2894,10 +2894,10 @@ def test_autocast_torch_bf16(self):
28942894
op, args = op_with_args[0], op_with_args[1]
28952895
if len(op_with_args) == 3:
28962896
skip_test = op_with_args[2] # TEST_WITH_ROCM
2897-
should_error_from_cudnn = 'cudnn' in op and \
2898-
('TORCH_CUDNN_V8_API_DISABLED' in os.environ and
2899-
int(os.environ['TORCH_CUDNN_V8_API_DISABLED']) or
2900-
torch.cuda.get_device_capability() < (8, 0))
2897+
should_error_from_cudnn = 'cudnn' in op and not\
2898+
('TORCH_CUDNN_V8_API_ENABLED' in os.environ and
2899+
int(os.environ['TORCH_CUDNN_V8_API_ENABLED']) and
2900+
torch.cuda.get_device_capability() >= (8, 0))
29012901
should_error_from_not_implemented = should_error_from_cudnn or 'prelu' in op or 'thnn' in op \
29022902
or 'fused' in op or 'gru' in op or op == '_thnn_fused_lstm_cell' or op == 'lstm_cell'
29032903
if not skip_test:

0 commit comments

Comments (0)