forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMemoryOverlap.cpp
105 lines (89 loc) · 3.89 KB
/
MemoryOverlap.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include <ATen/MemoryOverlap.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Layout.h>
#include <c10/util/irange.h>
namespace at {
// Convenience overload: unwraps the TensorBase and defers to the
// TensorImpl* overload, which performs the actual stride inspection.
MemOverlap has_internal_overlap(const TensorBase& tensor) {
  auto* impl = tensor.unsafeGetTensorImpl();
  return has_internal_overlap(impl);
}
// Reports whether distinct indices of `t` can refer to the same memory
// location.
//
// Returns:
//   MemOverlap::No      - `t` is non-overlapping and dense.
//   MemOverlap::Yes     - some dimension with more than one element has a
//                         stride of zero.
//   MemOverlap::TooHard - neither of the cheap checks above was conclusive.
//
// Debug builds assert that `t` uses the strided layout.
MemOverlap has_internal_overlap(TensorImpl* t) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t->layout() == kStrided);

  if (t->is_non_overlapping_and_dense()) {
    return MemOverlap::No;
  }

  const auto dim_strides = t->sym_strides();
  const auto dim_sizes = t->sym_sizes();
  const auto ndim = dim_strides.size();
  for (const auto d : c10::irange(ndim)) {
    // NB: The size-oblivious test is written very carefully here. When
    // unbacked SymInts are involved, we conservatively report whether memory
    // overlap /could/ happen under some setting of the unbacked SymInts.
    // Hence a u0 size is assumed to be > 1 (first check), while a u0 stride
    // is NOT assumed to be nonzero (second check). The size check must stay
    // first so the stride is only compared when the size check passes.
    if (!TORCH_GUARD_SIZE_OBLIVIOUS(dim_sizes[d].sym_gt(1))) {
      continue;
    }
    if (dim_strides[d] == 0) {
      return MemOverlap::Yes;
    }
  }

  return MemOverlap::TooHard;
}
// Convenience overload: unwraps the TensorBase and defers to the
// TensorImpl* overload, which performs the check.
void assert_no_internal_overlap(const TensorBase& t) {
  auto* impl = t.unsafeGetTensorImpl();
  assert_no_internal_overlap(impl);
}
// Rejects tensors that are known to overlap with themselves.
//
// The TORCH_CHECK below fails only when has_internal_overlap(t) reports
// MemOverlap::Yes. Both MemOverlap::No and MemOverlap::TooHard pass, so an
// inconclusive analysis is not treated as an error.
void assert_no_internal_overlap(TensorImpl* t) {
TORCH_CHECK(has_internal_overlap(t) != MemOverlap::Yes,
"unsupported operation: more than one element of the written-to tensor "
"refers to a single memory location. Please clone() the tensor before "
"performing the operation.");
}
// Convenience overload: unwraps both TensorBase arguments and defers to the
// TensorImpl* overload, which performs the actual comparison.
MemOverlapStatus get_overlap_status(const TensorBase& a, const TensorBase& b) {
  auto* impl_a = a.unsafeGetTensorImpl();
  auto* impl_b = b.unsafeGetTensorImpl();
  return get_overlap_status(impl_a, impl_b);
}
// Classifies how the memory referenced by `a` relates to the memory
// referenced by `b`.
//
// Returns:
//   MemOverlapStatus::Full    - `a` and `b` are the same TensorImpl, or they
//                               alias the same storage, cover exactly the same
//                               byte range, and have equal strides.
//   MemOverlapStatus::Partial - they alias the same storage and their byte
//                               ranges intersect without qualifying as Full.
//   MemOverlapStatus::No      - either tensor has no elements, the storages do
//                               not alias, or the byte ranges are disjoint.
//   MemOverlapStatus::TooHard - at least one tensor is not non-overlapping and
//                               dense, so the byte-range comparison below
//                               is not attempted.
MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) {
  if (a == b) return MemOverlapStatus::Full;
  if (a->numel() == 0 || b->numel() == 0) {
    return MemOverlapStatus::No;
  }
  if (!a->is_non_overlapping_and_dense() || !b->is_non_overlapping_and_dense()) {
    return MemOverlapStatus::TooHard;
  }
  // Test for storage equality, rather than pointer equality.
  // This reduces precision, but if people are aliasing the
  // same pointer across multiple storages there are many
  // similar situations (e.g., storage().data() == storage().data()+1)
  // which we will miss.
  //
  // Bind by reference rather than copying the storage handle: it is only
  // inspected here, never retained.
  const auto& a_storage = a->unsafe_storage();
  if (a_storage && a_storage.is_alias_of(b->unsafe_storage())) {
    // Half-open byte ranges [begin, end) spanned by each tensor's elements.
    const auto a_begin = static_cast<const char*>(a->data());
    const auto a_end = a_begin + a->numel() * a->itemsize();
    const auto b_begin = static_cast<const char*>(b->data());
    const auto b_end = b_begin + b->numel() * b->itemsize();

    if (a_begin == b_begin && a_end == b_end) {
      // Identical byte range: only equal strides make it a full overlap.
      return (a->strides() == b->strides()) ?
          MemOverlapStatus::Full : MemOverlapStatus::Partial;
    }
    // Standard interval-intersection test on the two half-open ranges.
    if (a_begin < b_end && b_begin < a_end) {
      return MemOverlapStatus::Partial;
    }
  }
  return MemOverlapStatus::No;
}
// Convenience overload: unwraps both TensorBase arguments and defers to the
// TensorImpl* overload, which performs the check.
void assert_no_partial_overlap(const TensorBase& a, const TensorBase& b) {
  auto* impl_a = a.unsafeGetTensorImpl();
  auto* impl_b = b.unsafeGetTensorImpl();
  assert_no_partial_overlap(impl_a, impl_b);
}
// Rejects tensor pairs that are known to overlap partially.
//
// The TORCH_CHECK below fails only when get_overlap_status(a, b) reports
// MemOverlapStatus::Partial. Full, No and TooHard all pass, so complete
// aliasing and inconclusive analyses are accepted here.
void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b) {
TORCH_CHECK(get_overlap_status(a, b) != MemOverlapStatus::Partial,
"unsupported operation: some elements of the input tensor and "
"the written-to tensor refer to a single memory location. "
"Please clone() the tensor before performing the operation.");
}
// Convenience overload: unwraps both TensorBase arguments and defers to the
// TensorImpl* overload, which performs the check.
void assert_no_overlap(const TensorBase& a, const TensorBase& b) {
  auto* impl_a = a.unsafeGetTensorImpl();
  auto* impl_b = b.unsafeGetTensorImpl();
  assert_no_overlap(impl_a, impl_b);
}
// Rejects tensor pairs that are known to overlap at all.
//
// Stricter than assert_no_partial_overlap: the TORCH_CHECK below fails when
// get_overlap_status(a, b) reports either MemOverlapStatus::Partial or
// MemOverlapStatus::Full. No and TooHard pass, so an inconclusive analysis
// is not treated as an error.
void assert_no_overlap(TensorImpl* a, TensorImpl* b) {
// Query the status once so both comparisons see the same result.
const auto lap = get_overlap_status(a, b);
TORCH_CHECK(lap != MemOverlapStatus::Partial && lap != MemOverlapStatus::Full,
"unsupported operation: some elements of the input tensor and "
"the written-to tensor refer to a single memory location. "
"Please clone() the tensor before performing the operation.");
}
}