-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
809 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .jastrow_graph import JastrowFactorGraph as JastrowFactor | ||
from .mgcn.mgcn_predictor import MGCNPredictor | ||
|
||
__all__ = ["JastrowFactor", "MGCNPredictor"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import dgl | ||
import torch | ||
|
||
|
||
def ElecElecGraph(nelec, nup): | ||
"""Create the elec-elec graph | ||
Args: | ||
nelec (int): total number of electrons | ||
nup (int): numpber of spin up electrons | ||
Returns: | ||
[dgl.DGLGraph]: DGL graph | ||
""" | ||
edges = get_elec_elec_edges(nelec) | ||
graph = dgl.graph(edges) | ||
graph.ndata["node_types"] = get_elec_elec_ndata(nelec, nup) | ||
return graph | ||
|
||
|
||
def get_elec_elec_edges(nelec): | ||
"""Compute the edge index of the electron-electron graph.""" | ||
ee_edges = ([], []) | ||
for i in range(nelec - 1): | ||
for j in range(i + 1, nelec): | ||
ee_edges[0].append(i) | ||
ee_edges[1].append(j) | ||
|
||
ee_edges[0].append(j) | ||
ee_edges[1].append(i) | ||
|
||
return ee_edges | ||
|
||
|
||
def get_elec_elec_ndata(nelec, nup): | ||
"""Compute the node data of the elec-elec graph""" | ||
|
||
ee_ndata = [] | ||
for i in range(nelec): | ||
if i < nup: | ||
ee_ndata.append(0) | ||
else: | ||
ee_ndata.append(1) | ||
|
||
return torch.LongTensor(ee_ndata) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import dgl | ||
import torch | ||
from mendeleev import element | ||
|
||
|
||
def ElecNucGraph(natoms, atom_types, atomic_features, nelec, nup): | ||
"""Create the elec-nuc graph | ||
Args: | ||
nelec (int): total number of electrons | ||
nup (int): numpber of spin up electrons | ||
Returns: | ||
[dgl.DGLGraph]: DGL graph | ||
""" | ||
edges = get_elec_nuc_edges(natoms, nelec) | ||
graph = dgl.graph(edges) | ||
graph.ndata["node_types"] = get_elec_nuc_ndata( | ||
natoms, atom_types, atomic_features, nelec, nup | ||
) | ||
return graph | ||
|
||
|
||
def get_elec_nuc_edges(natoms, nelec): | ||
"""Compute the edge index of the electron-nuclei graph.""" | ||
en_edges = ([], []) | ||
for i in range(natoms): | ||
for j in range(nelec): | ||
en_edges[0].append(i) | ||
en_edges[1].append(natoms + j) | ||
|
||
en_edges[0].append(natoms + j) | ||
en_edges[1].append(i) | ||
|
||
# for i in range(natoms-1): | ||
# for j in range(i+1, natoms): | ||
# en_edges[0].append(i) | ||
# en_edges[1].append(j) | ||
return en_edges | ||
|
||
|
||
def get_elec_nuc_ndata(natoms, atom_types, atomic_features, nelec, nup): | ||
"""Compute the node data of the elec-elec graph""" | ||
|
||
en_ndata = [] | ||
embed_number = 0 | ||
atom_dict = {} | ||
|
||
for i in range(natoms): | ||
if atom_types[i] not in atom_dict: | ||
atom_dict[atom_types[i]] = embed_number | ||
en_ndata.append(embed_number) | ||
embed_number += 1 | ||
else: | ||
en_ndata.append(atom_dict[atom_types[i]]) | ||
|
||
# feat = get_atomic_features(atom_types[i], atomic_features) | ||
# feat.append(0) # spin | ||
# en_ndata.append(feat) | ||
|
||
for i in range(nelec): | ||
# feat = get_atomic_features(None, atomic_features) | ||
if i < nup: | ||
en_ndata.append(embed_number) | ||
else: | ||
en_ndata.append(embed_number + 1) | ||
|
||
return torch.LongTensor(en_ndata) | ||
|
||
|
||
def get_atomic_features(atom_type, atomic_features): | ||
"""Get the atomic features requested.""" | ||
if atom_type is not None: | ||
data = element(atom_type) | ||
feat = [getattr(data, feat) for feat in atomic_features] | ||
else: | ||
feat = [] | ||
for atf in atomic_features: | ||
if atf == "atomic_number": | ||
feat.append(-1) | ||
else: | ||
feat.append(0) | ||
return feat |
Oops, something went wrong.