@@ -184,4 +184,76 @@ __MATX_INLINE__ auto qr_solver(const OpA &a) {
184
184
return detail::SolverQROp (a);
185
185
}
186
186
187
+
188
+ namespace detail {
189
+ template <typename OpA>
190
+ class EconQROp : public BaseOp <EconQROp<OpA>>
191
+ {
192
+ private:
193
+ typename detail::base_type_t <OpA> a_;
194
+
195
+ public:
196
+ using matxop = bool ;
197
+ using value_type = typename OpA::value_type;
198
+ using matx_transform_op = bool ;
199
+ using qr_solver_xform_op = bool ;
200
+
201
+ __MATX_INLINE__ std::string str () const { return " qr_econ()" ; }
202
+ __MATX_INLINE__ EconQROp (const OpA &a) : a_(a) { }
203
+
204
+ // This should never be called
205
+ template <typename ... Is>
206
+ __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype (auto ) operator()(Is... indices) const = delete;
207
+
208
+ template <typename Out, typename Executor>
209
+ void Exec (Out &&out, Executor &&ex) {
210
+ static_assert (cuda::std::tuple_size_v<remove_cvref_t <Out>> == 3 , " Must use mtie with 2 outputs on qr_econ(). ie: (mtie(Q, R) = qr_econ(A))" );
211
+
212
+ qr_econ_impl (cuda::std::get<0 >(out), cuda::std::get<1 >(out), a_, ex);
213
+ }
214
+
215
+ static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank ()
216
+ {
217
+ return OpA::Rank ();
218
+ }
219
+
220
+ template <typename ShapeType, typename Executor>
221
+ __MATX_INLINE__ void PreRun ([[maybe_unused]] ShapeType &&shape, Executor &&ex) noexcept
222
+ {
223
+ if constexpr (is_matx_op<OpA>()) {
224
+ a_.PreRun (std::forward<ShapeType>(shape), std::forward<Executor>(ex));
225
+ }
226
+ }
227
+
228
+ // Size is not relevant in qr_solver() since there are multiple return values and it
229
+ // is not allowed to be called in larger expressions
230
+ constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size (int dim) const
231
+ {
232
+ return a_.Size (dim);
233
+ }
234
+
235
+ };
236
+ }
237
+
238
+ /* *
239
+ * Perform an economic QR decomposition on a matrix using cuSolver.
240
+ *
241
+ * If rank > 2, operations are batched.
242
+ *
243
+ * @tparam OpA
244
+ * Data type of input a tensor or operator
245
+ *
246
+ * @param a
247
+ * Input tensor or operator of shape `... x m x n`
248
+ *
249
+ * @return
250
+ * Operator that produces QR outputs.
251
+ * - **Q** - Of shape `... x m x min(m, n)`, the reduced orthonormal basis for the span of A.
252
+ * - **R** - Upper triangular matrix of shape `... x min(m, n) x n`.
253
+ */
254
+ template <typename OpA>
255
+ __MATX_INLINE__ auto qr_econ (const OpA &a) {
256
+ return detail::EconQROp (a);
187
257
}
258
+
259
+ }
0 commit comments