Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Category for manipulating categories using boolean operators #218

Merged
merged 6 commits into from
May 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
417 changes: 417 additions & 0 deletions sisl/_category.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions sisl/geom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,7 @@
from .nanotube import *
from .special import *
from .bilayer import *
from .category import *


__all__ = [s for s in dir() if not s.startswith('_')]
3 changes: 3 additions & 0 deletions sisl/geom/category/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .base import *
from ._neighbours import *
from ._kind import *
84 changes: 84 additions & 0 deletions sisl/geom/category/_kind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import numpy as np

from sisl._internal import set_module, singledispatchmethod
from sisl._help import isiterable
from .base import AtomCategory, NullCategory, _sanitize_loop


__all__ = ["AtomZ", "AtomOdd", "AtomEven"]


@set_module("sisl.geom")
class AtomZ(AtomCategory):
r""" Classify atoms based on atomic number

Parameters
----------
Z : int or array_like
atomic number match for several values this is equivalent to AND
"""
__slots__ = ("_Z",)

def __init__(self, Z):
if isiterable(Z):
self._Z = set(Z)
else:
self._Z = set([Z])
# using a sorted list ensures that we can compare
super().__init__(f"Z={self._Z}")

@_sanitize_loop
def categorize(self, geometry, atoms=None):
# _sanitize_loop will ensure that atoms will always be an integer
if geometry.atoms.Z[atoms] in self._Z:
return self
return NullCategory()

def __eq__(self, other):
if isinstance(other, (list, tuple, np.ndarray)):
# this *should* use the dispatch method for different
# classes
return super().__eq__(other)

eq = self.__class__ is other.__class__
if eq:
return self._Z == other._Z
return False


@set_module("sisl.geom")
class AtomOdd(AtomCategory):
r""" Classify atoms based on indices (odd in this case)"""
__slots__ = []

def __init__(self):
super().__init__("odd")

@_sanitize_loop
def categorize(self, geometry, atoms=None):
# _sanitize_loop will ensure that atoms will always be an integer
if atoms % 2 == 1:
return self
return NullClass()

def __eq__(self, other):
return self.__class__ is other.__class__


@set_module("sisl.geom")
class AtomEven(AtomCategory):
r""" Classify atoms based on indices (even in this case)"""
__slots__ = []

def __init__(self):
super().__init__("even")

@_sanitize_loop
def categorize(self, geometry, atoms):
# _sanitize_loop will ensure that atoms will always be an integer
if atoms % 2 == 0:
return self
return NullClass()

def __eq__(self, other):
return self.__class__ is other.__class__
92 changes: 92 additions & 0 deletions sisl/geom/category/_neighbours.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from collections import namedtuple

from sisl._internal import set_module
from .base import AtomCategory, NullCategory, _sanitize_loop


__all__ = ["AtomNeighbours"]


@set_module("sisl.geom")
class AtomNeighbours(AtomCategory):
r""" Classify atoms based on number of neighbours


Parameters
----------
min : int, optional
minimum number of neighbours
max : int
maximum number of neighbours
neigh_cat : Category
a category the neighbour must be in to be counted
"""
__slots__ = ("_min", "_max", "_in")

def __init__(self, *args, **kwargs):
if len(args) > 0:
if isinstance(args[-1], AtomCategory):
kwargs["neigh_cat"] = args.pop()

self._min = 0
self._max = 2 ** 31

if len(args) == 1:
self._min = args[0]
self._max = args[0]

elif len(args) == 2:
self._min = args[0]
self._max = args[1]

if "min" in kwargs:
self._min = kwargs.pop("min")
if "max" in kwargs:
self._max = kwargs.pop("max")

if self._min == self._max:
name = f"={self._max}"
elif self._max == 2 ** 31:
name = f" ∈ [{self._min};∞["
else:
name = f" ∈ [{self._min};{self._max}]"

self._in = kwargs.get("neigh_cat", None)

# Determine name. If there are requirements for the neighbours
# then the name changes
if self._in is None:
self.set_name(f"neighbours{name}")
else:
self.set_name(f"neighbours({self._in}){name}")

@_sanitize_loop
def categorize(self, geometry, atoms=None):
""" Check that number of neighbours are matching """
idx, rij = geometry.close(atoms, R=(0.01, geometry.atoms[atoms].maxR()), ret_rij=True)
idx, rij = idx[1], rij[1]
if len(idx) < self._min:
return NullCategory()

# Check if we have a condition
if not self._in is None:
# Get category of neighbours
cat = self._in.categorize(geometry, geometry.asc2uc(idx))
idx1, rij1 = [], []
for i in range(len(idx)):
if not isinstance(cat[i], NullCategory):
idx1.append(idx[i])
rij1.append(rij[i])
idx, rij = idx1, rij1
n = len(idx)
if self._min <= n and n <= self._max:
return self
return NullCategory()

def __eq__(self, other):
eq = self.__class__ is other.__class__
if eq:
return (self._min == other._min and
self._max == other._max and
self._in == other._in)
return False
24 changes: 24 additions & 0 deletions sisl/geom/category/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from functools import wraps
from sisl._internal import set_module
from sisl._category import Category, NullCategory
from sisl.geometry import AtomCategory


__all__ = ["NullCategory", "AtomCategory"]


def _sanitize_loop(func):
@wraps(func)
def loop_func(self, geometry, atoms=None):
if atoms is None:
return [func(self, geometry, ia) for ia in geometry]
# extract based on atoms selection
atoms = geometry._sanitize_atom(atoms)
if atoms.ndim == 0:
return func(self, geometry, atoms)
return [func(self, geometry, ia) for ia in atoms]
return loop_func

#class AtomCategory(Category)
# is defined in sisl/geometry.py since it is required in
# that instance.
1 change: 1 addition & 0 deletions sisl/geom/category/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
""" tests for sisl/geom/category """
2 changes: 1 addition & 1 deletion sisl/geom/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/geom """
3 changes: 3 additions & 0 deletions sisl/geom/tests/test_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import numpy as np


pytestmark = [pytest.mark.geom]


def test_basis():
a = sc(2.52, Atom['Fe'])
a = bcc(2.52, Atom['Fe'])
Expand Down
20 changes: 20 additions & 0 deletions sisl/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,19 @@
from .atom import Atom, Atoms
from .shape import Shape, Sphere, Cube
from ._namedindex import NamedIndex
from ._category import Category


__all__ = ['Geometry', 'sgeom']


# It needs to be here otherwise we can't use it in these routines
# Note how we are overwriting the module
@set_module("sisl.geom")
class AtomCategory(Category):
__slots__ = tuple()


@set_module("sisl")
class Geometry(SuperCellChild):
""" Holds atomic information, coordinates, species, lattice vectors
Expand Down Expand Up @@ -312,6 +320,18 @@ def _(self, atom):
def _(self, atom):
return (self.atoms.specie == self.atoms.index(atom)).nonzero()[0]

@_sanitize_atom.register(AtomCategory)
def _(self, atom):
# First do categorization
cat = atom.categorize(self)
def m(cat):
for ia, c in enumerate(cat):
if c == None:
pass
else:
yield ia
return _a.fromiteri(m(cat))

def _sanitize_orb(self, orbital):
""" Converts an `orbital` to index under given inputs

Expand Down
2 changes: 1 addition & 1 deletion sisl/io/gulp/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/io/gulp """
2 changes: 1 addition & 1 deletion sisl/io/siesta/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/io/siesta """
2 changes: 1 addition & 1 deletion sisl/io/tbtrans/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/io/tbtrans """
2 changes: 1 addition & 1 deletion sisl/io/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/io """
2 changes: 1 addition & 1 deletion sisl/io/vasp/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/io/vasp """
2 changes: 1 addition & 1 deletion sisl/io/wannier90/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/io/wannier90 """
2 changes: 1 addition & 1 deletion sisl/linalg/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/linalg """
2 changes: 1 addition & 1 deletion sisl/mixing/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/mixing """
2 changes: 1 addition & 1 deletion sisl/physics/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/physics """
2 changes: 1 addition & 1 deletion sisl/shape/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/shape """
2 changes: 1 addition & 1 deletion sisl/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl """
2 changes: 1 addition & 1 deletion sisl/unit/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/unit """
2 changes: 1 addition & 1 deletion sisl/utils/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
""" test init for $d """
""" tests for sisl/utils """