From 5a06f7ac94f28f1ff8b0bbc63f07dd5f52a19efa Mon Sep 17 00:00:00 2001 From: Nico Date: Thu, 30 Nov 2023 11:40:46 +0100 Subject: [PATCH] add test for orbital dependent jastrow --- .../test_slater_orbital_dependent_jastrow.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/wavefunction/test_slater_orbital_dependent_jastrow.py diff --git a/tests/wavefunction/test_slater_orbital_dependent_jastrow.py b/tests/wavefunction/test_slater_orbital_dependent_jastrow.py new file mode 100644 index 00000000..1392c61a --- /dev/null +++ b/tests/wavefunction/test_slater_orbital_dependent_jastrow.py @@ -0,0 +1,59 @@ +import unittest +import numpy as np +import torch + + +from .base_test_cases import BaseTestCases + +from qmctorch.scf import Molecule +from qmctorch.wavefunction.slater_jastrow import SlaterJastrow + +from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import ( + JastrowFactorElectronElectron, +) +from qmctorch.wavefunction.jastrows.elec_elec.kernels import PadeJastrowKernel + + +from qmctorch.utils import set_torch_double_precision + + +torch.set_default_tensor_type(torch.DoubleTensor) + + +class TestSlaterJastrow(BaseTestCases.WaveFunctionBaseTest): + def setUp(self): + torch.manual_seed(101) + np.random.seed(101) + + set_torch_double_precision() + + # molecule + mol = Molecule( + atom="Li 0 0 0; H 0 0 3.14", + unit="bohr", + calculator="pyscf", + basis="sto-3g", + redo_scf=True, + ) + + # define jastrow factor + jastrow = JastrowFactorElectronElectron(mol, PadeJastrowKernel, orbital_dependent_kernel=True) + + self.wf = SlaterJastrow( + mol, + kinetic="auto", + include_all_mo=False, + configs="single_double(2,2)", + jastrow=jastrow, + backflow=None, + ) + + self.random_fc_weight = torch.rand(self.wf.fc.weight.shape) + self.wf.fc.weight.data = self.random_fc_weight + self.nbatch = 11 + self.pos = torch.Tensor(np.random.rand(self.nbatch, self.wf.nelec * 3)) + self.pos.requires_grad = True + + +if __name__ == "__main__": + unittest.main()