Skip to content

Commit 0cfbbce

Browse files
Dmytro Dzhulgakov, facebook-github-bot
Dmytro Dzhulgakov
authored and committed
Change Tensor::CopyFrom to a simple double dispatch (pytorch#14268)
Summary: Pull Request resolved: pytorch#14268 Removes the need for Context in Tensor by doing simple dispatch for CopyBytes. It'd eventually be subsumed by Roy Li's changes for a proper copy_ op, but before that is done, let's get a clear logic of how copies are implemented and clean up some cruft in the CopyFrom implementation. Note that with these changes, one can probably get rid of Context::CopyFromCPU/CopyToCPU, but it's a matter for follow-up diffs. This diff doesn't change the API of Tensor yet, but relies on the fact that passing `Context` to CopyFrom makes the copy async if the device is CUDA and doesn't have any effect otherwise (that's how Context methods are implemented). This doesn't change the semantics of the async copy implementation - as before, it blindly calls cudaMemcpyAsync, which probably means that it can be misused if invoked separately outside of an operator body. I'll leave it for the follow-up copy_ unification. For Extend() we always do an async copy - it makes sense as it's an in-place device-device operation and only any further op would be observable. Note: there are now three ways of invoking copy in C2 code - templated CopyBytes, virtual CopyFromCPU/etc, and the double-dispatch free function here. Hopefully we can get rid of the second one. Also, please advise whether it's c10-worthy :) Reviewed By: ezyang Differential Revision: D13117987 fbshipit-source-id: a6772d6dcf3effaf06717da3a656fc9873b310b5
1 parent f80d34a commit 0cfbbce

File tree

11 files changed

+306
-89
lines changed

11 files changed

+306
-89
lines changed

aten/src/ATen/core/TensorImpl.h

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -950,48 +950,35 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
950950
if (data_type_.copy()) {
951951
AT_ASSERTM(
952952
device_type() == ::at::DeviceType::CPU,
953-
"In CopyFrom source and dest tensors must both be CPU for meta copy, "
954-
"but dest tensor was ", device_type());
953+
"In CopyFrom source and dest tensors must both be CPU for "
954+
"non-POD copy, but dest tensor was ",
955+
device_type());
955956
AT_ASSERTM(
956957
src.device_type() == ::at::DeviceType::CPU,
957-
"In CopyFrom source and dest tensors must both be CPU for meta copy, "
958-
"but src tensor was ", src.device_type());
958+
"In CopyFrom source and dest tensors must both be CPU for "
959+
"non-POD copy, but src tensor was ",
960+
src.device_type());
959961
data_type_.copy()(src.data(), raw_mutable_data(data_type_), numel());
960962
} else {
961963
// The following copy uses the current (thread local) stream for copying
962-
// and also takes the current GPU id previously set through CUDA API
963-
// as we don't invoke SwitchToDevice anywhere
964-
// TODO: this logic is overly complex and can be replaced with simple
965-
// dispatch based on two device types
964+
// and also takes the GPU id from the device() field passed in.
966965
//
967-
// We'll need to use a non-CPU context to perform the copy if
968-
// one of the context is not CPU since only non-CPU context
969-
// knows how to copy between CPU and that context
970-
if (src.device_type() != ::at::DeviceType::CPU || device_type() == ::at::DeviceType::CPU) {
971-
if (!context) {
972-
CreateContext(src.GetDevice())
973-
->CopyBytesToDevice(
974-
numel() * itemsize(),
975-
src.data(),
976-
raw_mutable_data(data_type_),
977-
device_type());
978-
} else {
979-
AT_ASSERTM(
980-
context->device_type() == src.device_type(),
981-
"Type for provided context does not match the type of source");
982-
context->CopyBytesToDevice(
983-
numel() * itemsize(), src.data(), raw_mutable_data(data_type_), device_type());
984-
}
985-
} else {
986-
// In case source context is CPU, and target context is non-CPU
987-
// We'll have to create a Context from target and perform the
988-
// copy using that context
989-
CreateContext(GetDevice())
990-
->CopyBytesFromCPU(
991-
numel() * itemsize(),
992-
src.data(),
993-
raw_mutable_data(data_type_));
994-
}
966+
// TODO: Potentially more enforcements are necessary to avoid accidental
967+
// switch to sync copy if the currently set device is wrong.
968+
//
969+
// Specifically, we might need to switch to a different context device
970+
// here explicitly to avoid relying on user synchronizing things
971+
// properly.
972+
//
973+
// note: raw_mutable_data initializes device here
974+
void* new_data = raw_mutable_data(data_type_);
975+
at::CopyBytes(
976+
numel() * itemsize(),
977+
src.data(),
978+
src.device(),
979+
new_data,
980+
device(),
981+
context != nullptr);
995982
}
996983
}
997984
}
@@ -1037,8 +1024,29 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
10371024
auto* newData = raw_mutable_data(data_type_);
10381025
AT_ASSERTM(
10391026
context != nullptr, "Context must be provided to Extend the tensor");
1040-
context->CopyItemsSameDevice(
1041-
data_type_, oldSize, oldData.get(), newData);
1027+
if (data_type_.copy()) {
1028+
AT_ASSERTM(
1029+
device_type() == ::at::DeviceType::CPU,
1030+
"non-POD types work only on CPU");
1031+
data_type_.copy()(oldData.get(), newData, oldSize);
1032+
} else {
1033+
// The following copy uses the current (thread local) stream for copying
1034+
// and also takes the GPU id from the device() field passed in.
1035+
//
1036+
// TODO: Potentially more enforcements are necessary to avoid accidental
1037+
// switch to sync copy if the currently set device is wrong.
1038+
//
1039+
// Specifically, we might need to switch to a different context device
1040+
// here explicitly to avoid relying on user synchronizing things
1041+
// properly.
1042+
at::CopyBytes(
1043+
oldSize * itemsize(),
1044+
oldData.get(),
1045+
device(),
1046+
newData,
1047+
device(),
1048+
true); // non-blocking
1049+
}
10421050
reserved_ = true;
10431051
sizes_ = newDims;
10441052
numel_ = newNumel;

aten/src/ATen/core/context_base.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <ATen/core/context_base.h>
22

3+
#include <c10/util/Logging.h>
4+
35
namespace at {
46

57
C10_DEFINE_TYPED_REGISTRY(
@@ -9,6 +11,49 @@ C10_DEFINE_TYPED_REGISTRY(
911
std::unique_ptr,
1012
at::Device);
1113

14+
// First dimension of the array is `bool async`: 0 is sync,
15+
// 1 is async (non-blocking)
16+
static CopyBytesFunction g_copy_bytes[2][COMPILE_TIME_MAX_DEVICE_TYPES]
17+
[COMPILE_TIME_MAX_DEVICE_TYPES];
18+
19+
// Registers the copy routines for one (source, destination) device-type pair
// in the g_copy_bytes dispatch table.
//
//   fromType   - device type of the source buffer
//   toType     - device type of the destination buffer
//   func_sync  - routine stored in the sync slot
//   func_async - routine stored in the async (non-blocking) slot; when null,
//                func_sync is stored there as well
//
// Fails a CHECK if either device type does not fit in the table, or if a
// routine is already registered for this pair.
_CopyBytesFunctionRegisterer::_CopyBytesFunctionRegisterer(
    DeviceType fromType,
    DeviceType toType,
    CopyBytesFunction func_sync,
    CopyBytesFunction func_async) {
  const auto from = static_cast<int>(fromType);
  const auto to = static_cast<int>(toType);
  // The table has COMPILE_TIME_MAX_DEVICE_TYPES slots per dimension, while
  // DeviceType can carry larger values (ONLY_FOR_TEST is 20901), so validate
  // both indices before touching the array.
  CHECK(
      from >= 0 && from < COMPILE_TIME_MAX_DEVICE_TYPES && to >= 0 &&
      to < COMPILE_TIME_MAX_DEVICE_TYPES)
      << "Device type pair out of range for copy function registration: "
      << from << ", " << to;
  if (!func_async) {
    // default to the sync function
    func_async = func_sync;
  }
  CHECK(
      g_copy_bytes[0][from][to] == nullptr &&
      g_copy_bytes[1][from][to] == nullptr)
      << "Duplicate registration for device type pair "
      << c10::DeviceTypeName(fromType) << ", " << c10::DeviceTypeName(toType);
  g_copy_bytes[0][from][to] = func_sync;
  g_copy_bytes[1][from][to] = func_async;
}
38+
39+
void CopyBytes(
40+
size_t nbytes,
41+
const void* src,
42+
Device src_device,
43+
void* dst,
44+
Device dst_device,
45+
bool async) {
46+
auto ptr = g_copy_bytes[async ? 1 : 0][static_cast<int>(src_device.type())]
47+
[static_cast<int>(dst_device.type())];
48+
CAFFE_ENFORCE(
49+
ptr,
50+
"No function found for copying from ",
51+
c10::DeviceTypeName(src_device.type()),
52+
" to ",
53+
c10::DeviceTypeName(dst_device.type()));
54+
ptr(nbytes, src, src_device, dst, dst_device);
55+
}
56+
1257
} // namespace at
1358

1459
namespace caffe2 {

aten/src/ATen/core/context_base.h

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,6 @@ class CAFFE2_API BaseContext {
6363

6464
virtual void CopyBytesToCPU(size_t nbytes, const void* src, void* dst) = 0;
6565

66-
virtual void CopyBytesToDevice(
67-
size_t nbytes,
68-
const void* src,
69-
void* dst,
70-
DeviceType type) {
71-
if (type == DeviceType::CPU) {
72-
CopyBytesToCPU(nbytes, src, dst);
73-
} else if (type == device_type()) {
74-
CopyBytesSameDevice(nbytes, src, dst);
75-
} else {
76-
AT_ERROR(
77-
"CopyBytesToDevice can only copy to CPU or between same "
78-
"device. Can't copy from: ",
79-
device_type(),
80-
" to",
81-
type);
82-
}
83-
}
84-
8566
template <typename T>
8667
inline void CopySameDevice(size_t n, const T* src, T* dst) {
8768
static_assert(
@@ -175,9 +156,41 @@ inline std::unique_ptr<at::BaseContext> CreateContext(
175156

176157
} // namespace at
177158

159+
// TODO: move it to a separate file in c10 if possible
160+
namespace at {
161+
162+
using CopyBytesFunction = void (*)(
163+
size_t nbytes,
164+
const void* src,
165+
Device src_device,
166+
void* dst,
167+
Device dst_device);
168+
169+
struct CAFFE2_API _CopyBytesFunctionRegisterer {
170+
_CopyBytesFunctionRegisterer(
171+
DeviceType from,
172+
DeviceType to,
173+
CopyBytesFunction func_sync,
174+
CopyBytesFunction func_async = nullptr);
175+
};
176+
177+
#define REGISTER_COPY_BYTES_FUNCTION(from, to, ...) \
178+
namespace { \
179+
static _CopyBytesFunctionRegisterer C10_ANONYMOUS_VARIABLE( \
180+
g_copy_function)(from, to, __VA_ARGS__); \
181+
}
182+
183+
CAFFE2_API void CopyBytes(
184+
size_t nbytes,
185+
const void* src,
186+
Device src_device,
187+
void* dst,
188+
Device dst_device,
189+
bool async);
190+
} // namespace at
191+
178192
namespace caffe2 {
179193

180194
using at::BaseContext;
181195
using at::CreateContext;
182-
183196
} // namespace caffe2

c10/DeviceType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ enum class DeviceType : int16_t {
2929
ONLY_FOR_TEST = 20901, // This device type is only for test.
3030
};
3131

32+
// define explicit int constant
33+
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
34+
static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);
3235
C10_API std::string DeviceTypeName(
3336
DeviceType d,
3437
bool lower_case = false);

caffe2/core/blob_serialization.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ void TensorSerializer::Serialize(
222222
const TensorProto::DataType data_type = TypeMetaToDataType(input.dtype());
223223
proto.set_data_type(data_type);
224224
StoreDeviceDetail(input, &proto);
225+
// TODO: use DeviceGuard here instead of context and employ explicit sync
226+
// copy
225227
auto uniq_ptr = CreateContext(input.GetDevice());
226228
// A lot of copypaste is error prone. Should we create a macro for this?
227229
switch (data_type) {

caffe2/core/context.cc

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,6 @@
55
#include <process.h>
66
#endif
77

8-
namespace at {
9-
10-
REGISTER_CONTEXT(DeviceType::CPU, caffe2::CPUContext);
11-
} // namespace at
128
namespace caffe2 {
139

1410
uint32_t RandomNumberSeed() {
@@ -28,4 +24,41 @@ uint32_t RandomNumberSeed() {
2824
kPrime2 * tv_sec + kPrime3 * tv_usec;
2925
}
3026

27+
namespace {
28+
inline void CopyBytesImpl(size_t nbytes, const void* src, void* dst) {
29+
if (nbytes == 0) {
30+
return;
31+
}
32+
CAFFE_ENFORCE(src);
33+
CAFFE_ENFORCE(dst);
34+
memcpy(dst, src, nbytes);
35+
}
36+
37+
void CopyBytesWrapper(
38+
size_t nbytes,
39+
const void* src,
40+
Device src_device,
41+
void* dst,
42+
Device dst_device) {
43+
CopyBytesImpl(nbytes, src, dst);
44+
}
45+
} // namespace
46+
47+
void CPUContext::CopyBytesSameDevice(
48+
size_t nbytes,
49+
const void* src,
50+
void* dst) {
51+
CopyBytesImpl(nbytes, src, dst);
52+
}
53+
3154
} // namespace caffe2
55+
56+
namespace at {
57+
58+
REGISTER_CONTEXT(DeviceType::CPU, caffe2::CPUContext);
59+
60+
REGISTER_COPY_BYTES_FUNCTION(
61+
DeviceType::CPU,
62+
DeviceType::CPU,
63+
caffe2::CopyBytesWrapper);
64+
} // namespace at

caffe2/core/context.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,7 @@ class CAFFE2_API CPUContext final : public BaseContext {
7979
return GetCPUAllocator()->allocate(nbytes);
8080
}
8181

82-
void CopyBytesSameDevice(size_t nbytes, const void* src, void* dst) override {
83-
if (nbytes == 0) {
84-
return;
85-
}
86-
CAFFE_ENFORCE(src);
87-
CAFFE_ENFORCE(dst);
88-
memcpy(dst, src, nbytes);
89-
}
82+
void CopyBytesSameDevice(size_t nbytes, const void* src, void* dst) override;
9083

9184
void CopyBytesFromCPU(size_t nbytes, const void* src, void* dst) override {
9285
CopyBytesSameDevice(nbytes, src, dst);

0 commit comments

Comments
 (0)