Skip to content

Commit 52d91c5

Browse files
authored
Revert broadcasting changes to concat (#890)
1 parent 925a0b5 commit 52d91c5

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

include/matx/operators/concat.h

+5-8
Original file line numberDiff line numberDiff line change
@@ -91,48 +91,45 @@ namespace matx
9191
}
9292
}
9393

94-
95-
96-
9794
template <int I = 0, int N>
9895
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto GetVal(cuda::std::array<index_t,RANK> &indices) const {
9996

10097
if constexpr ( I == N ) {
10198
// This should never happen
10299
return value_type{};
103-
// returning this to satisfy lvalue requirements
104100
} else {
105101
const auto &op = cuda::std::get<I>(ops_);
106102
auto idx = indices[axis_];
107103
auto size = op.Size(axis_);
108104
// If in range of this operator
109105
if(idx < size) {
110106
// evaluate operator
111-
return get_value(cuda::std::forward<decltype(op)>(op), indices);
107+
return cuda::std::apply(op, indices);
112108
} else {
113109
// otherwise remove this operator and recurse
114110
indices[axis_] -= size;
115111
return GetVal<I+1, N>(indices);
116112
}
117113
}
118114
}
119-
115+
116+
120117
template <int I = 0, int N>
121118
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) GetVal(cuda::std::array<index_t,RANK> &indices) {
122119

123120
if constexpr ( I == N ) {
124121
// This should never happen
125122
// returning this to satisfy lvalue requirements
126123
auto &op = cuda::std::get<I-1>(ops_);
127-
return get_value(cuda::std::forward<decltype(op)>(op), indices);
124+
return cuda::std::apply(op, indices);
128125
} else {
129126
auto &op = cuda::std::get<I>(ops_);
130127
auto idx = indices[axis_];
131128
auto size = op.Size(axis_);
132129
// If in range of this operator
133130
if(idx < size) {
134131
// evaluate operator
135-
return get_value(cuda::std::forward<decltype(op)>(op), indices);
132+
return cuda::std::apply(op, indices);
136133
} else {
137134
// otherwise remove this operator and recurse
138135
indices[axis_] -= size;

0 commit comments

Comments
 (0)