Skip to content

Commit

Permalink
add test for orbital dependent jastrow
Browse files Browse the repository at this point in the history
  • Loading branch information
NicoRenaud committed Nov 30, 2023
1 parent 271cd33 commit 5a06f7a
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tests/wavefunction/test_slater_orbital_dependent_jastrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import unittest
import numpy as np
import torch


from .base_test_cases import BaseTestCases

from qmctorch.scf import Molecule
from qmctorch.wavefunction.slater_jastrow import SlaterJastrow

from qmctorch.wavefunction.jastrows.elec_elec.jastrow_factor_electron_electron import (
JastrowFactorElectronElectron,
)
from qmctorch.wavefunction.jastrows.elec_elec.kernels import PadeJastrowKernel


from qmctorch.utils import set_torch_double_precision


torch.set_default_tensor_type(torch.DoubleTensor)


class TestSlaterJastrow(BaseTestCases.WaveFunctionBaseTest):
def setUp(self):
torch.manual_seed(101)
np.random.seed(101)

set_torch_double_precision()

# molecule
mol = Molecule(
atom="Li 0 0 0; H 0 0 3.14",
unit="bohr",
calculator="pyscf",
basis="sto-3g",
redo_scf=True,
)

# define jastrow factor
jastrow = JastrowFactorElectronElectron(mol, PadeJastrowKernel, orbital_dependent_kernel=True)

self.wf = SlaterJastrow(
mol,
kinetic="auto",
include_all_mo=False,
configs="single_double(2,2)",
jastrow=jastrow,
backflow=None,
)

self.random_fc_weight = torch.rand(self.wf.fc.weight.shape)
self.wf.fc.weight.data = self.random_fc_weight
self.nbatch = 11
self.pos = torch.Tensor(np.random.rand(self.nbatch, self.wf.nelec * 3))
self.pos.requires_grad = True


if __name__ == "__main__":
unittest.main()

0 comments on commit 5a06f7a

Please sign in to comment.