|
1 |
| -#pragma once |
2 |
| - |
3 |
| -#include <c10/util/ArrayRef.h> |
4 |
| -#include "ATen/core/Half.h" |
5 |
| -#include <c10/util/typeid.h> |
6 |
| - |
7 |
| -#include <cstdint> |
8 |
| -#include <iostream> |
9 |
| -#include <complex> |
10 |
| - |
11 |
| -namespace at { |
12 |
| - |
13 |
| -// NB: Order matters for this macro; it is relied upon in |
14 |
| -// _promoteTypesLookup and the serialization format. |
15 |
| -#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \ |
16 |
| -_(uint8_t,Byte,i) /* 0 */ \ |
17 |
| -_(int8_t,Char,i) /* 1 */ \ |
18 |
| -_(int16_t,Short,i) /* 2 */ \ |
19 |
| -_(int,Int,i) /* 3 */ \ |
20 |
| -_(int64_t,Long,i) /* 4 */ \ |
21 |
| -_(at::Half,Half,d) /* 5 */ \ |
22 |
| -_(float,Float,d) /* 6 */ \ |
23 |
| -_(double,Double,d) /* 7 */ \ |
24 |
| -_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \ |
25 |
| -_(std::complex<float>,ComplexFloat,z) /* 9 */ \ |
26 |
| -_(std::complex<double>,ComplexDouble,z) /* 10 */ |
27 |
| - |
28 |
| -// If you want to support ComplexHalf for real, replace occurrences |
29 |
| -// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But |
30 |
| -// beware: convert() doesn't work for all the conversions you need... |
31 |
| -#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \ |
32 |
| -_(uint8_t,Byte,i) \ |
33 |
| -_(int8_t,Char,i) \ |
34 |
| -_(int16_t,Short,i) \ |
35 |
| -_(int,Int,i) \ |
36 |
| -_(int64_t,Long,i) \ |
37 |
| -_(at::Half,Half,d) \ |
38 |
| -_(float,Float,d) \ |
39 |
| -_(double,Double,d) \ |
40 |
| -_(std::complex<float>,ComplexFloat,z) \ |
41 |
| -_(std::complex<double>,ComplexDouble,z) |
42 |
| - |
43 |
| -#define AT_FORALL_SCALAR_TYPES(_) \ |
44 |
| -_(uint8_t,Byte,i) \ |
45 |
| -_(int8_t,Char,i) \ |
46 |
| -_(int16_t,Short,i) \ |
47 |
| -_(int,Int,i) \ |
48 |
| -_(int64_t,Long,i) \ |
49 |
| -_(at::Half,Half,d) \ |
50 |
| -_(float,Float,d) \ |
51 |
| -_(double,Double,d) |
52 |
| - |
53 |
| -#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \ |
54 |
| -_(uint8_t,Byte,i) \ |
55 |
| -_(int8_t,Char,i) \ |
56 |
| -_(int16_t,Short,i) \ |
57 |
| -_(int,Int,i) \ |
58 |
| -_(int64_t,Long,i) \ |
59 |
| -_(float,Float,d) \ |
60 |
| -_(double,Double,d) |
61 |
| - |
62 |
| -enum class ScalarType : int8_t { |
63 |
| -#define DEFINE_ENUM(_1,n,_2) \ |
64 |
| - n, |
65 |
| - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM) |
66 |
| -#undef DEFINE_ENUM |
67 |
| - Undefined, |
68 |
| - NumOptions |
69 |
| -}; |
70 |
| - |
71 |
| -static inline DataType scalarTypeToDataType(ScalarType scalar_type) { |
72 |
| -#define DEFINE_CASE(ctype, name, _) \ |
73 |
| - case ScalarType::name: \ |
74 |
| - return caffe2::TypeIdentifier::Get<ctype>(); |
75 |
| - |
76 |
| - switch(scalar_type) { |
77 |
| - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) |
78 |
| - case ScalarType::Undefined: return DataType::uninitialized(); |
79 |
| - default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)"); |
80 |
| - } |
81 |
| -#undef DEFINE_CASE |
82 |
| -} |
83 |
| - |
84 |
| -static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) { |
85 |
| -#define DEFINE_CASE(ctype,name,_) \ |
86 |
| - case ScalarType:: name : return caffe2::TypeMeta::Make<ctype>(); |
87 |
| - |
88 |
| - switch(scalar_type) { |
89 |
| - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) |
90 |
| - case ScalarType::Undefined: return caffe2::TypeMeta(); |
91 |
| - default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)"); |
92 |
| - } |
93 |
| -#undef DEFINE_CASE |
94 |
| -} |
95 |
| - |
96 |
| -static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) { |
97 |
| -#define DEFINE_IF(ctype, name, _) \ |
98 |
| - if (dtype == caffe2::TypeMeta::Make<ctype>()) { \ |
99 |
| - return ScalarType::name; \ |
100 |
| - } |
101 |
| - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF) |
102 |
| -#undef DEFINE_IF |
103 |
| - if (dtype == caffe2::TypeMeta()) { |
104 |
| - return ScalarType::Undefined; |
105 |
| - } |
106 |
| - AT_ERROR("Unsupported TypeMeta in ATen: ", dtype, " (please report this error)"); |
107 |
| -} |
108 |
| - |
109 |
| -static inline bool operator==(ScalarType t, caffe2::TypeMeta m) { |
110 |
| - return typeMetaToScalarType(m) == t; |
111 |
| -} |
112 |
| - |
113 |
| -static inline bool operator==(caffe2::TypeMeta m, ScalarType t) { |
114 |
| - return typeMetaToScalarType(m) == t; |
115 |
| -} |
116 |
| - |
117 |
| -#define DEFINE_CONSTANT(_,name,_2) \ |
118 |
| -constexpr ScalarType k##name = ScalarType::name; |
119 |
| - |
120 |
| -AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CONSTANT) |
121 |
| -#undef DEFINE_CONSTANT |
122 |
| - |
123 |
| -static inline const char * toString(ScalarType t) { |
124 |
| -#define DEFINE_CASE(_,name,_2) \ |
125 |
| - case ScalarType:: name : return #name; |
126 |
| - |
127 |
| - switch(t) { |
128 |
| - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) |
129 |
| - default: |
130 |
| - return "UNKNOWN_SCALAR"; |
131 |
| - } |
132 |
| -#undef DEFINE_CASE |
133 |
| -} |
134 |
| - |
135 |
| -static inline size_t elementSize(ScalarType t) { |
136 |
| -#define CASE_ELEMENTSIZE_CASE(ctype,name,_2) \ |
137 |
| - case ScalarType:: name : return sizeof(ctype); |
138 |
| - |
139 |
| - switch(t) { |
140 |
| - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CASE_ELEMENTSIZE_CASE) |
141 |
| - default: |
142 |
| - AT_ERROR("Unknown ScalarType"); |
143 |
| - } |
144 |
| -#undef CASE_ELEMENTSIZE_CASE |
145 |
| -} |
146 |
| - |
147 |
| -static inline bool isIntegralType(ScalarType t) { |
148 |
| - return (t == ScalarType::Byte || |
149 |
| - t == ScalarType::Char || |
150 |
| - t == ScalarType::Int || |
151 |
| - t == ScalarType::Long || |
152 |
| - t == ScalarType::Short); |
153 |
| -} |
154 |
| - |
155 |
| -static inline bool isFloatingType(ScalarType t) { |
156 |
| - return (t == ScalarType::Double || |
157 |
| - t == ScalarType::Float || |
158 |
| - t == ScalarType::Half); |
159 |
| -} |
160 |
| - |
161 |
| -static inline bool isComplexType(ScalarType t) { |
162 |
| - return (t == ScalarType::ComplexHalf || |
163 |
| - t == ScalarType::ComplexFloat || |
164 |
| - t == ScalarType::ComplexDouble); |
165 |
| -} |
166 |
| - |
167 |
| -static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { |
168 |
| - // This is generated according to NumPy's promote_types |
169 |
| - constexpr auto u1 = ScalarType::Byte; |
170 |
| - constexpr auto i1 = ScalarType::Char; |
171 |
| - constexpr auto i2 = ScalarType::Short; |
172 |
| - constexpr auto i4 = ScalarType::Int; |
173 |
| - constexpr auto i8 = ScalarType::Long; |
174 |
| - constexpr auto f2 = ScalarType::Half; |
175 |
| - constexpr auto f4 = ScalarType::Float; |
176 |
| - constexpr auto f8 = ScalarType::Double; |
177 |
| - constexpr auto ud = ScalarType::Undefined; |
178 |
| - if (a == ud || b == ud) { |
179 |
| - return ScalarType::Undefined; |
180 |
| - } |
181 |
| - if (isComplexType(a) || isComplexType(b)) { |
182 |
| - AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be"); |
183 |
| - } |
184 |
| - static constexpr ScalarType _promoteTypesLookup |
185 |
| - [static_cast<int>(ScalarType::NumOptions)] |
186 |
| - [static_cast<int>(ScalarType::NumOptions)] = { |
187 |
| - /* u1 i1 i2 i4 i8 f2 f4 f8 */ |
188 |
| - /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8 }, |
189 |
| - /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8 }, |
190 |
| - /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8 }, |
191 |
| - /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8 }, |
192 |
| - /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8 }, |
193 |
| - /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8 }, |
194 |
| - /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8 }, |
195 |
| - /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8 }, |
196 |
| - }; |
197 |
| - return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)]; |
198 |
| -} |
199 |
| - |
200 |
| -class Tensor; |
201 |
| -typedef ArrayRef<Tensor> TensorList; |
202 |
| - |
203 |
| -inline std::ostream& operator<<( |
204 |
| - std::ostream& stream, |
205 |
| - at::ScalarType scalar_type) { |
206 |
| - return stream << toString(scalar_type); |
207 |
| -} |
208 |
| - |
209 |
| -} // namespace at |
| 1 | +#include <c10/core/ScalarType.h> |
0 commit comments