Skip to content

Commit

Permalink
Merge branch 'restart_real_time_propagation' into 'master'
Browse files Browse the repository at this point in the history
Restart real time propagation

See merge request npneq/inq!1126
  • Loading branch information
xavierandrade committed Aug 30, 2024
2 parents 097c88b + cade8bd commit d915519
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 10 deletions.
59 changes: 59 additions & 0 deletions src/interface/run.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ These are the options available:
Python example: `pinq.run.real_time()`
- Shell: `run resume`
Python: `run.resume()`
Restarts the run of a stopped simulation from a checkpoint. For the
moment this only works for real-time simulations.
Shell example: `inq run resume`
Python example: `pinq.run.resume()`
)"""";
}

Expand Down Expand Up @@ -107,6 +117,49 @@ These are the options available:
res.save(input::environment::global().comm(), ".inq/default_results_real_time");

}

static void resume() {

auto dirname = std::string(".inq/default_checkpoint/");

std::optional<std::string> run_type;
utils::load_optional(dirname + "/type", run_type);

if(not run_type.has_value()) {
actions::error(input::environment::global().comm(), "Cannot resume run, a checkpoint was not found.");
}

if(*run_type != "real-time") {
actions::error(input::environment::global().comm(), "Unknown checkpoint type '" + *run_type + "'.");
}

auto ions = systems::ions::load(".inq/default_ions");
auto bz = ionic::brillouin(systems::ions::load(".inq/default_ions"), input::kpoints::gamma());

try { bz = ionic::brillouin::load(".inq/default_brillouin"); }
catch(...) {
bz.save(input::environment::global().comm(), ".inq/default_brillouin");
}

systems::electrons electrons(ions, options::electrons::load(".inq/default_electrons_options"), bz);

if(not electrons.try_load(dirname + "/real-time/orbitals")) {
actions::error(input::environment::global().comm(), "Cannot load the restart electron orbitals.\n The checkpoint must be corrupt");
}

auto opts = options::real_time::load(".inq/default_real_time_options");

auto res = real_time::results::load(dirname + "/real-time/observables");

res.obs = opts.observables_container();

real_time::propagate(ions, electrons, [&res](auto obs){ res(obs); },
options::theory::load(".inq/default_theory"), opts, perturbations::blend::load(".inq/default_perturbations"),
/* start_step = */ res.total_steps);
res.save(input::environment::global().comm(), ".inq/default_results_real_time");

}


template <typename ArgsType>
void command(ArgsType const & args, bool quiet) const {
Expand All @@ -122,6 +175,11 @@ These are the options available:
real_time();
actions::normal_exit();
}

if(args.size() == 1 and args[0] == "resume") {
resume();
actions::normal_exit();
}

actions::error(input::environment::global().comm(), "Invalid syntax in the 'run' command");
}
Expand All @@ -135,6 +193,7 @@ These are the options available:
auto sub = module.def_submodule(name(), help());
sub.def("ground_state", &ground_state);
sub.def("real_time", &real_time);
sub.def("resume", &real_time);

}
#endif
Expand Down
25 changes: 17 additions & 8 deletions src/real_time/propagate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,19 @@ namespace inq {
namespace real_time {

template <typename ProcessFunction, typename IonSubPropagator = ionic::propagator::fixed, typename Perturbation = perturbations::none>
void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunction func, const options::theory & inter, const options::real_time & opts, Perturbation const & pert = {}){
void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunction func, const options::theory & inter,
const options::real_time & opts, Perturbation const & pert = {}, int const start_step = 0){

assert(start_step >= 0);

CALI_CXX_MARK_FUNCTION;

auto console = electrons.logger();

ionic::propagator::runtime ion_propagator{opts.ion_dynamics_value()};


if(start_step > 0) assert(ion_propagator.static_ions()); //restart doesn't work with moving ions for now

const double dt = opts.dt();
const int numsteps = opts.num_steps();

Expand All @@ -43,10 +49,13 @@ void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunc
std::string(" time step = {} atomictime ({:.2f} as)\n") +
std::string(" number of steps = {}\n") +
std::string(" propagation time = {} atomictime ({:.2f} fs)"), dt, dt/0.041341373, numsteps, numsteps*dt, numsteps*dt/41.341373);
if(start_step > 0) console->trace("restarting propagation from step {}", start_step);
console->trace("\n{}", pert);
}


if(start_step == 0) {
for(auto & phi : electrons.kpin()) pert.zero_step(phi);
}

electrons.spin_density() = observables::density::calculate(electrons);

Expand All @@ -56,8 +65,8 @@ void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunc
hamiltonian::energy energy;

sc.update_ionic_fields(electrons.states_comm(), ions, electrons.atomic_pot());
sc.update_hamiltonian(ham, energy, electrons.spin_density(), /* time = */ 0.0);

sc.update_hamiltonian(ham, energy, electrons.spin_density(), /* time = */ start_step*dt);
ham.exchange().update(electrons);

energy.calculate(ham, electrons);
Expand All @@ -69,13 +78,13 @@ void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunc
auto current = vector3<double, covariant>{0.0, 0.0, 0.0};
if(sc.has_induced_vector_potential()) current = observables::current(ions, electrons, ham);

func(real_time::viewables{false, 0, 0.0, ions, electrons, energy, forces, ham, pert});
if(start_step == 0) func(real_time::viewables{false, start_step, start_step*dt, ions, electrons, energy, forces, ham, pert});

if(console) console->trace("starting real-time propagation");
if(console) console->info("step {:9d} : t = {:9.3f} e = {:.12f}", 0, 0.0, energy.total());
if(console) console->info("step {:9d} : t = {:9.3f} e = {:.12f}", start_step, start_step*dt, energy.total());

auto iter_start_time = std::chrono::high_resolution_clock::now();
for(int istep = 0; istep < numsteps; istep++){
for(int istep = start_step; istep < numsteps; istep++){
CALI_CXX_MARK_SCOPE("time_step");

switch(opts.propagator()){
Expand Down
13 changes: 11 additions & 2 deletions src/real_time/results.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,17 @@ class results {

if(not observables.every(500)) return;

save(observables.electrons().full_comm(), ".inq/default_checkpoint/observables");
observables.electrons().save(".inq/default_checkpoint/orbitals");
auto dirname = std::string(".inq/default_checkpoint/");
auto error_message = "INQ error: Cannot save checkpoint '" + dirname + "'.";

//remove the type file and only create at the end, to avoid partially written checkpoints
if(observables.electrons().full_comm().root()) std::filesystem::remove(dirname + "/type");
observables.electrons().full_comm().barrier();

save(observables.electrons().full_comm(), dirname + "real-time/observables");
observables.electrons().save(dirname + "real-time/orbitals");

utils::save_value(observables.electrons().full_comm(), dirname + "/type", std::string("real-time"), error_message);
}

template <typename Comm>
Expand Down

0 comments on commit d915519

Please sign in to comment.