Skip to content

Commit

Permalink
ci: add pyright checker (#215)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored May 25, 2024
1 parent 56eb530 commit e76b413
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 35 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
on:
push:
branches:
- master
tags:
- v*
pull_request:
merge_group:

concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: true

name: Type checker
jobs:
test:
name: pyright
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: pipx run uv pip install --system -e .
- uses: jakebailey/pyright-action@v2
with:
version: 1.1.364
24 changes: 18 additions & 6 deletions mddatasetbuilder/datasetbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@
import tempfile
import time
from collections import Counter, defaultdict
from typing import List, Optional

import numpy as np
from ase.atoms import Atoms
from ase.data import atomic_numbers
from ase.io import write as write_xyz
from sklearn import preprocessing
from sklearn.cluster import MiniBatchKMeans

from ._logger import logger
from .detect import Detect
from .detect import Detect, DetectDump
from .utils import (
bytestolist,
listtobytes,
Expand Down Expand Up @@ -97,7 +99,7 @@ def __init__(
nproc=None,
pbc=True,
fragment=False,
errorfilename=None,
errorfilename: Optional[List[str]] = None,
errorlimit=0.0,
atom_pref=False,
):
Expand Down Expand Up @@ -302,13 +304,15 @@ def _writestepmatrix(self, item):
step, lines = item
results = []
if step in self.dstep:
assert isinstance(self.crddetector, DetectDump)
step_atoms, _ = self.crddetector.readcrd(lines)
for atoma in self.dstep[step]:
# atom ID starts from 1
distances = step_atoms.get_distances(
atoma - 1, range(len(step_atoms)), mic=True
)
cutoffatoms = step_atoms[distances < self.cutoff]
assert isinstance(cutoffatoms, Atoms)
symbols = cutoffatoms.get_chemical_symbols()
results.append(
(
Expand Down Expand Up @@ -364,7 +368,9 @@ def _clusterdatas(cls, X, n_clusters, n_each=1):
min_max_scaler = preprocessing.MinMaxScaler()
X = np.array(min_max_scaler.fit_transform(X))
clus = MiniBatchKMeans(
n_clusters=n_clusters, init_size=(min(3 * n_clusters, len(X))), n_init=3
n_clusters=n_clusters,
init_size=(min(3 * n_clusters, len(X))),
n_init=3, # type: ignore
)
labels = clus.fit_predict(X)
choosedidx = []
Expand Down Expand Up @@ -477,6 +483,7 @@ def _convertgjf(self, gjffilename, takenatomidindex, atoms_whole):
multiplicity_whole = sum(multiplicities) - len(takenatomidindex) + 1
multiplicity_whole_str = f"0 {multiplicity_whole}"
title = "\nGenerated by MDDatasetMaker (Author: Jinzhe Zeng)\n"
connect = None
if len(self.qmkeywords) > 1:
connect = "\n--link1--\n"
chk = [f"%chk={os.path.splitext(os.path.basename(gjffilename))[0]}.chk"]
Expand Down Expand Up @@ -505,6 +512,7 @@ def _convertgjf(self, gjffilename, takenatomidindex, atoms_whole):
]
)
for kw in itertools.islice(self.qmkeywords, 1, None):
assert connect is not None
buff.extend((connect, *chk, kw, title, f"0 {multiplicity_whole}", "\n"))
buff.append("\n")
with open(gjffilename, "w") as f:
Expand All @@ -530,10 +538,12 @@ def _writestepxyzfile(self, item):
results = 0
if step in self.dstep:
if len(lines) == 2:
assert isinstance(self.crddetector, DetectDump)
step_atoms, _ = self.crddetector.readcrd(lines[0])
molecules = self.bonddetector.readmolecule(lines[1])
molecules, _ = self.bonddetector.readmolecule(lines[1])
else:
molecules, step_atoms = self.bonddetector.readmolecule(lines)
assert step_atoms is not None
for atoma, trajatomfilename, itype, itotal in self.dstep[step]:
# update counter
folder = str(itotal // 1000).zfill(self.foldermaxlength)
Expand All @@ -555,9 +565,10 @@ def _writestepxyzfile(self, item):
idsum += len(mol_atomid)
idx = np.concatenate(takenatomids)
cutoffatoms = step_atoms[idx]
cutoffatoms[np.nonzero(idx == atoma - 1)[0][0]].tag = 1
assert isinstance(cutoffatoms, Atoms)
cutoffatoms[np.nonzero(idx == atoma - 1)[0][0]].tag = 1 # type: ignore
cutoffatoms.wrap(
center=step_atoms[atoma - 1].position
center=step_atoms[atoma - 1].position # type: ignore
/ cutoffatoms.get_cell_lengths_and_angles()[0:3],
pbc=cutoffatoms.get_pbc(),
)
Expand Down Expand Up @@ -632,6 +643,7 @@ def erroriter(self):
str
the line of model deviation
"""
assert self.errorfilename is not None
fns = must_be_list(self.errorfilename)
for fn in fns:
with open(fn) as f:
Expand Down
2 changes: 1 addition & 1 deletion mddatasetbuilder/deepmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _searchpath(self):
self._preparedeepmdforLOG, tqdm(logfiles, disable=None)
):
multi_systems.append(system)
multi_systems.to_deepmd_npy(self.deepmd_dir)
multi_systems.to_deepmd_npy(self.deepmd_dir) # type: ignore
for formula, system in multi_systems.systems.items():
self.system_paths.append(os.path.join(self.deepmd_dir, formula))
self.batch_size.append(
Expand Down
76 changes: 51 additions & 25 deletions mddatasetbuilder/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from enum import Enum, auto
from typing import List, Optional, Tuple, Union, cast

import numpy as np
from ase import Atom, Atoms
from openbabel import openbabel

from .dps import dps as connectmolecule
from mddatasetbuilder.dps import dps as connectmolecule

Check warning on line 13 in mddatasetbuilder/detect.py

View workflow job for this annotation

GitHub Actions / pyright

Import "mddatasetbuilder.dps" could not be resolved from source (reportMissingModuleSource)


class Detect(metaclass=ABCMeta):
Expand All @@ -24,16 +25,16 @@ def __init__(self, filename, atomname, pbc, errorlimit=None, errorfilename=None)
self.steplinenum = self._readN()

@abstractmethod
def _readN(self):
def _readN(self) -> int:
pass

@abstractmethod
def readatombondtype(self, item):
def readatombondtype(self, item) -> Tuple[dict, int]:
"""Read bond types of atoms such as C1111."""
pass

@abstractmethod
def readmolecule(self, lines):
def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]:
"""Read molecules."""
pass

Expand All @@ -54,6 +55,10 @@ class DetectBond(Detect):

def _readN(self):
"""Read bondfile N, which should be at very beginning."""
N = None
atomtype = None
stepaindex = None
stepbindex = None
# copy from reacnetgenerator on 2018-12-15
with open(
self.filename if isinstance(self.filename, str) else self.filename[0]
Expand All @@ -72,7 +77,10 @@ def _readN(self):
atomtype = np.zeros(N, dtype=int)
else:
s = line.split()
assert atomtype is not None
atomtype[int(s[0]) - 1] = int(s[1])
if stepaindex is None or stepbindex is None or N is None or atomtype is None:
raise RuntimeError("The bond file is not completed")
steplinenum = stepbindex - stepaindex
self._N = N
self.atomtype = atomtype
Expand Down Expand Up @@ -110,7 +118,7 @@ def readatombondtype(self, item):
)
return d, step

def readmolecule(self, lines):
def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]:
"""Return molecules from lines.
Parameters
Expand All @@ -122,16 +130,19 @@ def readmolecule(self, lines):
-------
molecules: list
Indexes of atoms in molecules.
None
None
"""
# copy from reacnetgenerator on 2018-12-15
bond = [None] * self._N
bond: List[Optional[List[int]]] = [None] * self._N
for line in lines:
if line:
if not line.startswith("#"):
s = line.split()
bond[int(s[0]) - 1] = [int(x) - 1 for x in s[3 : 3 + int(s[2])]]
molecules = connectmolecule(bond)
return molecules
bond_ = cast(List[List[int]], bond)
molecules = connectmolecule(bond_)
return molecules, None


class DetectDump(Detect):
Expand All @@ -140,9 +151,14 @@ class DetectDump(Detect):
def _readN(self):
# copy from reacnetgenerator on 2018-12-15
iscompleted = False
N = None
atomtype = None
stepaindex = None
stepbindex = None
with open(
self.filename if isinstance(self.filename, str) else self.filename[0]
) as f:
linecontent = None
for index, line in enumerate(f):
if line.startswith("ITEM:"):
linecontent = self.LineType.linecontent(line)
Expand All @@ -154,7 +170,9 @@ def _readN(self):
self.yidx = keys.index("y") - 2
self.zidx = keys.index("z") - 2
else:
if linecontent == self.LineType.NUMBER:
if linecontent is None:
raise RuntimeError("No ITEM: in the dump file")
elif linecontent == self.LineType.NUMBER:
if iscompleted:
stepbindex = index
break
Expand All @@ -165,7 +183,10 @@ def _readN(self):
atomtype = np.zeros(N, dtype=int)
elif linecontent == self.LineType.ATOMS:
s = line.split()
assert atomtype is not None
atomtype[int(s[self.id_idx]) - 1] = int(s[self.tidx])
if stepaindex is None or stepbindex is None or N is None or atomtype is None:
raise RuntimeError("The dump file is not completed")
steplinenum = stepbindex - stepaindex
self._N = N
self.atomtype = atomtype
Expand All @@ -188,21 +209,23 @@ def readatombondtype(self, item):
the step index
"""
(step, lines), needlerror = item
if needlerror:
trajline, errorline = lines
lerror = np.fromstring(errorline, dtype=float, sep=" ")[7:]
lerror: Optional[Union[np.ndarray, List[float]]] = None
d = defaultdict(list)
step_atoms, ids = self.readcrd(lines)
if needlerror:
trajline, errorline = lines
lerror = np.fromstring(errorline, dtype=float, sep=" ")[7:]
lerror = [x for (y, x) in sorted(zip(ids, lerror))]
level = self._crd2bond(step_atoms, readlevel=True)
for i, (n, l) in enumerate(zip(self.atomnames, level)):
if not needlerror or lerror[i] > self.errorlimit:
if lerror is None or (
self.errorlimit is not None and lerror[i] > self.errorlimit
):
# Note that atom id starts from 1
d[pickle.dumps((n, sorted(l)))].append(i + 1)
return d, step

def readmolecule(self, lines):
def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]:
"""Return molecules from lines.
Parameters
Expand Down Expand Up @@ -250,11 +273,11 @@ def _crd2bond(cls, step_atoms, readlevel):
mol.CloneData(uc)
mol.SetPeriodicMol()
mol.ConnectTheDots()
if not readlevel:
bond = [[] for i in range(atomnumber)]
else:
# when readlevel is False, bond is used to store connected atoms
# otherwise, bondlevel is used to store bond orders
bond = [[] for i in range(atomnumber)]
if readlevel:
mol.PerceiveBondOrders()
bondlevel = [[] for i in range(atomnumber)]
mol.EndModify()
for b in openbabel.OBMolBondIter(mol):
s1 = b.GetBeginAtom().GetId()
Expand All @@ -264,23 +287,26 @@ def _crd2bond(cls, step_atoms, readlevel):
bond[s2].append(s1)
else:
level = b.GetBondOrder()
bondlevel[s1].append(level)
bondlevel[s2].append(level)
return bondlevel if readlevel else bond
bond[s1].append(level)
bond[s2].append(level)
return bond

def readcrd(self, item):
def readcrd(self, item) -> Tuple[Atoms, List[int]]:
"""Only this function can read coordinates."""
lines = item
# box information
ss = []
step_atoms = []
ids = []
linecontent = None
for line in lines:
if line:
if line.startswith("ITEM:"):
linecontent = self.LineType.linecontent(line)
else:
if linecontent == self.LineType.ATOMS:
if linecontent is None:
raise RuntimeError("No ITEM: in the dump file")
elif linecontent == self.LineType.ATOMS:
s = line.split()
ids.append(int(s[self.id_idx]))
step_atoms.append(
Expand Down Expand Up @@ -314,8 +340,8 @@ def readcrd(self, item):
[[xhi - xlo, 0.0, 0.0], [xy, yhi - ylo, 0.0], [xz, yz, zhi - zlo]]
)
# sort by ID
step_atoms = [x for (y, x) in sorted(zip(ids, step_atoms))]
step_atoms = Atoms(step_atoms, cell=boxsize, pbc=self.pbc)
step_atoms_ = [x for (y, x) in sorted(zip(ids, step_atoms))]
step_atoms = Atoms(step_atoms_, cell=boxsize, pbc=self.pbc)
return step_atoms, ids

class LineType(Enum):
Expand Down
1 change: 1 addition & 0 deletions mddatasetbuilder/dps.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
def dps(bonds: list[list[int]]) -> list[list[int]]: ...
Empty file added mddatasetbuilder/py.typed
Empty file.
15 changes: 12 additions & 3 deletions mddatasetbuilder/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
import pickle
from multiprocessing import Pool, Semaphore
from typing import BinaryIO, Union
from typing import BinaryIO, List, TypeVar, Union, overload

import lz4.frame
from tqdm.auto import tqdm
Expand Down Expand Up @@ -192,7 +192,7 @@ def bytestolist(x):
object
The decompressed object.
"""
return pickle.loads(decompress(x, isbytes=True))
return pickle.loads(decompress(x, isbytes=True)) # type: ignore


def run_mp(nproc, **arg):
Expand Down Expand Up @@ -231,7 +231,16 @@ def run_mp(nproc, **arg):
pool.join()


def must_be_list(obj):
T = TypeVar("T")


@overload
def must_be_list(obj: List[T]) -> List[T]: ...
@overload
def must_be_list(obj: T) -> List[T]: ...


def must_be_list(obj: Union[T, List[T]]) -> List[T]:
"""Convert a object to a list if the object is not a list.
Parameters
Expand Down
Loading

0 comments on commit e76b413

Please sign in to comment.