Skip to content

Commit

Permalink
Fix conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
siuwuncheung committed Dec 6, 2024
1 parent 47e6335 commit cdf38c8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
6 changes: 6 additions & 0 deletions src/MGmol.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include "MGmol_prototypes.h"
#include <iostream>

#ifdef MGMOL_HAS_LIBROM
#include "librom.h"
#endif // MGMOL_HAS_LIBROM

// inline double one(const double r){ return 1.; }

inline double linear(const double r) { return 1. - r; }
Expand Down Expand Up @@ -359,6 +363,8 @@ class MGmol : public MGmolInterface
}

#ifdef MGMOL_HAS_LIBROM
const CAROM::Matrix* orbitals_to_carom_matrix(const OrbitalsType& orbitals);
void carom_matrix_to_orbitals(const CAROM::Matrix* Psi, OrbitalsType& orbitals);
int save_orbital_snapshot(std::string snapshot_dir, OrbitalsType& orbitals);
void project_orbital(std::string snapshot_dir, int rdim, OrbitalsType& orbitals);
#endif
Expand Down
59 changes: 34 additions & 25 deletions src/rom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,36 @@
#include <fstream>
#include <sys/stat.h>

// Save the wavefunction snapshots
template <class OrbitalsType>
const CAROM::Matrix* MGmol<OrbitalsType>::orbitals_to_carom_matrix(const OrbitalsType& orbitals)
{
const int dim = orbitals.getLocNumpt();
const int num_orbitals = orbitals.chromatic_number();

CAROM::Options svd_options(dim, num_orbitals, 1);
CAROM::BasisGenerator basis_generator(svd_options, false, "foo");

for (int i = 0; i < num_orbitals; ++i)
basis_generator.takeSample(orbitals.getPsi(i));

return basis_generator.getSnapshotMatrix();
}

template <class OrbitalsType>
void MGmol<OrbitalsType>::carom_matrix_to_orbitals(const CAROM::Matrix* Psi, OrbitalsType& orbitals)
{
Control& ct = *(Control::instance());
Mesh* mesh = Mesh::instance();
pb::GridFunc<ORBDTYPE> gf_psi(mesh->grid(), ct.bcWF[0], ct.bcWF[1], ct.bcWF[2]);
CAROM::Vector psi;
for (int i = 0; i < Psi->numColumns(); ++i)
{
Psi->getColumn(i, psi);
gf_psi.assign(psi.getData());
orbitals.setPsi(gf_psi, i);
}
}

template <class OrbitalsType>
int MGmol<OrbitalsType>::save_orbital_snapshot(std::string file_path, OrbitalsType& orbitals)
{
Expand Down Expand Up @@ -60,35 +89,15 @@ int MGmol<OrbitalsType>::save_orbital_snapshot(std::string file_path, OrbitalsTy
template <class OrbitalsType>
void MGmol<OrbitalsType>::project_orbital(std::string file_path, int rdim, OrbitalsType& orbitals)
{
const int dim = orbitals.getLocNumpt();
const int totalSamples = orbitals.chromatic_number();

CAROM::Options svd_options(dim, totalSamples, 1);
CAROM::BasisGenerator basis_generator(svd_options, false, "foo");

for (int i = 0; i < totalSamples; ++i)
basis_generator.takeSample(orbitals.getPsi(i));
const CAROM::Matrix* orbital_snapshots = basis_generator.getSnapshotMatrix();
const CAROM::Matrix* Psi = orbitals_to_carom_matrix(orbitals);

CAROM::BasisReader reader(file_path);
CAROM::Matrix* orbital_basis = reader.getSpatialBasis(rdim);

CAROM::Matrix* proj_orbital_coeff = orbital_basis->transposeMult(orbital_snapshots);
CAROM::Matrix* proj_orbital_snapshots = orbital_basis->mult(proj_orbital_coeff);
CAROM::Matrix* Psi_reduced = orbital_basis->transposeMult(Psi);
CAROM::Matrix* Psi_projected = orbital_basis->mult(Psi_reduced);

Control& ct = *(Control::instance());
Mesh* mesh = Mesh::instance();
pb::GridFunc<ORBDTYPE> gf_psi(mesh->grid(), ct.bcWF[0], ct.bcWF[1], ct.bcWF[2]);
CAROM::Vector snapshot, proj_snapshot;
for (int i = 0; i < totalSamples; ++i)
{
orbital_snapshots->getColumn(i, snapshot);
proj_orbital_snapshots->getColumn(i, proj_snapshot);
gf_psi.assign(proj_snapshot.getData());
orbitals.setPsi(gf_psi, i);
snapshot -= proj_snapshot;
std::cout << "Error for orbital " << i << " = " << snapshot.norm() << std::endl;
}
carom_matrix_to_orbitals(Psi_projected, orbitals);
}

template class MGmol<LocGridOrbitals>;
Expand Down

0 comments on commit cdf38c8

Please sign in to comment.