Skip to content

Commit

Permalink
Merge pull request #9 from BIG-MAP/partial_matrix
Browse files Browse the repository at this point in the history
Added functionality to be able to learn partial matrices
  • Loading branch information
pfebrer authored Jul 15, 2024
2 parents cea5b9e + be1314c commit 3ac5be8
Show file tree
Hide file tree
Showing 23 changed files with 611 additions and 108 deletions.
30 changes: 25 additions & 5 deletions src/e3nn_matrix/data/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class PointBasis:
irreps : o3.Irreps
Irreps of the basis. E.g. ``o3.Irreps("3x0e + 2x1o")``
for a basis with 3 l=0 functions and 2 sets of l=1 functions.
``o3.Irreps("")``, the default value, means that this point
has no basis functions.
R : Union[float, np.ndarray]
The reach of the basis.
If a float, the same reach is used for all functions.
Expand All @@ -75,17 +78,17 @@ class PointBasis:
# Let's create a basis with 3 l=0 functions and 2 sets of l=1 functions.
# The convention for spherical harmonics will be the standard one.
# We call this type of basis set "A", and functions have a reach of 5.
basis = PointBasis("A", "spherical", o3.Irreps("3x0e + 2x1o"), 5)
basis = PointBasis("A", R=5, irreps=o3.Irreps("3x0e + 2x1o"), basis_convention="spherical")
# Same but with a different reach for l=0 (R=5) and l=1 functions (R=3).
basis = PointBasis("A", "spherical", o3.Irreps("3x0e + 2x1o"), np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]))
basis = PointBasis("A", R=np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]), irreps=o3.Irreps("3x0e + 2x1o"), basis_convention="spherical" )
"""

type: Union[str, int]
basis_convention: BasisConvention
irreps: o3.Irreps
R: Union[float, np.ndarray]
irreps: o3.Irreps = o3.Irreps("")
basis_convention: BasisConvention = "spherical"

def __post_init__(self):
assert isinstance(self.R, Number) or (
Expand Down Expand Up @@ -141,7 +144,7 @@ def from_sisl_atom(
type=atom.Z,
basis_convention=basis_convention,
irreps=get_atom_irreps(atom),
R=atom.R,
R=atom.R if atom.no != 0 else atom.R[0],
)

def to_sisl_atom(self, Z: int = 1) -> "sisl.Atom":
Expand All @@ -155,6 +158,9 @@ def to_sisl_atom(self, Z: int = 1) -> "sisl.Atom":

import sisl

if self.basis_size == 0:
return NoBasisAtom(Z=Z, R=self.R)

orbitals = []

R = (
Expand All @@ -174,3 +180,17 @@ def to_sisl_atom(self, Z: int = 1) -> "sisl.Atom":
i += 1

return sisl.Atom(Z=Z, orbitals=orbitals)

class NoBasisAtom(sisl.Atom):
"""Placeholder for atoms without orbitals.
This should no longer be needed once sisl allows atoms with 0 orbitals."""

@property
def no(self):
return 0

@property
def q0(self):
return np.array([])

61 changes: 58 additions & 3 deletions src/e3nn_matrix/data/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import sisl

from .basis import PointBasis
from .basis import PointBasis, NoBasisAtom
from .matrices import OrbitalMatrix, BasisMatrix, get_matrix_cls
from .sparse import csr_to_block_dict

Expand Down Expand Up @@ -233,7 +233,11 @@ def from_matrix(
# sparse structure.
matrix_cls = get_matrix_cls(matrix.__class__)
matrix_block = csr_to_block_dict(
matrix._csr, matrix.atoms, nsc=matrix.nsc, matrix_cls=matrix_cls
matrix._csr,
matrix.atoms,
nsc=matrix.nsc,
matrix_cls=matrix_cls,
geometry_atoms=geometry.atoms,
)

kwargs["matrix"] = matrix_block
Expand All @@ -244,6 +248,7 @@ def from_matrix(
def from_run(
cls,
runfilepath: Union[str, Path],
geometry_path: Optional[Union[str, Path]] = None,
out_matrix: Optional[PhysicsMatrixType] = None,
basis: Optional[sisl.Atoms] = None,
) -> "OrbitalConfiguration":
Expand Down Expand Up @@ -290,6 +295,35 @@ def _read_geometry(main_input, basis):

return geometry

def _copy_basis(
original: sisl.Geometry, geometry: sisl.Geometry, notfound_ok=False
) -> sisl.Geometry:
import warnings

new_geometry = geometry.copy()

for atom in geometry.atoms.atom:
for basis_atom in original.atoms:
if basis_atom.tag == atom.tag:
break
else:
if not notfound_ok:
raise ValueError(f"Couldn't find atom {atom} in the basis")
basis_atom = NoBasisAtom(atom.Z, tag=atom.tag)

with warnings.catch_warnings():
warnings.simplefilter("ignore")
new_geometry.atoms.replace_atom(atom, basis_atom)

return new_geometry

if isinstance(main_input, sisl.io.fdfSileSiesta):
type_of_run = main_input.get("MD.TypeOfRun")
if type_of_run == "qmmm":
pipe_file = main_input.get("QMMM.Driver.QMRegionFile")
#geometry_path = main_input.file.parent / (pipe_file.split(".")[0] + ".last.pdb")
geometry_path = main_input.file.parent / (pipe_file.split(".")[0] + ".XV")

if out_matrix is not None:
# Get the method to read the desired matrix and read it
read = getattr(main_input, f"read_{out_matrix}")
Expand All @@ -299,11 +333,32 @@ def _read_geometry(main_input, basis):
else:
matrix = read()

kwargs = {}
if geometry_path is not None:
# If we have a geometry path, we will read the geometry from there.
#from ase.io import read

kwargs["geometry"] = sisl.Geometry.read(geometry_path)
kwargs["geometry"] = _copy_basis(
matrix.geometry, kwargs["geometry"], notfound_ok=True
)

metadata["geometry"] = kwargs.get("geometry", matrix.geometry)

# Now build the OrbitalConfiguration object using this matrix.
return cls.from_matrix(matrix=matrix, metadata=metadata)
return cls.from_matrix(matrix=matrix, metadata=metadata, **kwargs)
else:
# We have no matrix to read, we will just read the geometry.
geometry = _read_geometry(main_input, basis)

if geometry_path is not None:
# If we have a geometry path, we will read the geometry from there.
from ase.io import read

new_geometry = sisl.Geometry.new(read(geometry_path))
geometry = _copy_basis(geometry, new_geometry, notfound_ok=True)

metadata["geometry"] = geometry

# And build the OrbitalConfiguration object using this geometry.
return cls.from_geometry(geometry=geometry, metadata=metadata)
3 changes: 3 additions & 0 deletions src/e3nn_matrix/data/irreps_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def get_atom_irreps(atom: sisl.Atom):
the basis irreps.
"""

if atom.no == 0:
return o3.Irreps("")

atom_irreps = []

# Array that stores the number of orbitals for each l.
Expand Down
8 changes: 7 additions & 1 deletion src/e3nn_matrix/data/matrices/basis_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,22 @@ def to_flat_nodes_and_edges(
blocks = [
(self.block_dict[i, i, 0] - point_matrices[point_types[i]]).flatten()
for i in order
if self.basis_count[i] > 0
]
else:
blocks = [self.block_dict[i, i, 0].flatten() for i in order]
blocks = [
self.block_dict[i, i, 0].flatten()
for i in order
if self.basis_count[i] > 0
]

node_values = np.concatenate(blocks)

assert edge_index.shape[0] == 2, "edge_index is assumed to be [2, n_edges]"
blocks = [
self.block_dict[edge[0], edge[1], sc_shift].flatten()
for edge, sc_shift in zip(edge_index.transpose(), edge_sc_shifts)
if self.basis_count[edge[0]] > 0 and self.basis_count[edge[1]] > 0
]

edge_values = np.concatenate(blocks)
Expand Down
132 changes: 132 additions & 0 deletions src/e3nn_matrix/data/node_feats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from typing import Callable, Optional

import numpy as np

class NodeFeature:

def __new__(cls, config, data_processor):
return cls.get_feature(config, data_processor)

registry = {}

def __init_subclass__(cls) -> None:
NodeFeature.registry[cls.__name__] = cls

@staticmethod
def get_feature(config: dict, data_processor) -> np.ndarray:
raise NotImplementedError

@staticmethod
def get_irreps(data_processor):
raise NotImplementedError

class OneHotZ(NodeFeature):

@staticmethod
def get_irreps(basis_table):
from e3nn import o3
return o3.Irreps([(len(basis_table), (0, 1))])

@staticmethod
def get_feature(config, data_processor):
indices = data_processor.get_point_types(config)
return data_processor.one_hot_encode(indices)

class WaterDipole(NodeFeature):

@staticmethod
def get_irreps(basis_table):
from e3nn import o3
return o3.Irreps("1x1o")

@staticmethod
def get_feature(config, data_processor):

n_atoms = len(config.positions)

z_dipole = np.array([0.0, 0.0, 0.0])
for position, point_type in zip(config.positions, config.point_types):
if point_type == 8 or point_type == 1:

z_dipole[2] += position[2]*(-2 if point_type == 8 else 1)

z_dipole = data_processor.cartesian_to_basis(z_dipole) / 30
z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 3)

return z_dipole

class WaterDipoleInv(NodeFeature):

@staticmethod
def get_irreps(basis_table):
from e3nn import o3
return o3.Irreps("1x0e")

@staticmethod
def get_feature(config, data_processor):

n_atoms = len(config.positions)

z_dipole = np.array([0.0])
for position, point_type in zip(config.positions, config.point_types):
if point_type == 8 or point_type == 1:

z_dipole[0] += position[2]*(-2 if point_type == 8 else 1)

z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 1)

return z_dipole / 30

class Nothing(NodeFeature):

@staticmethod
def get_irreps(basis_table):
from e3nn import o3
return o3.Irreps("1x0e")

@staticmethod
def get_feature(config, data_processor):

n_atoms = len(config.positions)

z_dipole = np.array([0.0])

z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 1)

return z_dipole

class NothingVector(NodeFeature):

@staticmethod
def get_irreps(basis_table):
from e3nn import o3
return o3.Irreps("1x1o")

@staticmethod
def get_feature(config, data_processor):

n_atoms = len(config.positions)

z_dipole = np.array([0.0, 0.0, 0.0])

z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 3)

return z_dipole

class One(NodeFeature):

@staticmethod
def get_irreps(basis_table):
from e3nn import o3
return o3.Irreps("1x0e")

@staticmethod
def get_feature(config, data_processor):

n_atoms = len(config.positions)

z_dipole = np.array([1.0])

z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 1)

return z_dipole
Loading

0 comments on commit 3ac5be8

Please sign in to comment.