Skip to content

Commit

Permalink
add pyright checker
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed May 25, 2024
1 parent 0bdadc2 commit 5af6c4f
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 32 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
on:
push:
branches:
- master
tags:
- v*
pull_request:
merge_group:

concurrency:
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
cancel-in-progress: true

name: Type checker
jobs:
test:
name: pyright
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: pipx run uv pip install --system -e .
- uses: jakebailey/pyright-action@v2
with:
version: 1.1.364
22 changes: 16 additions & 6 deletions mddatasetbuilder/datasetbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tempfile
import time
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, List, Optional

import numpy as np
from ase.data import atomic_numbers
Expand All @@ -29,14 +30,15 @@
from sklearn.cluster import MiniBatchKMeans

from ._logger import logger
from .detect import Detect
from .detect import Detect, DetectDump
from .utils import (
bytestolist,
listtobytes,
must_be_list,
read_compressed_block,
run_mp,
)
from ase.atoms import Atoms


class DatasetBuilder:
Expand Down Expand Up @@ -97,7 +99,7 @@ def __init__(
nproc=None,
pbc=True,
fragment=False,
errorfilename=None,
errorfilename: Optional[List[str]]=None,
errorlimit=0.0,
atom_pref=False,
):
Expand Down Expand Up @@ -302,13 +304,15 @@ def _writestepmatrix(self, item):
step, lines = item
results = []
if step in self.dstep:
assert isinstance(self.crddetector, DetectDump)
step_atoms, _ = self.crddetector.readcrd(lines)
for atoma in self.dstep[step]:
# atom ID starts from 1
distances = step_atoms.get_distances(
atoma - 1, range(len(step_atoms)), mic=True
)
cutoffatoms = step_atoms[distances < self.cutoff]
assert isinstance(cutoffatoms, Atoms)
symbols = cutoffatoms.get_chemical_symbols()
results.append(
(
Expand Down Expand Up @@ -364,7 +368,7 @@ def _clusterdatas(cls, X, n_clusters, n_each=1):
min_max_scaler = preprocessing.MinMaxScaler()
X = np.array(min_max_scaler.fit_transform(X))
clus = MiniBatchKMeans(
n_clusters=n_clusters, init_size=(min(3 * n_clusters, len(X))), n_init=3
n_clusters=n_clusters, init_size=(min(3 * n_clusters, len(X))), n_init=3 # type: ignore
)
labels = clus.fit_predict(X)
choosedidx = []
Expand Down Expand Up @@ -477,6 +481,7 @@ def _convertgjf(self, gjffilename, takenatomidindex, atoms_whole):
multiplicity_whole = sum(multiplicities) - len(takenatomidindex) + 1
multiplicity_whole_str = f"0 {multiplicity_whole}"
title = "\nGenerated by MDDatasetMaker (Author: Jinzhe Zeng)\n"
connect = None
if len(self.qmkeywords) > 1:
connect = "\n--link1--\n"
chk = [f"%chk={os.path.splitext(os.path.basename(gjffilename))[0]}.chk"]
Expand Down Expand Up @@ -505,6 +510,7 @@ def _convertgjf(self, gjffilename, takenatomidindex, atoms_whole):
]
)
for kw in itertools.islice(self.qmkeywords, 1, None):
assert connect is not None

Check warning on line 513 in mddatasetbuilder/datasetbuilder.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/datasetbuilder.py#L513

Added line #L513 was not covered by tests
buff.extend((connect, *chk, kw, title, f"0 {multiplicity_whole}", "\n"))
buff.append("\n")
with open(gjffilename, "w") as f:
Expand All @@ -530,10 +536,12 @@ def _writestepxyzfile(self, item):
results = 0
if step in self.dstep:
if len(lines) == 2:
assert isinstance(self.crddetector, DetectDump)
step_atoms, _ = self.crddetector.readcrd(lines[0])
molecules = self.bonddetector.readmolecule(lines[1])
molecules, _ = self.bonddetector.readmolecule(lines[1])
else:
molecules, step_atoms = self.bonddetector.readmolecule(lines)
assert step_atoms is not None
for atoma, trajatomfilename, itype, itotal in self.dstep[step]:
# update counter
folder = str(itotal // 1000).zfill(self.foldermaxlength)
Expand All @@ -555,9 +563,10 @@ def _writestepxyzfile(self, item):
idsum += len(mol_atomid)
idx = np.concatenate(takenatomids)
cutoffatoms = step_atoms[idx]
cutoffatoms[np.nonzero(idx == atoma - 1)[0][0]].tag = 1
assert isinstance(cutoffatoms, Atoms)
cutoffatoms[np.nonzero(idx == atoma - 1)[0][0]].tag = 1 # type: ignore
cutoffatoms.wrap(
center=step_atoms[atoma - 1].position
center=step_atoms[atoma - 1].position # type: ignore
/ cutoffatoms.get_cell_lengths_and_angles()[0:3],
pbc=cutoffatoms.get_pbc(),
)
Expand Down Expand Up @@ -632,6 +641,7 @@ def erroriter(self):
str
the line of model deviation
"""
assert self.errorfilename is not None

Check warning on line 644 in mddatasetbuilder/datasetbuilder.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/datasetbuilder.py#L644

Added line #L644 was not covered by tests
fns = must_be_list(self.errorfilename)
for fn in fns:
with open(fn) as f:
Expand Down
69 changes: 46 additions & 23 deletions mddatasetbuilder/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from enum import Enum, auto
from typing import List, Optional, Tuple, Union

import numpy as np
from ase import Atom, Atoms
Expand All @@ -24,16 +25,16 @@ def __init__(self, filename, atomname, pbc, errorlimit=None, errorfilename=None)
self.steplinenum = self._readN()

@abstractmethod
def _readN(self):
def _readN(self) -> int:
pass

@abstractmethod
def readatombondtype(self, item):
def readatombondtype(self, item) -> Tuple[dict, int]:
"""Read bond types of atoms such as C1111."""
pass

@abstractmethod
def readmolecule(self, lines):
def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]:
"""Read molecules."""
pass

Expand All @@ -54,6 +55,10 @@ class DetectBond(Detect):

def _readN(self):
"""Read bondfile N, which should be at very beginning."""
N = None
atomtype = None
stepaindex = None
stepbindex = None
# copy from reacnetgenerator on 2018-12-15
with open(
self.filename if isinstance(self.filename, str) else self.filename[0]
Expand All @@ -72,7 +77,10 @@ def _readN(self):
atomtype = np.zeros(N, dtype=int)
else:
s = line.split()
assert atomtype is not None
atomtype[int(s[0]) - 1] = int(s[1])
if stepaindex is None or stepbindex is None or N is None or atomtype is None:
raise RuntimeError("The bond file is not completed")

Check warning on line 83 in mddatasetbuilder/detect.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/detect.py#L83

Added line #L83 was not covered by tests
steplinenum = stepbindex - stepaindex
self._N = N
self.atomtype = atomtype
Expand Down Expand Up @@ -110,7 +118,7 @@ def readatombondtype(self, item):
)
return d, step

def readmolecule(self, lines):
def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]:
"""Return molecules from lines.
Parameters
Expand All @@ -122,16 +130,18 @@ def readmolecule(self, lines):
-------
molecules: list
Indexes of atoms in molecules.
None
None
"""
# copy from reacnetgenerator on 2018-12-15
bond = [None] * self._N
bond: List[Optional[List[int]]] = [None] * self._N
for line in lines:
if line:
if not line.startswith("#"):
s = line.split()
bond[int(s[0]) - 1] = [int(x) - 1 for x in s[3 : 3 + int(s[2])]]
molecules = connectmolecule(bond)
return molecules
return molecules, None


class DetectDump(Detect):
Expand All @@ -140,9 +150,14 @@ class DetectDump(Detect):
def _readN(self):
# copy from reacnetgenerator on 2018-12-15
iscompleted = False
N = None
atomtype = None
stepaindex = None
stepbindex = None
with open(
self.filename if isinstance(self.filename, str) else self.filename[0]
) as f:
linecontent = None
for index, line in enumerate(f):
if line.startswith("ITEM:"):
linecontent = self.LineType.linecontent(line)
Expand All @@ -154,7 +169,9 @@ def _readN(self):
self.yidx = keys.index("y") - 2
self.zidx = keys.index("z") - 2
else:
if linecontent == self.LineType.NUMBER:
if linecontent is None:
raise RuntimeError("No ITEM: in the dump file")

Check warning on line 173 in mddatasetbuilder/detect.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/detect.py#L173

Added line #L173 was not covered by tests
elif linecontent == self.LineType.NUMBER:
if iscompleted:
stepbindex = index
break
Expand All @@ -165,7 +182,10 @@ def _readN(self):
atomtype = np.zeros(N, dtype=int)
elif linecontent == self.LineType.ATOMS:
s = line.split()
assert atomtype is not None
atomtype[int(s[self.id_idx]) - 1] = int(s[self.tidx])
if stepaindex is None or stepbindex is None or N is None or atomtype is None:
raise RuntimeError("The dump file is not completed")

Check warning on line 188 in mddatasetbuilder/detect.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/detect.py#L188

Added line #L188 was not covered by tests
steplinenum = stepbindex - stepaindex
self._N = N
self.atomtype = atomtype
Expand All @@ -188,21 +208,21 @@ def readatombondtype(self, item):
the step index
"""
(step, lines), needlerror = item
if needlerror:
trajline, errorline = lines
lerror = np.fromstring(errorline, dtype=float, sep=" ")[7:]
lerror: Optional[Union[np.ndarray, List[float]]] = None
d = defaultdict(list)
step_atoms, ids = self.readcrd(lines)
if needlerror:
trajline, errorline = lines
lerror = np.fromstring(errorline, dtype=float, sep=" ")[7:]

Check warning on line 216 in mddatasetbuilder/detect.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/detect.py#L215-L216

Added lines #L215 - L216 were not covered by tests
lerror = [x for (y, x) in sorted(zip(ids, lerror))]
level = self._crd2bond(step_atoms, readlevel=True)
for i, (n, l) in enumerate(zip(self.atomnames, level)):
if not needlerror or lerror[i] > self.errorlimit:
if lerror is None or (self.errorlimit is not None and lerror[i] > self.errorlimit):
# Note that atom id starts from 1
d[pickle.dumps((n, sorted(l)))].append(i + 1)
return d, step

def readmolecule(self, lines):
def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]:
"""Return molecules from lines.
Parameters
Expand Down Expand Up @@ -250,11 +270,11 @@ def _crd2bond(cls, step_atoms, readlevel):
mol.CloneData(uc)
mol.SetPeriodicMol()
mol.ConnectTheDots()
if not readlevel:
bond = [[] for i in range(atomnumber)]
else:
# when readlevel is False, bond is used to store connected atoms
# otherwise, bondlevel is used to store bond orders
bond = [[] for i in range(atomnumber)]
if readlevel:
mol.PerceiveBondOrders()
bondlevel = [[] for i in range(atomnumber)]
mol.EndModify()
for b in openbabel.OBMolBondIter(mol):
s1 = b.GetBeginAtom().GetId()
Expand All @@ -264,23 +284,26 @@ def _crd2bond(cls, step_atoms, readlevel):
bond[s2].append(s1)
else:
level = b.GetBondOrder()
bondlevel[s1].append(level)
bondlevel[s2].append(level)
return bondlevel if readlevel else bond
bond[s1].append(level)
bond[s2].append(level)
return bond

def readcrd(self, item):
def readcrd(self, item) -> tuple[Atoms, List[int]]:
"""Only this function can read coordinates."""
lines = item
# box information
ss = []
step_atoms = []
ids = []
linecontent = None
for line in lines:
if line:
if line.startswith("ITEM:"):
linecontent = self.LineType.linecontent(line)
else:
if linecontent == self.LineType.ATOMS:
if linecontent is None:
raise RuntimeError("No ITEM: in the dump file")

Check warning on line 305 in mddatasetbuilder/detect.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/detect.py#L305

Added line #L305 was not covered by tests
elif linecontent == self.LineType.ATOMS:
s = line.split()
ids.append(int(s[self.id_idx]))
step_atoms.append(
Expand Down Expand Up @@ -314,8 +337,8 @@ def readcrd(self, item):
[[xhi - xlo, 0.0, 0.0], [xy, yhi - ylo, 0.0], [xz, yz, zhi - zlo]]
)
# sort by ID
step_atoms = [x for (y, x) in sorted(zip(ids, step_atoms))]
step_atoms = Atoms(step_atoms, cell=boxsize, pbc=self.pbc)
step_atoms_ = [x for (y, x) in sorted(zip(ids, step_atoms))]
step_atoms = Atoms(step_atoms_, cell=boxsize, pbc=self.pbc)
return step_atoms, ids

class LineType(Enum):
Expand Down
15 changes: 12 additions & 3 deletions mddatasetbuilder/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
import pickle
from multiprocessing import Pool, Semaphore
from typing import BinaryIO, Union
from typing import BinaryIO, List, TypeVar, Union, overload

import lz4.frame
from tqdm.auto import tqdm
Expand Down Expand Up @@ -192,7 +192,7 @@ def bytestolist(x):
object
The decompressed object.
"""
return pickle.loads(decompress(x, isbytes=True))
return pickle.loads(decompress(x, isbytes=True)) # type: ignore


def run_mp(nproc, **arg):
Expand Down Expand Up @@ -231,7 +231,16 @@ def run_mp(nproc, **arg):
pool.join()


def must_be_list(obj):

T = TypeVar('T')
@overload
def must_be_list(obj: List[T]) -> List[T]:
...

Check warning on line 238 in mddatasetbuilder/utils.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/utils.py#L238

Added line #L238 was not covered by tests
@overload
def must_be_list(obj: T) -> List[T]:
...

Check warning on line 241 in mddatasetbuilder/utils.py

View check run for this annotation

Codecov / codecov/patch

mddatasetbuilder/utils.py#L241

Added line #L241 was not covered by tests

def must_be_list(obj: Union[T, List[T]]) -> List[T]:
"""Convert a object to a list if the object is not a list.
Parameters
Expand Down

0 comments on commit 5af6c4f

Please sign in to comment.