Skip to content

Commit 1e5c64a

Browse files
authored
feat: added economic QR (#903)
* feat: added economic QR * fix memory alloc in gesvdjBatched
1 parent 5ee372a commit 1e5c64a

File tree

6 files changed

+606
-7
lines changed

6 files changed

+606
-7
lines changed

docs_input/api/linalg/decomp/qr.rst

+21-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ Perform a QR decomposition.
88
.. doxygenfunction:: qr
99

1010
.. note::
11-
This function is currently not supported with host-based executors (CPU)
11+
This function is currently not supported with host-based executors (CPU), and performs a full QR
12+
decomposition of a tensor `A` with shape `... x m x n`, where `Q` is shaped `... x m x m` and `R`
13+
is shaped `... x m x n`.
1214

1315
Examples
1416
~~~~~~~~
@@ -19,12 +21,27 @@ Examples
1921
:end-before: example-end qr-test-1
2022
:dedent:
2123

24+
.. doxygenfunction:: qr_econ
25+
26+
.. note::
27+
This function is currently not supported with host-based executors (CPU). It returns an economic
28+
QR decomposition, where `Q/R` are shaped `m x k` and `k x n` respectively, where `k = min(m, n)`.
29+
This is useful when `m >> n` to save memory and computation time.
30+
31+
Examples
32+
~~~~~~~~
33+
34+
.. literalinclude:: ../../../../test/00_solver/QREcon.cu
35+
:language: cpp
36+
:start-after: example-begin qr-econ-test-1
37+
:end-before: example-end qr-econ-test-1
38+
:dedent:
2239

2340
.. doxygenfunction:: qr_solver
2441

2542
.. note::
2643
This function does not return `Q` explicitly as it only runs :literal:`geqrf` from LAPACK/cuSolver.
27-
For full `Q/R`, use :literal:`qr_solver` on a CUDA executor.
44+
For full or economic `Q/R`, use :literal:`qr` or :literal:`qr_econ` on a CUDA executor.
2845

2946
Examples
3047
~~~~~~~~
@@ -33,4 +50,5 @@ Examples
3350
:language: cpp
3451
:start-after: example-begin qr_solver-test-1
3552
:end-before: example-end qr_solver-test-1
36-
:dedent:
53+
:dedent:
54+

include/matx/operators/qr.h

+72
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,76 @@ __MATX_INLINE__ auto qr_solver(const OpA &a) {
184184
return detail::SolverQROp(a);
185185
}
186186

187+
188+
namespace detail {
189+
template<typename OpA>
190+
class EconQROp : public BaseOp<EconQROp<OpA>>
191+
{
192+
private:
193+
typename detail::base_type_t<OpA> a_;
194+
195+
public:
196+
using matxop = bool;
197+
using value_type = typename OpA::value_type;
198+
using matx_transform_op = bool;
199+
using qr_solver_xform_op = bool;
200+
201+
__MATX_INLINE__ std::string str() const { return "qr_econ()"; }
202+
__MATX_INLINE__ EconQROp(const OpA &a) : a_(a) { }
203+
204+
// This should never be called
205+
template <typename... Is>
206+
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) operator()(Is... indices) const = delete;
207+
208+
template <typename Out, typename Executor>
209+
void Exec(Out &&out, Executor &&ex) {
210+
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 3, "Must use mtie with 2 outputs on qr_econ(). ie: (mtie(Q, R) = qr_econ(A))");
211+
212+
qr_econ_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex);
213+
}
214+
215+
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
216+
{
217+
return OpA::Rank();
218+
}
219+
220+
template <typename ShapeType, typename Executor>
221+
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) noexcept
222+
{
223+
if constexpr (is_matx_op<OpA>()) {
224+
a_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
225+
}
226+
}
227+
228+
// Size is not relevant in qr_solver() since there are multiple return values and it
229+
// is not allowed to be called in larger expressions
230+
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
231+
{
232+
return a_.Size(dim);
233+
}
234+
235+
};
236+
}
237+
238+
/**
239+
* Perform an economic QR decomposition on a matrix using cuSolver.
240+
*
241+
* If rank > 2, operations are batched.
242+
*
243+
* @tparam OpA
244+
* Data type of input a tensor or operator
245+
*
246+
* @param a
247+
* Input tensor or operator of shape `... x m x n`
248+
*
249+
* @return
250+
* Operator that produces QR outputs.
251+
* - **Q** - Of shape `... x m x min(m, n)`, the reduced orthonormal basis for the span of A.
252+
* - **R** - Upper triangular matrix of shape `... x min(m, n) x n`.
253+
*/
254+
template<typename OpA>
255+
__MATX_INLINE__ auto qr_econ(const OpA &a) {
256+
return detail::EconQROp(a);
187257
}
258+
259+
}

0 commit comments

Comments
 (0)