Skip to content

Commit

Permalink
Revert "removed graph jastrow"
Browse files Browse the repository at this point in the history
This reverts commit 46e191b.
  • Loading branch information
NicoRenaud committed Dec 12, 2023
1 parent 926a51e commit 43f466b
Show file tree
Hide file tree
Showing 7 changed files with 809 additions and 0 deletions.
4 changes: 4 additions & 0 deletions qmctorch/wavefunction/jastrows/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .jastrow_graph import JastrowFactorGraph as JastrowFactor
from .mgcn.mgcn_predictor import MGCNPredictor

__all__ = ["JastrowFactor", "MGCNPredictor"]
45 changes: 45 additions & 0 deletions qmctorch/wavefunction/jastrows/graph/elec_elec_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import dgl
import torch


def ElecElecGraph(nelec, nup):
"""Create the elec-elec graph
Args:
nelec (int): total number of electrons
nup (int): numpber of spin up electrons
Returns:
[dgl.DGLGraph]: DGL graph
"""
edges = get_elec_elec_edges(nelec)
graph = dgl.graph(edges)
graph.ndata["node_types"] = get_elec_elec_ndata(nelec, nup)
return graph


def get_elec_elec_edges(nelec):
"""Compute the edge index of the electron-electron graph."""
ee_edges = ([], [])
for i in range(nelec - 1):
for j in range(i + 1, nelec):
ee_edges[0].append(i)
ee_edges[1].append(j)

ee_edges[0].append(j)
ee_edges[1].append(i)

return ee_edges


def get_elec_elec_ndata(nelec, nup):
"""Compute the node data of the elec-elec graph"""

ee_ndata = []
for i in range(nelec):
if i < nup:
ee_ndata.append(0)
else:
ee_ndata.append(1)

return torch.LongTensor(ee_ndata)
83 changes: 83 additions & 0 deletions qmctorch/wavefunction/jastrows/graph/elec_nuc_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import dgl
import torch
from mendeleev import element


def ElecNucGraph(natoms, atom_types, atomic_features, nelec, nup):
"""Create the elec-nuc graph
Args:
nelec (int): total number of electrons
nup (int): numpber of spin up electrons
Returns:
[dgl.DGLGraph]: DGL graph
"""
edges = get_elec_nuc_edges(natoms, nelec)
graph = dgl.graph(edges)
graph.ndata["node_types"] = get_elec_nuc_ndata(
natoms, atom_types, atomic_features, nelec, nup
)
return graph


def get_elec_nuc_edges(natoms, nelec):
"""Compute the edge index of the electron-nuclei graph."""
en_edges = ([], [])
for i in range(natoms):
for j in range(nelec):
en_edges[0].append(i)
en_edges[1].append(natoms + j)

en_edges[0].append(natoms + j)
en_edges[1].append(i)

# for i in range(natoms-1):
# for j in range(i+1, natoms):
# en_edges[0].append(i)
# en_edges[1].append(j)
return en_edges


def get_elec_nuc_ndata(natoms, atom_types, atomic_features, nelec, nup):
"""Compute the node data of the elec-elec graph"""

en_ndata = []
embed_number = 0
atom_dict = {}

for i in range(natoms):
if atom_types[i] not in atom_dict:
atom_dict[atom_types[i]] = embed_number
en_ndata.append(embed_number)
embed_number += 1
else:
en_ndata.append(atom_dict[atom_types[i]])

# feat = get_atomic_features(atom_types[i], atomic_features)
# feat.append(0) # spin
# en_ndata.append(feat)

for i in range(nelec):
# feat = get_atomic_features(None, atomic_features)
if i < nup:
en_ndata.append(embed_number)
else:
en_ndata.append(embed_number + 1)

return torch.LongTensor(en_ndata)


def get_atomic_features(atom_type, atomic_features):
"""Get the atomic features requested."""
if atom_type is not None:
data = element(atom_type)
feat = [getattr(data, feat) for feat in atomic_features]
else:
feat = []
for atf in atomic_features:
if atf == "atomic_number":
feat.append(-1)
else:
feat.append(0)
return feat
Loading

0 comments on commit 43f466b

Please sign in to comment.