diff --git a/src/interface/run.hpp b/src/interface/run.hpp index 6fe36231..4a52edc2 100644 --- a/src/interface/run.hpp +++ b/src/interface/run.hpp @@ -61,6 +61,16 @@ These are the options available: Python example: `pinq.run.real_time()` +- Shell: `run resume` + Python: `run.resume()` + + Restarts the run of a stopped simulation from a checkpoint. For the + moment this only works for real-time simulations. + + Shell example: `inq run resume` + Python example: `pinq.run.resume()` + + )""""; } @@ -107,6 +117,49 @@ These are the options available: res.save(input::environment::global().comm(), ".inq/default_results_real_time"); } + + static void resume() { + + auto dirname = std::string(".inq/default_checkpoint/"); + + std::optional run_type; + utils::load_optional(dirname + "/type", run_type); + + if(not run_type.has_value()) { + actions::error(input::environment::global().comm(), "Cannot resume run, a checkpoint was not found."); + } + + if(*run_type != "real-time") { + actions::error(input::environment::global().comm(), "Unknown checkpoint type '" + *run_type + "'."); + } + + auto ions = systems::ions::load(".inq/default_ions"); + auto bz = ionic::brillouin(systems::ions::load(".inq/default_ions"), input::kpoints::gamma()); + + try { bz = ionic::brillouin::load(".inq/default_brillouin"); } + catch(...) { + bz.save(input::environment::global().comm(), ".inq/default_brillouin"); + } + + systems::electrons electrons(ions, options::electrons::load(".inq/default_electrons_options"), bz); + + if(not electrons.try_load(dirname + "/real-time/orbitals")) { + actions::error(input::environment::global().comm(), "Cannot load the restart electron orbitals.\n The checkpoint must be corrupt"); + } + + auto opts = options::real_time::load(".inq/default_real_time_options"); + + auto res = real_time::results::load(dirname + "/real-time/observables"); + + res.obs = opts.observables_container(); + + real_time::propagate(ions, electrons, [&res](auto obs){ res(obs); }, + options::theory::load(".inq/default_theory"), opts, perturbations::blend::load(".inq/default_perturbations"), + /* start_step = */ res.total_steps); + res.save(input::environment::global().comm(), ".inq/default_results_real_time"); + + } + template void command(ArgsType const & args, bool quiet) const { @@ -122,6 +175,11 @@ These are the options available: real_time(); actions::normal_exit(); } + + if(args.size() == 1 and args[0] == "resume") { + resume(); + actions::normal_exit(); + } actions::error(input::environment::global().comm(), "Invalid syntax in the 'run' command"); } @@ -135,6 +193,7 @@ These are the options available: auto sub = module.def_submodule(name(), help()); sub.def("ground_state", &ground_state); sub.def("real_time", &real_time); + sub.def("resume", &real_time); } #endif diff --git a/src/real_time/propagate.hpp b/src/real_time/propagate.hpp index 0591cc30..a1898f58 100644 --- a/src/real_time/propagate.hpp +++ b/src/real_time/propagate.hpp @@ -28,13 +28,19 @@ namespace inq { namespace real_time { template -void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunction func, const options::theory & inter, const options::real_time & opts, Perturbation const & pert = {}){ +void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunction func, const options::theory & inter, + const options::real_time & opts, Perturbation const & pert = {}, int const start_step = 0){ + + assert(start_step >= 0); + CALI_CXX_MARK_FUNCTION; auto console = electrons.logger(); ionic::propagator::runtime ion_propagator{opts.ion_dynamics_value()}; - + + if(start_step > 0) assert(ion_propagator.static_ions()); //restart doesn't work with moving ions for now + const double dt = opts.dt(); const int numsteps = opts.num_steps(); @@ -43,10 +49,13 @@ void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunc std::string(" time step = {} atomictime ({:.2f} as)\n") + std::string(" number of steps = {}\n") + std::string(" propagation time = {} atomictime ({:.2f} fs)"), dt, dt/0.041341373, numsteps, numsteps*dt, numsteps*dt/41.341373); + if(start_step > 0) console->trace("restarting propagation from step {}", start_step); console->trace("\n{}", pert); } - + + if(start_step == 0) { for(auto & phi : electrons.kpin()) pert.zero_step(phi); + } electrons.spin_density() = observables::density::calculate(electrons); @@ -56,8 +65,8 @@ void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunc hamiltonian::energy energy; sc.update_ionic_fields(electrons.states_comm(), ions, electrons.atomic_pot()); - sc.update_hamiltonian(ham, energy, electrons.spin_density(), /* time = */ 0.0); - + sc.update_hamiltonian(ham, energy, electrons.spin_density(), /* time = */ start_step*dt); + ham.exchange().update(electrons); energy.calculate(ham, electrons); @@ -69,13 +78,13 @@ void propagate(systems::ions & ions, systems::electrons & electrons, ProcessFunc auto current = vector3{0.0, 0.0, 0.0}; if(sc.has_induced_vector_potential()) current = observables::current(ions, electrons, ham); - func(real_time::viewables{false, 0, 0.0, ions, electrons, energy, forces, ham, pert}); + if(start_step == 0) func(real_time::viewables{false, start_step, start_step*dt, ions, electrons, energy, forces, ham, pert}); if(console) console->trace("starting real-time propagation"); - if(console) console->info("step {:9d} : t = {:9.3f} e = {:.12f}", 0, 0.0, energy.total()); + if(console) console->info("step {:9d} : t = {:9.3f} e = {:.12f}", start_step, start_step*dt, energy.total()); auto iter_start_time = std::chrono::high_resolution_clock::now(); - for(int istep = 0; istep < numsteps; istep++){ + for(int istep = start_step; istep < numsteps; istep++){ CALI_CXX_MARK_SCOPE("time_step"); switch(opts.propagator()){ diff --git a/src/real_time/results.hpp b/src/real_time/results.hpp index 91c7ce96..4aa8285a 100644 --- a/src/real_time/results.hpp +++ b/src/real_time/results.hpp @@ -52,8 +52,17 @@ class results { if(not observables.every(500)) return; - save(observables.electrons().full_comm(), ".inq/default_checkpoint/observables"); - observables.electrons().save(".inq/default_checkpoint/orbitals"); + auto dirname = std::string(".inq/default_checkpoint/"); + auto error_message = "INQ error: Cannot save checkpoint '" + dirname + "'."; + + //remove the type file and only create at the end, to avoid partially written checkpoints + if(observables.electrons().full_comm().root()) std::filesystem::remove(dirname + "/type"); + observables.electrons().full_comm().barrier(); + + save(observables.electrons().full_comm(), dirname + "real-time/observables"); + observables.electrons().save(dirname + "real-time/orbitals"); + + utils::save_value(observables.electrons().full_comm(), dirname + "/type", std::string("real-time"), error_message); } template