Skip to content

Commit

Permalink
fix partial spectrum index
Browse files Browse the repository at this point in the history
  • Loading branch information
RMeli committed Dec 19, 2024
1 parent 972b5e4 commit 0b4be03
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 41 deletions.
4 changes: 2 additions & 2 deletions include/dlaf_c/eigensolver/eigensolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ DLAF_EXTERN_C void dlaf_pssyevd(const char uplo, const int n, float* a, const in
const int descz[9], int* info) DLAF_NOEXCEPT;

// @copydoc dlaf_pssyevd
/// @param eigenvalues_index_begin index of the first eigenvalue to compute (has to be 0)
/// @param eigenvalues_index_end index of the last eigenvalue to compute (exclusive)
/// @param eigenvalues_index_begin index of the first eigenvalue to compute (has to be 1)
/// @param eigenvalues_index_end index of the last eigenvalue to compute (inclusive)
DLAF_EXTERN_C void dlaf_pssyevd_partial_spectrum(
const char uplo, const int n, float* a, const int ia, const int ja, const int desca[9], float* w,
float* z, const int iz, const int jz, const int descz[9], const SizeType eigenvalues_index_begin,
Expand Down
12 changes: 6 additions & 6 deletions include/dlaf_c/eigensolver/gen_eigensolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ DLAF_EXTERN_C int dlaf_symmetric_generalized_eigensolver_s(
const struct DLAF_descriptor dlaf_descz) DLAF_NOEXCEPT;

/// @copydoc dlaf_symmetric_generalized_eigensolver_s
/// @param eigenvalues_index_begin index of the first eigenvalue to compute
/// @param eigenvalues_index_begin index of the first eigenvalue to compute (has to be 0)
/// @param eigenvalues_index_end index of the last eigenvalue to compute (exclusive)
DLAF_EXTERN_C int dlaf_symmetric_generalized_eigensolver_partial_spectrum_s(
const int dlaf_context, const char uplo, float* a, const struct DLAF_descriptor dlaf_desca, float* b,
Expand Down Expand Up @@ -116,7 +116,7 @@ DLAF_EXTERN_C int dlaf_symmetric_generalized_eigensolver_factorized_s(
const struct DLAF_descriptor dlaf_descz) DLAF_NOEXCEPT;

/// @copydoc dlaf_symmetric_generalized_eigensolver_factorized_s
/// @param eigenvalues_index_begin index of the first eigenvalue to compute
/// @param eigenvalues_index_begin index of the first eigenvalue to compute (has to be 0)
/// @param eigenvalues_index_end index of the last eigenvalue to compute (exclusive)
DLAF_EXTERN_C int dlaf_symmetric_generalized_eigensolver_partial_spectrum_factorized_s(
const int dlaf_context, const char uplo, float* a, const struct DLAF_descriptor dlaf_desca, float* b,
Expand Down Expand Up @@ -208,8 +208,8 @@ DLAF_EXTERN_C void dlaf_pssygvd(const char uplo, const int n, float* a, const in
const int descz[9], int* info) DLAF_NOEXCEPT;

/// @copydoc dlaf_pssygvd
/// @param eigenvalues_index_begin index of the first eigenvalue to compute
/// @param eigenvalues_index_end index of the last eigenvalue to compute (exclusive)
/// @param eigenvalues_index_begin index of the first eigenvalue to compute (has to be 1)
/// @param eigenvalues_index_end index of the last eigenvalue to compute (inclusive)
DLAF_EXTERN_C void dlaf_pssygvd_partial_spectrum(
const char uplo, const int n, float* a, const int ia, const int ja, const int desca[9], float* b,
const int ib, const int jb, const int descb[9], float* w, float* z, const int iz, const int jz,
Expand Down Expand Up @@ -303,8 +303,8 @@ DLAF_EXTERN_C void dlaf_pssygvd_factorized(const char uplo, const int n, float*
int* info) DLAF_NOEXCEPT;

/// @copydoc dlaf_pssygvx_factorized
/// @param eigenvalues_index_begin index of the first eigenvalue to compute
/// @param eigenvalues_index_end index of the last eigenvalue to compute (exclusive)
/// @param eigenvalues_index_begin index of the first eigenvalue to compute (has to be 1)
/// @param eigenvalues_index_end index of the last eigenvalue to compute (inclusive)
DLAF_EXTERN_C void dlaf_pssygvd_partial_spectrum_factorized(
const char uplo, const int n, float* a, const int ia, const int ja, const int desca[9], float* b,
const int ib, const int jb, const int descb[9], float* w, float* z, const int iz, const int jz,
Expand Down
8 changes: 4 additions & 4 deletions src/c_api/eigensolver/eigensolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ int dlaf_hermitian_eigensolver_partial_spectrum_z(
void dlaf_pssyevd(const char uplo, const int m, float* a, const int ia, const int ja, const int desca[9],
float* w, float* z, const int iz, const int jz, const int descz[9],
int* info) noexcept {
pxheevd<float>(uplo, m, a, ia, ja, desca, w, z, iz, jz, descz, 0l, m, *info);
pxheevd<float>(uplo, m, a, ia, ja, desca, w, z, iz, jz, descz, 1l, m, *info);
}

void dlaf_pssyevd_partial_spectrum(const char uplo, const int m, float* a, const int ia, const int ja,
Expand All @@ -97,7 +97,7 @@ void dlaf_pssyevd_partial_spectrum(const char uplo, const int m, float* a, const
void dlaf_pdsyevd(const char uplo, const int m, double* a, const int ia, const int ja,
const int desca[9], double* w, double* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxheevd<double>(uplo, m, a, ia, ja, desca, w, z, iz, jz, descz, 0l, m, *info);
pxheevd<double>(uplo, m, a, ia, ja, desca, w, z, iz, jz, descz, 1l, m, *info);
}

void dlaf_pdsyevd_partial_spectrum(const char uplo, const int m, double* a, const int ia, const int ja,
Expand All @@ -110,7 +110,7 @@ void dlaf_pdsyevd_partial_spectrum(const char uplo, const int m, double* a, cons
void dlaf_pcheevd(const char uplo, const int m, dlaf_complex_c* a, const int ia, const int ja,
const int desca[9], float* w, dlaf_complex_c* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxheevd<std::complex<float>>(uplo, m, a, ia, ja, desca, w, z, iz, jz, descz, 0l, m, *info);
pxheevd<std::complex<float>>(uplo, m, a, ia, ja, desca, w, z, iz, jz, descz, 1l, m, *info);
}

void dlaf_pcheevd_partial_spectrum(const char uplo, const int m, dlaf_complex_c* a, const int ia,
Expand All @@ -125,7 +125,7 @@ void dlaf_pcheevd_partial_spectrum(const char uplo, const int m, dlaf_complex_c*
void dlaf_pzheevd(const char uplo, const int m, dlaf_complex_z* a, const int ia, const int ja,
const int desca[9], double* w, dlaf_complex_z* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxheevd<std::complex<double>>(uplo, m, a, ia, ja, desca, w, z, iz, jz, descz, 0l, m, *info);
pxheevd<std::complex<double>>(uplo, m, a, ia, ja, desca, w, z, iz, jz, descz, 1l, m, *info);
}

void dlaf_pzheevd_partial_spectrum(const char uplo, const int m, dlaf_complex_z* a, const int ia,
Expand Down
7 changes: 6 additions & 1 deletion src/c_api/eigensolver/eigensolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,17 @@ void pxheevd(const char uplo, const int m, T* a, const int ia, const int ja, con
DLAF_ASSERT(ja == 1, ja);
DLAF_ASSERT(iz == 1, iz);
DLAF_ASSERT(iz == 1, iz);
DLAF_ASSERT(m > 0 ? eigenvalues_index_begin >= 1 : eigenvalues_index_begin == 1, m,
eigenvalues_index_begin);
DLAF_ASSERT(m > 0 ? eigenvalues_index_end <= m : eigenvalues_index_end == 0, m, eigenvalues_index_end);
DLAF_ASSERT(m > 0 ? eigenvalues_index_begin <= eigenvalues_index_end : true, m,
eigenvalues_index_begin, eigenvalues_index_end);

auto dlaf_desca = make_dlaf_descriptor(m, m, ia, ja, desca);
auto dlaf_descz = make_dlaf_descriptor(m, m, iz, jz, descz);

auto _info = hermitian_eigensolver(desca[1], uplo, a, dlaf_desca, w, z, dlaf_descz,
eigenvalues_index_begin, eigenvalues_index_end);
eigenvalues_index_begin - 1, eigenvalues_index_end);
info = _info;
}

Expand Down
16 changes: 8 additions & 8 deletions src/c_api/eigensolver/gen_eigensolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ int dlaf_hermitian_generalized_eigensolver_partial_spectrum_factorized_z(
void dlaf_pssygvd(const char uplo, const int m, float* a, const int ia, const int ja, const int desca[9],
float* b, const int ib, const int jb, const int descb[9], float* w, float* z,
const int iz, const int jz, const int descz[9], int* info) noexcept {
pxhegvd<float>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 0, m, *info);
pxhegvd<float>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 1, m, *info);
}

void dlaf_pssygvd_partial_spectrum(const char uplo, const int m, float* a, const int ia, const int ja,
Expand All @@ -191,7 +191,7 @@ void dlaf_pdsygvd(const char uplo, const int m, double* a, const int ia, const i
const int desca[9], double* b, const int ib, const int jb, const int descb[9],
double* w, double* z, const int iz, const int jz, const int descz[9],
int* info) noexcept {
pxhegvd<double>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 0, m, *info);
pxhegvd<double>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 1, m, *info);
}

void dlaf_pdsygvd_partial_spectrum(const char uplo, const int m, double* a, const int ia, const int ja,
Expand All @@ -207,7 +207,7 @@ void dlaf_pchegvd(const char uplo, const int m, dlaf_complex_c* a, const int ia,
const int desca[9], dlaf_complex_c* b, const int ib, const int jb, const int descb[9],
float* w, dlaf_complex_c* z, const int iz, const int jz, const int descz[9],
int* info) noexcept {
pxhegvd<std::complex<float>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 0, m,
pxhegvd<std::complex<float>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 1, m,
*info);
}

Expand All @@ -225,7 +225,7 @@ void dlaf_pzhegvd(const char uplo, const int m, dlaf_complex_z* a, const int ia,
const int desca[9], dlaf_complex_z* b, const int ib, const int jb, const int descb[9],
double* w, dlaf_complex_z* z, const int iz, const int jz, const int descz[9],
int* info) noexcept {
pxhegvd<std::complex<double>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 0, m,
pxhegvd<std::complex<double>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 1, m,
*info);
}

Expand All @@ -243,7 +243,7 @@ void dlaf_pssygvd_factorized(const char uplo, const int m, float* a, const int i
const int desca[9], float* b, const int ib, const int jb,
const int descb[9], float* w, float* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxhegvd_factorized<float>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 0, m,
pxhegvd_factorized<float>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 1, m,
*info);
}

Expand All @@ -261,7 +261,7 @@ void dlaf_pdsygvd_factorized(const char uplo, const int m, double* a, const int
const int desca[9], double* b, const int ib, const int jb,
const int descb[9], double* w, double* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxhegvd_factorized<double>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 0, m,
pxhegvd_factorized<double>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz, descz, 1, m,
*info);
}

Expand All @@ -280,7 +280,7 @@ void dlaf_pchegvd_factorized(const char uplo, const int m, dlaf_complex_c* a, co
const int descb[9], float* w, dlaf_complex_c* z, const int iz, const int jz,
const int descz[9], int* info) noexcept {
pxhegvd_factorized<std::complex<float>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz,
descz, 0, m, *info);
descz, 1, m, *info);
}

void dlaf_pchegvd_partial_spectrum_factorized(
Expand All @@ -297,7 +297,7 @@ void dlaf_pzhegvd_factorized(const char uplo, const int m, dlaf_complex_z* a, co
const int descb[9], double* w, dlaf_complex_z* z, const int iz,
const int jz, const int descz[9], int* info) noexcept {
pxhegvd_factorized<std::complex<double>>(uplo, m, a, ia, ja, desca, b, ib, jb, descb, w, z, iz, jz,
descz, 0, m, *info);
descz, 1, m, *info);
}

void dlaf_pzhegvd_partial_spectrum_factorized(
Expand Down
13 changes: 9 additions & 4 deletions src/c_api/eigensolver/gen_eigensolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,24 @@ void pxhegvd(const char uplo, const int m, T* a, const int ia, const int ja, con
DLAF_ASSERT(jb == 1, jb);
DLAF_ASSERT(iz == 1, iz);
DLAF_ASSERT(iz == 1, iz);
DLAF_ASSERT(m > 0 ? eigenvalues_index_begin >= 1 : eigenvalues_index_begin == 1, m,
eigenvalues_index_begin);
DLAF_ASSERT(m > 0 ? eigenvalues_index_end <= m : eigenvalues_index_end == 0, m, eigenvalues_index_end);
DLAF_ASSERT(m > 0 ? eigenvalues_index_begin <= eigenvalues_index_end : true, m,
eigenvalues_index_begin, eigenvalues_index_end);

auto dlaf_desca = make_dlaf_descriptor(m, m, ia, ja, desca);
auto dlaf_descb = make_dlaf_descriptor(m, m, ib, jb, descb);
auto dlaf_descz = make_dlaf_descriptor(m, m, iz, jz, descz);

if (!factorized) {
info =
hermitian_generalized_eigensolver<T>(desca[1], uplo, a, dlaf_desca, b, dlaf_descb, w, z,
dlaf_descz, eigenvalues_index_begin, eigenvalues_index_end);
info = hermitian_generalized_eigensolver<T>(desca[1], uplo, a, dlaf_desca, b, dlaf_descb, w, z,
dlaf_descz, eigenvalues_index_begin - 1,
eigenvalues_index_end);
}
else {
info = hermitian_generalized_eigensolver_factorized<T>(desca[1], uplo, a, dlaf_desca, b, dlaf_descb,
w, z, dlaf_descz, eigenvalues_index_begin,
w, z, dlaf_descz, eigenvalues_index_begin - 1,
eigenvalues_index_end);
}
}
Expand Down
17 changes: 13 additions & 4 deletions test/unit/c_api/eigensolver/test_eigensolver_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,34 +188,43 @@ void testEigensolver(const blas::Uplo uplo, const SizeType m, const SizeType mb,
int desc_a[] = {1, dlaf_context, (int) m, (int) m, (int) mb, (int) mb, 0, 0, lld_a};
int desc_z[] = {1, dlaf_context, (int) m, (int) m, (int) mb, (int) mb, 0, 0, lld_eigenvectors};
int info = -1;

// Treat special case when eval_idx_end is 0 for the C API
// The ScaLAPACK API uses base 1 indexing
const SizeType eval_idx_end_scalapack = m > 0 && eval_idx_end == 0 ? 1 : eval_idx_end;

if constexpr (std::is_same_v<T, double>) {
if (eigenvalues_index_end.has_value())
C_dlaf_pdsyevd_partial_spectrum(dlaf_uplo, (int) m, local_a_ptr, 1, 1, desc_a, eigenvalues_ptr,
local_eigenvectors_ptr, 1, 1, desc_z, 0, eval_idx_end, &info);
local_eigenvectors_ptr, 1, 1, desc_z, 1,
eval_idx_end_scalapack, &info);
else
C_dlaf_pdsyevd(dlaf_uplo, (int) m, local_a_ptr, 1, 1, desc_a, eigenvalues_ptr,
local_eigenvectors_ptr, 1, 1, desc_z, &info);
}
else if constexpr (std::is_same_v<T, float>) {
if (eigenvalues_index_end.has_value())
C_dlaf_pssyevd_partial_spectrum(dlaf_uplo, (int) m, local_a_ptr, 1, 1, desc_a, eigenvalues_ptr,
local_eigenvectors_ptr, 1, 1, desc_z, 0, eval_idx_end, &info);
local_eigenvectors_ptr, 1, 1, desc_z, 1,
eval_idx_end_scalapack, &info);
else
C_dlaf_pssyevd(dlaf_uplo, (int) m, local_a_ptr, 1, 1, desc_a, eigenvalues_ptr,
local_eigenvectors_ptr, 1, 1, desc_z, &info);
}
else if constexpr (std::is_same_v<T, std::complex<double>>) {
if (eigenvalues_index_end.has_value())
C_dlaf_pzheevd_partial_spectrum(dlaf_uplo, (int) m, local_a_ptr, 1, 1, desc_a, eigenvalues_ptr,
local_eigenvectors_ptr, 1, 1, desc_z, 0, eval_idx_end, &info);
local_eigenvectors_ptr, 1, 1, desc_z, 1,
eval_idx_end_scalapack, &info);
else
C_dlaf_pzheevd(dlaf_uplo, (int) m, local_a_ptr, 1, 1, desc_a, eigenvalues_ptr,
local_eigenvectors_ptr, 1, 1, desc_z, &info);
}
else if constexpr (std::is_same_v<T, std::complex<float>>) {
if (eigenvalues_index_end.has_value())
C_dlaf_pcheevd_partial_spectrum(dlaf_uplo, (int) m, local_a_ptr, 1, 1, desc_a, eigenvalues_ptr,
local_eigenvectors_ptr, 1, 1, desc_z, 0, eval_idx_end, &info);
local_eigenvectors_ptr, 1, 1, desc_z, 1,
eval_idx_end_scalapack, &info);
else
C_dlaf_pcheevd(dlaf_uplo, (int) m, local_a_ptr, 1, 1, desc_a, eigenvalues_ptr,
local_eigenvectors_ptr, 1, 1, desc_z, &info);
Expand Down
Loading

0 comments on commit 0b4be03

Please sign in to comment.