diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 0000000..407fb4a --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,27 @@ +on: + push: + branches: + - master + tags: + - v* + pull_request: + merge_group: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref || github.run_id }} + cancel-in-progress: true + +name: Type checker +jobs: + test: + name: pyright + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - run: pipx run uv pip install --system -e . + - uses: jakebailey/pyright-action@v2 + with: + version: 1.1.364 \ No newline at end of file diff --git a/mddatasetbuilder/datasetbuilder.py b/mddatasetbuilder/datasetbuilder.py index 880e3d1..6f01544 100644 --- a/mddatasetbuilder/datasetbuilder.py +++ b/mddatasetbuilder/datasetbuilder.py @@ -21,6 +21,7 @@ import tempfile import time from collections import Counter, defaultdict +from typing import TYPE_CHECKING, List, Optional import numpy as np from ase.data import atomic_numbers @@ -29,7 +30,7 @@ from sklearn.cluster import MiniBatchKMeans from ._logger import logger -from .detect import Detect +from .detect import Detect, DetectDump from .utils import ( bytestolist, listtobytes, @@ -37,6 +38,7 @@ read_compressed_block, run_mp, ) +from ase.atoms import Atoms class DatasetBuilder: @@ -97,7 +99,7 @@ def __init__( nproc=None, pbc=True, fragment=False, - errorfilename=None, + errorfilename: Optional[List[str]]=None, errorlimit=0.0, atom_pref=False, ): @@ -302,6 +304,7 @@ def _writestepmatrix(self, item): step, lines = item results = [] if step in self.dstep: + assert isinstance(self.crddetector, DetectDump) step_atoms, _ = self.crddetector.readcrd(lines) for atoma in self.dstep[step]: # atom ID starts from 1 @@ -309,6 +312,7 @@ def _writestepmatrix(self, item): atoma - 1, range(len(step_atoms)), mic=True ) cutoffatoms = step_atoms[distances < self.cutoff] + assert isinstance(cutoffatoms, Atoms) symbols = cutoffatoms.get_chemical_symbols() results.append( ( @@ -364,7 +368,7 @@ def _clusterdatas(cls, X, n_clusters, n_each=1): min_max_scaler = preprocessing.MinMaxScaler() X = np.array(min_max_scaler.fit_transform(X)) clus = MiniBatchKMeans( - n_clusters=n_clusters, init_size=(min(3 * n_clusters, len(X))), n_init=3 + n_clusters=n_clusters, init_size=(min(3 * n_clusters, len(X))), n_init=3 # type: ignore ) labels = clus.fit_predict(X) choosedidx = [] @@ -477,6 +481,7 @@ def _convertgjf(self, gjffilename, takenatomidindex, atoms_whole): multiplicity_whole = sum(multiplicities) - len(takenatomidindex) + 1 multiplicity_whole_str = f"0 {multiplicity_whole}" title = "\nGenerated by MDDatasetMaker (Author: Jinzhe Zeng)\n" + connect = None if len(self.qmkeywords) > 1: connect = "\n--link1--\n" chk = [f"%chk={os.path.splitext(os.path.basename(gjffilename))[0]}.chk"] @@ -505,6 +510,7 @@ def _convertgjf(self, gjffilename, takenatomidindex, atoms_whole): ] ) for kw in itertools.islice(self.qmkeywords, 1, None): + assert connect is not None buff.extend((connect, *chk, kw, title, f"0 {multiplicity_whole}", "\n")) buff.append("\n") with open(gjffilename, "w") as f: @@ -530,10 +536,12 @@ def _writestepxyzfile(self, item): results = 0 if step in self.dstep: if len(lines) == 2: + assert isinstance(self.crddetector, DetectDump) step_atoms, _ = self.crddetector.readcrd(lines[0]) - molecules = self.bonddetector.readmolecule(lines[1]) + molecules, _ = self.bonddetector.readmolecule(lines[1]) else: molecules, step_atoms = self.bonddetector.readmolecule(lines) + assert step_atoms is not None for atoma, trajatomfilename, itype, itotal in self.dstep[step]: # update counter folder = str(itotal // 1000).zfill(self.foldermaxlength) @@ -555,9 +563,10 @@ def _writestepxyzfile(self, item): idsum += len(mol_atomid) idx = np.concatenate(takenatomids) cutoffatoms = step_atoms[idx] - cutoffatoms[np.nonzero(idx == atoma - 1)[0][0]].tag = 1 + assert isinstance(cutoffatoms, Atoms) + cutoffatoms[np.nonzero(idx == atoma - 1)[0][0]].tag = 1 # type: ignore cutoffatoms.wrap( - center=step_atoms[atoma - 1].position + center=step_atoms[atoma - 1].position # type: ignore / cutoffatoms.get_cell_lengths_and_angles()[0:3], pbc=cutoffatoms.get_pbc(), ) @@ -632,6 +641,7 @@ def erroriter(self): str the line of model deviation """ + assert self.errorfilename is not None fns = must_be_list(self.errorfilename) for fn in fns: with open(fn) as f: diff --git a/mddatasetbuilder/detect.py b/mddatasetbuilder/detect.py index 54624b8..75b86da 100644 --- a/mddatasetbuilder/detect.py +++ b/mddatasetbuilder/detect.py @@ -4,6 +4,7 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict from enum import Enum, auto +from typing import List, Optional, Tuple, Union import numpy as np from ase import Atom, Atoms @@ -24,16 +25,16 @@ def __init__(self, filename, atomname, pbc, errorlimit=None, errorfilename=None) self.steplinenum = self._readN() @abstractmethod - def _readN(self): + def _readN(self) -> int: pass @abstractmethod - def readatombondtype(self, item): + def readatombondtype(self, item) -> Tuple[dict, int]: """Read bond types of atoms such as C1111.""" pass @abstractmethod - def readmolecule(self, lines): + def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]: """Read molecules.""" pass @@ -54,6 +55,10 @@ class DetectBond(Detect): def _readN(self): """Read bondfile N, which should be at very beginning.""" + N = None + atomtype = None + stepaindex = None + stepbindex = None # copy from reacnetgenerator on 2018-12-15 with open( self.filename if isinstance(self.filename, str) else self.filename[0] @@ -72,7 +77,10 @@ def _readN(self): atomtype = np.zeros(N, dtype=int) else: s = line.split() + assert atomtype is not None atomtype[int(s[0]) - 1] = int(s[1]) + if stepaindex is None or stepbindex is None or N is None or atomtype is None: + raise RuntimeError("The bond file is not completed") steplinenum = stepbindex - stepaindex self._N = N self.atomtype = atomtype @@ -110,7 +118,7 @@ def readatombondtype(self, item): ) return d, step - def readmolecule(self, lines): + def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]: """Return molecules from lines. Parameters @@ -122,16 +130,18 @@ def readmolecule(self, lines): ------- molecules: list Indexes of atoms in molecules. + None + None """ # copy from reacnetgenerator on 2018-12-15 - bond = [None] * self._N + bond: List[Optional[List[int]]] = [None] * self._N for line in lines: if line: if not line.startswith("#"): s = line.split() bond[int(s[0]) - 1] = [int(x) - 1 for x in s[3 : 3 + int(s[2])]] molecules = connectmolecule(bond) - return molecules + return molecules, None class DetectDump(Detect): @@ -140,9 +150,14 @@ class DetectDump(Detect): def _readN(self): # copy from reacnetgenerator on 2018-12-15 iscompleted = False + N = None + atomtype = None + stepaindex = None + stepbindex = None with open( self.filename if isinstance(self.filename, str) else self.filename[0] ) as f: + linecontent = None for index, line in enumerate(f): if line.startswith("ITEM:"): linecontent = self.LineType.linecontent(line) @@ -154,7 +169,9 @@ def _readN(self): self.yidx = keys.index("y") - 2 self.zidx = keys.index("z") - 2 else: - if linecontent == self.LineType.NUMBER: + if linecontent is None: + raise RuntimeError("No ITEM: in the dump file") + elif linecontent == self.LineType.NUMBER: if iscompleted: stepbindex = index break @@ -165,7 +182,10 @@ def _readN(self): atomtype = np.zeros(N, dtype=int) elif linecontent == self.LineType.ATOMS: s = line.split() + assert atomtype is not None atomtype[int(s[self.id_idx]) - 1] = int(s[self.tidx]) + if stepaindex is None or stepbindex is None or N is None or atomtype is None: + raise RuntimeError("The dump file is not completed") steplinenum = stepbindex - stepaindex self._N = N self.atomtype = atomtype @@ -188,21 +208,21 @@ def readatombondtype(self, item): the step index """ (step, lines), needlerror = item - if needlerror: - trajline, errorline = lines - lerror = np.fromstring(errorline, dtype=float, sep=" ")[7:] + lerror: Optional[Union[np.ndarray, List[float]]] = None d = defaultdict(list) step_atoms, ids = self.readcrd(lines) if needlerror: + trajline, errorline = lines + lerror = np.fromstring(errorline, dtype=float, sep=" ")[7:] lerror = [x for (y, x) in sorted(zip(ids, lerror))] level = self._crd2bond(step_atoms, readlevel=True) for i, (n, l) in enumerate(zip(self.atomnames, level)): - if not needlerror or lerror[i] > self.errorlimit: + if lerror is None or (self.errorlimit is not None and lerror[i] > self.errorlimit): # Note that atom id starts from 1 d[pickle.dumps((n, sorted(l)))].append(i + 1) return d, step - def readmolecule(self, lines): + def readmolecule(self, lines) -> Tuple[List[List[int]], Optional[Atoms]]: """Return molecules from lines. Parameters @@ -250,11 +270,11 @@ def _crd2bond(cls, step_atoms, readlevel): mol.CloneData(uc) mol.SetPeriodicMol() mol.ConnectTheDots() - if not readlevel: - bond = [[] for i in range(atomnumber)] - else: + # when readlevel is False, bond is used to store connected atoms + # otherwise, bondlevel is used to store bond orders + bond = [[] for i in range(atomnumber)] + if readlevel: mol.PerceiveBondOrders() - bondlevel = [[] for i in range(atomnumber)] mol.EndModify() for b in openbabel.OBMolBondIter(mol): s1 = b.GetBeginAtom().GetId() @@ -264,23 +284,26 @@ def _crd2bond(cls, step_atoms, readlevel): bond[s2].append(s1) else: level = b.GetBondOrder() - bondlevel[s1].append(level) - bondlevel[s2].append(level) - return bondlevel if readlevel else bond + bond[s1].append(level) + bond[s2].append(level) + return bond - def readcrd(self, item): + def readcrd(self, item) -> tuple[Atoms, List[int]]: """Only this function can read coordinates.""" lines = item # box information ss = [] step_atoms = [] ids = [] + linecontent = None for line in lines: if line: if line.startswith("ITEM:"): linecontent = self.LineType.linecontent(line) else: - if linecontent == self.LineType.ATOMS: + if linecontent is None: + raise RuntimeError("No ITEM: in the dump file") + elif linecontent == self.LineType.ATOMS: s = line.split() ids.append(int(s[self.id_idx])) step_atoms.append( @@ -314,8 +337,8 @@ def readcrd(self, item): [[xhi - xlo, 0.0, 0.0], [xy, yhi - ylo, 0.0], [xz, yz, zhi - zlo]] ) # sort by ID - step_atoms = [x for (y, x) in sorted(zip(ids, step_atoms))] - step_atoms = Atoms(step_atoms, cell=boxsize, pbc=self.pbc) + step_atoms_ = [x for (y, x) in sorted(zip(ids, step_atoms))] + step_atoms = Atoms(step_atoms_, cell=boxsize, pbc=self.pbc) return step_atoms, ids class LineType(Enum): diff --git a/mddatasetbuilder/utils.py b/mddatasetbuilder/utils.py index 42909d2..2b9d3f3 100644 --- a/mddatasetbuilder/utils.py +++ b/mddatasetbuilder/utils.py @@ -3,7 +3,7 @@ import itertools import pickle from multiprocessing import Pool, Semaphore -from typing import BinaryIO, Union +from typing import BinaryIO, List, TypeVar, Union, overload import lz4.frame from tqdm.auto import tqdm @@ -192,7 +192,7 @@ def bytestolist(x): object The decompressed object. """ - return pickle.loads(decompress(x, isbytes=True)) + return pickle.loads(decompress(x, isbytes=True)) # type: ignore def run_mp(nproc, **arg): @@ -231,7 +231,16 @@ def run_mp(nproc, **arg): pool.join() -def must_be_list(obj): + +T = TypeVar('T') +@overload +def must_be_list(obj: List[T]) -> List[T]: + ... +@overload +def must_be_list(obj: T) -> List[T]: + ... + +def must_be_list(obj: Union[T, List[T]]) -> List[T]: """Convert a object to a list if the object is not a list. Parameters