From be1314ce51c2552d84b7632180aec50c32f4d564 Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Mon, 15 Jul 2024 23:13:18 +0200 Subject: [PATCH] Allow learning partial matrices --- src/e3nn_matrix/data/basis.py | 30 +++- src/e3nn_matrix/data/configuration.py | 61 +++++++- src/e3nn_matrix/data/irreps_tools.py | 3 + src/e3nn_matrix/data/matrices/basis_matrix.py | 8 +- src/e3nn_matrix/data/node_feats.py | 132 ++++++++++++++++++ src/e3nn_matrix/data/processing.py | 100 ++++++++++++- src/e3nn_matrix/data/sparse.py | 26 +++- src/e3nn_matrix/data/table.py | 36 +++-- src/e3nn_matrix/data/tests/test_basis.py | 29 ++-- src/e3nn_matrix/models/mace/models.py | 18 ++- src/e3nn_matrix/tools/lightning/callbacks.py | 79 ++++++++++- src/e3nn_matrix/tools/lightning/cli.py | 20 +-- src/e3nn_matrix/tools/lightning/data.py | 9 +- src/e3nn_matrix/tools/lightning/model.py | 14 +- .../tools/lightning/models/mace.py | 7 +- src/e3nn_matrix/tools/server/extrapolation.py | 26 +++- src/e3nn_matrix/tools/server/server_app.py | 17 +-- src/e3nn_matrix/torch/conftest.py | 19 ++- src/e3nn_matrix/torch/data.py | 14 +- src/e3nn_matrix/torch/dataset.py | 2 +- src/e3nn_matrix/torch/load.py | 1 + src/e3nn_matrix/torch/modules/basis_matrix.py | 56 ++++++-- .../torch/modules/tests/test_basis_matrix.py | 12 +- 23 files changed, 611 insertions(+), 108 deletions(-) create mode 100644 src/e3nn_matrix/data/node_feats.py diff --git a/src/e3nn_matrix/data/basis.py b/src/e3nn_matrix/data/basis.py index 74f6b16..8590464 100644 --- a/src/e3nn_matrix/data/basis.py +++ b/src/e3nn_matrix/data/basis.py @@ -51,6 +51,9 @@ class PointBasis: irreps : o3.Irreps Irreps of the basis. E.g. ``o3.Irreps("3x0e + 2x1o")`` for a basis with 3 l=0 functions and 2 sets of l=1 functions. + + ``o3.Irreps("")``, the default value, means that this point + has no basis functions. R : Union[float, np.ndarray] The reach of the basis. If a float, the same reach is used for all functions. @@ -75,17 +78,17 @@ class PointBasis: # Let's create a basis with 3 l=0 functions and 2 sets of l=1 functions. # The convention for spherical harmonics will be the standard one. # We call this type of basis set "A", and functions have a reach of 5. - basis = PointBasis("A", "spherical", o3.Irreps("3x0e + 2x1o"), 5) + basis = PointBasis("A", R=5, irreps=o3.Irreps("3x0e + 2x1o"), basis_convention="spherical") # Same but with a different reach for l=0 (R=5) and l=1 functions (R=3). - basis = PointBasis("A", "spherical", o3.Irreps("3x0e + 2x1o"), np.array([5, 5, 5, 3, 3, 3, 3, 3, 3])) + basis = PointBasis("A", R=np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]), irreps=o3.Irreps("3x0e + 2x1o"), basis_convention="spherical" ) """ type: Union[str, int] - basis_convention: BasisConvention - irreps: o3.Irreps R: Union[float, np.ndarray] + irreps: o3.Irreps = o3.Irreps("") + basis_convention: BasisConvention = "spherical" def __post_init__(self): assert isinstance(self.R, Number) or ( @@ -141,7 +144,7 @@ def from_sisl_atom( type=atom.Z, basis_convention=basis_convention, irreps=get_atom_irreps(atom), - R=atom.R, + R=atom.R if atom.no != 0 else atom.R[0], ) def to_sisl_atom(self, Z: int = 1) -> "sisl.Atom": @@ -155,6 +158,9 @@ def to_sisl_atom(self, Z: int = 1) -> "sisl.Atom": import sisl + if self.basis_size == 0: + return NoBasisAtom(Z=Z, R=self.R) + orbitals = [] R = ( @@ -174,3 +180,17 @@ def to_sisl_atom(self, Z: int = 1) -> "sisl.Atom": i += 1 return sisl.Atom(Z=Z, orbitals=orbitals) + +class NoBasisAtom(sisl.Atom): + """Placeholder for atoms without orbitals. + + This should no longer be needed once sisl allows atoms with 0 orbitals.""" + + @property + def no(self): + return 0 + + @property + def q0(self): + return np.array([]) + \ No newline at end of file diff --git a/src/e3nn_matrix/data/configuration.py b/src/e3nn_matrix/data/configuration.py index e11e5d0..aa55fd2 100644 --- a/src/e3nn_matrix/data/configuration.py +++ b/src/e3nn_matrix/data/configuration.py @@ -21,7 +21,7 @@ import numpy as np import sisl -from .basis import PointBasis +from .basis import PointBasis, NoBasisAtom from .matrices import OrbitalMatrix, BasisMatrix, get_matrix_cls from .sparse import csr_to_block_dict @@ -233,7 +233,11 @@ def from_matrix( # sparse structure. matrix_cls = get_matrix_cls(matrix.__class__) matrix_block = csr_to_block_dict( - matrix._csr, matrix.atoms, nsc=matrix.nsc, matrix_cls=matrix_cls + matrix._csr, + matrix.atoms, + nsc=matrix.nsc, + matrix_cls=matrix_cls, + geometry_atoms=geometry.atoms, ) kwargs["matrix"] = matrix_block @@ -244,6 +248,7 @@ def from_matrix( def from_run( cls, runfilepath: Union[str, Path], + geometry_path: Optional[Union[str, Path]] = None, out_matrix: Optional[PhysicsMatrixType] = None, basis: Optional[sisl.Atoms] = None, ) -> "OrbitalConfiguration": @@ -290,6 +295,35 @@ def _read_geometry(main_input, basis): return geometry + def _copy_basis( + original: sisl.Geometry, geometry: sisl.Geometry, notfound_ok=False + ) -> sisl.Geometry: + import warnings + + new_geometry = geometry.copy() + + for atom in geometry.atoms.atom: + for basis_atom in original.atoms: + if basis_atom.tag == atom.tag: + break + else: + if not notfound_ok: + raise ValueError(f"Couldn't find atom {atom} in the basis") + basis_atom = NoBasisAtom(atom.Z, tag=atom.tag) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + new_geometry.atoms.replace_atom(atom, basis_atom) + + return new_geometry + + if isinstance(main_input, sisl.io.fdfSileSiesta): + type_of_run = main_input.get("MD.TypeOfRun") + if type_of_run == "qmmm": + pipe_file = main_input.get("QMMM.Driver.QMRegionFile") + #geometry_path = main_input.file.parent / (pipe_file.split(".")[0] + ".last.pdb") + geometry_path = main_input.file.parent / (pipe_file.split(".")[0] + ".XV") + if out_matrix is not None: # Get the method to read the desired matrix and read it read = getattr(main_input, f"read_{out_matrix}") @@ -299,11 +333,32 @@ def _read_geometry(main_input, basis): else: matrix = read() + kwargs = {} + if geometry_path is not None: + # If we have a geometry path, we will read the geometry from there. + #from ase.io import read + + kwargs["geometry"] = sisl.Geometry.read(geometry_path) + kwargs["geometry"] = _copy_basis( + matrix.geometry, kwargs["geometry"], notfound_ok=True + ) + + metadata["geometry"] = kwargs.get("geometry", matrix.geometry) + # Now build the OrbitalConfiguration object using this matrix. - return cls.from_matrix(matrix=matrix, metadata=metadata) + return cls.from_matrix(matrix=matrix, metadata=metadata, **kwargs) else: # We have no matrix to read, we will just read the geometry. geometry = _read_geometry(main_input, basis) + if geometry_path is not None: + # If we have a geometry path, we will read the geometry from there. + from ase.io import read + + new_geometry = sisl.Geometry.new(read(geometry_path)) + geometry = _copy_basis(geometry, new_geometry, notfound_ok=True) + + metadata["geometry"] = geometry + # And build the OrbitalConfiguration object using this geometry. return cls.from_geometry(geometry=geometry, metadata=metadata) diff --git a/src/e3nn_matrix/data/irreps_tools.py b/src/e3nn_matrix/data/irreps_tools.py index 39d5e5f..07a00a3 100644 --- a/src/e3nn_matrix/data/irreps_tools.py +++ b/src/e3nn_matrix/data/irreps_tools.py @@ -26,6 +26,9 @@ def get_atom_irreps(atom: sisl.Atom): the basis irreps. """ + if atom.no == 0: + return o3.Irreps("") + atom_irreps = [] # Array that stores the number of orbitals for each l. diff --git a/src/e3nn_matrix/data/matrices/basis_matrix.py b/src/e3nn_matrix/data/matrices/basis_matrix.py index 8b6f7b6..dffb2ad 100644 --- a/src/e3nn_matrix/data/matrices/basis_matrix.py +++ b/src/e3nn_matrix/data/matrices/basis_matrix.py @@ -64,9 +64,14 @@ def to_flat_nodes_and_edges( blocks = [ (self.block_dict[i, i, 0] - point_matrices[point_types[i]]).flatten() for i in order + if self.basis_count[i] > 0 ] else: - blocks = [self.block_dict[i, i, 0].flatten() for i in order] + blocks = [ + self.block_dict[i, i, 0].flatten() + for i in order + if self.basis_count[i] > 0 + ] node_values = np.concatenate(blocks) @@ -74,6 +79,7 @@ def to_flat_nodes_and_edges( blocks = [ self.block_dict[edge[0], edge[1], sc_shift].flatten() for edge, sc_shift in zip(edge_index.transpose(), edge_sc_shifts) + if self.basis_count[edge[0]] > 0 and self.basis_count[edge[1]] > 0 ] edge_values = np.concatenate(blocks) diff --git a/src/e3nn_matrix/data/node_feats.py b/src/e3nn_matrix/data/node_feats.py new file mode 100644 index 0000000..9100a8f --- /dev/null +++ b/src/e3nn_matrix/data/node_feats.py @@ -0,0 +1,132 @@ +from typing import Callable, Optional + +import numpy as np + +class NodeFeature: + + def __new__(cls, config, data_processor): + return cls.get_feature(config, data_processor) + + registry = {} + + def __init_subclass__(cls) -> None: + NodeFeature.registry[cls.__name__] = cls + + @staticmethod + def get_feature(config: dict, data_processor) -> np.ndarray: + raise NotImplementedError + + @staticmethod + def get_irreps(data_processor): + raise NotImplementedError + +class OneHotZ(NodeFeature): + + @staticmethod + def get_irreps(basis_table): + from e3nn import o3 + return o3.Irreps([(len(basis_table), (0, 1))]) + + @staticmethod + def get_feature(config, data_processor): + indices = data_processor.get_point_types(config) + return data_processor.one_hot_encode(indices) + +class WaterDipole(NodeFeature): + + @staticmethod + def get_irreps(basis_table): + from e3nn import o3 + return o3.Irreps("1x1o") + + @staticmethod + def get_feature(config, data_processor): + + n_atoms = len(config.positions) + + z_dipole = np.array([0.0, 0.0, 0.0]) + for position, point_type in zip(config.positions, config.point_types): + if point_type == 8 or point_type == 1: + + z_dipole[2] += position[2]*(-2 if point_type == 8 else 1) + + z_dipole = data_processor.cartesian_to_basis(z_dipole) / 30 + z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 3) + + return z_dipole + +class WaterDipoleInv(NodeFeature): + + @staticmethod + def get_irreps(basis_table): + from e3nn import o3 + return o3.Irreps("1x0e") + + @staticmethod + def get_feature(config, data_processor): + + n_atoms = len(config.positions) + + z_dipole = np.array([0.0]) + for position, point_type in zip(config.positions, config.point_types): + if point_type == 8 or point_type == 1: + + z_dipole[0] += position[2]*(-2 if point_type == 8 else 1) + + z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 1) + + return z_dipole / 30 + +class Nothing(NodeFeature): + + @staticmethod + def get_irreps(basis_table): + from e3nn import o3 + return o3.Irreps("1x0e") + + @staticmethod + def get_feature(config, data_processor): + + n_atoms = len(config.positions) + + z_dipole = np.array([0.0]) + + z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 1) + + return z_dipole + +class NothingVector(NodeFeature): + + @staticmethod + def get_irreps(basis_table): + from e3nn import o3 + return o3.Irreps("1x1o") + + @staticmethod + def get_feature(config, data_processor): + + n_atoms = len(config.positions) + + z_dipole = np.array([0.0, 0.0, 0.0]) + + z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 3) + + return z_dipole + +class One(NodeFeature): + + @staticmethod + def get_irreps(basis_table): + from e3nn import o3 + return o3.Irreps("1x0e") + + @staticmethod + def get_feature(config, data_processor): + + n_atoms = len(config.positions) + + z_dipole = np.array([1.0]) + + z_dipole = np.tile(z_dipole, n_atoms).reshape(n_atoms, 1) + + return z_dipole \ No newline at end of file diff --git a/src/e3nn_matrix/data/processing.py b/src/e3nn_matrix/data/processing.py index 8b5e130..330a647 100644 --- a/src/e3nn_matrix/data/processing.py +++ b/src/e3nn_matrix/data/processing.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -from typing import Optional, Tuple, Union, Dict, Any, Callable, Sequence, Generator +from typing import Optional, Tuple, Union, Dict, Any, Callable, Sequence, Generator, List from functools import cached_property from pathlib import Path import dataclasses @@ -58,6 +58,8 @@ class MatrixDataProcessor: symmetric_matrix: bool = False sub_point_matrix: bool = True out_matrix: Optional[PhysicsMatrixType] = None + node_attr_getters: List[Any] = dataclasses.field(default_factory=list) + def copy(self, **kwargs): """Create a copy of the object with the given attributes replaced.""" @@ -81,6 +83,28 @@ def get_config_kwargs(self, obj: Any) -> Dict[str, Any]: else: return {} + def torch_predict(self, torch_model, geometry: sisl.Geometry): + import torch + + from ..torch import BasisMatrixTorchData + + with torch.no_grad(): + # USE THE MODEL + # First, we need to process the input data, to get inputs as the model expects. + input_data = BasisMatrixTorchData.new( + geometry, data_processor=self, labels=False + ) + + # Then, we run the model. + out = torch_model(input_data) + + # And finally, we convert the output to a matrix. + matrix = self.matrix_from_data( + input_data, predictions=out + ) + + return matrix + def matrix_from_data( self, data: BasisMatrixData, @@ -638,6 +662,10 @@ def get_labels_from_types_and_edges( return point_labels, edge_labels + def get_node_attrs(self, config: BasisConfiguration) -> np.ndarray: + """Returns the initial features of nodes.""" + return np.concatenate([getter(config, self) for getter in self.node_attr_getters], axis=1) + def one_hot_encode(self, point_types: np.ndarray) -> np.ndarray: """One hot encodes a vector of point types. @@ -793,6 +821,11 @@ def labels_to_sparse_orbital( threshold=threshold, ) + # Remove atoms with no basis. + for i, point_basis in enumerate(self.basis_table.basis): + if point_basis.basis_size == 0: + matrix = matrix.remove(unique_atoms[i]) + return matrix @@ -959,6 +992,9 @@ class BasisMatrixData: postprocess outputs, for example. It includes the data processor. """ + _node_attr_keys = ("node_attrs", "positions", "point_types") + _edge_attr_keys = ("edge_types", "shifts", "neigh_isc") + num_nodes: Optional[int] edge_index: np.ndarray neigh_isc: np.ndarray @@ -973,6 +1009,8 @@ class BasisMatrixData: point_types: np.ndarray edge_types: np.ndarray edge_type_nlabels: np.ndarray + labels_point_filter: np.ndarray + labels_edge_filter: np.ndarray metadata: Dict[str, Any] def __init__( @@ -987,6 +1025,8 @@ def __init__( nsc: Optional[np.ndarray] = None, # [3,] point_labels: Optional[np.ndarray] = None, # [total_point_elements] edge_labels: Optional[np.ndarray] = None, # [total_edge_elements] + labels_point_filter: Optional[np.ndarray] = None, # [n_point_labels] + labels_edge_filter: Optional[np.ndarray] = None, # [n_edge_labels] point_types: Optional[np.ndarray] = None, # [n_nodes] edge_types: Optional[np.ndarray] = None, # [n_edges] edge_type_nlabels: Optional[np.ndarray] = None, # [n_edge_types] @@ -1003,6 +1043,8 @@ def __init__( nsc=nsc, point_labels=point_labels, edge_labels=edge_labels, + labels_point_filter=labels_point_filter, + labels_edge_filter=labels_edge_filter, point_types=point_types, edge_types=edge_types, edge_type_nlabels=edge_type_nlabels, @@ -1024,6 +1066,8 @@ def _sanitize_data( nsc: Optional[np.ndarray] = None, point_labels: Optional[np.ndarray] = None, # [total_point_elements] edge_labels: Optional[np.ndarray] = None, # [total_edge_elements] + labels_point_filter: Optional[np.ndarray] = None, # [total_point_elements] + labels_edge_filter: Optional[np.ndarray] = None, # [total_edge_elements] point_types: Optional[np.ndarray] = None, # [n_nodes] edge_types: Optional[np.ndarray] = None, # [n_edges] edge_type_nlabels: Optional[np.ndarray] = None, # [n_edge_types] @@ -1113,7 +1157,7 @@ def from_config( cls, config: BasisConfiguration, data_processor: MatrixDataProcessor, nsc=None ) -> "BasisMatrixData": indices = data_processor.get_point_types(config) - one_hot = data_processor.one_hot_encode(indices) + node_attrs = data_processor.get_node_attrs(config) # Search for the neighbors. We use the max radius of each atom as cutoff for looking over neighbors. # This means that two atoms ij are neighbors if they have some overlap between their orbitals. That is @@ -1172,7 +1216,7 @@ def from_config( return cls( edge_index=edge_index, neigh_isc=neigh_isc, - node_attrs=one_hot, + node_attrs=node_attrs, positions=config.positions, shifts=shifts, cell=config.cell if config.cell is not None else None, @@ -1218,3 +1262,53 @@ def to_sparse_orbital_matrix(self, threshold: float = 1e-8) -> sisl.SparseOrbita arrays = self.numpy_arrays() return data_processor.labels_to_sparse_orbital(arrays, threshold=threshold) + + def node_types_subgraph(self, node_types: np.ndarray) -> "BasisMatrixData": + """Returns a subgraph with only the nodes of the given types. + + If the BasisMatrixData has labels (i.e. a matrix), this function will + raise an error because we don't support filtering labels yet. + + Parameters + ---------- + node_types : + Array with the node types to keep. + """ + # Initialize the data dictionary, removing the num_nodes and n_edges keys + # which should be recomputed on init. Also, pass the data processor as an argument. + new_data = {**self._data} + new_data.pop("num_nodes") + new_data.pop("n_edges") + new_data["metadata"] = new_data["metadata"].copy() + new_data["data_processor"] = new_data["metadata"].pop("data_processor", None) + + # Filtering point labels and edge labels is complicated, we don't support it yet + if "point_labels" in new_data or "edge_labels" in new_data: + raise ValueError("point_labels and edge_labels are not supported yet") + + # Find the indices of the nodes that belong to the requested types + mask = np.isin(self.point_types, node_types) + # And the edge indices for edges between nodes that we will keep + edge_mask = np.all(mask[self.edge_index], axis=0) + + # Filter node attributes + for k in self._node_attr_keys: + if new_data.get(k) is not None: + new_data[k] = new_data[k][mask] + + # Filter edge indices + new_data["edge_index"] = new_data["edge_index"][:, edge_mask] + + # Filter edge attributes + for k in self._edge_attr_keys: + if new_data.get(k) is not None: + new_data[k] = new_data[k][edge_mask] + + # Set nlabels to 0 for edge types that are not present anymore + new_data["edge_type_nlabels"] = copy(new_data["edge_type_nlabels"]) + u_edge_types = abs(new_data["edge_types"]).unique() + for i in range(new_data["edge_type_nlabels"].shape[1]): + if i not in u_edge_types: + new_data["edge_type_nlabels"][:, i] = 0 + + return self.__class__(**new_data) diff --git a/src/e3nn_matrix/data/sparse.py b/src/e3nn_matrix/data/sparse.py index a98a2bb..396684a 100644 --- a/src/e3nn_matrix/data/sparse.py +++ b/src/e3nn_matrix/data/sparse.py @@ -3,8 +3,6 @@ Different sparse representations of a matrix are required during the different steps of a typical workflow using ``e3nn_matrix``. """ - -from dataclasses import dataclass from typing import Dict, Tuple, Type, Optional import itertools @@ -23,13 +21,25 @@ def csr_to_block_dict( spmat: sisl.SparseCSR, atoms: sisl.Atoms, nsc: np.ndarray, + geometry_atoms: Optional[sisl.Atoms] = None, matrix_cls: Type[OrbitalMatrix] = OrbitalMatrix, ) -> OrbitalMatrix: - """ - Creates a OrbitalMatrix object from a SparseCSR matrix - In the block dictionary of the OrbitalMatrix: - Each key is a 2-tuple of atom indices - each value is the corresponding block of the orbital matrix as a dense numpy ndarray + """Creates a OrbitalMatrix object from a SparseCSR matrix + + Parameters + ---------- + spmat : + The sparse matrix to convert to a block dictionary. + atoms : + The atoms object for the matrix, containing orbital information. + nsc : + The auxiliary supercell size. + matrix_cls : + Matrix class to initialize. + geometry_atoms : + The atoms object for the full geometry. This allows the matrix to contain + atoms without any orbital. Geometry atoms should contain the matrix atoms + first and then the orbital-less atoms. """ orbitals = atoms.orbitals @@ -42,6 +52,8 @@ def csr_to_block_dict( n_atoms=len(atoms.specie), ) + orbitals = geometry_atoms.orbitals if geometry_atoms is not None else atoms.orbitals + return matrix_cls(block_dict=block_dict, nsc=nsc, orbital_count=orbitals) diff --git a/src/e3nn_matrix/data/table.py b/src/e3nn_matrix/data/table.py index 0a0d799..9eb09e1 100644 --- a/src/e3nn_matrix/data/table.py +++ b/src/e3nn_matrix/data/table.py @@ -24,8 +24,7 @@ import sisl from .basis import PointBasis, BasisConvention, get_change_of_basis - - + class BasisTableWithEdges: """Stores the unique types of points in the system, with their basis and the possible edges. @@ -132,6 +131,7 @@ class BasisTableWithEdges: def __init__( self, basis: Sequence[PointBasis], get_point_matrix: Optional[Callable] = None ): + self._init_args = {"atoms": basis, "get_point_matrix": get_point_matrix} self.basis = list(basis) self.types = [point_basis.type for point_basis in self.basis] @@ -377,15 +377,24 @@ class AtomicTableWithEdges(BasisTableWithEdges): def __init__(self, atoms: Sequence[sisl.Atom]): from .matrices.physics.density_matrix import get_atomic_DM - self.atoms = list(atoms) + self.atoms = list([atom if isinstance(atom, sisl.Atom) else atom.to_sisl_atom(Z=atom.type) for atom in atoms]) - basis = [PointBasis.from_sisl_atom(atom) for atom in self.atoms] + basis = [ + PointBasis.from_sisl_atom(atom) + if not isinstance(atom, PointBasis) + else atom + for atom in atoms + ] super().__init__(basis=basis, get_point_matrix=None) + self._init_args = {"atoms": atoms} # Get the point matrix for each type. This is the matrix that a point would # have if it was the only one in the system, and it depends only on the type. - self.point_matrix = [get_atomic_DM(atom) for atom in self.atoms] + self.point_matrix = [ + get_atomic_DM(atom) if not isinstance(atom, PointBasis) else None + for atom in self.atoms + ] self.file_names = None self.file_contents = None @@ -414,7 +423,10 @@ def atomic_DM(self): @classmethod def from_basis_dir( - cls, basis_dir: str, basis_ext: str = "ion.xml" + cls, + basis_dir: str, + basis_ext: str = "ion.xml", + no_basis_atoms: Optional[dict] = None, ) -> "AtomicTableWithEdges": """Generates a table from a directory containing basis files. @@ -427,11 +439,13 @@ def from_basis_dir( """ basis_path = Path(basis_dir) - return cls.from_basis_glob(basis_path.glob(f"*.{basis_ext}")) + return cls.from_basis_glob( + basis_path.glob(f"*.{basis_ext}"), no_basis_atoms=no_basis_atoms + ) @classmethod def from_basis_glob( - cls, basis_glob: Union[str, Generator] + cls, basis_glob: Union[str, Generator], no_basis_atoms: Optional[dict] = None ) -> "AtomicTableWithEdges": """Generates a table from basis files that match a glob pattern. @@ -453,6 +467,12 @@ def from_basis_glob( # file_contents.append(f.read()) basis.append(sisl.get_sile(basis_file).read_basis()) + if no_basis_atoms is not None: + for k, v in no_basis_atoms.items(): + basis.append( + PointBasis(k, R=v["R"], basis_convention="siesta_spherical") + ) + obj = cls(basis) # obj.file_names = file_names # obj.file_contents = file_contents diff --git a/src/e3nn_matrix/data/tests/test_basis.py b/src/e3nn_matrix/data/tests/test_basis.py index b1e0bef..eac9805 100644 --- a/src/e3nn_matrix/data/tests/test_basis.py +++ b/src/e3nn_matrix/data/tests/test_basis.py @@ -9,25 +9,36 @@ def test_simplest(): - basis = PointBasis("A", "spherical", o3.Irreps("3x0e + 2x1o"), 5) + basis = PointBasis( + "A", basis_convention="spherical", irreps=o3.Irreps("3x0e + 2x1o"), R=5 + ) def test_siesta_convention(): - basis = PointBasis("A", "siesta_spherical", o3.Irreps("3x0e + 2x1o"), 5) + basis = PointBasis( + "A", basis_convention="siesta_spherical", irreps=o3.Irreps("3x0e + 2x1o"), R=5 + ) + + +def test_no_basis(): + basis = PointBasis("A", R=5) def test_multiple_R(): basis = PointBasis( "A", - "spherical", - o3.Irreps("3x0e + 2x1o"), - np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]), + basis_convention="spherical", + irreps=o3.Irreps("3x0e + 2x1o"), + R=np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]), ) # Wrong number of Rs with pytest.raises(AssertionError): basis = PointBasis( - "A", "spherical", o3.Irreps("3x0e + 2x1o"), np.array([5, 5, 5, 3, 3]) + "A", + basis_convention="spherical", + irreps=o3.Irreps("3x0e + 2x1o"), + R=np.array([5, 5, 5, 3, 3]), ) @@ -46,9 +57,9 @@ def test_from_sisl_atom(): def test_to_sisl_atom(): basis = PointBasis( "A", - "siesta_spherical", - o3.Irreps("3x0e + 2x1o"), - np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]), + basis_convention="siesta_spherical", + irreps=o3.Irreps("3x0e + 2x1o"), + R=np.array([5, 5, 5, 3, 3, 3, 3, 3, 3]), ) atom = basis.to_sisl_atom() diff --git a/src/e3nn_matrix/models/mace/models.py b/src/e3nn_matrix/models/mace/models.py index 21c2319..74c0601 100644 --- a/src/e3nn_matrix/models/mace/models.py +++ b/src/e3nn_matrix/models/mace/models.py @@ -1,6 +1,6 @@ """Variant of the MACE model using the orbital matrix readouts.""" -from typing import Any, Dict, Type, Sequence +from typing import Any, Dict, Type, Sequence, Optional import torch from e3nn import o3 @@ -40,12 +40,16 @@ def __init__( node_block_readout: Type[NodeBlock], edge_block_readout: Type[EdgeBlock], only_last_readout: bool, + node_attr_irreps: Optional[o3.Irreps] = None, ): super().__init__() self.r_max = r_max + self.num_elements = num_elements # Embedding - node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) - node_feats_irreps = o3.Irreps([(hidden_irreps.count(o3.Irrep(0, 1)), (0, 1))]) + if node_attr_irreps is None: + node_attr_irreps = o3.Irreps([(num_elements, (0, 1))]) + node_feats_irreps = o3.Irreps([(hidden_irreps.count(ir.ir), ir.ir) for ir in node_attr_irreps]) + self.node_embedding = LinearNodeEmbeddingBlock( irreps_in=node_attr_irreps, irreps_out=node_feats_irreps ) @@ -145,7 +149,11 @@ def __init__( self.readouts.append(readout) - def forward(self, data: BasisMatrixTorchData, training=False) -> Dict[str, Any]: + def forward( + self, + data: BasisMatrixTorchData, + training=False, + ) -> Dict[str, Any]: # Setup # This is only if we want to compute matrix gradients. For now, we don't. # data.positions.requires_grad = True @@ -181,7 +189,7 @@ def forward(self, data: BasisMatrixTorchData, training=False) -> Dict[str, Any]: ) node_feats = product( - node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] + node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"][:, :self.num_elements] ) if readout is None: diff --git a/src/e3nn_matrix/tools/lightning/callbacks.py b/src/e3nn_matrix/tools/lightning/callbacks.py index ce6884e..26bfb7a 100644 --- a/src/e3nn_matrix/tools/lightning/callbacks.py +++ b/src/e3nn_matrix/tools/lightning/callbacks.py @@ -25,19 +25,43 @@ from e3nn_matrix.tools.viz import plot_basis_matrix -class MatrixWriter(BasePredictionWriter): +class MatrixWriter(Callback): """Callback to write predicted matrices to disk.""" + + def __init__( + self, + output_file: str, + splits: Sequence = [ + "train", + "val", + "test", + "predict" + ], # I don't know why, but Sequence[str] breaks the lightning CLI + ): + super().__init__() + + splits = [ + "train", + "val", + "test", + "predict" + ] - def __init__(self, output_file: str, write_interval: str = "batch"): - super().__init__(write_interval) + if splits in ["train", "val", "test", "predict"]: + splits = [splits] + elif isinstance(splits, str): + raise ValueError(f"Invalid value for splits: {splits}") + + self.splits = splits self.output_file = output_file + self.out_is_absolute = Path(output_file).is_absolute() - def write_on_batch_end( + def _on_batch_end( self, + split: str, trainer: "pl.Trainer", pl_module: "pl.LightningModule", prediction: Dict, - batch_indices: Sequence[int], batch: Any, batch_idx: int, dataloader_idx: int, @@ -50,12 +74,53 @@ def write_on_batch_end( # Loop through structures in the batch for matrix_data in matrix_iter: + sparse_orbital_matrix = matrix_data.to_sparse_orbital_matrix() + # Get the path from which this structure was read. path = matrix_data.metadata["path"] - sparse_orbital_matrix = matrix_data.to_sparse_orbital_matrix() + out_file = Path(self.output_file.replace("$name$", path.parent.name)) + if not self.out_is_absolute: + out_file = path.parent / out_file + + if not out_file.parent.exists(): + out_file.parent.mkdir(parents=True) # And write the matrix to it. - sparse_orbital_matrix.write(path.parent / self.output_file) + sparse_orbital_matrix.write(out_file) + + + def on_train_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None + ): + if "train" in self.splits: + self._on_batch_end( + "train", trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ) + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None + ): + if "val" in self.splits: + self._on_batch_end( + "val", trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ) + + def on_test_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None + ): + if "test" in self.splits: + self._on_batch_end( + "test", trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ) + + def on_predict_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None + ): + if "predict" in self.splits: + self._on_batch_end( + "predict", trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ) + class SamplewiseMetricsLogger(Callback): diff --git a/src/e3nn_matrix/tools/lightning/cli.py b/src/e3nn_matrix/tools/lightning/cli.py index 8723723..ae0cad3 100644 --- a/src/e3nn_matrix/tools/lightning/cli.py +++ b/src/e3nn_matrix/tools/lightning/cli.py @@ -13,6 +13,7 @@ ) import torch from jsonargparse import Namespace +from jsonargparse._typehints import ActionTypeHint from e3nn_matrix.torch.load import sanitize_checkpoint @@ -34,7 +35,9 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser): parser.link_arguments("data.root_dir", "model.root_dir") parser.link_arguments("data.basis_files", "model.basis_files") parser.link_arguments("data.basis_table", "model.basis_table") + parser.link_arguments("data.no_basis", "model.no_basis") parser.link_arguments("data.symmetric_matrix", "model.symmetric_matrix") + parser.link_arguments("data.initial_node_feats", "model.initial_node_feats") defaults = {} # Set logger defaults based on environment variables @@ -113,14 +116,15 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No # arguments are linked and therefore "data.x" might not exist because it is linked # to "model.x". That's why we need to set the defaults one by one inside a try/except # block (I found no way to check if the argument is defined in the parser). - for k in defaults: - for subkey in defaults[k]: - try: - subcommand_parser.set_defaults( - {f"{k}.{subkey}": defaults[k][subkey]} - ) - except KeyError: - pass + with ActionTypeHint.allow_default_instance_context(): + for k in defaults: + for subkey in defaults[k]: + try: + subcommand_parser.set_defaults( + {f"{k}.{subkey}": defaults[k][subkey]} + ) + except KeyError: + pass # We have set all the right defaults now! So we can reparse the arguments. super().parse_arguments(parser, args) diff --git a/src/e3nn_matrix/tools/lightning/data.py b/src/e3nn_matrix/tools/lightning/data.py index 13576c1..aa3232b 100644 --- a/src/e3nn_matrix/tools/lightning/data.py +++ b/src/e3nn_matrix/tools/lightning/data.py @@ -1,4 +1,5 @@ """Data loading for pytorch_lightning workflows.""" +from typing import Type, Union, List import json import math @@ -18,6 +19,7 @@ from e3nn_matrix.data.configuration import PhysicsMatrixType from e3nn_matrix.data.table import BasisTableWithEdges, AtomicTableWithEdges from e3nn_matrix.data.processing import MatrixDataProcessor +from e3nn_matrix.data.node_feats import NodeFeature from e3nn_matrix.torch.data import BasisMatrixTorchData from e3nn_matrix.torch.dataset import ( BasisMatrixDataset, @@ -31,6 +33,7 @@ def __init__( self, out_matrix: Optional[PhysicsMatrixType] = None, basis_files: Optional[str] = None, + no_basis: Optional[dict] = None, basis_table: Optional[BasisTableWithEdges] = None, root_dir: str = ".", train_runs: Optional[str] = None, @@ -45,6 +48,7 @@ def __init__( copy_root_to_tmp: bool = False, store_in_memory: bool = False, rotating_pool_size: Optional[int] = None, + initial_node_feats: str = "OneHotZ", ): """ @@ -83,9 +87,11 @@ def __init__( self.root_dir = root_dir self.basis_files = basis_files + self.no_basis = no_basis self.basis_table = basis_table self.out_matrix: Optional[PhysicsMatrixType] = out_matrix self.symmetric_matrix = symmetric_matrix + self.initial_node_feats = [NodeFeature.registry[k] for k in initial_node_feats.split(" ")] self.train_runs = train_runs self.val_runs = val_runs @@ -127,7 +133,7 @@ def setup(self, stage: str): # Read the basis from the basis files provided. assert self.basis_files is not None self.basis_table = AtomicTableWithEdges.from_basis_glob( - Path(root).glob(self.basis_files) + Path(root).glob(self.basis_files), no_basis_atoms=self.no_basis ) # Initialize the data. @@ -151,6 +157,7 @@ def setup(self, stage: str): out_matrix=self.out_matrix, symmetric_matrix=self.symmetric_matrix, sub_point_matrix=self.sub_point_matrix, + node_attr_getters=self.initial_node_feats, ) # Set the paths for each split diff --git a/src/e3nn_matrix/tools/lightning/model.py b/src/e3nn_matrix/tools/lightning/model.py index a41de50..c54ae2d 100644 --- a/src/e3nn_matrix/tools/lightning/model.py +++ b/src/e3nn_matrix/tools/lightning/model.py @@ -1,9 +1,11 @@ """Wrapping of raw models to use them in pytorch_lightning.""" from pathlib import Path -from typing import Type, Union +from typing import Type, Union, Optional import warnings +from e3nn import o3 + import pytorch_lightning as pl import torch @@ -11,6 +13,7 @@ from e3nn_matrix.data.metrics import OrbitalMatrixMetric, block_type_mse from e3nn_matrix.data.table import BasisTableWithEdges, AtomicTableWithEdges from e3nn_matrix.torch.load import sanitize_checkpoint +from e3nn_matrix.data.node_feats import NodeFeature from e3nn_matrix import __version__ @@ -27,7 +30,9 @@ def __init__( root_dir: str = ".", basis_files: Union[str, None] = None, basis_table: Union[BasisTableWithEdges, None] = None, + no_basis: Optional[dict] = None, loss: Type[OrbitalMatrixMetric] = block_type_mse, + initial_node_feats: str = "OneHotZ", **kwargs, ): super().__init__() @@ -39,10 +44,15 @@ def __init__( self.basis_table = None else: self.basis_table = AtomicTableWithEdges.from_basis_glob( - Path(root_dir).glob(basis_files) + Path(root_dir).glob(basis_files), no_basis_atoms=no_basis ) else: self.basis_table = basis_table + + self.initial_node_feats = [NodeFeature.registry[k] for k in initial_node_feats.split(" ")] + self.initial_node_feats_irreps = sum([f.get_irreps(self.basis_table) for f in self.initial_node_feats], o3.Irreps()).simplify() + + print(self.initial_node_feats_irreps) self.loss_fn = loss() diff --git a/src/e3nn_matrix/tools/lightning/models/mace.py b/src/e3nn_matrix/tools/lightning/models/mace.py index cb0785b..a0421c3 100644 --- a/src/e3nn_matrix/tools/lightning/models/mace.py +++ b/src/e3nn_matrix/tools/lightning/models/mace.py @@ -1,4 +1,4 @@ -from typing import Type, Union +from typing import Type, Union, Optional from e3nn import o3 import torch @@ -30,6 +30,7 @@ def __init__( root_dir: str = ".", basis_files: Union[str, None] = None, basis_table: Union[BasisTableWithEdges, None] = None, + no_basis: Optional[dict] = None, # r_max: float=3.0, num_bessel: int = 10, num_polynomial_cutoff: int = 3, @@ -45,6 +46,7 @@ def __init__( avg_num_neighbors: float = 1.0, # atomic_numbers: List[int], correlation: int = 1, + initial_node_feats: str = "OneHotZ", # unique_atoms: Sequence[sisl.Atom], matrix_readout: Type[MACEBasisMatrixReadout] = MACEBasisMatrixReadout, symmetric_matrix: bool = False, @@ -60,7 +62,9 @@ def __init__( root_dir=root_dir, basis_files=basis_files, basis_table=basis_table, + no_basis=no_basis, loss=loss, + initial_node_feats=initial_node_feats, model_cls=OrbitalMatrixMACE, ) self.save_hyperparameters() @@ -89,6 +93,7 @@ def __init__( node_block_readout=node_block_readout, edge_block_readout=edge_block_readout, only_last_readout=only_last_readout, + node_attr_irreps=self.initial_node_feats_irreps, ) def configure_optimizers(self): diff --git a/src/e3nn_matrix/tools/server/extrapolation.py b/src/e3nn_matrix/tools/server/extrapolation.py index bc04b8f..af017bf 100644 --- a/src/e3nn_matrix/tools/server/extrapolation.py +++ b/src/e3nn_matrix/tools/server/extrapolation.py @@ -386,7 +386,12 @@ def add_next_geometry(self, geometry): next_config = OrbitalConfiguration.from_geometry( geometry, metadata={"geometry": geometry} ) - self.configs.append(next_config) + self.add_next_config(next_config) + + def add_next_config(self, config): + print("ADDING CONFIG") + print(config.metadata["geometry"]) + self.configs.append(config) def add_last_matrix_ref(self, matrix_ref): self.last_matrix_ref = matrix_ref @@ -539,6 +544,7 @@ def write_extrapolate_from_series( descriptor_order: int = 2, node_rcond: float = 1e-6, edge_rcond: float = 1e-6, + m_0: Optional[str] = None, ): this_series = time_series[series] @@ -576,26 +582,36 @@ def write_extrapolate_from_series( if this_series.last_matrix_ref is not None: m = m + this_series.last_matrix_ref + if m_0 is not None: + mat_0 = sisl.get_sile(m_0).read_density_matrix(geometry=m.geometry) + m = mat_0 + m + m.write(out) @app.get("/add_step") def add_geometry(path: str, series: int = 0, matrix_ref: Union[str, None] = None): - geometry = sisl.get_sile(path).read_geometry(output=True) - time_series[series].add_next_geometry(geometry) + + config = OrbitalConfiguration.from_run(path) + time_series[series].add_next_config(config) if matrix_ref is not None: - time_series[series].add_last_matrix_ref(matrix_refs[matrix_ref](geometry)) + time_series[series].add_last_matrix_ref(matrix_refs[matrix_ref](config.metadata["geometry"])) # geometry_xv = sisl.get_sile(Path(path).parent / "siesta.XV").read_geometry() # print("XV", np.allclose(geometry.xyz, geometry_xv.xyz)) @app.get("/add_matrix") - def add_matrix(path: str, series: int = 0): + def add_matrix(path: str, series: int = 0, m_0: Optional[str] = None): this_series = time_series[series] sile = sisl.get_sile(path) matrix = getattr(sile, f"read_{this_series.processor.out_matrix}")() + + if m_0 is not None: + mat_0 = getattr(sisl.get_sile(m_0), f"read_{this_series.processor.out_matrix}")() + matrix = matrix - mat_0 + this_series.add_last_matrix(matrix) return app diff --git a/src/e3nn_matrix/tools/server/server_app.py b/src/e3nn_matrix/tools/server/server_app.py index 2037850..27c25d3 100644 --- a/src/e3nn_matrix/tools/server/server_app.py +++ b/src/e3nn_matrix/tools/server/server_app.py @@ -228,22 +228,7 @@ async def model_files_info(model_name: ModelName): } def predict_from_geometry(model, geometry): - with torch.no_grad(): - # USE THE MODEL - # First, we need to process the input data, to get inputs as the model expects. - input_data = BasisMatrixTorchData.new( - geometry, data_processor=model["data_processor"], labels=False - ) - - # Then, we run the model. - out = model["prediction_function"](input_data) - - # And finally, we convert the output to a matrix. - matrix = model["data_processor"].matrix_from_data( - input_data, predictions=out - ) - - return matrix + return model["data_processor"].torch_predict(model["prediction_function"], geometry) @api.post("/models/{model_name}/predict", response_class=FileResponse) async def predict( diff --git a/src/e3nn_matrix/torch/conftest.py b/src/e3nn_matrix/torch/conftest.py index 82ce021..4d20597 100644 --- a/src/e3nn_matrix/torch/conftest.py +++ b/src/e3nn_matrix/torch/conftest.py @@ -18,17 +18,26 @@ from e3nn_matrix.torch.modules import BasisMatrixReadout -@pytest.fixture(scope="module", params=[True, False]) -def long_A_basis(request): +@pytest.fixture(scope="module", params=["normal", "long_A", "nobasis_A"]) +def basis_type(request): return request.param @pytest.fixture(scope="module") -def ABA_basis_configuration(long_A_basis): +def ABA_basis_configuration(basis_type): """Dummy basis configuration with""" - point_1 = PointBasis("A", "spherical", o3.Irreps("0e"), R=5 if long_A_basis else 2) - point_2 = PointBasis("B", "spherical", o3.Irreps("1o"), R=5) + if basis_type == "nobasis_A": + point_1 = PointBasis("A", R=5) + else: + point_1 = PointBasis( + "A", + R=5 if basis_type == "long_A" else 2, + irreps=o3.Irreps("0e"), + basis_convention="spherical", + ) + + point_2 = PointBasis("B", R=5, irreps=o3.Irreps("1o"), basis_convention="spherical") positions = np.array([[0, 0, 0], [3.0, 0, 0], [5.0, 0, 0]]) diff --git a/src/e3nn_matrix/torch/data.py b/src/e3nn_matrix/torch/data.py index f288836..b26451a 100644 --- a/src/e3nn_matrix/torch/data.py +++ b/src/e3nn_matrix/torch/data.py @@ -49,7 +49,9 @@ def __getitem__(self, key: str) -> Any: return Data.__getitem__(self, key) def process_input_array(self, key: str, array: np.ndarray) -> Any: - if issubclass(array.dtype.type, float): + if isinstance(array, torch.Tensor): + return array + elif issubclass(array.dtype.type, float): return torch.tensor(array, dtype=torch.get_default_dtype()) else: return torch.tensor(array) @@ -59,3 +61,13 @@ def ensure_numpy(self, array: torch.Tensor) -> np.ndarray: return array.numpy(force=True) else: return np.array(array) + + def is_node_attr(self, key: str) -> bool: + return key in self._node_attr_keys + + def is_edge_attr(self, key: str) -> bool: + return key in self._edge_attr_keys + + @property + def _data(self): + return {**self._store} diff --git a/src/e3nn_matrix/torch/dataset.py b/src/e3nn_matrix/torch/dataset.py index f6d8af6..83dd95f 100644 --- a/src/e3nn_matrix/torch/dataset.py +++ b/src/e3nn_matrix/torch/dataset.py @@ -11,7 +11,7 @@ from ..data import BasisConfiguration from ..data.processing import MatrixDataProcessor -from .data import BasisMatrixTorchData +from .data import BasisMatrixTorchData class BasisMatrixDataset(torch.utils.data.Dataset): diff --git a/src/e3nn_matrix/torch/load.py b/src/e3nn_matrix/torch/load.py index 40cbb7a..bd46085 100644 --- a/src/e3nn_matrix/torch/load.py +++ b/src/e3nn_matrix/torch/load.py @@ -99,6 +99,7 @@ def load_from_lit_ckpt( sub_point_matrix=ckpt["datamodule_hyper_parameters"]["sub_point_matrix"], symmetric_matrix=ckpt["datamodule_hyper_parameters"]["symmetric_matrix"], basis_table=ckpt["basis_table"], + node_attr_getters=model.initial_node_feats, ) return model, data_processor diff --git a/src/e3nn_matrix/torch/modules/basis_matrix.py b/src/e3nn_matrix/torch/modules/basis_matrix.py index aafa072..37ddd49 100644 --- a/src/e3nn_matrix/torch/modules/basis_matrix.py +++ b/src/e3nn_matrix/torch/modules/basis_matrix.py @@ -257,13 +257,17 @@ def _init_self_interactions(self, basis_irreps, **kwargs) -> List[torch.nn.Modul self_interactions = [] for point_type_irreps in basis_irreps: - self_interactions.append( - MatrixBlock( - i_irreps=point_type_irreps, - j_irreps=point_type_irreps, - **kwargs, + if point_type_irreps.dim == 0: + # The point type has no basis functions + self_interactions.append(None) + else: + self_interactions.append( + MatrixBlock( + i_irreps=point_type_irreps, + j_irreps=point_type_irreps, + **kwargs, + ) ) - ) return self_interactions @@ -285,12 +289,16 @@ def _init_interactions( perms.append((-edge_type, neigh_type, point_type)) for signed_edge_type, point_i, point_j in perms: - interactions[point_i, point_j, signed_edge_type] = MatrixBlock( - i_irreps=basis_irreps[point_i], - j_irreps=basis_irreps[point_j], - symm_transpose=neigh_type == point_type, - **kwargs, - ) + if basis_irreps[point_i].dim == 0 or basis_irreps[point_j].dim == 0: + # One of the involved point types has no basis functions + interactions[point_i, point_j, signed_edge_type] = None + else: + interactions[point_i, point_j, signed_edge_type] = MatrixBlock( + i_irreps=basis_irreps[point_i], + j_irreps=basis_irreps[point_j], + symm_transpose=neigh_type == point_type, + **kwargs, + ) return interactions @@ -307,6 +315,11 @@ def summary(self) -> str: s += "Node operations:" for i, x in enumerate(self.self_interactions): point = self._unique_basis[i] + + if x is None: + s += f"\n ({point.type}) No basis functions." + continue + s = ( s + f"\n ({point.type}) {str(x.operation.__class__.__name__)}: ({point.irreps})^2 -> {x._irreps_out}" @@ -322,6 +335,10 @@ def summary(self) -> str: point = self._unique_basis[point_type] neigh = self._unique_basis[neigh_type] + if x is None: + s += f"\n ({point.type}, {neigh.type}) No basis functions." + continue + s = ( s + f"\n ({point.type}, {neigh.type}) {str(x.operation.__class__.__name__)}: ({point.irreps}) x ({neigh.irreps}) -> {x._irreps_out}." @@ -456,17 +473,24 @@ def forward( return (node_labels, edge_labels) def _forward_self_interactions( - self, node_types: torch.Tensor, node_kwargs, global_kwargs + self, + node_types: torch.Tensor, + node_kwargs, + global_kwargs, ) -> torch.Tensor: # Allocate a list where we will store the outputs of all node blocks. n_nodes = len(node_types) - node_labels = [None] * n_nodes + node_labels = [torch.tensor([], device=node_types.device)] * n_nodes # Call each unique self interaction function with only the features # of nodes that correspond to that type. for node_type, func in enumerate(self.self_interactions): + if func is None: + continue + # Select the features for nodes of this type mask = node_types == node_type + # Quick exit if there are no features of this type if not mask.any(): continue @@ -514,6 +538,10 @@ def _forward_interactions( # Call each unique interaction function with only the features # of edges that correspond to that type. for module_key, func in self.interactions.items(): + if func is None: + # Case where one of the point types has no basis functions. + continue + # The key of the module is the a tuple (int, int, int) converted to a string. point_type, neigh_type, edge_type = map(int, module_key[1:-1].split(",")) diff --git a/src/e3nn_matrix/torch/modules/tests/test_basis_matrix.py b/src/e3nn_matrix/torch/modules/tests/test_basis_matrix.py index 24d50b6..7c3f345 100644 --- a/src/e3nn_matrix/torch/modules/tests/test_basis_matrix.py +++ b/src/e3nn_matrix/torch/modules/tests/test_basis_matrix.py @@ -43,7 +43,7 @@ def test_irreps_in(ABA_basis_configuration: BasisConfiguration): assert str(readout) == str(readout2) -def test_readout(ABA_basis_configuration: BasisConfiguration, long_A_basis: bool): +def test_readout(ABA_basis_configuration: BasisConfiguration, basis_type: str): config = ABA_basis_configuration basis = ABA_basis_configuration.basis @@ -89,12 +89,12 @@ def test_readout(ABA_basis_configuration: BasisConfiguration, long_A_basis: bool ) assert isinstance(matrix, csr_matrix) - assert matrix.shape == (5, 5) - assert matrix.nnz == 25 if long_A_basis else 23 + assert matrix.shape == (5, 5) if basis_type != "nobasis_A" else (3, 3) + assert matrix.nnz == {"normal": 23, "long_A": 25, "nobasis_A": 9}[basis_type] def test_readout_filtering( - ABA_basis_configuration: BasisConfiguration, long_A_basis: bool + ABA_basis_configuration: BasisConfiguration, basis_type: str ): config = ABA_basis_configuration basis = ABA_basis_configuration.basis @@ -182,5 +182,5 @@ def forward(self, node_types, **kwargs): ) assert isinstance(matrix, csr_matrix) - assert matrix.shape == (5, 5) - assert matrix.nnz == 25 if long_A_basis else 23 + assert matrix.shape == (5, 5) if basis_type != "nobasis_A" else (3, 3) + assert matrix.nnz == {"normal": 23, "long_A": 25, "nobasis_A": 9}[basis_type]