Repeat.cpp
#include <ATen/ATen.h>
#include <ATen/native/Repeat.h>
#include <ATen/Parallel.h>

// Fills result_ptr so that index i appears repeat_ptr[i] times: for each i,
// the half-open range [cumsum_ptr[i] - repeat_ptr[i], cumsum_ptr[i]) of
// result_ptr is set to i. The outer loop is parallelized over input indices.
static void compute_cpu(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
    at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
        for (int64_t i = i_begin; i < i_end; i++) {
            int64_t end = cumsum_ptr[i];
            int64_t size = repeat_ptr[i];  // note: shadows the outer `size` parameter
            int64_t start = end - size;
            for (int64_t j = start; j < end; j++) {
                result_ptr[j] = i;
            }
        }
    });
}
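// Worked example (illustrative values, not from the source): for
// repeats = {2, 0, 3}, the caller computes cumsum = {2, 2, 5} and allocates
// a result of length 5; compute_cpu then writes result = {0, 0, 2, 2, 2},
// i.e. each input index i appears exactly repeats[i] times, in order.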
namespace at { namespace native {

// CPU entry point: expands a 1-D repeats tensor into the index tensor
// described above. The shared driver in repeat_interleave_common handles
// setup; only the fill kernel (compute_cpu here) differs per backend.
Tensor repeat_interleave_cpu(const Tensor &repeat) {
    return repeat_interleave_common<compute_cpu>(repeat);
}
Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optional<int64_t> dim) {
    Tensor input = self;
    // Without an explicit dim, operate on the flattened input.
    if (!dim) {
        input = self.flatten();
        dim = 0;
    }

    Tensor repeats_ = repeats;
    if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) {
        // A scalar (or single-element) repeat count broadcasts along dim.
        repeats_ = repeats.reshape({1}).expand({input.size(dim.value())});
    } else if (repeats.dim() == 1) {
        TORCH_CHECK(repeats.size(0) == input.size(dim.value()), "repeats must have the same size as input along dim");
    } else {
        AT_ERROR("repeats must be 0-dim or 1-dim tensor");
    }

    // at::repeat_interleave(repeats_) yields the gather indices (0 repeated
    // repeats_[0] times, 1 repeated repeats_[1] times, ...); index_select
    // then materializes the repeated tensor along dim.
    return input.index_select(dim.value(), at::repeat_interleave(repeats_));
}
// Integer overload: wrap the scalar count in a 1-element Long tensor and
// defer to the tensor overload above.
Tensor repeat_interleave(const Tensor &self, int64_t repeats, c10::optional<int64_t> dim) {
    return at::native::repeat_interleave(self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
}

}} // namespace at::native
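For reference, a minimal caller-side sketch of the two overloads above, assuming a libtorch build; torch::tensor and torch::repeat_interleave are from the public C++ API, not from this file:

#include <torch/torch.h>
#include <iostream>

int main() {
    auto x = torch::tensor({10, 20, 30}, torch::kLong);

    // Tensor overload: element i of x is repeated repeats[i] times.
    auto repeats = torch::tensor({1, 2, 3}, torch::kLong);
    std::cout << torch::repeat_interleave(x, repeats) << "\n";
    // -> 10, 20, 20, 30, 30, 30

    // Integer overload with no dim: the input is flattened and every
    // element is repeated twice.
    std::cout << torch::repeat_interleave(x, 2) << "\n";
    // -> 10, 10, 20, 20, 30, 30
    return 0;
}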