Skip to content

Commit 50e9c56

Browse files
smessmerfacebook-github-bot
authored andcommitted
Move Scalar and ScalarType to c10/core
Summary: Pull Request resolved: pytorch#14022 Reviewed By: ezyang Differential Revision: D13015236 fbshipit-source-id: 92aac4e342d85f75a31837b2943fa5b80f0c35c9
1 parent 3fca4bd commit 50e9c56

13 files changed

+340
-326
lines changed

aten/src/ATen/ScalarOps.h

+5-4
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
#include "ATen/core/Scalar.h"
44
#include "ATen/Tensor.h"
55

6-
namespace at {
6+
// This is in the c10 namespace because we use ADL to find the functions in it.
7+
namespace c10 {
78

89
// FIXME: this should be (and was) Scalar::toTensor, but there is currently no way
910
// to implement this without going through Derived Types (which are not part of core).
10-
inline Tensor scalar_to_tensor(Scalar s) {
11+
inline at::Tensor scalar_to_tensor(Scalar s) {
1112
if (s.isFloatingPoint()) {
12-
return CPU(kDouble).scalarTensor(s);
13+
return at::CPU(kDouble).scalarTensor(s);
1314
} else {
1415
AT_ASSERT(s.isIntegral());
15-
return CPU(kLong).scalarTensor(s);
16+
return at::CPU(kLong).scalarTensor(s);
1617
}
1718
}
1819

aten/src/ATen/core/Backend.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,4 @@ static inline const char* toString(Backend b) {
165165
}
166166
}
167167

168-
} // namespace at
168+
} // namespace c10

aten/src/ATen/core/Scalar.h

+1-103
Original file line numberDiff line numberDiff line change
@@ -1,103 +1 @@
1-
#pragma once
2-
3-
#include <assert.h>
4-
#include <stdint.h>
5-
#include <stdexcept>
6-
#include <string>
7-
#include <utility>
8-
9-
#include "ATen/core/ATenGeneral.h"
10-
#include "ATen/core/ScalarType.h"
11-
#include "ATen/core/Half.h"
12-
13-
namespace at {
14-
15-
class Tensor;
16-
17-
class CAFFE2_API Scalar {
18-
public:
19-
Scalar() : Scalar(int64_t(0)) {}
20-
21-
#define DEFINE_IMPLICIT_CTOR(type,name,member) \
22-
Scalar(type vv) \
23-
: tag(Tag::HAS_##member) { \
24-
v . member = convert<decltype(v.member),type>(vv); \
25-
}
26-
// We can't set v in the initializer list using the
27-
// syntax v{ .member = ... } because it doesn't work on MSVC
28-
29-
AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR)
30-
31-
#undef DEFINE_IMPLICIT_CTOR
32-
33-
#define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \
34-
Scalar(type vv) : tag(Tag::HAS_##member) { \
35-
v.member[0] = c10::convert<double>(vv.real()); \
36-
v.member[1] = c10::convert<double>(vv.imag()); \
37-
}
38-
39-
DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf,ComplexHalf,z)
40-
DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>,ComplexFloat,z)
41-
DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>,ComplexDouble,z)
42-
43-
#undef DEFINE_IMPLICIT_COMPLEX_CTOR
44-
45-
#define DEFINE_ACCESSOR(type,name,member) \
46-
type to##name () const { \
47-
if (Tag::HAS_d == tag) { \
48-
return checked_convert<type, double>(v.d, #type); \
49-
} else if (Tag::HAS_z == tag) { \
50-
return checked_convert<type, std::complex<double>>({v.z[0], v.z[1]}, #type); \
51-
} else { \
52-
return checked_convert<type, int64_t>(v.i, #type); \
53-
} \
54-
}
55-
56-
// TODO: Support ComplexHalf accessor
57-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)
58-
59-
//also support scalar.to<int64_t>();
60-
template<typename T>
61-
T to();
62-
63-
#undef DEFINE_ACCESSOR
64-
bool isFloatingPoint() const {
65-
return Tag::HAS_d == tag;
66-
}
67-
bool isIntegral() const {
68-
return Tag::HAS_i == tag;
69-
}
70-
bool isComplex() const {
71-
return Tag::HAS_z == tag;
72-
}
73-
74-
Scalar operator-() const;
75-
76-
private:
77-
enum class Tag { HAS_d, HAS_i, HAS_z };
78-
Tag tag;
79-
union {
80-
double d;
81-
int64_t i;
82-
// Can't do put std::complex in the union, because it triggers
83-
// an nvcc bug:
84-
// error: designator may not specify a non-POD subobject
85-
double z[2];
86-
} v;
87-
friend struct Type;
88-
};
89-
90-
// define the scalar.to<int64_t>() specializations
91-
template<typename T>
92-
inline T Scalar::to() {
93-
throw std::runtime_error("to() cast to unexpected type.");
94-
}
95-
96-
#define DEFINE_TO(T,name,_) \
97-
template<> \
98-
inline T Scalar::to<T>() { \
99-
return to##name(); \
100-
}
101-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
102-
#undef DEFINE_TO
103-
}
1+
#include <c10/core/Scalar.h>

aten/src/ATen/core/ScalarType.h

+1-209
Original file line numberDiff line numberDiff line change
@@ -1,209 +1 @@
1-
#pragma once
2-
3-
#include <c10/util/ArrayRef.h>
4-
#include "ATen/core/Half.h"
5-
#include <c10/util/typeid.h>
6-
7-
#include <cstdint>
8-
#include <iostream>
9-
#include <complex>
10-
11-
namespace at {
12-
13-
// NB: Order matters for this macro; it is relied upon in
14-
// _promoteTypesLookup and the serialization format.
15-
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
16-
_(uint8_t,Byte,i) /* 0 */ \
17-
_(int8_t,Char,i) /* 1 */ \
18-
_(int16_t,Short,i) /* 2 */ \
19-
_(int,Int,i) /* 3 */ \
20-
_(int64_t,Long,i) /* 4 */ \
21-
_(at::Half,Half,d) /* 5 */ \
22-
_(float,Float,d) /* 6 */ \
23-
_(double,Double,d) /* 7 */ \
24-
_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
25-
_(std::complex<float>,ComplexFloat,z) /* 9 */ \
26-
_(std::complex<double>,ComplexDouble,z) /* 10 */
27-
28-
// If you want to support ComplexHalf for real, replace occurrences
29-
// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
30-
// beware: convert() doesn't work for all the conversions you need...
31-
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
32-
_(uint8_t,Byte,i) \
33-
_(int8_t,Char,i) \
34-
_(int16_t,Short,i) \
35-
_(int,Int,i) \
36-
_(int64_t,Long,i) \
37-
_(at::Half,Half,d) \
38-
_(float,Float,d) \
39-
_(double,Double,d) \
40-
_(std::complex<float>,ComplexFloat,z) \
41-
_(std::complex<double>,ComplexDouble,z)
42-
43-
#define AT_FORALL_SCALAR_TYPES(_) \
44-
_(uint8_t,Byte,i) \
45-
_(int8_t,Char,i) \
46-
_(int16_t,Short,i) \
47-
_(int,Int,i) \
48-
_(int64_t,Long,i) \
49-
_(at::Half,Half,d) \
50-
_(float,Float,d) \
51-
_(double,Double,d)
52-
53-
#define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
54-
_(uint8_t,Byte,i) \
55-
_(int8_t,Char,i) \
56-
_(int16_t,Short,i) \
57-
_(int,Int,i) \
58-
_(int64_t,Long,i) \
59-
_(float,Float,d) \
60-
_(double,Double,d)
61-
62-
enum class ScalarType : int8_t {
63-
#define DEFINE_ENUM(_1,n,_2) \
64-
n,
65-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ENUM)
66-
#undef DEFINE_ENUM
67-
Undefined,
68-
NumOptions
69-
};
70-
71-
static inline DataType scalarTypeToDataType(ScalarType scalar_type) {
72-
#define DEFINE_CASE(ctype, name, _) \
73-
case ScalarType::name: \
74-
return caffe2::TypeIdentifier::Get<ctype>();
75-
76-
switch(scalar_type) {
77-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
78-
case ScalarType::Undefined: return DataType::uninitialized();
79-
default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
80-
}
81-
#undef DEFINE_CASE
82-
}
83-
84-
static inline caffe2::TypeMeta scalarTypeToTypeMeta(ScalarType scalar_type) {
85-
#define DEFINE_CASE(ctype,name,_) \
86-
case ScalarType:: name : return caffe2::TypeMeta::Make<ctype>();
87-
88-
switch(scalar_type) {
89-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
90-
case ScalarType::Undefined: return caffe2::TypeMeta();
91-
default: AT_ERROR("Unrecognized Scalartype ", scalar_type, " (please report this error)");
92-
}
93-
#undef DEFINE_CASE
94-
}
95-
96-
static inline ScalarType typeMetaToScalarType(caffe2::TypeMeta dtype) {
97-
#define DEFINE_IF(ctype, name, _) \
98-
if (dtype == caffe2::TypeMeta::Make<ctype>()) { \
99-
return ScalarType::name; \
100-
}
101-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_IF)
102-
#undef DEFINE_IF
103-
if (dtype == caffe2::TypeMeta()) {
104-
return ScalarType::Undefined;
105-
}
106-
AT_ERROR("Unsupported TypeMeta in ATen: ", dtype, " (please report this error)");
107-
}
108-
109-
static inline bool operator==(ScalarType t, caffe2::TypeMeta m) {
110-
return typeMetaToScalarType(m) == t;
111-
}
112-
113-
static inline bool operator==(caffe2::TypeMeta m, ScalarType t) {
114-
return typeMetaToScalarType(m) == t;
115-
}
116-
117-
#define DEFINE_CONSTANT(_,name,_2) \
118-
constexpr ScalarType k##name = ScalarType::name;
119-
120-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CONSTANT)
121-
#undef DEFINE_CONSTANT
122-
123-
static inline const char * toString(ScalarType t) {
124-
#define DEFINE_CASE(_,name,_2) \
125-
case ScalarType:: name : return #name;
126-
127-
switch(t) {
128-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
129-
default:
130-
return "UNKNOWN_SCALAR";
131-
}
132-
#undef DEFINE_CASE
133-
}
134-
135-
static inline size_t elementSize(ScalarType t) {
136-
#define CASE_ELEMENTSIZE_CASE(ctype,name,_2) \
137-
case ScalarType:: name : return sizeof(ctype);
138-
139-
switch(t) {
140-
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CASE_ELEMENTSIZE_CASE)
141-
default:
142-
AT_ERROR("Unknown ScalarType");
143-
}
144-
#undef CASE_ELEMENTSIZE_CASE
145-
}
146-
147-
static inline bool isIntegralType(ScalarType t) {
148-
return (t == ScalarType::Byte ||
149-
t == ScalarType::Char ||
150-
t == ScalarType::Int ||
151-
t == ScalarType::Long ||
152-
t == ScalarType::Short);
153-
}
154-
155-
static inline bool isFloatingType(ScalarType t) {
156-
return (t == ScalarType::Double ||
157-
t == ScalarType::Float ||
158-
t == ScalarType::Half);
159-
}
160-
161-
static inline bool isComplexType(ScalarType t) {
162-
return (t == ScalarType::ComplexHalf ||
163-
t == ScalarType::ComplexFloat ||
164-
t == ScalarType::ComplexDouble);
165-
}
166-
167-
static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
168-
// This is generated according to NumPy's promote_types
169-
constexpr auto u1 = ScalarType::Byte;
170-
constexpr auto i1 = ScalarType::Char;
171-
constexpr auto i2 = ScalarType::Short;
172-
constexpr auto i4 = ScalarType::Int;
173-
constexpr auto i8 = ScalarType::Long;
174-
constexpr auto f2 = ScalarType::Half;
175-
constexpr auto f4 = ScalarType::Float;
176-
constexpr auto f8 = ScalarType::Double;
177-
constexpr auto ud = ScalarType::Undefined;
178-
if (a == ud || b == ud) {
179-
return ScalarType::Undefined;
180-
}
181-
if (isComplexType(a) || isComplexType(b)) {
182-
AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
183-
}
184-
static constexpr ScalarType _promoteTypesLookup
185-
[static_cast<int>(ScalarType::NumOptions)]
186-
[static_cast<int>(ScalarType::NumOptions)] = {
187-
/* u1 i1 i2 i4 i8 f2 f4 f8 */
188-
/* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8 },
189-
/* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8 },
190-
/* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8 },
191-
/* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8 },
192-
/* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8 },
193-
/* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8 },
194-
/* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8 },
195-
/* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8 },
196-
};
197-
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
198-
}
199-
200-
class Tensor;
201-
typedef ArrayRef<Tensor> TensorList;
202-
203-
inline std::ostream& operator<<(
204-
std::ostream& stream,
205-
at::ScalarType scalar_type) {
206-
return stream << toString(scalar_type);
207-
}
208-
209-
} // namespace at
1+
#include <c10/core/ScalarType.h>

aten/src/ATen/core/TensorImpl.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@ class DeviceOption;
4242

4343
}
4444

45-
namespace at {
45+
namespace c10 {
4646
class Scalar;
47+
}
48+
namespace at {
4749
struct Type;
4850
struct Storage;
4951
class Tensor;

aten/src/ATen/core/TensorMethods.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ inline bool is_sparse(Tensor self) {
12931293
"expected scalar type ", \
12941294
#name, \
12951295
" but found ", \
1296-
at::toString(type().scalarType())); \
1296+
c10::toString(type().scalarType())); \
12971297
return static_cast<T*>(this->data_ptr()); \
12981298
}
12991299

aten/src/ATen/templates/NativeFunctions.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
#include <tuple>
1414
#include <vector>
1515

16+
namespace c10 {
17+
class Scalar;
18+
}
1619
namespace at {
1720
struct Generator;
18-
class Scalar;
1921
class Tensor;
2022
struct Type;
2123
} // namespace at

aten/src/ATen/templates/TensorMethods.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ inline bool is_sparse(Tensor self) {
106106
"expected scalar type ", \
107107
#name, \
108108
" but found ", \
109-
at::toString(type().scalarType())); \
109+
c10::toString(type().scalarType())); \
110110
return static_cast<T*>(this->data_ptr()); \
111111
}
112112

0 commit comments

Comments
 (0)