From 8559edab979e048da87d0ea3c99d61e43fc92d9d Mon Sep 17 00:00:00 2001 From: Nick Papior Date: Mon, 25 Nov 2024 14:55:08 +0100 Subject: [PATCH] streamlined projection options (#875) * streamlined projection options Still things to do, but for now this catches many things. One thing is that 'basis' projections are now called hadamard (the proper name of the operation). While 'basis' is still allowed it seems better to streamline a against a common name. Signed-off-by: Nick Papior --------- Signed-off-by: Nick Papior --- CHANGELOG.md | 4 + docs/api/typing.rst | 6 ++ src/sisl/__init__.py | 4 + src/sisl/physics/__init__.py | 1 + src/sisl/physics/_common.py | 41 ++++++++ src/sisl/physics/_feature.py | 15 +-- src/sisl/physics/_matrix_ddk.pyx | 2 +- src/sisl/physics/_matrix_dk.pyx | 2 +- src/sisl/physics/_matrix_k.pyx | 2 +- src/sisl/physics/electron.py | 81 +++++++++++++--- src/sisl/physics/hamiltonian.py | 2 +- src/sisl/physics/state.py | 107 +++++++++++++-------- src/sisl/physics/tests/test_electron.py | 5 +- src/sisl/physics/tests/test_hamiltonian.py | 10 +- src/sisl/physics/tests/test_state.py | 10 +- src/sisl/typing/_physics.py | 19 +++- 16 files changed, 223 insertions(+), 88 deletions(-) create mode 100644 src/sisl/physics/_common.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fcef2fbd2e..1d7b72543c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ we hit release version 1.0.0. import sisl sisl.geom.graphene +### Fixed + +- `projection` arguments of several functions has been streamlined + ## [0.15.2] - 2024-11-06 diff --git a/docs/api/typing.rst b/docs/api/typing.rst index 1106a75a34..5cabbc1a55 100644 --- a/docs/api/typing.rst +++ b/docs/api/typing.rst @@ -48,6 +48,12 @@ The typing types are shown below: LatticeLike LatticeOrGeometry LatticeOrGeometryLike + ProjectionTypeMatrix + ProjectionTypeTrace + ProjectionTypeDiag + ProjectionTypeHadamard + ProjectionTypeHadamardAtoms + ProjectionType SileLike SparseMatrix SparseMatrixExt diff --git a/src/sisl/__init__.py b/src/sisl/__init__.py index 51f5ed86cd..8b1f7b065b 100644 --- a/src/sisl/__init__.py +++ b/src/sisl/__init__.py @@ -219,6 +219,10 @@ def __getattr__(attr): import sisl.constant as constant return constant + if attr == "typing": + import sisl.typing as typing + + return typing raise AttributeError(f"module {__name__} has no attribute {attr}") diff --git a/src/sisl/physics/__init__.py b/src/sisl/physics/__init__.py index d9886eac7d..88f81ffe57 100644 --- a/src/sisl/physics/__init__.py +++ b/src/sisl/physics/__init__.py @@ -89,6 +89,7 @@ """ +from ._common import * from ._feature import * from .distribution import * from .sparse import * diff --git a/src/sisl/physics/_common.py b/src/sisl/physics/_common.py new file mode 100644 index 0000000000..56d94bfb55 --- /dev/null +++ b/src/sisl/physics/_common.py @@ -0,0 +1,41 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. +from __future__ import annotations + +from sisl.typing import GaugeType, ProjectionType + +__all__ = ["comply_gauge", "comply_projection"] + + +def comply_gauge(gauge: GaugeType) -> str: + """Comply the gauge to one of two words: atom | cell""" + return { + "R": "cell", + "cell": "cell", + "r": "atom", + "orbital": "atom", + "orbitals": "atom", + "atom": "atom", + "atoms": "atom", + }[gauge] + + +def comply_projection(projection: ProjectionType) -> str: + """Comply the projection to one of the allowed variants""" + return { + "matrix": "matrix", + "ij": "matrix", + "trace": "trace", + "sum": "trace", + "diagonal": "diagonal", + "diag": "diagonal", + "ii": "diagonal", + "hadamard": "hadamard", + "basis": "hadamard", + "orbital": "hadamard", + "orbitals": "hadamard", + "hadamard:atoms": "hadamard:atoms", + "atoms": "hadamard:atoms", + "atom": "hadamard:atoms", + }[projection] diff --git a/src/sisl/physics/_feature.py b/src/sisl/physics/_feature.py index fa96bd19dc..9ed2319acd 100644 --- a/src/sisl/physics/_feature.py +++ b/src/sisl/physics/_feature.py @@ -7,20 +7,7 @@ import numpy as np -__all__ = ["yield_manifolds", "comply_gauge"] - - -def comply_gauge(gauge: GaugeType) -> str: - """Comply the gauge to one of two words: atom | cell""" - return { - "R": "cell", - "cell": "cell", - "r": "atom", - "orbital": "atom", - "orbitals": "atom", - "atom": "atom", - "atoms": "atom", - }[gauge] +__all__ = ["yield_manifolds"] def yield_manifolds(values, atol: float = 0.1, axis: int = -1) -> Iterator[list]: diff --git a/src/sisl/physics/_matrix_ddk.pyx b/src/sisl/physics/_matrix_ddk.pyx index 1d8fb03e6e..d1fa01cc41 100644 --- a/src/sisl/physics/_matrix_ddk.pyx +++ b/src/sisl/physics/_matrix_ddk.pyx @@ -8,7 +8,7 @@ import numpy as np cimport numpy as np -from ._feature import comply_gauge +from ._common import comply_gauge from ._matrix_phase3 import * from ._matrix_phase3_nc import * from ._matrix_phase3_so import * diff --git a/src/sisl/physics/_matrix_dk.pyx b/src/sisl/physics/_matrix_dk.pyx index 67c0bb08fd..0523e7a8ba 100644 --- a/src/sisl/physics/_matrix_dk.pyx +++ b/src/sisl/physics/_matrix_dk.pyx @@ -8,7 +8,7 @@ import numpy as np cimport numpy as np -from ._feature import comply_gauge +from ._common import comply_gauge from ._matrix_phase3 import * from ._matrix_phase3_nc import * from ._matrix_phase3_so import * diff --git a/src/sisl/physics/_matrix_k.pyx b/src/sisl/physics/_matrix_k.pyx index ae82761aba..be8b80a451 100644 --- a/src/sisl/physics/_matrix_k.pyx +++ b/src/sisl/physics/_matrix_k.pyx @@ -7,7 +7,7 @@ import numpy as np cimport numpy as np -from ._feature import comply_gauge +from ._common import comply_gauge from ._matrix_phase import * from ._matrix_phase_nc import * from ._matrix_phase_nc_diag import * diff --git a/src/sisl/physics/electron.py b/src/sisl/physics/electron.py index 36ab7322d4..0610058ed6 100644 --- a/src/sisl/physics/electron.py +++ b/src/sisl/physics/electron.py @@ -88,7 +88,14 @@ progressbar, warn, ) -from sisl.typing import CartesianAxisStrLiteral +from sisl.physics._common import comply_projection +from sisl.typing import ( + CartesianAxisStrLiteral, + ProjectionType, + ProjectionTypeHadamard, + ProjectionTypeHadamardAtoms, +) +from sisl.typing._physics import ProjectionTypeDiag from sisl.utils.misc import direction if TYPE_CHECKING: @@ -329,7 +336,7 @@ def COP(E, eig, state, M, distribution="gaussian", atol: float = 1e-10): distribution : func or str, optional a function that accepts :math:`E-\epsilon` as argument and calculates the distribution function. - atol : float, optional + atol : tolerance value where the distribution should be above before considering an eigenstate to contribute to an energy point, a higher value means that more energy points are discarded and so the calculation @@ -437,7 +444,20 @@ def new_list(bools, tmp, we): @set_module("sisl.physics.electron") -def spin_moment(state, S=None, project: bool = False): +@deprecate_argument( + "project", + "projection", + "argument project has been deprecated in favor of projection", + "0.15", + "0.16", +) +def spin_moment( + state, + S=None, + projection: Union[ + ProjectionTypeTrace, ProjectionTypeDiag, ProjectionTypeHadamard, True, False + ] = "diagonal", +): r""" Spin magnetic moment (spin texture) and optionally orbitally resolved moments This calculation only makes sense for non-colinear calculations. @@ -458,7 +478,7 @@ def spin_moment(state, S=None, project: bool = False): \\ \mathbf{S}_\alpha^z &= \langle \psi_\alpha | \boldsymbol\sigma_z \mathbf S | \psi_\alpha \rangle - If `project` is true, the above will be the orbitally resolved quantities. + If `projection` is orbitals/basis/true, the above will be the orbitally resolved quantities. Parameters ---------- @@ -468,8 +488,8 @@ def spin_moment(state, S=None, project: bool = False): overlap matrix used in the :math:`\langle\psi|\mathbf S|\psi\rangle` calculation. If `None` the identity matrix is assumed. The overlap matrix should correspond to the system and :math:`\mathbf k` point the eigenvectors has been evaluated at. - project: bool, optional - whether the spin-moments will be orbitally resolved or not + projection: + how the projection should be done Notes ----- @@ -485,10 +505,15 @@ def spin_moment(state, S=None, project: bool = False): Returns ------- numpy.ndarray - spin moments per state with final dimension ``(3, state.shape[0])``, or ``(3, state.shape[0], state.shape[1]//2)`` if project is true + spin moments per state with final dimension ``(3, state.shape[0])``, or ``(3, + state.shape[0], state.shape[1]//2)`` if projection is orbitals/basis/true """ if state.ndim == 1: - return spin_moment(state.reshape(1, -1), S, project)[0] + return spin_moment(state.reshape(1, -1), S, projection)[0] + + if isinstance(projection, bool): + projection = "hadamard" if projection else "diagonal" + projection = comply_projection(projection) if S is None: S = _FakeMatrix(state.shape[1] // 2, state.shape[1] // 2) @@ -498,7 +523,7 @@ def spin_moment(state, S=None, project: bool = False): # see PDOS for details related to the spin-box calculations - if project: + if projection == "hadamard": s = empty( [3, state.shape[0], state.shape[1] // 2], dtype=state.real.dtype, @@ -514,7 +539,7 @@ def spin_moment(state, S=None, project: bool = False): s[0, i] = D1.real + D2.real s[1, i] = D2.imag - D1.imag - else: + elif projection == "diagonal": s = empty([3, state.shape[0]], dtype=state.real.dtype) # TODO consider doing this all in a few lines @@ -529,6 +554,20 @@ def spin_moment(state, S=None, project: bool = False): s[0, i] = D[1, 0].real + D[0, 1].real s[1, i] = D[0, 1].imag - D[1, 0].imag + elif projection == "trace": + s = empty([3], dtype=state.real.dtype) + + for i in range(len(state)): + cs = conj(state[i]).reshape(-1, 2) + Sstate = S @ state[i].reshape(-1, 2) + D = cs.T @ Sstate + s[2] = (D[0, 0].real - D[1, 1].real).sum() + s[0] = (D[1, 0].real + D[0, 1].real).sum() + s[1] = (D[0, 1].imag - D[1, 0].imag).sum() + + else: + raise ValueError(f"spin_moment got wrong 'projection' argument: {projection}.") + return s @@ -561,7 +600,7 @@ def spin_contamination(state_alpha, state_beta, S=None, sum: bool = True): have been evaluated at. sum: whether the spin-contamination should be summed for all states (a single number returned). - If false, a spin-contamination per state per spin-channel will be returned. + If sum, a spin-contamination per state per spin-channel will be returned. Notes ----- @@ -1671,7 +1710,12 @@ def Sk(self, format=None): "0.15", "0.16", ) - def norm2(self, projection: Literal["sum", "orbitals", "basis", "atoms"] = "sum"): + def norm2( + self, + projection: Union[ + ProjectionType, ProjectionTypeHadamard, ProjectionTypeHadamardAtoms + ] = "diagonal", + ): r"""Return a vector with the norm of each state :math:`\langle\psi|\mathbf S|\psi\rangle` :math:`\mathbf S` is the overlap matrix (or basis), for orthogonal basis @@ -1693,7 +1737,14 @@ def norm2(self, projection: Literal["sum", "orbitals", "basis", "atoms"] = "sum" """ return self.inner(matrix=self.Sk(), projection=projection) - def spin_moment(self, project=False): + @deprecate_argument( + "project", + "projection", + "argument project has been deprecated in favor of projection", + "0.15", + "0.16", + ) + def spin_moment(self, projection="diagonal"): r"""Calculate spin moment from the states This routine calls `~sisl.physics.electron.spin_moment` with appropriate arguments @@ -1703,10 +1754,10 @@ def spin_moment(self, project=False): Parameters ---------- - project : bool, optional + projection: whether the moments are orbitally resolved or not """ - return spin_moment(self.state, self.Sk(), project=project) + return spin_moment(self.state, self.Sk(), projection=projection) def wavefunction(self, grid, spinor=0, eta=None): r"""Expand the coefficients as the wavefunction on `grid` *as-is* diff --git a/src/sisl/physics/hamiltonian.py b/src/sisl/physics/hamiltonian.py index d9ad46b2fb..6a243dde60 100644 --- a/src/sisl/physics/hamiltonian.py +++ b/src/sisl/physics/hamiltonian.py @@ -9,7 +9,7 @@ from sisl._internal import set_module from sisl.typing import GaugeType -from ._feature import comply_gauge +from ._common import comply_gauge from .distribution import get_distribution from .electron import EigenstateElectron, EigenvalueElectron from .sparse import SparseOrbitalBZSpin diff --git a/src/sisl/physics/state.py b/src/sisl/physics/state.py index 4b0f8f0ba7..e20828c2a4 100644 --- a/src/sisl/physics/state.py +++ b/src/sisl/physics/state.py @@ -17,9 +17,10 @@ from sisl._internal import set_module from sisl.linalg import eigh_destroy from sisl.messages import deprecate_argument, warn -from sisl.typing import CartesianAxes, GaugeType +from sisl.typing import CartesianAxes, GaugeType, ProjectionType +from sisl.typing._physics import ProjectionTypeHadamard, ProjectionTypeHadamardAtoms -from ._feature import comply_gauge +from ._common import comply_gauge, comply_projection __all__ = ["degenerate_decouple", "Coefficient", "State", "StateC"] @@ -489,7 +490,12 @@ def norm(self): "0.15", "0.16", ) - def norm2(self, projection: Literal["sum", "atoms", "basis"] = "sum"): + def norm2( + self, + projection: Union[ + ProjectionType, ProjectionTypeHadamard, ProjectionTypeHadamardAtoms + ] = "diagonal", + ): r"""Return a vector with the norm of each state :math:`\langle\psi|\psi\rangle` Parameters @@ -549,7 +555,7 @@ def ipr(self, q: int = 2): order parameter for the IPR """ # This *has* to be a real value C * C^* == real - state_abs2 = self.norm2(projection="basis").real + state_abs2 = self.norm2(projection="hadamard").real assert q >= 2, f"{self.__class__.__name__}.ipr requires q>=2" # abs2 is already having the exponent 2 return (state_abs2**q).sum(-1) / state_abs2.sum(-1) ** q @@ -659,7 +665,9 @@ def inner( self, ket=None, matrix=None, - projection: Literal["diag", "atoms", "basis", "matrix"] = "diag", + projection: Union[ + ProjectionType, ProjectionTypeHadamard, ProjectionTypeHadamardAtoms + ] = "diagonal", ): r"""Calculate the inner product as :math:`\mathbf A_{ij} = \langle\psi_i|\mathbf M|\psi'_j\rangle` @@ -695,14 +703,17 @@ def inner( This can be used to sum specific sub-elements, return the diagonal, or the full matrix. - * ``diag`` only return the diagonal of the inner product - * ``matrix`` a matrix with diagonals and the off-diagonals - * ``basis`` only do inner products for individual states, but return them basis-resolved + * ``diagonal`` only return the diagonal of the inner product ('ii' elements) + * ``matrix`` a matrix with diagonals and the off-diagonals ('ij' elements) + * ``hadamard`` only do element wise products for the states (equivalent to + basis resolved inner-products) * ``atoms`` only do inner products for individual states, but return them atom-resolved Notes ----- - This does *not* take into account a possible overlap matrix when non-orthogonal basis sets are used. One have to add the overlap matrix in the `matrix` argument, if needed. + This does *not* take into account a possible overlap matrix when + non-orthogonal basis sets are used. + One have to add the overlap matrix in the `matrix` argument, if needed. Raises ------ @@ -758,19 +769,15 @@ def inner( f"{self.__class__.__name__}.inner requires the objects to have matching shapes bra @ M @ ket bra={self.shape}, M={M.shape}, ket={ket.shape[::-1]}" ) - projection = { - # temporary work-around for older codes where project/diag=T|F were allowed - True: "diag", - False: "matrix", - "sum": "diag", # still allowed here (for bypass options) - "atoms": "atom", # plural s allowed - "orbitals": "orbital", # still allowed here (for bypass options) - }.get(projection, projection) + if isinstance(projection, bool): + projection = "diagonal" if projection else "matrix" + projection = comply_projection(projection) - if projection in ("diag", "diagonal"): + if projection == "diagonal": if bra.shape[0] != ket.shape[0]: raise ValueError( - f"{self.__class__.__name__}.inner diagonal matrix product is non-square, please use diag=False or reduce number of vectors." + f"{self.__class__.__name__}.inner diagonal matrix product is " + "non-square, please use projection!=diagonal or reduce number of vectors." ) if ndim == 2: Aij = einsum("ij,ji->i", np.conj(bra), M @ ket.T) @@ -779,7 +786,7 @@ def inner( elif ndim == 0: Aij = einsum("ij,ij->i", np.conj(bra), ket) * M - elif projection in ("matrix", "none"): + elif projection == "matrix": if ndim == 2: Aij = np.conj(bra) @ (M @ ket.T) elif ndim == 1: @@ -787,34 +794,50 @@ def inner( elif ndim == 0: Aij = einsum("ij,kj->ik", np.conj(bra), ket) * M - elif projection in ("atom", "basis", "orbital"): + elif projection == "hadamard": if ndim == 2: Aij = np.conj(bra) * (M @ ket.T).T else: Aij = np.conj(bra) * ket * M - # Now do the projection - if projection == "atom": - # Now we need to convert it - geom = self._geometry() - if Aij.shape[1] == geom.no * 2: - # We have some kind of spin-configuration (hidden) - def mapper(atom): - return np.arange( - geom.firsto[atom] * 2, geom.firsto[atom + 1] * 2 - ) + elif projection == "hadamard:atoms": + if ndim == 2: + Aij = np.conj(bra) * (M @ ket.T).T + else: + Aij = np.conj(bra) * ket * M - elif Aij.shape[1] == geom.no: + # Now we need to convert it + geom = self._geometry() + if Aij.shape[1] == geom.no * 2: + # We have some kind of spin-configuration (hidden) + def mapper(atom): + return np.arange(geom.firsto[atom] * 2, geom.firsto[atom + 1] * 2) - def mapper(atom): - return np.arange(geom.firsto[atom], geom.firsto[atom + 1]) + elif Aij.shape[1] == geom.no: + + def mapper(atom): + return np.arange(geom.firsto[atom], geom.firsto[atom + 1]) + + else: + raise RuntimeError( + f"{self.__class__.__name__}.inner could not determine " + "the correct atom conversions." + ) + Aij = geom.apply(Aij, np.sum, mapper, axis=1) + + elif projection == "trace": + if bra.shape[0] != ket.shape[0]: + raise ValueError( + f"{self.__class__.__name__}.inner diagonal matrix product is " + "non-square, cannot do the trace." + ) + if ndim == 2: + Aij = einsum("ij,ji->i", np.conj(bra), M @ ket.T).sum() + elif ndim == 1: + Aij = einsum("ij,j,ij->i", np.conj(bra), M, ket).sum() + elif ndim == 0: + Aij = (einsum("ij,ij->i", np.conj(bra), ket) * M).sum() - else: - raise RuntimeError( - f"{self.__class__.__name__}.inner could not determine " - "the correct atom conversions." - ) - Aij = geom.apply(Aij, np.sum, mapper, axis=1) else: raise ValueError( f"{self.__class__.__name__}.inner got unknown argument 'projection'={projection}" @@ -919,8 +942,8 @@ def align_norm(self, other: State, ret_index: bool = False, inplace: bool = Fals -------- align_phase : rotate states such that their phases align """ - snorm = self.norm2(projection="basis").real - onorm = other.norm2(projection="basis").real + snorm = self.norm2(projection="hadamard").real + onorm = other.norm2(projection="hadamard").real # Now find new orderings show_warn = False diff --git a/src/sisl/physics/tests/test_electron.py b/src/sisl/physics/tests/test_electron.py index 5d51a2d64c..a0af01c9c3 100644 --- a/src/sisl/physics/tests/test_electron.py +++ b/src/sisl/physics/tests/test_electron.py @@ -28,12 +28,13 @@ def test_EigenstateElectron_norm2(): assert len(state) == H.no assert state.norm2()[0] == pytest.approx(1) assert state.norm2().shape == (H.no,) - for p in ("sum", "orbital", "atom"): + for p in ("diagonal", "orbital", "atom"): assert state.norm2(projection=p).sum() == pytest.approx(H.no) ns = 3 state3 = state.sub(range(ns)) - assert state3.norm2(projection="sum").shape == (ns,) + assert state3.norm2(projection="trace").ndim == 0 + assert state3.norm2(projection="diagonal").shape == (ns,) assert state3.norm2(projection="orbital").shape == (ns, H.no) assert state3.norm2(projection="atom").shape == (ns, H.na) assert state3.norm2(projection="atom").sum() == pytest.approx(ns) diff --git a/src/sisl/physics/tests/test_hamiltonian.py b/src/sisl/physics/tests/test_hamiltonian.py index 3831f9a9a6..e9d314648b 100644 --- a/src/sisl/physics/tests/test_hamiltonian.py +++ b/src/sisl/physics/tests/test_hamiltonian.py @@ -1252,7 +1252,7 @@ def dist(E, *args): es = H.eigenstate() PDOS = es.PDOS(E, dist)[..., 0] SM = es.spin_moment() - SMp = es.spin_moment(project=True) + SMp = es.spin_moment(projection=True) # now check with spin stuff pdos = es.inner().real @@ -1288,7 +1288,7 @@ def dist(E, *args): es = H.eigenstate() PDOS = es.PDOS(E, dist)[..., 0] SM = es.spin_moment() - SMp = es.spin_moment(project=True) + SMp = es.spin_moment(projection=True) # now check with spin stuff pdos = es.inner().real @@ -1636,7 +1636,7 @@ def test_non_colinear_orthogonal(self, setup, sisl_tolerance): assert np.allclose(sm[2], sm2) assert np.allclose(sm[2], sm3) - om = es.spin_moment(project=True) + om = es.spin_moment(projection=True) assert np.allclose(sm, om.sum(-1)) PDOS = es.PDOS(np.linspace(-1, 1, 21)) @@ -1710,7 +1710,7 @@ def test_non_colinear_non_orthogonal(self, sisl_tolerance): sm = es.spin_moment() - om = es.spin_moment(project=True) + om = es.spin_moment(projection=True) assert np.allclose(sm, om.sum(-1)) PDOS = es.PDOS(np.linspace(-1, 1, 21)) @@ -1795,7 +1795,7 @@ def test_spin_orbit_orthogonal(self, sisl_tolerance): assert np.allclose(sm[2], sm2) assert np.allclose(sm[2], sm3) - om = es.spin_moment(project=True) + om = es.spin_moment(projection=True) assert np.allclose(sm, om.sum(-1)) PDOS = es.PDOS(np.linspace(-1, 1, 21)) diff --git a/src/sisl/physics/tests/test_state.py b/src/sisl/physics/tests/test_state.py index e14bfce916..b715f6bc12 100644 --- a/src/sisl/physics/tests/test_state.py +++ b/src/sisl/physics/tests/test_state.py @@ -231,7 +231,8 @@ def test_state_inner_projections(): state = State(ar(n, g.no), parent=g) for projs, shape in ( - (("diag", "diagonal", "sum", True), (n,)), + (("diag", "diagonal", True), (n,)), + (("trace", "sum"), tuple()), (("matrix", False), (n, n)), (("basis", "orbitals", "orbital"), (n, g.no)), (("atoms", "atom"), (n, g.na)), @@ -248,9 +249,10 @@ def test_state_norm_projections(): assert state.shape[0] != state.shape[1] for projs, shape in ( - (("sum", True), (n,)), - (("basis", "orbitals", "orbital"), (n, g.no)), - (("atoms", "atom"), (n, g.na)), + (("diagonal", True), (n,)), + (("trace", "sum"), tuple()), + (("hadamard", "basis", "orbitals", "orbital"), (n, g.no)), + (("atoms", "atom", "hadamard:atoms"), (n, g.na)), ): for proj in projs: data = state.norm2(projection=proj) diff --git a/src/sisl/typing/_physics.py b/src/sisl/typing/_physics.py index 1e724bdd61..5b0639e742 100644 --- a/src/sisl/typing/_physics.py +++ b/src/sisl/typing/_physics.py @@ -3,8 +3,23 @@ # file, You can obtain one at https://mozilla.org/MPL/2.0/. from __future__ import annotations -from typing import Literal +from typing import Literal, Union -__all__ = ["GaugeType"] +__all__ = [ + "GaugeType", + "ProjectionType", + "ProjectionTypeTrace", + "ProjectionTypeDiag", + "ProjectionTypeMatrix", + "ProjectionTypeHadamard", + "ProjectionTypeHadamardAtoms", +] GaugeType = Literal["cell", "atom"] + +ProjectionTypeMatrix = Literal["matrix", "ij"] +ProjectionTypeTrace = Literal["trace", "sum"] +ProjectionTypeDiag = Literal["diagonal", "diag", "ii"] +ProjectionTypeHadamard = Literal["hadamard", "basis"] +ProjectionTypeHadamardAtoms = Literal["hadamard:atoms", "atoms"] +ProjectionType = Union[ProjectionTypeMatrix, ProjectionTypeDiag, ProjectionTypeTrace]