Skip to content

Commit cf83d50

Browse files
yzhliuvinx13
authored and committed
[Codegen] remove fp16 function override for cuda (#4331)
* add volatile override back * [codegen] remove fp16 function override for cuda
1 parent b127dc7 commit cf83d50

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

src/codegen/codegen_cuda.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,19 @@ std::string CodeGenCUDA::Finish() {
5858
<< "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
5959
decl_stream << "__device__ half min(half a, half b)\n"
6060
<< "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
61-
decl_stream << "__device__ half operator<="
62-
<< "(__half a, __half b)\n"
63-
<< "{\n return __hlt(a, b);\n}\n";
64-
decl_stream << "__device__ half operator+"
65-
<< "(__half a, __half &b)\n"
66-
<<"{\n return __hadd(a, b);\n}\n";
67-
decl_stream << "__device__ half operator*"
68-
<< "(__half a, __half b)\n"
69-
<< "{\n return __hmul(a, b);\n}\n";
61+
// FIXME(tvm-team): "volatile" is used to enable cross thread reduction,
62+
// which is needed by operations such as softmax.
63+
// However, volatile overloading is not supported in NVRTC and CUDA < 9.2.
64+
// We need to figure out a solution that satisfies both scenarios.
65+
// decl_stream << "__device__ half operator<="
66+
// << "(const volatile __half &a, const volatile __half &b)\n"
67+
// << "{\n return __hlt(a, b);\n}\n";
68+
// decl_stream << "__device__ half operator+"
69+
// << "(const volatile __half &a, const volatile __half &b)\n"
70+
// <<"{\n return __hadd(a, b);\n}\n";
71+
// decl_stream << "__device__ half operator*"
72+
// << "(const volatile __half &a, const volatile __half &b)\n"
73+
// << "{\n return __hmul(a, b);\n}\n";
7074
// otherwise simulate computation via float32
7175
decl_stream << "#else\n";
7276
decl_stream << _cuda_half_t_def;

src/codegen/literal/cuda_half_t.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
static constexpr const char* _cuda_half_t_def = R"(
2929
typedef unsigned short uint16_t;
3030
typedef unsigned char uint8_t;
31+
typedef signed char int8_t;
3132
typedef int int32_t;
3233
typedef unsigned long long uint64_t;
3334
typedef unsigned int uint32_t;
@@ -76,7 +77,7 @@ class TVM_ALIGNED(2) half {
7677
TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
7778
TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
7879
TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
79-
TVM_XINLINE explicit half(const int64_t& value) { constructor(value); }
80+
TVM_XINLINE explicit half(const long long& value) { constructor(value); }
8081
TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
8182
8283
TVM_XINLINE operator float() const { \

0 commit comments

Comments
 (0)