forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAllocator.h
164 lines (147 loc) · 4.84 KB
/
Allocator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#pragma once
#include <stddef.h>
#include <functional>
#include <memory>
#include <utility>
#include <c10/Device.h>
#include <ATen/core/UniqueVoidPtr.h>
#include <c10/util/Exception.h>
namespace at {
// DataPtr: owning handle to a block of memory plus the device that memory
// belongs to. Ownership is held by a detail::UniqueVoidPtr, i.e. a unique
// pointer that carries an attached deleter and a context value handed to
// that deleter.
//
// A null DataPtr may still report a nontrivial device, which lets zero-size
// allocations flow through the same code paths as non-zero-size ones.
//
class DataPtr {
 public:
  // ---- construction -------------------------------------------------------

  // Null DataPtr on the CPU device. CPU is an arbitrary pick; an "undefined"
  // device would do just as well if one existed.
  DataPtr() : ptr_(), device_(DeviceType::CPU) {}

  // Wraps `data` for `device`; ptr_ is built from `data` alone (no explicit
  // context/deleter is passed here -- see UniqueVoidPtr for what that means).
  DataPtr(void* data, Device device) : ptr_(data), device_(device) {}

  // Wraps `data` for `device`, attaching `ctx` and the deleter `ctx_deleter`
  // that will be given that context.
  DataPtr(void* data, void* ctx, DeleterFnPtr ctx_deleter, Device device)
      : ptr_(data, ctx, ctx_deleter), device_(device) {}

  // ---- the data pointer ---------------------------------------------------

  // Raw data pointer (may be null).
  void* get() const {
    return ptr_.get();
  }

  // Same pointer as get().
  void* operator->() const {
    return ptr_.get();
  }

  // True when the underlying UniqueVoidPtr converts to true.
  operator bool() const {
    return static_cast<bool>(ptr_);
  }

  // Device recorded for this allocation (reported even when the data is null).
  Device device() const {
    return device_;
  }

  // ---- the deleter context ------------------------------------------------

  // Context value that is passed to the deleter.
  void* get_context() const {
    return ptr_.get_context();
  }

  // Returns the context as a T*, checked against `expected_deleter`; the exact
  // failure behavior is defined by UniqueVoidPtr::cast_context.
  template <typename T>
  T* cast_context(DeleterFnPtr expected_deleter) const {
    return ptr_.cast_context<T>(expected_deleter);
  }

  // Deleter currently attached to the context.
  DeleterFnPtr get_deleter() const {
    return ptr_.get_deleter();
  }

  // Forwards to UniqueVoidPtr::release_context().
  void* release_context() {
    return ptr_.release_context();
  }

  // Forwards the rvalue reference produced by UniqueVoidPtr::move_context().
  // NOTE(review): this is a reference into this object; move from it right
  // away rather than storing it.
  std::unique_ptr<void, DeleterFnPtr>&& move_context() {
    return ptr_.move_context();
  }

  // Forwards to UniqueVoidPtr::clear().
  void clear() {
    ptr_.clear();
  }

 private:
  detail::UniqueVoidPtr ptr_;
  Device device_;
};
// Null comparisons for DataPtr. Only the data pointer is inspected; the
// device is deliberately ignored, so a null CUDA DataPtr compares equal to
// nullptr exactly like a null CPU one does.
inline bool operator==(const at::DataPtr& dp, std::nullptr_t) noexcept {
  const bool non_null = static_cast<bool>(dp);
  return !non_null;
}
// Mirror of the overload above with the operands swapped.
inline bool operator==(std::nullptr_t, const at::DataPtr& dp) noexcept {
  return dp == nullptr;
}
inline bool operator!=(const at::DataPtr& dp, std::nullptr_t) noexcept {
  const bool non_null = static_cast<bool>(dp);
  return non_null;
}
// Mirror of the overload above with the operands swapped.
inline bool operator!=(std::nullptr_t, const at::DataPtr& dp) noexcept {
  return dp != nullptr;
}
// Note [raw_allocate/raw_deallocate and Thrust]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Thrust's support for custom allocators requires us to write something
// like this:
//
//  class ThrustAllocator {
//    char* allocate(size_t);
//    void deallocate(char*, size_t);
//  };
//
// This does not fit our unique_ptr-based allocator interface: deallocate()
// receives only the data pointer, so there is no way to get at the deleter
// context when we free.
//
// However, in some cases the context is exactly the same as the data
// pointer. In that case we can support the "raw" allocate/deallocate
// interface, and raw_deleter() is how an allocator signals this. By default
// raw_deleter() returns nullptr, which means the raw interface is not
// implemented. Be sure to override it whenever possible; otherwise the raw
// interface will be incorrectly reported as unsupported when it is actually
// usable.
struct Allocator {
  virtual ~Allocator() = default;

  // Allocates `n` bytes and returns them wrapped in a DataPtr.
  virtual at::DataPtr allocate(size_t n) const = 0;

  // If this returns a non-nullptr deleter, allocate() is guaranteed to
  // return a DataPtr with this deleter attached, which means raw_allocate()
  // and raw_deallocate() are safe to use.
  // This function MUST always return the same DeleterFnPtr.
  virtual DeleterFnPtr raw_deleter() const {
    return nullptr;
  }

  // Allocates `n` bytes through allocate() and returns the bare pointer
  // (obtained via release_context()); pair it with raw_deallocate().
  // Asserts that the DataPtr's context is its data pointer, since the raw
  // interface can only hand back one pointer.
  void* raw_allocate(size_t n) {
    auto dptr = allocate(n);
    AT_ASSERT(dptr.get() == dptr.get_context());
    return dptr.release_context();
  }

  // Frees `ptr` (obtained from raw_allocate()) by calling raw_deleter() on
  // it. Asserts that raw_deleter() is non-null, i.e. that the raw interface
  // is actually supported by this allocator.
  void raw_deallocate(void* ptr) {
    auto d = raw_deleter();
    AT_ASSERT(d);
    d(ptr);
  }
};
// Deleter context that owns a pointer through a std::function deleter.
// makeDataPtr() (defined out of line) lets a DataPtr be built from a raw
// pointer plus an arbitrary std::function deleter. NOTE(review): the
// "Inefficient" in the name presumably refers to the extra context object
// and type-erased call compared with a plain DeleterFnPtr -- confirm
// against the implementation.
//
// Open question (kept from the original): is this still needed?
struct CAFFE2_API InefficientStdFunctionContext {
  // Owns the wrapped pointer; std::unique_ptr invokes the std::function
  // deleter on it when this context is destroyed (if non-null).
  std::unique_ptr<void, std::function<void(void*)>> ptr_;

  // Takes ownership of `ptr` (moved in).
  InefficientStdFunctionContext(
      std::unique_ptr<void, std::function<void(void*)>>&& ptr)
      : ptr_(std::move(ptr)) {}

  // Presumably wraps `ptr` for `device` in a DataPtr whose memory is freed by
  // `deleter` -- the definition lives in the corresponding source file.
  static at::DataPtr makeDataPtr(
      void* ptr,
      const std::function<void(void*)>& deleter,
      Device device);
};
} // namespace at
namespace caffe2 {
/** Set the allocator for DeviceType `t`. The passed-in allocator pointer is
 * expected to have static lifetime; this function does NOT take ownership
 * of the raw pointer. (The reason for this is to prevent existing pointers
 * to an allocator of a particular device from being invalidated when
 * SetAllocator is called.)
 *
 * Also note that this is not thread-safe, and we assume this function will
 * only be called during initialization.
 */
CAFFE2_API void SetAllocator(at::DeviceType t, at::Allocator* alloc);
/** Get the allocator for DeviceType `t`, presumably the one installed by
 * SetAllocator (directly or via REGISTER_ALLOCATOR). NOTE(review): behavior
 * when nothing is registered for `t` is defined in the .cpp -- confirm.
 */
CAFFE2_API at::Allocator* GetAllocator(const at::DeviceType& t);
template <at::DeviceType t>
struct AllocatorRegisterer {
explicit AllocatorRegisterer(at::Allocator* alloc) {
SetAllocator(t, alloc);
}
};
// REGISTER_ALLOCATOR(t, f): registers allocator `f` (an at::Allocator*) for
// device type `t` by defining a file-local static AllocatorRegisterer<t>.
//  - The registerer is fully qualified so the macro also works when invoked
//    outside namespace caffe2.
//  - The object's name is suffixed with __LINE__ (the two-level CONCAT makes
//    __LINE__ expand before pasting), so one translation unit can register
//    several allocators without redefining the same variable. The previous
//    `g_allocator_##d` pasted a non-parameter and always produced
//    `g_allocator_d`.
#define CAFFE2_ALLOCATOR_CONCAT_IMPL(a, b) a##b
#define CAFFE2_ALLOCATOR_CONCAT(a, b) CAFFE2_ALLOCATOR_CONCAT_IMPL(a, b)
#define REGISTER_ALLOCATOR(t, f)                                   \
  namespace {                                                      \
  static ::caffe2::AllocatorRegisterer<t> CAFFE2_ALLOCATOR_CONCAT( \
      g_allocator_, __LINE__)(f);                                  \
  }
} // namespace caffe2