A collection of GPU-friendly and neural-network-friendly scalable Quantum Monte Carlo (QMC) implementations in JAX.
Currently supported functionalities:
- Diffusion Monte Carlo (DMC)
- Spin Symmetry Enforcement
JaQMC can be installed via the supplied `setup.py` file:
```shell
pip3 install -e .
```
JaQMC is designed modularly for easy integration with various neural-network-based Quantum Monte Carlo (NNQMC) projects. The core functionality lives in the `jaqmc` module, and the `examples` directory provides a number of scripts integrating it with different NNQMC projects.
JaQMC implements the fixed-node DMC method introduced in "Towards the ground state of molecules via diffusion Monte Carlo on neural networks". See the DMC section for more details.
JaQMC implements the spin-symmetry-enforced solution introduced in "Symmetry enforced solution of the many-body Schrödinger equation with deep neural network". See the Spin Symmetry section for more details.
The fixed-node diffusion Monte Carlo (FNDMC) implementation has a simple interface. In the simplest case, it requires only a real-valued trial wavefunction. This function takes a 3N-dimensional electron configuration (for N electrons) and returns two outputs: the sign of the wavefunction value and the logarithm of its absolute value. In more sophisticated cases, users can also provide their own implementations of the local energy and the quantum force. One example is when effective core potentials (ECPs) are used.
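The following is a minimal sketch of the "(sign, log|psi|)" interface. `toy_trial_wavefunction` is a hypothetical function and is not part of JaQMC. It builds a spin-polarized Slater determinant of hydrogen-like orbitals centered at the origin. It is written in NumPy for brevity, and `numpy.linalg.slogdet` already returns exactly the sign / log-absolute-value pair described above. `jax.numpy.linalg.slogdet` behaves the same way in a JAX version. This sketch does not assume how JaQMC expects such a function to be passed, for example whether it also takes network parameters.

```python
from typing import Tuple

import numpy as np


def toy_trial_wavefunction(x: np.ndarray) -> Tuple[float, float]:
    """Hypothetical toy trial wavefunction for up to 4 same-spin electrons.

    Args:
        x: flat electron configuration of shape (3N,), i.e. (x1, y1, z1, x2, ...).
    Returns:
        (sign of psi, log|psi|), computed stably via a log-determinant.
    """
    r = np.asarray(x, dtype=float).reshape(-1, 3)  # (N, 3) electron positions
    n = r.shape[0]
    if n > 4:
        raise ValueError("this toy model only has 4 orbitals")
    dist = np.linalg.norm(r, axis=1)  # distance of each electron from the origin
    # Orbital matrix: entry [i, j] is orbital j evaluated at electron i.
    orbitals = np.stack(
        [
            np.exp(-dist),                # 1s-like
            r[:, 0] * np.exp(-dist / 2),  # 2p_x-like
            r[:, 1] * np.exp(-dist / 2),  # 2p_y-like
            r[:, 2] * np.exp(-dist / 2),  # 2p_z-like
        ],
        axis=1,
    )[:, :n]
    sign, log_abs = np.linalg.slogdet(orbitals)
    return float(sign), float(log_abs)
```

Swapping two electrons permutes two rows of the orbital matrix. This flips the returned sign and leaves log|psi| unchanged, which is the antisymmetry a fermionic trial wavefunction must have. For example, compare `x` with `x[[3, 4, 5, 0, 1, 2]]` for two electrons. For reference, the local energy and quantum force mentioned above are conventionally defined as $E_L = \hat{H}\Psi_T/\Psi_T$ and $\mathbf{F} = 2\nabla\ln|\Psi_T|$.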
Several examples are provided that integrate with neural-network-based trial wavefunctions. The DMC-related config can be found in `examples/dmc_config.py`.
See here for instructions on how to work with these configs and flags.
Please first install FermiNet following the instructions at https://github.com/deepmind/ferminet. Then train FermiNet for your favorite atom or molecule and generate a checkpoint to be reused in DMC as the trial wavefunction.
```shell
python3 examples/dmc/ferminet/run.py \
  --config $YOUR_FERMINET_CONFIG_FILE \
  --config.log.save_path $YOUR_FERMINET_CKPT_DIRECTORY \
  --dmc_config.iterations 100 \
  --dmc_config.fix_size \
  --dmc_config.block_size 10 \
  --dmc_config.log.save_path $YOUR_DMC_CKPT_DIRECTORY
```
Please first install LapNet following the instructions at https://github.com/bytedance/lapnet. Then train LapNet for your favorite atom or molecule and generate a checkpoint to be reused in DMC as the trial wavefunction.
```shell
python3 examples/dmc/lapnet/run.py \
  --config $YOUR_LAPNET_CONFIG_FILE \
  --config.log.save_path $YOUR_LAPNET_CKPT_DIRECTORY \
  --dmc_config.iterations 100 \
  --dmc_config.fix_size \
  --dmc_config.block_size 10 \
  --dmc_config.log.save_path $YOUR_DMC_CKPT_DIRECTORY
```
Please first install DeepErwin following the instructions at https://mdsunivie.github.io/deeperwin/. Then train DeepErwin for your favorite atom or molecule and generate a checkpoint to be reused in DMC as the trial wavefunction.
```shell
python3 examples/dmc/deeperwin/run.py \
  --deeperwin_ckpt $YOUR_DEEPERWIN_CKPT_FILE \
  --dmc_config.iterations 100 \
  --dmc_config.fix_size \
  --dmc_config.block_size 10 \
  --dmc_config.log.save_path $YOUR_DMC_CKPT_DIRECTORY
```
The entry point for DMC integration is the `run` function in `jaqmc/dmc/dmc.py`, which is extensively commented.
Essentially, you only need to construct your trial wavefunction in JAX and pass it to this `run` function, and it should work out of the box.
Please don't hesitate to file an issue if you need help integrating your favorite (JAX-implemented) trial wavefunction.
Note that our DMC implementation is "multi-node ready". If you initialize the distributed JAX runtime on a multi-node cluster, the DMC implementation performs multi-node calculations correctly, including aggregation across the computing nodes. See here for instructions on initializing the distributed JAX runtime.
The data at each checkpoint step is stored in the specified path (`$YOUR_DMC_CKPT_DIRECTORY` in the examples above) with the naming pattern `dmc_data_{step}.tgz`. Each archive contains a CSV file with the metrics produced at every DMC step up to the checkpoint step. The columns of the metric file are:
- step: The step index in DMC.
- estimator: The mixed estimator of the energy at each step, smoothed over a certain time window.
- offset: The energy offset used to update the DMC walker weights.
- average: The weighted average of the local energy at each DMC step.
- num_walkers: The total number of walkers across all computing nodes.
- old_walkers: The number of walkers that have been rejected too many times.
- total_weight: The total weight of all walkers across all computing nodes.
- acceptance_ratio: The acceptance ratio of the acceptance-rejection step.
- effective_time_step: The effective time step.
- num_cutoff_updated, num_cutoff_orig: Used for debugging; they give the number of outliers in terms of local energy.
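The checkpoint archives can be inspected with the Python standard library alone. `load_dmc_metrics` is a hypothetical helper and is not part of JaQMC. It assumes each `dmc_data_{step}.tgz` holds exactly one `.csv` member whose first row is a header naming the columns above. It does not assume the member's file name.

```python
import csv
import io
import tarfile
from typing import Dict, List


def load_dmc_metrics(tgz_path: str) -> List[Dict[str, str]]:
    """Read the per-step metric rows from a dmc_data_{step}.tgz archive.

    Assumes the archive contains exactly one .csv file with a header row.
    Each returned row maps column name -> raw string value.
    """
    with tarfile.open(tgz_path, "r:*") as tar:  # "r:*" auto-detects gzip
        members = [m for m in tar.getmembers() if m.isfile() and m.name.endswith(".csv")]
        if len(members) != 1:
            raise ValueError(f"expected one .csv file in {tgz_path}, found {len(members)}")
        raw = tar.extractfile(members[0]).read().decode("utf-8")
    return list(csv.DictReader(io.StringIO(raw)))
```

Values are returned as strings. For example, `[float(row["estimator"]) for row in load_dmc_metrics(path)]` gives the smoothed mixed-estimator trace up to that checkpoint.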
We enforce the spin symmetry in two steps:
- Set the spin magnetic quantum number to the target spin value, $s_z = s$, by choosing the numbers of spin-up and spin-down electrons in the input of the neural network wavefunction.
- Add an $\hat{S}_+$ penalty to the loss function to enforce the spin symmetry.
We implement the `loss` module in JaQMC for this purpose.
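As background for the $\hat{S}_+$ penalty, the standard angular-momentum identity below shows why it singles out the target spin state. It uses $\hbar = 1$ and $\hat{S}_- = \hat{S}_+^\dagger$. This is only the textbook motivation. The exact penalty term and its Monte Carlo estimator used by JaQMC are given in the spin-symmetry paper cited in this README.

$$
\hat{S}^2 = \hat{S}_-\hat{S}_+ + \hat{S}_z^2 + \hat{S}_z
$$

For a wavefunction $\psi$ with $\hat{S}_z\psi = s\,\psi$ and $s \ge 0$:

$$
\frac{\langle\psi|\hat{S}_-\hat{S}_+|\psi\rangle}{\langle\psi|\psi\rangle}
= \frac{\lVert\hat{S}_+\psi\rVert^2}{\lVert\psi\rVert^2}
= \frac{\langle\psi|\hat{S}^2|\psi\rangle}{\langle\psi|\psi\rangle} - s(s+1) \;\ge\; 0
$$

Equality holds if and only if $\hat{S}_+\psi = 0$, that is, if $\psi$ is a pure spin eigenstate with $S = s$. Once $s_z = s$ is fixed, driving this quantity to zero therefore removes contamination from states with higher total spin.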
For each loss component, such as the VMC energy and the spin-related penalties, we build a factory method that produces losses with the same interface:
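```python
from __future__ import annotations  # keeps the JaQMC type names below as unevaluated annotations

from typing import Protocol, Tuple

import chex
import jax.numpy as jnp

# ParamTree, BaseFuncState and BaseAuxData are types defined in JaQMC.


class Loss(Protocol):
    def __call__(self,
                 params: ParamTree,
                 func_state: BaseFuncState,
                 key: chex.PRNGKey,
                 data: jnp.ndarray) -> Tuple[jnp.ndarray, Tuple[BaseFuncState, BaseAuxData]]:
        """
        Args:
            params: network parameters.
            func_state: function state passed to the loss function to control its behavior.
            key: JAX PRNG state.
            data: QMC walkers with electronic configuration to evaluate.
        Returns:
            (loss value, (updated func_state, auxiliary data))
        """
```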
This loss interface works well with the KFAC optimizer. It is also flexible enough to work with other optimizers, such as those in optax and SPRING.
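A shared interface makes loss components composable. The following is a hypothetical sketch and is not JaQMC's code. It shows one loss factory and a combinator that forms a weighted sum of several losses, for example an energy term plus a spin penalty. It uses NumPy and plain Python objects in place of JAX arrays and JaQMC's `ParamTree`/`BaseFuncState`/`BaseAuxData`. Purely for illustration, it assumes that the combined `func_state` is a dict with one sub-state per named loss. In real JAX code, the combined loss would typically be differentiated with `jax.value_and_grad(loss, has_aux=True)`, and the PRNG key would be split per component.

```python
from typing import Any, Callable, Dict, Tuple

import numpy as np

# Hypothetical stand-in for the Loss call signature:
# (params, func_state, key, data) -> (value, (new_func_state, aux_data)).
LossFn = Callable[[Any, Any, Any, np.ndarray], Tuple[float, Tuple[Any, Any]]]


def make_energy_loss(local_energy: Callable[[Any, np.ndarray], np.ndarray]) -> LossFn:
    """Hypothetical factory: mean local energy over walkers, VMC-style."""

    def loss(params, func_state, key, data):
        e_loc = local_energy(params, data)  # one local energy per walker
        value = float(np.mean(e_loc))
        aux = {"energy": value, "variance": float(np.var(e_loc))}
        return value, (func_state, aux)  # this term has no state to update

    return loss


def weighted_sum(losses: Dict[str, LossFn], weights: Dict[str, float]) -> LossFn:
    """Hypothetical combinator over losses sharing the same interface.

    Assumes func_state is a dict with one sub-state per named loss. A JAX
    version would also split `key` with jax.random.split per component.
    """

    def loss(params, func_state, key, data):
        total = 0.0
        new_states, aux = {}, {}
        for name, fn in losses.items():
            value, (state, sub_aux) = fn(params, func_state[name], key, data)
            total += weights[name] * value
            new_states[name], aux[name] = state, sub_aux
        return total, (new_states, aux)

    return loss
```

Because every component returns `(value, (state, aux))`, the combinator needs no knowledge of what each term computes. An optimizer driving it only ever sees a single loss with the same signature.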
We also provide user-facing entry points in `jaqmc/loss/factory.py`. One builds `func_state`, which is one of the inputs to the loss function. The other builds the loss function itself.
```python
def build_func_state(step=None) -> FuncState:
    """
    Helper function to create parent FuncState from actual data.
    """
    ...  # elided
```
Please first install LapNet following the instructions at https://github.com/bytedance/lapnet.
To simulate the singlet state of the oxygen atom with LapNet and spin symmetry enforced, simply turn on the `loss_config.enforce_spin.with_spin` flag as follows:
```shell
python3 $JAQMC_PATH/examples/loss/lapnet/run.py \
  --config $JAQMC_PATH/examples/loss/lapnet/atom_spin_state.py:O,0 \
  --loss_config.enforce_spin.with_spin \
  --config.$OTHER_LAPNET_CONFIGS \
  --loss_config.enforce_spin.$OTHER_SPIN_CONFIGS
```
Note that this example script is by no means "production-ready". It is only a showcase of how to integrate the `loss` module with existing NNQMC projects. For instance, it does not include the pretraining phase, since that phase is unrelated to the `loss` module.
If you use certain functionalities of JaQMC in your work, please consider citing the corresponding papers.
```bibtex
@article{ren2023towards,
  title={Towards the ground state of molecules via diffusion Monte Carlo on neural networks},
  author={Ren, Weiluo and Fu, Weizhong and Wu, Xiaojie and Chen, Ji},
  journal={Nature Communications},
  volume={14},
  number={1},
  pages={1860},
  year={2023},
  publisher={Nature Publishing Group UK London}
}

@article{li2024symmetry,
  title={Symmetry enforced solution of the many-body Schr{\"o}dinger equation with deep neural network},
  author={Li, Zhe and Lu, Zixiang and Li, Ruichen and Wen, Xuelan and Li, Xiang and Wang, Liwei and Chen, Ji and Ren, Weiluo},
  journal={arXiv preprint arXiv:2406.01222},
  year={2024}
}
```