Skip to content

Commit e4c59c4

Browse files
authored
Fixed issue with operator types used as both lvalue/rvalue no assigning (#655)
1 parent 0c00b59 commit e4c59c4

14 files changed

+137
-7
lines changed

include/matx/operators/collapse.h

+14
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ namespace matx
5151
using scalar_type = typename T1::scalar_type;
5252
using shape_type = index_t;
5353
using matxoplvalue = bool;
54+
using self_type = LCollapseOp<DIM, T1>;
5455

5556
__MATX_INLINE__ std::string str() const { return "lcollapse<" + std::to_string(DIM) + ">(" + op_.str() + ")"; }
5657
__MATX_INLINE__ LCollapseOp(const T1 op) : op_(op)
@@ -132,6 +133,12 @@ namespace matx
132133
return op_.Size(DIM + dim - 1);
133134
}
134135

136+
~LCollapseOp() = default;
137+
LCollapseOp(const LCollapseOp &rhs) = default;
138+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
139+
return set(*this, rhs);
140+
}
141+
135142
template<typename R>
136143
__MATX_INLINE__ auto operator=(const R &rhs) {
137144
if constexpr (is_matx_transform_op<R>()) {
@@ -199,6 +206,7 @@ namespace matx
199206
using scalar_type = typename T1::scalar_type;
200207
using shape_type = index_t;
201208
using matxlvalue = bool;
209+
using self_type = RCollapseOp<DIM, T1>;
202210

203211
__MATX_INLINE__ std::string str() const { return "rcollapse<" + std::to_string(DIM) + ">(" + op_.str() + ")"; }
204212

@@ -281,6 +289,12 @@ namespace matx
281289
return op_.Size(dim);
282290
}
283291

292+
~RCollapseOp() = default;
293+
RCollapseOp(const RCollapseOp &rhs) = default;
294+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
295+
return set(*this, rhs);
296+
}
297+
284298
template<typename R>
285299
__MATX_INLINE__ auto operator=(const R &rhs) {
286300
if constexpr (is_matx_transform_op<R>()) {

include/matx/operators/concat.h

+7
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ namespace matx
5151
{
5252
using first_type = cuda::std::tuple_element_t<0, cuda::std::tuple<Ts...>>;
5353
using first_value_type = typename first_type::scalar_type;
54+
using self_type = ConcatOp<Ts...>;
5455

5556
static constexpr int RANK = first_type::Rank();
5657

@@ -165,6 +166,12 @@ namespace matx
165166
return cuda::std::get<0>(ops_).Size(dim);
166167
}
167168

169+
~ConcatOp() = default;
170+
ConcatOp(const ConcatOp &rhs) = default;
171+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
172+
return set(*this, rhs);
173+
}
174+
168175
template<typename R>
169176
__MATX_INLINE__ auto operator=(const R &rhs) {
170177
if constexpr (is_matx_transform_op<R>()) {

include/matx/operators/overlap.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ namespace matx
4848
public:
4949
using scalar_type = typename T::scalar_type;
5050
using shape_type = index_t;
51+
using self_type = OverlapOp<DIM, T>;
5152

5253
private:
5354
typename base_type<T>::type op_;
@@ -118,7 +119,13 @@ namespace matx
118119
if constexpr (is_matx_op<T>()) {
119120
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
120121
}
121-
}
122+
}
123+
124+
~OverlapOp() = default;
125+
OverlapOp(const OverlapOp &rhs) = default;
126+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
127+
return set(*this, rhs);
128+
}
122129

123130
template<typename R>
124131
__MATX_INLINE__ auto operator=(const R &rhs) {

include/matx/operators/permute.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ namespace matx
4747
{
4848
public:
4949
using scalar_type = typename T::scalar_type;
50+
using self_type = PermuteOp<T>;
5051

5152
private:
5253
typename base_type<T>::type op_;
@@ -133,7 +134,13 @@ namespace matx
133134
if constexpr (is_matx_op<T>()) {
134135
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
135136
}
136-
}
137+
}
138+
139+
~PermuteOp() = default;
140+
PermuteOp(const PermuteOp &rhs) = default;
141+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
142+
return set(*this, rhs);
143+
}
137144

138145
template<typename R>
139146
__MATX_INLINE__ auto operator=(const R &rhs) {

include/matx/operators/remap.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ namespace matx
5757
using scalar_type = typename T::scalar_type;
5858
using shape_type = std::conditional_t<has_shape_type_v<T>, typename T::shape_type, index_t>;
5959
using index_type = typename IdxType::scalar_type;
60+
using self_type = RemapOp<DIM, T, IdxType>;
6061
static_assert(std::is_integral<index_type>::value, "RemapOp: Type for index operator must be integral");
6162
static_assert(IdxType::Rank() <= 1, "RemapOp: Rank of index operator must be 0 or 1");
6263
static_assert(DIM<T::Rank(), "RemapOp: DIM must be less than Rank of tensor");
@@ -134,7 +135,13 @@ namespace matx
134135
if constexpr (is_matx_op<T>()) {
135136
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
136137
}
137-
}
138+
}
139+
140+
~RemapOp() = default;
141+
RemapOp(const RemapOp &rhs) = default;
142+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
143+
return set(*this, rhs);
144+
}
138145

139146
template<typename R>
140147
__MATX_INLINE__ auto operator=(const R &rhs) {

include/matx/operators/reshape.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ namespace matx
5656
public:
5757
using matxop = bool;
5858
using matxoplvalue = bool;
59+
using self_type = ReshapeOp<RANK, T, ShapeType>;
5960

6061
__MATX_INLINE__ std::string str() const { return "reshape(" + op_.str() + ")"; }
6162

@@ -149,7 +150,13 @@ namespace matx
149150
if constexpr (is_matx_op<T>()) {
150151
op_.PostRun(std::forward<S2>(shape), std::forward<Executor>(ex));
151152
}
152-
}
153+
}
154+
155+
~ReshapeOp() = default;
156+
ReshapeOp(const ReshapeOp &rhs) = default;
157+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
158+
return set(*this, rhs);
159+
}
153160

154161
template<typename R>
155162
__MATX_INLINE__ auto operator=(const R &rhs) {

include/matx/operators/reverse.h

+8-1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ namespace matx
5757
using matxop = bool;
5858
using matxoplvalue = bool;
5959
using scalar_type = typename T1::scalar_type;
60+
using self_type = ReverseOp<DIM, T1>;
6061

6162
__MATX_INLINE__ std::string str() const { return "reverse(" + op_.str() + ")"; }
6263

@@ -112,7 +113,13 @@ namespace matx
112113
if constexpr (is_matx_op<T1>()) {
113114
op_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
114115
}
115-
}
116+
}
117+
118+
~ReverseOp() = default;
119+
ReverseOp(const ReverseOp &rhs) = default;
120+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
121+
return set(*this, rhs);
122+
}
116123

117124
template<typename R>
118125
__MATX_INLINE__ auto operator=(const R &rhs) {

include/matx/operators/set.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ class set : public BaseOp<set<T, Op>> {
6868
public:
6969
// Type specifier for reflection on class
7070
using scalar_type = typename T::scalar_type;
71-
using shape_type = std::conditional_t<has_shape_type_v<T>, typename T::shape_type, index_t>;
7271
using tensor_type = T;
7372
using op_type = Op;
7473
using matx_setop = bool;
@@ -136,7 +135,9 @@ class set : public BaseOp<set<T, Op>> {
136135
return r;
137136
}
138137
}
139-
__MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(cuda::std::array<shape_type, T::Rank()> idx) const noexcept
138+
139+
template <typename ShapeType>
140+
__MATX_DEVICE__ __MATX_HOST__ inline decltype(auto) operator()(cuda::std::array<ShapeType, T::Rank()> idx) const noexcept
140141
{
141142
auto res = cuda::std::apply([&](auto &&...args) {
142143
return _internal_mapply(args...);

include/matx/operators/shift.h

+7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ namespace matx
6161
using matxop = bool;
6262
using matxoplvalue = bool;
6363
using scalar_type = typename T1::scalar_type;
64+
using self_type = ShiftOp<DIM, T1, T2>;
6465

6566
__MATX_INLINE__ std::string str() const { return "shift(" + op_.str() + ")"; }
6667

@@ -131,6 +132,12 @@ namespace matx
131132
return detail::matx_max(size1,size2);
132133
}
133134

135+
~ShiftOp() = default;
136+
ShiftOp(const ShiftOp &rhs) = default;
137+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
138+
return set(*this, rhs);
139+
}
140+
134141
template<typename R>
135142
__MATX_INLINE__ auto operator=(const R &rhs) {
136143
if constexpr (is_matx_transform_op<R>()) {

include/matx/operators/slice.h

+7
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ namespace matx
4848
public:
4949
using scalar_type = typename T::scalar_type;
5050
using shape_type = index_t;
51+
using self_type = SliceOp<DIM, T>;
5152

5253
private:
5354
typename base_type<T>::type op_;
@@ -158,6 +159,12 @@ namespace matx
158159
return sizes_[dim];
159160
}
160161

162+
~SliceOp() = default;
163+
SliceOp(const SliceOp &rhs) = default;
164+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
165+
return set(*this, rhs);
166+
}
167+
161168
template<typename R>
162169
__MATX_INLINE__ auto operator=(const R &rhs) {
163170
if constexpr (is_matx_transform_op<R>()) {

include/matx/operators/stack.h

+7
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace matx
5050
{
5151
using first_type = cuda::std::tuple_element_t<0, cuda::std::tuple<Ts...>>;
5252
using first_value_type = typename first_type::scalar_type;
53+
using self_type = StackOp<Ts...>;
5354

5455
static constexpr int RANK = first_type::Rank();
5556

@@ -181,6 +182,12 @@ namespace matx
181182
}
182183
}
183184

185+
~StackOp() = default;
186+
StackOp(const StackOp &rhs) = default;
187+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
188+
return set(*this, rhs);
189+
}
190+
184191
template<typename R>
185192
__MATX_INLINE__ auto operator=(const R &rhs) {
186193
if constexpr (is_matx_transform_op<R>()) {

include/matx/operators/transpose.h

+7
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ namespace detail {
5757
using matx_transform_op = bool;
5858
using matxoplvalue = bool;
5959
using transpose_xform_op = bool;
60+
using self_type = TransposeMatrixOp<OpA>;
6061

6162
__MATX_INLINE__ std::string str() const { return "transpose_matrix(" + get_type_str(a_) + ")"; }
6263
__MATX_INLINE__ TransposeMatrixOp(OpA a) : a_(a) {
@@ -121,6 +122,12 @@ namespace detail {
121122
return out_dims_[dim];
122123
}
123124

125+
~TransposeMatrixOp() = default;
126+
TransposeMatrixOp(const TransposeMatrixOp &rhs) = default;
127+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
128+
return set(*this, rhs);
129+
}
130+
124131
template<typename R>
125132
__MATX_INLINE__ auto operator=(const R &rhs) {
126133
if constexpr (is_matx_transform_op<R>()) {

include/matx/operators/updownsample.h

+7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ namespace matx
5454
using matxop = bool;
5555
using matxoplvalue = bool;
5656
using scalar_type = typename T::scalar_type;
57+
using self_type = UpsampleOp<T>;
5758

5859
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
5960
{
@@ -109,6 +110,12 @@ namespace matx
109110
}
110111
}
111112

113+
~UpsampleOp() = default;
114+
UpsampleOp(const UpsampleOp &rhs) = default;
115+
__MATX_INLINE__ auto operator=(const self_type &rhs) {
116+
return set(*this, rhs);
117+
}
118+
112119
template<typename R> __MATX_INLINE__ auto operator=(const R &rhs) { return set(*this, rhs); }
113120
};
114121
}

test/00_operators/OperatorTests.cu

+38
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,44 @@ TYPED_TEST(OperatorTestsNumericAllExecs, RemapOp)
15891589
}
15901590
}
15911591

1592+
{
1593+
// Remap as both LHS and RHS
1594+
auto in = make_tensor<TestType>({4,4,4});
1595+
auto out = make_tensor<TestType>({4,4,4});
1596+
TestType c = GenerateData<TestType>();
1597+
for (int i = 0; i < in.Size(0); i++){
1598+
for (int j = 0; j < in.Size(1); j++){
1599+
for (int k = 0; k < in.Size(2); k++){
1600+
in(i,j,k) = c;
1601+
}
1602+
}
1603+
}
1604+
1605+
auto map1 = matx::make_tensor<int>({2});
1606+
auto map2 = matx::make_tensor<int>({2});
1607+
map1(0) = 1;
1608+
map1(1) = 2;
1609+
map2(0) = 0;
1610+
map2(1) = 1;
1611+
1612+
(out = static_cast<TestType>(0)).run(exec);
1613+
(matx::remap<2>(out, map2) = matx::remap<2>(in, map1)).run(exec);
1614+
exec.sync();
1615+
1616+
for (int i = 0; i < in.Size(0); i++){
1617+
for (int j = 0; j < in.Size(1); j++){
1618+
for (int k = 0; k < in.Size(2); k++){
1619+
if (k > 1) {
1620+
ASSERT_EQ(out(i,j,k), (TestType)0);
1621+
}
1622+
else {
1623+
ASSERT_EQ(out(i,j,k), in(i,j,k));
1624+
}
1625+
}
1626+
}
1627+
}
1628+
}
1629+
15921630
MATX_EXIT_HANDLER();
15931631
}
15941632

0 commit comments

Comments
 (0)