diff --git a/sisl/_category.py b/sisl/_category.py new file mode 100644 index 0000000000..964e41027c --- /dev/null +++ b/sisl/_category.py @@ -0,0 +1,417 @@ +from collections import defaultdict, ChainMap +from abc import ABCMeta, abstractmethod +from functools import lru_cache +import numpy as np + +from ._internal import set_module, singledispatchmethod + +__all__ = ["Category", "CompositeCategory", "NullCategory"] +__all__ += ["AndCategory", "OrCategory", "XOrCategory"] +__all__ += ["InstanceCache"] + + +class InstanceCache: + """ Wraps an instance to cache *all* results based on `functools.lru_cache` + + Parameters + ---------- + obj : object + the object to get cached results + lru_size : int or dict, optional + initial size of the lru cache. For integers this + is the default size of the cache, for a dictionary + it should return the ``maxsize`` argument for `functools.lru_cache` + no_cache : searchable (list or dict) + a list-like (or dictionary) for searching for + methods that don't require caches (e.g. small methods) + """ + + def __init__(self, obj, lru_size=1, no_cache=None): + self.__obj = obj + + # Handle user input for lru_size + if isinstance(lru_size, defaultdict): + # fine, user did everything good + self.__lru_size = lru_size + elif isinstance(lru_size, dict): + default = lru_size.pop("default", 1) + self.__lru_size = ChainMap(lru_size, defaultdict(lambda: default)) + else: + self.__lru_size = defaultdict(lambda: lru_size) + + if no_cache is None: + self.__no_cache = [] + else: + self.__no_cache = no_cache + + def __getattr__(self, name): + attr = getattr(self.__obj, name) + # Check if the attribute has the cached functionality + try: + attr.cache_info() + except AttributeError: + # Fix it and set it to this one + if name in self.__no_cache: + # We have to make it cacheable + maxsize = self.__lru_size[name] + if maxsize != 0: + attr = wraps(attr)(lru_cache(maxsize)(attr)) + + # offload the attribute to this class (to minimize overhead) + object.__setattr__(self, name, attr) + return attr + + +@set_module("sisl.category") +class Category(metaclass=ABCMeta): + r""" A category """ + __slots__ = ("_name", "_wrapper") + + def __init__(self, name=None): + if name is None: + self._name = self.__class__.__name__ + else: + self._name = name + + @property + def name(self): + r""" Name of category """ + return self._name + + def set_name(self, name): + r""" Override the name of the categorization """ + self._name = name + + @classmethod + def kw(cls, **kwargs): + """ Create categories based on keywords + + This will search through the inherited classes and + return and & category object for all keywords. + + Since this is a class method one should use this + on the base category class in the given section + of the code. + """ + + subcls = set() + work = [cls] + while work: + parent = work.pop() + for child in parent.__subclasses__(): + if child not in subcls: + subcls.add(child) + work.append(child) + + del work, parent, child + + # create dictionary look-up + subcls = {cl.__name__.lower(): cl for cl in subcls} + + def get_cat(cl, args): + if isinstance(args, dict): + return cl(**args) + return cl(args) + + # Now search keywords and create category + cat = None + for key, args in kwargs.items(): + lkey = key.lower() + found = '' + for name, cl in subcls.items(): + if name.endswith(lkey): + if found: + raise ValueError(f"{cls.__name__}.kw got a non-unique argument for category name:\n" + f" Searching for {name} and found matches {found} and {name}.") + found = name + if cat is None: + cat = get_cat(cl, args) + else: + cat = cat & get_cat(cl, args) + + return cat + + @abstractmethod + def categorize(self, *args, **kwargs): + r""" Do categorization """ + pass + + def __str__(self): + r""" String representation of the class (non-distinguishable between equivalent classifiers) """ + return self.name + + def __repr__(self): + r""" String representation of the class (non-distinguishable between equivalent classifiers) """ + return self.name + + @singledispatchmethod + def __eq__(self, other): + """ Comparison of two categories, they are compared by class-type """ + # This is not totally safe since composites *could* be generated + # in different sequences and result in the same boolean expression. + # This we do not check and thus are not fool proof... + # The exact action also depends on whether we are dealing with + # an And/Or/XOr operation.... + # I.e. something like + # (A & B & C) != (A & C & B) + # (A ^ B ^ C) != (C ^ A ^ B) + if isinstance(self, CompositeCategory): + if isinstance(other, CompositeCategory): + return (self.__class__ is other.__class__ and + (self.A == other.A and self.B == other.B or + self.A == other.B and self.B == other.A)) + # if neither is a compositecategory, then they cannot + # be the same category + return False + elif self.__class__ != other.__class__: + return False + return self == other + + @__eq__.register(list) + @__eq__.register(tuple) + @__eq__.register(np.ndarray) + def _(self, other): + return [self.__eq__(o) for o in other] + + def __ne__(self, other): + eq = self == other + if isinstance(eq, list): + return [not e for e in eq] + return eq + + # Implement logical operators to enable composition of sets + def __and__(self, other): + return AndCategory(self, other) + + def __or__(self, other): + return OrCategory(self, other) + + def __xor__(self, other): + return XOrCategory(self, other) + + def __invert__(self): + if isinstance(self, NotCategory): + return self._cat + return NotCategory(self) + + +@set_module("sisl.category") +class NullCategory(Category): + r""" Special Null class which always represents a classification not being *anything* """ + __slots__ = tuple() + + def __init__(self): + pass + + def categorize(self, *args, **kwargs): + return self + + @singledispatchmethod + def __eq__(self, other): + if other is None: + return True + return self.__class__ == other.__class__ + + @__eq__.register(list) + @__eq__.register(tuple) + @__eq__.register(np.ndarray) + def _(self, other): + return super().__eq__(other) + + @property + def name(self): + return "∅" + + +@set_module("sisl.category") +class NotCategory(Category): + """ A class returning the *opposite* of this class (NullCategory) if it is categorized as such """ + __slots__ = ("_cat",) + + def __init__(self, cat): + super().__init__() + if isinstance(cat, CompositeCategory): + self.set_name(f"~({cat})") + else: + self.set_name(f"~{cat}") + self._cat = cat + + def categorize(self, *args, **kwargs): + r""" Base method for queriyng whether an object is a certain category """ + cat = self._cat.categorize(*args, **kwargs) + + def check(cat): + if isinstance(cat, NullCategory): + return self + return NullCategory() + + if isinstance(cat, list): + return list(map(check, cat)) + return check(cat) + + @singledispatchmethod + def __eq__(self, other): + if isinstance(other, NotCategory): + return self._cat == other._cat + return False + + @__eq__.register(list) + @__eq__.register(tuple) + @__eq__.register(np.ndarray) + def _(self, other): + # this will call the list approach + return super().__eq__(other) + + +@set_module("sisl.category") +class CompositeCategory(Category): + """ A composite class consisting of two categories + + This should take 2 categories as arguments and a binary operator to define + how the categories are related. + + Parameters + ---------- + A : Category + the left hand side of the set operation + B : Category + the right hand side of the set operation + op : {_OR, _AND, _XOR} + the operator defining the sets relation + """ + __slots__ = ("A", "B") + + def __init__(self, A, B): + # To ensure we always get composite name + super().__init__() + self._name = None + self.A = A + self.B = B + + def categorize(self, *args, **kwargs): + r""" Base method for queriyng whether an object is a certain category """ + catA = self.A.categorize(*args, **kwargs) + catB = self.B.categorize(*args, **kwargs) + return catA, catB + + +def _composite_name(sep): + def name(self): + if not self._name is None: + return self._name + + # Name is unset, we simply return the other parts + if isinstance(self.A, CompositeCategory): + nameA = f"({self.A.name})" + else: + nameA = self.A.name + if isinstance(self.B, CompositeCategory): + nameB = f"({self.B.name})" + else: + nameB = self.B.name + + return f"{nameA} {sep} {nameB}" + + return property(name) + + +@set_module("sisl.category") +class OrCategory(CompositeCategory): + """ A class consisting of two categories + + This should take 2 categories as arguments and a binary operator to define + how the categories are related. + + Parameters + ---------- + A : Category + the left hand side of the set operation + B : Category + the right hand side of the set operation + """ + __slots__ = tuple() + + def categorize(self, *args, **kwargs): + r""" Base method for queriyng whether an object is a certain category """ + catA, catB = super().categorize(*args, **kwargs) + + def cmp(a, b): + if isinstance(a, NullCategory): + return b + return a + + if isinstance(catA, list): + return list(map(cmp, catA, catB)) + return cmp(catA, catB) + + name = _composite_name("|") + + +@set_module("sisl.category") +class AndCategory(CompositeCategory): + """ A class consisting of two categories + + This should take 2 categories as arguments and a binary operator to define + how the categories are related. + + Parameters + ---------- + A : Category + the left hand side of the set operation + B : Category + the right hand side of the set operation + """ + __slots__ = tuple() + + def categorize(self, *args, **kwargs): + r""" Base method for queriyng whether an object is a certain category """ + catA, catB = super().categorize(*args, **kwargs) + + def cmp(a, b): + if isinstance(a, NullCategory): + return a + elif isinstance(b, NullCategory): + return b + return self + + if isinstance(catA, list): + return list(map(cmp, catA, catB)) + return cmp(catA, catB) + + name = _composite_name("&") + + +@set_module("sisl.category") +class XOrCategory(CompositeCategory): + """ A class consisting of two categories + + This should take 2 categories as arguments and a binary operator to define + how the categories are related. + + Parameters + ---------- + A : Category + the left hand side of the set operation + B : Category + the right hand side of the set operation + """ + __slots__ = tuple() + + def categorize(self, *args, **kwargs): + r""" Base method for queriyng whether an object is a certain category """ + catA, catB = super().categorize(*args, **kwargs) + + def cmp(a, b): + if isinstance(a, NullCategory): + return b + elif isinstance(b, NullCategory): + return a + # both are not NullCategory, in which case nothing + # is exclusive, so we return the NullCategory + return NullCategory() + + if isinstance(catA, list): + return list(map(cmp, catA, catB)) + return cmp(catA, catB) + + name = _composite_name("⊕") diff --git a/sisl/geom/__init__.py b/sisl/geom/__init__.py index 37e0a84ea3..e1eedcb811 100644 --- a/sisl/geom/__init__.py +++ b/sisl/geom/__init__.py @@ -81,5 +81,7 @@ from .nanotube import * from .special import * from .bilayer import * +from .category import * + __all__ = [s for s in dir() if not s.startswith('_')] diff --git a/sisl/geom/category/__init__.py b/sisl/geom/category/__init__.py new file mode 100644 index 0000000000..3484ca6bf5 --- /dev/null +++ b/sisl/geom/category/__init__.py @@ -0,0 +1,3 @@ +from .base import * +from ._neighbours import * +from ._kind import * diff --git a/sisl/geom/category/_kind.py b/sisl/geom/category/_kind.py new file mode 100644 index 0000000000..368c274471 --- /dev/null +++ b/sisl/geom/category/_kind.py @@ -0,0 +1,84 @@ +import numpy as np + +from sisl._internal import set_module, singledispatchmethod +from sisl._help import isiterable +from .base import AtomCategory, NullCategory, _sanitize_loop + + +__all__ = ["AtomZ", "AtomOdd", "AtomEven"] + + +@set_module("sisl.geom") +class AtomZ(AtomCategory): + r""" Classify atoms based on atomic number + + Parameters + ---------- + Z : int or array_like + atomic number match for several values this is equivalent to AND + """ + __slots__ = ("_Z",) + + def __init__(self, Z): + if isiterable(Z): + self._Z = set(Z) + else: + self._Z = set([Z]) + # using a sorted list ensures that we can compare + super().__init__(f"Z={self._Z}") + + @_sanitize_loop + def categorize(self, geometry, atoms=None): + # _sanitize_loop will ensure that atoms will always be an integer + if geometry.atoms.Z[atoms] in self._Z: + return self + return NullCategory() + + def __eq__(self, other): + if isinstance(other, (list, tuple, np.ndarray)): + # this *should* use the dispatch method for different + # classes + return super().__eq__(other) + + eq = self.__class__ is other.__class__ + if eq: + return self._Z == other._Z + return False + + +@set_module("sisl.geom") +class AtomOdd(AtomCategory): + r""" Classify atoms based on indices (odd in this case)""" + __slots__ = [] + + def __init__(self): + super().__init__("odd") + + @_sanitize_loop + def categorize(self, geometry, atoms=None): + # _sanitize_loop will ensure that atoms will always be an integer + if atoms % 2 == 1: + return self + return NullClass() + + def __eq__(self, other): + return self.__class__ is other.__class__ + + +@set_module("sisl.geom") +class AtomEven(AtomCategory): + r""" Classify atoms based on indices (even in this case)""" + __slots__ = [] + + def __init__(self): + super().__init__("even") + + @_sanitize_loop + def categorize(self, geometry, atoms): + # _sanitize_loop will ensure that atoms will always be an integer + if atoms % 2 == 0: + return self + return NullClass() + + def __eq__(self, other): + return self.__class__ is other.__class__ diff --git a/sisl/geom/category/_neighbours.py b/sisl/geom/category/_neighbours.py new file mode 100644 index 0000000000..3a473e65a0 --- /dev/null +++ b/sisl/geom/category/_neighbours.py @@ -0,0 +1,92 @@ +from collections import namedtuple + +from sisl._internal import set_module +from .base import AtomCategory, NullCategory, _sanitize_loop + + +__all__ = ["AtomNeighbours"] + + +@set_module("sisl.geom") +class AtomNeighbours(AtomCategory): + r""" Classify atoms based on number of neighbours + + + Parameters + ---------- + min : int, optional + minimum number of neighbours + max : int + maximum number of neighbours + neigh_cat : Category + a category the neighbour must be in to be counted + """ + __slots__ = ("_min", "_max", "_in") + + def __init__(self, *args, **kwargs): + if len(args) > 0: + if isinstance(args[-1], AtomCategory): + kwargs["neigh_cat"] = args.pop() + + self._min = 0 + self._max = 2 ** 31 + + if len(args) == 1: + self._min = args[0] + self._max = args[0] + + elif len(args) == 2: + self._min = args[0] + self._max = args[1] + + if "min" in kwargs: + self._min = kwargs.pop("min") + if "max" in kwargs: + self._max = kwargs.pop("max") + + if self._min == self._max: + name = f"={self._max}" + elif self._max == 2 ** 31: + name = f" ∈ [{self._min};∞[" + else: + name = f" ∈ [{self._min};{self._max}]" + + self._in = kwargs.get("neigh_cat", None) + + # Determine name. If there are requirements for the neighbours + # then the name changes + if self._in is None: + self.set_name(f"neighbours{name}") + else: + self.set_name(f"neighbours({self._in}){name}") + + @_sanitize_loop + def categorize(self, geometry, atoms=None): + """ Check that number of neighbours are matching """ + idx, rij = geometry.close(atoms, R=(0.01, geometry.atoms[atoms].maxR()), ret_rij=True) + idx, rij = idx[1], rij[1] + if len(idx) < self._min: + return NullCategory() + + # Check if we have a condition + if not self._in is None: + # Get category of neighbours + cat = self._in.categorize(geometry, geometry.asc2uc(idx)) + idx1, rij1 = [], [] + for i in range(len(idx)): + if not isinstance(cat[i], NullCategory): + idx1.append(idx[i]) + rij1.append(rij[i]) + idx, rij = idx1, rij1 + n = len(idx) + if self._min <= n and n <= self._max: + return self + return NullCategory() + + def __eq__(self, other): + eq = self.__class__ is other.__class__ + if eq: + return (self._min == other._min and + self._max == other._max and + self._in == other._in) + return False diff --git a/sisl/geom/category/base.py b/sisl/geom/category/base.py new file mode 100644 index 0000000000..1c9c62c6e7 --- /dev/null +++ b/sisl/geom/category/base.py @@ -0,0 +1,24 @@ +from functools import wraps +from sisl._internal import set_module +from sisl._category import Category, NullCategory +from sisl.geometry import AtomCategory + + +__all__ = ["NullCategory", "AtomCategory"] + + +def _sanitize_loop(func): + @wraps(func) + def loop_func(self, geometry, atoms=None): + if atoms is None: + return [func(self, geometry, ia) for ia in geometry] + # extract based on atoms selection + atoms = geometry._sanitize_atom(atoms) + if atoms.ndim == 0: + return func(self, geometry, atoms) + return [func(self, geometry, ia) for ia in atoms] + return loop_func + +#class AtomCategory(Category) +# is defined in sisl/geometry.py since it is required in +# that instance. diff --git a/sisl/geom/category/tests/__init__.py b/sisl/geom/category/tests/__init__.py new file mode 100644 index 0000000000..21a3aae7b5 --- /dev/null +++ b/sisl/geom/category/tests/__init__.py @@ -0,0 +1 @@ +""" tests for sisl/geom/category """ diff --git a/sisl/geom/tests/__init__.py b/sisl/geom/tests/__init__.py index cefab91711..379cf9aafc 100644 --- a/sisl/geom/tests/__init__.py +++ b/sisl/geom/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/geom """ diff --git a/sisl/geom/tests/test_geom.py b/sisl/geom/tests/test_geom.py index 09dae1bc71..d6ddd701c9 100644 --- a/sisl/geom/tests/test_geom.py +++ b/sisl/geom/tests/test_geom.py @@ -7,6 +7,9 @@ import numpy as np +pytestmark = [pytest.mark.geom] + + def test_basis(): a = sc(2.52, Atom['Fe']) a = bcc(2.52, Atom['Fe']) diff --git a/sisl/geometry.py b/sisl/geometry.py index 3e8dcda94a..5bf93075fc 100644 --- a/sisl/geometry.py +++ b/sisl/geometry.py @@ -27,11 +27,19 @@ from .atom import Atom, Atoms from .shape import Shape, Sphere, Cube from ._namedindex import NamedIndex +from ._category import Category __all__ = ['Geometry', 'sgeom'] +# It needs to be here otherwise we can't use it in these routines +# Note how we are overwriting the module +@set_module("sisl.geom") +class AtomCategory(Category): + __slots__ = tuple() + + @set_module("sisl") class Geometry(SuperCellChild): """ Holds atomic information, coordinates, species, lattice vectors @@ -312,6 +320,18 @@ def _(self, atom): def _(self, atom): return (self.atoms.specie == self.atoms.index(atom)).nonzero()[0] + @_sanitize_atom.register(AtomCategory) + def _(self, atom): + # First do categorization + cat = atom.categorize(self) + def m(cat): + for ia, c in enumerate(cat): + if c == None: + pass + else: + yield ia + return _a.fromiteri(m(cat)) + def _sanitize_orb(self, orbital): """ Converts an `orbital` to index under given inputs diff --git a/sisl/io/gulp/tests/__init__.py b/sisl/io/gulp/tests/__init__.py index cefab91711..4ba813e40e 100644 --- a/sisl/io/gulp/tests/__init__.py +++ b/sisl/io/gulp/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/io/gulp """ diff --git a/sisl/io/siesta/tests/__init__.py b/sisl/io/siesta/tests/__init__.py index cefab91711..2beb5280e3 100644 --- a/sisl/io/siesta/tests/__init__.py +++ b/sisl/io/siesta/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/io/siesta """ diff --git a/sisl/io/tbtrans/tests/__init__.py b/sisl/io/tbtrans/tests/__init__.py index cefab91711..96fd180be3 100644 --- a/sisl/io/tbtrans/tests/__init__.py +++ b/sisl/io/tbtrans/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/io/tbtrans """ diff --git a/sisl/io/tests/__init__.py b/sisl/io/tests/__init__.py index cefab91711..7fcaa3e72f 100644 --- a/sisl/io/tests/__init__.py +++ b/sisl/io/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/io """ diff --git a/sisl/io/vasp/tests/__init__.py b/sisl/io/vasp/tests/__init__.py index cefab91711..379ec9ee4d 100644 --- a/sisl/io/vasp/tests/__init__.py +++ b/sisl/io/vasp/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/io/vasp """ diff --git a/sisl/io/wannier90/tests/__init__.py b/sisl/io/wannier90/tests/__init__.py index cefab91711..48feb4b321 100644 --- a/sisl/io/wannier90/tests/__init__.py +++ b/sisl/io/wannier90/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/io/wannier90 """ diff --git a/sisl/linalg/tests/__init__.py b/sisl/linalg/tests/__init__.py index cefab91711..8c47eb61bc 100644 --- a/sisl/linalg/tests/__init__.py +++ b/sisl/linalg/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/linalg """ diff --git a/sisl/mixing/tests/__init__.py b/sisl/mixing/tests/__init__.py index cefab91711..287b7b003e 100644 --- a/sisl/mixing/tests/__init__.py +++ b/sisl/mixing/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/mixing """ diff --git a/sisl/physics/tests/__init__.py b/sisl/physics/tests/__init__.py index cefab91711..0cf59d6b50 100644 --- a/sisl/physics/tests/__init__.py +++ b/sisl/physics/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/physics """ diff --git a/sisl/shape/tests/__init__.py b/sisl/shape/tests/__init__.py index cefab91711..713e853944 100644 --- a/sisl/shape/tests/__init__.py +++ b/sisl/shape/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/shape """ diff --git a/sisl/tests/__init__.py b/sisl/tests/__init__.py index cefab91711..f6bab9a668 100644 --- a/sisl/tests/__init__.py +++ b/sisl/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl """ diff --git a/sisl/unit/tests/__init__.py b/sisl/unit/tests/__init__.py index cefab91711..e750f24293 100644 --- a/sisl/unit/tests/__init__.py +++ b/sisl/unit/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/unit """ diff --git a/sisl/utils/tests/__init__.py b/sisl/utils/tests/__init__.py index cefab91711..45d506d685 100644 --- a/sisl/utils/tests/__init__.py +++ b/sisl/utils/tests/__init__.py @@ -1 +1 @@ -""" test init for $d """ +""" tests for sisl/utils """