diff --git a/docs/source/conf.py b/docs/source/conf.py index 950b8eb3..44522a11 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,4 @@ -# Configuration file for the Sphinx documentation builder. +"""Configuration file for the Sphinx documentation builder.""" # # This file only contains a selection of the most common options. For a full # list see the documentation: @@ -10,10 +10,10 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +from __future__ import annotations + import os import sys -import sphinx_rtd_theme, sphinx_autodoc_typehints -from typing import List sys.path.insert(0, os.path.abspath("../../mdgo")) @@ -55,7 +55,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns: List[str] = [] +exclude_patterns: list[str] = [] # -- Options for HTML output ------------------------------------------------- diff --git a/mdgo/__init__.py b/mdgo/__init__.py index fa4ec20d..689a18ae 100644 --- a/mdgo/__init__.py +++ b/mdgo/__init__.py @@ -1,10 +1,10 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. -""" -This package contains core modules and classes molecular dynamics simulation setup and analysis. -""" +"""This package contains core modules and classes molecular dynamics simulation setup and analysis.""" + +from __future__ import annotations + __author__ = "Mdgo Development Team" __email__ = "tingzheng_hou@berkeley.edu" __maintainer__ = "Tingzheng Hou" diff --git a/mdgo/conductivity.py b/mdgo/conductivity.py index 7c05999c..0c76bdcc 100644 --- a/mdgo/conductivity.py +++ b/mdgo/conductivity.py @@ -1,19 +1,21 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. -""" -This module implements functions to calculate the ionic conductivity. -""" -from typing import Union +"""This module implements functions to calculate the ionic conductivity.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np -from tqdm.auto import tqdm from scipy import stats -from MDAnalysis import Universe, AtomGroup +from tqdm.auto import tqdm from mdgo.msd import msd_fft +if TYPE_CHECKING: + from MDAnalysis import AtomGroup, Universe + __author__ = "Kara Fong, Tingzheng Hou" __version__ = "0.3.0" __maintainer__ = "Tingzheng Hou" @@ -26,10 +28,10 @@ def calc_cond_msd( anions: AtomGroup, cations: AtomGroup, run_start: int, - cation_charge: Union[int, float] = 1, - anion_charge: Union[int, float] = -1, + cation_charge: float = 1, + anion_charge: float = -1, ) -> np.ndarray: - """Calculates the conductivity "mean square displacement" over time + """Calculates the conductivity "mean square displacement" over time. Note: Coordinates must be unwrapped (in dcd file when creating MDAnalysis Universe) @@ -51,15 +53,14 @@ def calc_cond_msd( anion_list = anions.split("residue") # compute sum over all charges and positions qr = [] - for ts in tqdm(u.trajectory[run_start:]): + for _ts in tqdm(u.trajectory[run_start:]): qr_temp = np.zeros(3) for anion in anion_list: qr_temp += anion.center_of_mass() * anion_charge for cation in cation_list: qr_temp += cation.center_of_mass() * cation_charge qr.append(qr_temp) - msd = msd_fft(np.array(qr)) - return msd + return msd_fft(np.array(qr)) def get_beta( @@ -127,14 +128,14 @@ def choose_msd_fitting_region( def conductivity_calculator( time_array: np.ndarray, cond_array: np.ndarray, - v: Union[int, float], + v: float, name: str, start: int, end: int, - T: Union[int, float], + T: float, units: str = "real", ) -> float: - """Calculates the overall conductivity of the system + """Calculates the overall conductivity of the system. Args: time_array: times at which position data was collected in the simulation diff --git a/mdgo/coordination.py b/mdgo/coordination.py index 9c3004e8..f7c350eb 100644 --- a/mdgo/coordination.py +++ b/mdgo/coordination.py @@ -1,21 +1,24 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. -""" -This module implements functions for coordination analysis. -""" +"""This module implements functions for coordination analysis.""" -from typing import Dict, List, Tuple, Union, Callable, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np -from tqdm.auto import tqdm -from MDAnalysis import Universe, AtomGroup -from MDAnalysis.core.groups import Atom from MDAnalysis.analysis.distances import distance_array from scipy.signal import savgol_filter -from mdgo.util.coord import atom_vec, angle +from tqdm.auto import tqdm + +from mdgo.util.coord import angle, atom_vec +if TYPE_CHECKING: + from collections.abc import Callable + + from MDAnalysis import AtomGroup, Universe + from MDAnalysis.core.groups import Atom __author__ = "Tingzheng Hou" __version__ = "0.3.0" @@ -30,9 +33,9 @@ def neighbor_distance( run_start: int, run_end: int, species: str, - select_dict: Dict[str, str], + select_dict: dict[str, str], distance: float, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Calculates a dictionary of distances between the ``center_atom`` and neighbor atoms. @@ -50,12 +53,11 @@ def neighbor_distance( A dictionary of distance of neighbor atoms to the ``center_atom``. The keys are atom indexes in string type . """ dist_dict = {} - time_count = 0 trj_analysis = nvt_run.trajectory[run_start:run_end:] species_selection = select_dict.get(species) if species_selection is None: raise ValueError("Invalid species selection") - for ts in trj_analysis: + for _ts in trj_analysis: selection = ( "(" + species_selection + ") and (around " + str(distance) + " index " + str(center_atom.index) + ")" ) @@ -63,23 +65,20 @@ def neighbor_distance( for atom in shell.atoms: if str(atom.index) not in dist_dict: dist_dict[str(atom.index)] = np.full(run_end - run_start, 100.0) - time_count += 1 - time_count = 0 - for ts in trj_analysis: + for time_count, ts in enumerate(trj_analysis): for atom_index, val in dist_dict.items(): dist = distance_array(ts[center_atom.index], ts[int(atom_index)], ts.dimensions) val[time_count] = dist - time_count += 1 return dist_dict def find_nearest( - trj: Dict[str, np.ndarray], + trj: dict[str, np.ndarray], time_step: float, binding_cutoff: float, hopping_cutoff: float, smooth: int = 51, -) -> Tuple[List[int], Union[float, np.floating], List[int]]: +) -> tuple[list[int], float | np.floating, list[int]]: """Using the dictionary of neighbor distance ``trj``, finds the nearest neighbor ``sites`` that the center atom binds to, and calculates the ``frequency`` of hopping between each neighbor, and ``steps`` when each binding site exhibits the closest distance to the center atom. @@ -96,12 +95,12 @@ def find_nearest( the ``frequency`` of hopping between sites, and ``steps`` when each binding site exhibits the closest distance to the center atom. """ - time_span = len(list(trj.values())[0]) + time_span = len(next(iter(trj.values()))) if smooth > 0: for kw in list(trj): trj[kw] = savgol_filter(trj.get(kw), smooth, 2) site_distance = [100 for _ in range(time_span)] - sites: List[Union[int, np.integer]] = [0 for _ in range(time_span)] + sites: list[int | np.integer] = [0 for _ in range(time_span)] start_site = min(trj, key=lambda k: trj[k][0]) kw_start = trj.get(start_site) assert kw_start is not None @@ -133,7 +132,7 @@ def find_nearest( sites = [int(i) for i in sites] sites_and_distance_array = np.array([[sites[i], site_distance[i]] for i in range(len(sites))]) steps = [] - closest_step: Optional[int] = 0 + closest_step: int | None = 0 previous_site = sites_and_distance_array[0][0] if previous_site == 0: closest_step = None @@ -162,12 +161,12 @@ def find_nearest( def find_nearest_free_only( - trj: Dict[str, np.ndarray], + trj: dict[str, np.ndarray], time_step: float, binding_cutoff: float, hopping_cutoff: float, smooth: int = 51, -) -> Tuple[List[int], Union[float, np.floating], List[int]]: +) -> tuple[list[int], float | np.floating, list[int]]: """Using the dictionary of neighbor distance ``trj``, finds the nearest neighbor ``sites`` that the ``center_atom`` binds to, and calculates the ``frequency`` of hopping between each neighbor, and ``steps`` when each binding site exhibits the closest distance to the center atom. @@ -185,12 +184,12 @@ def find_nearest_free_only( the ``frequency`` of hopping between sites, and ``steps`` when each binding site exhibits the closest distance to the center atom. """ - time_span = len(list(trj.values())[0]) + time_span = len(next(iter(trj.values()))) if smooth > 0: for kw in list(trj): trj[kw] = savgol_filter(trj.get(kw), smooth, 2) site_distance = [100 for _ in range(time_span)] - sites: List[Union[int, np.integer]] = [0 for _ in range(time_span)] + sites: list[int | np.integer] = [0 for _ in range(time_span)] start_site = min(trj, key=lambda k: trj[k][0]) kw_start = trj.get(start_site) assert kw_start is not None @@ -222,7 +221,7 @@ def find_nearest_free_only( sites = [int(i) for i in sites] sites_and_distance_array = np.array([[sites[i], site_distance[i]] for i in range(len(sites))]) steps = [] - closest_step: Optional[int] = 0 + closest_step: int | None = 0 previous_site = sites_and_distance_array[0][0] previous_zero = False if previous_site == 0: @@ -257,8 +256,8 @@ def find_nearest_free_only( def find_in_n_out( - trj: Dict[str, np.ndarray], binding_cutoff: float, hopping_cutoff: float, smooth: int = 51, cool: int = 20 -) -> Tuple[List[int], List[int]]: + trj: dict[str, np.ndarray], binding_cutoff: float, hopping_cutoff: float, smooth: int = 51, cool: int = 20 +) -> tuple[list[int], list[int]]: """Finds the frames when the center atom binds with the neighbor (binding) or hopping out (hopping) according to the dictionary of neighbor distance. @@ -272,7 +271,7 @@ def find_in_n_out( Returns: Two arrays of numberings of frames with hopping in and hopping out event, respectively. """ - time_span = len(list(trj.values())[0]) + time_span = len(next(iter(trj.values()))) if smooth > 0: for kw in list(trj): trj[kw] = savgol_filter(trj.get(kw), smooth, 2) @@ -309,8 +308,8 @@ def find_in_n_out( sites = [int(i) for i in sites] last = sites[0] - steps_in: List[int] = [] - steps_out: List[int] = [] + steps_in: list[int] = [] + steps_out: list[int] = [] in_cool = cool out_cool = cool for i, s in enumerate(sites): @@ -339,13 +338,13 @@ def find_in_n_out( def check_contiguous_steps( nvt_run: Universe, center_atom: Atom, - distance_dict: Dict[str, float], - select_dict: Dict[str, str], + distance_dict: dict[str, float], + select_dict: dict[str, str], run_start: int, run_end: int, checkpoints: np.ndarray, lag: int = 20, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """Calculates the distance between the center atom and the neighbor atom in the checkpoint +/- lag time range. @@ -364,12 +363,10 @@ def check_contiguous_steps( An array of distance between the center atom and the neighbor atoms in the checkpoint +/- lag time range. """ - coord_num: Dict[str, Union[List[List[int]], np.ndarray]] = { - x: [[] for _ in range(lag * 2 + 1)] for x in distance_dict - } + coord_num: dict[str, list[list[int]] | np.ndarray] = {x: [[] for _ in range(lag * 2 + 1)] for x in distance_dict} trj_analysis = nvt_run.trajectory[run_start:run_end:] has = False - for i, ts in enumerate(trj_analysis): + for i, _ts in enumerate(trj_analysis): log = False checkpoint = -1 for j in checkpoints: @@ -393,8 +390,8 @@ def check_contiguous_steps( def heat_map( nvt_run: Universe, floating_atom: Atom, - cluster_center_sites: List[int], - cluster_terminal: Union[str, List[str]], + cluster_center_sites: list[int], + cluster_terminal: str | list[str], cartesian_by_ref: np.ndarray, run_start: int, run_end: int, @@ -441,7 +438,7 @@ def heat_map( for species in cluster_terminal ] bind_atoms_xyz = [nvt_run.select_atoms(sel, periodic=True) for sel in selections] - vertex_atoms: List[Atom] = [] + vertex_atoms: list[Atom] = [] for atoms in bind_atoms_xyz: if len(atoms) == 1: vertex_atoms.append(atoms[0]) @@ -451,8 +448,8 @@ def heat_map( vertex_atoms.append(atoms[idx[0]]) else: raise ValueError( - f"There should be at least 1 cluster_terminal atom in the {str(dim[i])} dimension." - f"Try broadening the selection at index {str(i + 1)} of the cluster_terminal " + f"There should be at least 1 cluster_terminal atom in the {dim[i]!s} dimension." + f"Try broadening the selection at index {i + 1!s} of the cluster_terminal " ) else: assert isinstance(cluster_terminal, str) @@ -501,10 +498,10 @@ def heat_map( def process_evol( nvt_run: Universe, - select_dict: Dict[str, str], - in_list: Dict[str, List[np.ndarray]], - out_list: Dict[str, List[np.ndarray]], - distance_dict: Dict[str, float], + select_dict: dict[str, str], + in_list: dict[str, list[np.ndarray]], + out_list: dict[str, list[np.ndarray]], + distance_dict: dict[str, float], run_start: int, run_end: int, lag: int, @@ -572,10 +569,10 @@ def process_evol( def get_full_coords( coords: np.ndarray, - reflection: Optional[List[np.ndarray]] = None, - rotation: Optional[List[np.ndarray]] = None, - inversion: Optional[List[np.ndarray]] = None, - sample: Optional[int] = None, + reflection: list[np.ndarray] | None = None, + rotation: list[np.ndarray] | None = None, + inversion: list[np.ndarray] | None = None, + sample: int | None = None, dim: str = "xyz", ) -> np.ndarray: """ @@ -640,12 +637,12 @@ def get_full_coords( def cluster_coordinates( # TODO: rewrite the method nvt_run: Universe, - select_dict: Dict[str, str], + select_dict: dict[str, str], run_start: int, run_end: int, - species: List[str], + species: list[str], distance: float, - basis_vectors: Optional[Union[List[np.ndarray], np.ndarray]] = None, + basis_vectors: list[np.ndarray] | np.ndarray | None = None, cluster_center: str = "center", ) -> np.ndarray: """Calculates the average position of a cluster. @@ -679,7 +676,7 @@ def cluster_coordinates( # TODO: rewrite the method cluster = [] for atom in shell: coord_list = [] - for ts in trj_analysis: + for _ts in trj_analysis: coord_list.append(atom.position) cluster.append(np.mean(np.array(coord_list), axis=0)) cluster_array = np.array(cluster) @@ -700,23 +697,22 @@ def cluster_coordinates( # TODO: rewrite the method vec3 = vec3 / np.linalg.norm(vec3) basis_xyz = np.transpose([vec1, vec2, vec3]) cluster_norm = np.linalg.solve(basis_xyz, cluster_array.T).T - cluster_norm = cluster_norm - np.mean(cluster_norm, axis=0) - return cluster_norm + return cluster_norm - np.mean(cluster_norm, axis=0) return cluster_array def num_of_neighbor( nvt_run: Universe, center_atom: Atom, - distance_dict: Dict[str, float], - select_dict: Dict[str, str], + distance_dict: dict[str, float], + select_dict: dict[str, str], run_start, run_end, write=False, structure_code=None, write_freq=0, write_path=None, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """Calculates the coordination number of each specified neighbor species and the total coordination number in the specified frame range. @@ -738,14 +734,13 @@ def num_of_neighbor( A diction containing the coordination number sequence of each specified neighbor species and the total coordination number sequence in the specified frame range . """ - time_count = 0 trj_analysis = nvt_run.trajectory[run_start:run_end:] cn_values = {} species = list(distance_dict.keys()) for kw in species: cn_values[kw] = np.zeros(int(len(trj_analysis))) cn_values["total"] = np.zeros(int(len(trj_analysis))) - for ts in trj_analysis: + for time_count, ts in enumerate(trj_analysis): digit_of_species = len(species) - 1 for kw in species: selection = select_shell(select_dict, distance_dict, center_atom, kw) @@ -771,18 +766,17 @@ def num_of_neighbor( center_name = center_atom.name path = write_path + str(center_atom.id) + "_" + str(int(ts.time)) + "_" + str(structure_code) + ".xyz" write_out(center_pos, center_name, structure, path) - time_count += 1 return cn_values def num_of_neighbor_simple( nvt_run: Universe, center_atom: Atom, - distance_dict: Dict[str, float], - select_dict: Dict[str, str], + distance_dict: dict[str, float], + select_dict: dict[str, str], run_start: int, run_end: int, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """Calculates solvation structure type (1 for SSIP, 2 for CIP and 3 for AGG) with respect to the ``enter_atom`` in the specified frame range. @@ -799,14 +793,12 @@ def num_of_neighbor_simple( A dict with "total" as the key and an array of the solvation structure type in the specified frame range as the value. """ - - time_count = 0 trj_analysis = nvt_run.trajectory[run_start:run_end:] center_selection = "same type as index " + str(center_atom.index) assert len(distance_dict) == 1, "Please only specify the counter-ion species in the distance_dict" - species = list(distance_dict.keys())[0] + species = next(iter(distance_dict.keys())) cn_values = np.zeros(int(len(trj_analysis))) - for ts in trj_analysis: + for time_count, _ts in enumerate(trj_analysis): selection = select_shell(select_dict, distance_dict, center_atom, species) shell = nvt_run.select_atoms(selection, periodic=True) shell_len = len(shell) @@ -822,20 +814,18 @@ def num_of_neighbor_simple( cn_values[time_count] = 3 else: cn_values[time_count] = 3 - time_count += 1 - cn_values = {"total": cn_values} - return cn_values + return {"total": cn_values} def angular_dist_of_neighbor( nvt_run: Universe, center_atom: Atom, - distance_dict: Dict[str, float], - select_dict: Dict[str, str], + distance_dict: dict[str, float], + select_dict: dict[str, str], run_start: int, run_end: int, cip: bool = True, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Calculates the angle of a-c-b of center atom c in the specified frames. @@ -860,7 +850,7 @@ def angular_dist_of_neighbor( neighbor_a, neighbor_b, center_c = tuple(names) acb_angle = [] trj_analysis = nvt_run.trajectory[run_start:run_end:] - for ts in trj_analysis: + for _ts in trj_analysis: a_selection = select_shell(select_dict, distance_dict, center_atom, neighbor_a) a_group = nvt_run.select_atoms(a_selection, periodic=True) a_num = len(a_group) @@ -870,10 +860,7 @@ def angular_dist_of_neighbor( c_selection = select_shell(select_dict, distance_dict, a_group.atoms[0], center_c) c_atoms = nvt_run.select_atoms(c_selection, periodic=True) shell_species_len = len(c_atoms) - 1 - if shell_species_len == 0: - shell_type = "cip" - else: - shell_type = "agg" + shell_type = "cip" if shell_species_len == 0 else "agg" else: shell_type = "agg" if shell_type == "agg" and cip: @@ -893,12 +880,12 @@ def angular_dist_of_neighbor( def num_of_neighbor_specific( nvt_run: Universe, center_atom: Atom, - distance_dict: Dict[str, float], - select_dict: Dict[str, str], + distance_dict: dict[str, float], + select_dict: dict[str, str], run_start: int, run_end: int, counter_atom: str = "anion", -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Calculates the coordination number of each specific solvation structure type (SSIP, CIP, AGG). @@ -916,7 +903,6 @@ def num_of_neighbor_specific( A tuple containing three dictionary of the coordination number of each neighbor species and total coordination number for the three solvation structure type, respectively. """ - time_count = 0 trj_analysis = nvt_run.trajectory[run_start:run_end:] cip_step = [] ssip_step = [] @@ -925,7 +911,7 @@ def num_of_neighbor_specific( for kw in distance_dict: cn_values[kw] = np.zeros(int(len(trj_analysis))) cn_values["total"] = np.zeros(int(len(trj_analysis))) - for ts in trj_analysis: + for time_count, _ts in enumerate(trj_analysis): for kw in distance_dict: kw_selection = select_shell(select_dict, distance_dict, center_atom, kw) kw_shell = nvt_run.select_atoms(kw_selection, periodic=True) @@ -948,7 +934,6 @@ def num_of_neighbor_specific( agg_step.append(time_count) else: agg_step.append(time_count) - time_count += 1 cn_dict = {} for kw in distance_dict: cn_dict["ssip_" + kw] = cn_values[kw][ssip_step] @@ -962,7 +947,7 @@ def full_solvation_structure( # TODO: rewrite the method center_atom: Atom, center_species: str, counter_species: str, - select_dict: Dict[str, str], + select_dict: dict[str, str], distance: float, run_start: int, run_end: int, @@ -989,7 +974,8 @@ def full_solvation_structure( # TODO: rewrite the method """ center_selection = select_dict.get(center_species) counter_selection = select_dict.get(counter_species) - assert (center_selection is not None) and (counter_selection is not None) + assert center_selection is not None + assert counter_selection is not None def select_counter_ion(selection, dist, atom): return "(" + selection + " and around " + str(dist) + " same fragment as index " + str(atom.index) + ")" @@ -1021,9 +1007,9 @@ def counter_shell(this_shell, this_layer, frame): time_count = 0 trj_analysis = nvt_run.trajectory[run_start:run_end:] cn_values = np.zeros((int(len(trj_analysis)), depth)) - for ts in trj_analysis: - center_ion_list: List[np.int_] = [center_atom.id] - counter_ion_list: List[np.int_] = [] + for _ts in trj_analysis: + center_ion_list: list[np.int_] = [center_atom.id] + counter_ion_list: list[np.int_] = [] first_shell = nvt_run.select_atoms( select_counter_ion(counter_selection, distance, center_atom), periodic=True, @@ -1036,12 +1022,12 @@ def concat_coord_array( nvt_run: Universe, func: Callable, center_atoms: AtomGroup, - distance_dict: Dict[str, float], - select_dict: Dict[str, str], + distance_dict: dict[str, float], + select_dict: dict[str, str], run_start: int, run_end: int, - **kwargs: Union[bool, str], -) -> Dict[str, np.ndarray]: + **kwargs: bool | str, +) -> dict[str, np.ndarray]: """ A helper function to analyze the coordination number/structure of every atoms in an ``AtomGroup`` using the specified function. @@ -1055,6 +1041,7 @@ def concat_coord_array( and the corresponding values are the selection language. run_start: Start frame of analysis. run_end: End frame of analysis. + kwargs: Keyword arguments in the func. Returns: A diction containing the coordination number sequence of each specified neighbor species @@ -1103,9 +1090,7 @@ def write_out(center_pos: np.ndarray, center_name: str, neighbors: AtomGroup, pa xyz_file.write("\n".join(lines)) -def select_shell( - select: Union[Dict[str, str], str], distance: Union[Dict[str, float], str], center_atom: Atom, kw: str -) -> str: +def select_shell(select: dict[str, str] | str, distance: dict[str, float] | str, center_atom: Atom, kw: str) -> str: """ Select a group of atoms that is within a distance of an ``center_atom``. @@ -1132,5 +1117,4 @@ def select_shell( distance_str = str(distance_value) else: distance_str = distance - selection = "(" + species_selection + ") and (around " + distance_str + " index " + str(center_atom.index) + ")" - return selection + return "(" + species_selection + ") and (around " + distance_str + " index " + str(center_atom.index) + ")" diff --git a/mdgo/core/__init__.py b/mdgo/core/__init__.py index 1f80549c..a817d7b1 100644 --- a/mdgo/core/__init__.py +++ b/mdgo/core/__init__.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -7,6 +6,8 @@ setup, run and analysis. """ +from __future__ import annotations + __author__ = "Tingzheng Hou" __version__ = "0.3.1" __maintainer__ = "Tingzheng Hou" diff --git a/mdgo/core/analysis.py b/mdgo/core/analysis.py index ea97b086..ed24532d 100644 --- a/mdgo/core/analysis.py +++ b/mdgo/core/analysis.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -8,39 +7,40 @@ """ from __future__ import annotations -from typing import Union, Dict, Tuple, List, Optional + +import matplotlib.pyplot as plt import MDAnalysis import numpy as np import pandas as pd -import matplotlib.pyplot as plt from MDAnalysis import Universe from MDAnalysis.analysis.distances import distance_array from MDAnalysis.lib.distances import capped_distance from tqdm.auto import tqdm -from mdgo.util.dict_utils import ( - mass_to_name, - assign_name, - assign_resname, - res_dict_from_select_dict, - res_dict_from_datafile, - select_dict_from_resname, -) -from mdgo.conductivity import calc_cond_msd, conductivity_calculator, choose_msd_fitting_region, get_beta + +from mdgo.conductivity import calc_cond_msd, choose_msd_fitting_region, conductivity_calculator, get_beta from mdgo.coordination import ( + angular_dist_of_neighbor, concat_coord_array, + find_nearest, + find_nearest_free_only, + get_full_coords, + heat_map, + neighbor_distance, num_of_neighbor, num_of_neighbor_simple, num_of_neighbor_specific, - angular_dist_of_neighbor, - neighbor_distance, - find_nearest, - find_nearest_free_only, process_evol, - heat_map, - get_full_coords, ) -from mdgo.msd import total_msd, partial_msd, DIM +from mdgo.msd import DIM, partial_msd, total_msd from mdgo.residence_time import calc_neigh_corr, fit_residence_time +from mdgo.util.dict_utils import ( + assign_name, + assign_resname, + mass_to_name, + res_dict_from_datafile, + res_dict_from_select_dict, + select_dict_from_resname, +) class MdRun: @@ -75,8 +75,8 @@ def __init__( nvt_start: int, time_step: float, name: str, - select_dict: Optional[Dict[str, str]] = None, - res_dict: Optional[Dict[str, str]] = None, + select_dict: dict[str, str] | None = None, + res_dict: dict[str, str] | None = None, cation_name: str = "cation", anion_name: str = "anion", cation_charge: float = 1, @@ -90,7 +90,6 @@ def __init__( parsed data (``Universe``) or other bridging objects (``CombinedData``). Not recommended to use directly. """ - self.wrapped_run = wrapped_run self.unwrapped_run = unwrapped_run self.nvt_start = nvt_start @@ -149,8 +148,8 @@ def from_lammps( nvt_start: int, time_step: float, name: str, - select_dict: Optional[Dict[str, str]] = None, - res_dict: Optional[Dict[str, str]] = None, + select_dict: dict[str, str] | None = None, + res_dict: dict[str, str] | None = None, cation_name: str = "cation", anion_name: str = "anion", cation_charge: float = 1, @@ -202,9 +201,7 @@ def from_lammps( ) def get_init_dimension(self) -> np.ndarray: - """ - Returns the initial box dimension. - """ + """Returns the initial box dimension.""" return self.wrapped_run.trajectory[0].dimensions def get_equilibrium_dimension(self, npt_range: int, period: int = 200) -> np.ndarray: @@ -235,9 +232,7 @@ def get_equilibrium_dimension(self, npt_range: int, period: int = 200) -> np.nda return np.mean(np.array(d), axis=0) def get_nvt_dimension(self) -> np.ndarray: - """ - Returns the box dimension at the last frame. - """ + """Returns the box dimension at the last frame.""" return self.wrapped_run.trajectory[-1].dimensions def get_cond_array(self) -> np.ndarray: @@ -250,7 +245,7 @@ def get_cond_array(self) -> np.ndarray: nvt_run = self.unwrapped_run cations = nvt_run.select_atoms(self.select_dict.get("cation")) anions = nvt_run.select_atoms(self.select_dict.get("anion")) - cond_array = calc_cond_msd( + return calc_cond_msd( nvt_run, anions, cations, @@ -258,7 +253,6 @@ def get_cond_array(self) -> np.ndarray: self.cation_charge, self.anion_charge, ) - return cond_array def choose_cond_fit_region(self) -> tuple: """ @@ -359,10 +353,9 @@ def get_conductivity(self, start: int = -1, end: int = -1) -> float: print(f"Start of linear fitting regime: {start} ({self.time_array[start]} {time_units})") print(f"End of linear fitting regime: {end} ({self.time_array[end]} {time_units})") print(f"Beta value (fit to MSD = t^\u03B2): {beta} (\u03B2 = 1 in the diffusive regime)") - cond = conductivity_calculator( + return conductivity_calculator( self.time_array, self.cond_array, self.nvt_v, self.name, start, end, self.temp, self.units ) - return cond def coord_num_array_single_species( self, @@ -387,7 +380,7 @@ def coord_num_array_single_species( nvt_run = self.wrapped_run distance_dict = {species: distance} center_atoms = nvt_run.select_atoms(self.select_dict.get(center_atom)) - num_array = concat_coord_array( + return concat_coord_array( nvt_run, num_of_neighbor, center_atoms, @@ -396,15 +389,14 @@ def coord_num_array_single_species( run_start, run_end, )["total"] - return num_array def coord_num_array_multi_species( self, - distance_dict: Dict[str, float], + distance_dict: dict[str, float], run_start: int, run_end: int, center_atom: str = "cation", - ) -> Dict[str, np.ndarray]: + ) -> dict[str, np.ndarray]: """Calculates the coordination number array of multiple species around the interested ``center_atom``. Args: @@ -419,7 +411,7 @@ def coord_num_array_multi_species( """ nvt_run = self.wrapped_run center_atoms = nvt_run.select_atoms(self.select_dict.get(center_atom)) - num_array_dict = concat_coord_array( + return concat_coord_array( nvt_run, num_of_neighbor, center_atoms, @@ -428,16 +420,15 @@ def coord_num_array_multi_species( run_start, run_end, ) - return num_array_dict def coord_num_array_specific( self, - distance_dict: Dict[str, float], + distance_dict: dict[str, float], run_start: int, run_end: int, center_atom: str = "cation", counter_atom: str = "anion", - ) -> Dict[str, np.ndarray]: + ) -> dict[str, np.ndarray]: """Calculates the coordination number array of multiple species of specific coordination types (SSIP, CIP, AGG). @@ -454,7 +445,7 @@ def coord_num_array_specific( """ nvt_run = self.wrapped_run center_atoms = nvt_run.select_atoms(self.select_dict.get(center_atom)) - num_array_dict = concat_coord_array( + return concat_coord_array( nvt_run, num_of_neighbor_specific, center_atoms, @@ -464,11 +455,10 @@ def coord_num_array_specific( run_end, counter_atom=counter_atom, ) - return num_array_dict def write_solvation_structure( self, - distance_dict: Dict[str, float], + distance_dict: dict[str, float], run_start: int, run_end: int, structure_code: int, @@ -476,7 +466,7 @@ def write_solvation_structure( write_path: str, center_atom: str = "cation", ): - """Writes out a series of desired solvation structures as ``*.xyz`` files + """Writes out a series of desired solvation structures as ``*.xyz`` files. Args: distance_dict: A dict of coordination cutoff distance of the neighbor species. @@ -529,7 +519,7 @@ def coord_type_array( nvt_run = self.wrapped_run distance_dict = {counter_atom: distance} center_atoms = nvt_run.select_atoms(self.select_dict.get(center_atom)) - num_array = concat_coord_array( + return concat_coord_array( nvt_run, num_of_neighbor_simple, center_atoms, @@ -538,11 +528,10 @@ def coord_type_array( run_start, run_end, )["total"] - return num_array def angle_array( self, - distance_dict: Dict[str, float], + distance_dict: dict[str, float], run_start: int, run_end: int, center_atom: str = "cation", @@ -568,8 +557,8 @@ def angle_array( nvt_run = self.wrapped_run center_atoms = nvt_run.select_atoms(self.select_dict.get(center_atom)) assert len(distance_dict) == 2, "Only distance a->c, b->c shoud be specified in the distance_dict." - distance_dict[center_atom] = list(distance_dict.values())[0] - ang_array = concat_coord_array( + distance_dict[center_atom] = next(iter(distance_dict.values())) + return concat_coord_array( nvt_run, angular_dist_of_neighbor, center_atoms, @@ -579,7 +568,6 @@ def angle_array( run_end, cip=cip, )["total"] - return ang_array def coordination( self, @@ -613,12 +601,11 @@ def coordination( item_list.append(str(int(combined[i, 0]))) percent_list.append(f"{(combined[i, 1] / combined[:, 1].sum() * 100):.4f}%") df_dict = {item_name: item_list, "Percentage": percent_list} - df = pd.DataFrame(df_dict) - return df + return pd.DataFrame(df_dict) def rdf_integral( self, - distance_dict: Dict[str, float], + distance_dict: dict[str, float], run_start: int, run_end: int, center_atom: str = "cation", @@ -645,8 +632,7 @@ def rdf_integral( item_list.append(kw) cn_list.append(cn) df_dict = {item_name: item_list, "CN": cn_list} - df = pd.DataFrame(df_dict) - return df + return pd.DataFrame(df_dict) def coordination_type( self, @@ -656,7 +642,7 @@ def coordination_type( center_atom: str = "cation", counter_atom: str = "anion", ) -> pd.DataFrame: - """Tabulates the percentage of each solvation structures (CIP/SSIP/AGG) + """Tabulates the percentage of each solvation structures (CIP/SSIP/AGG). Args: distance: The coordination cutoff distance. @@ -684,19 +670,18 @@ def coordination_type( item_list.append(item_dict.get(item)) percent_list.append(f"{(combined[i, 1] / combined[:, 1].sum() * 100):.4f}%") df_dict = {item_name: item_list, "Percentage": percent_list} - df = pd.DataFrame(df_dict) - return df + return pd.DataFrame(df_dict) def coordination_specific( self, - distance_dict: Dict[str, float], + distance_dict: dict[str, float], run_start: int, run_end: int, center_atom: str = "cation", counter_atom: str = "anion", ) -> pd.DataFrame: """Calculates the integral of the coordiantion number of selected species - in each type of solvation structures (CIP/SSIP/AGG) + in each type of solvation structures (CIP/SSIP/AGG). Args: distance_dict: A dict of coordination cutoff distance of the neighbor species. @@ -728,8 +713,7 @@ def coordination_specific( else: agg_list.append(cn) df_dict = {item_name: item_list, "CN in SSIP": ssip_list, "CN in CIP": cip_list, "CN in AGG": agg_list} - df = pd.DataFrame(df_dict) - return df + return pd.DataFrame(df_dict) def get_msd_all( self, @@ -746,7 +730,7 @@ def get_msd_all( Args: start: Start time step. end: End time step. - msd_type: Desired dimensions to be included in the MSD. Defaults to ‘xyz’. + msd_type: Desired dimensions to be included in the MSD. Defaults to "xyz". fft: Whether to use FFT to accelerate the calculation. Default to True. built_in: Whether to use built in method to calculate msd or use function from mds. Default to True. @@ -759,7 +743,7 @@ def get_msd_all( """ selection = self.select_dict.get(species) assert selection is not None - msd_array = total_msd( + return total_msd( self.unwrapped_run, start=start, end=end, @@ -769,7 +753,6 @@ def get_msd_all( built_in=built_in, center_of_mass=center_of_mass, ) - return msd_array def get_msd_partial( self, @@ -779,7 +762,7 @@ def get_msd_partial( largest: int = 1000, center_atom: str = "cation", binding_site: str = "anion", - ) -> Tuple[Optional[List[np.ndarray]], Optional[List[np.ndarray]]]: + ) -> tuple[list[np.ndarray] | None, list[np.ndarray] | None]: """ Calculates the mean square displacement (MSD) of the ``center_atom`` according to coordination states. The returned ``free_array`` include the MSD when ``center_atom`` is not coordinated to ``binding_site``. @@ -832,11 +815,11 @@ def get_d(self, msd_array: np.ndarray, start: int, end: int, percentage: float = def get_neighbor_corr( self, - distance_dict: Dict[str, float], + distance_dict: dict[str, float], run_start: int, run_end: int, center_atom: str = "cation", - ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]: + ) -> tuple[np.ndarray, dict[str, np.ndarray]]: """Calculates the neighbor auto-correlation function (ACF) of selected species around center_atom. @@ -860,9 +843,9 @@ def get_neighbor_corr( ) def get_residence_time( - self, times: np.ndarray, acf_avg_dict: Dict[str, np.ndarray], cutoff_time: int - ) -> Dict[str, np.floating]: - """Calculates the residence time of selected species around cation + self, times: np.ndarray, acf_avg_dict: dict[str, np.ndarray], cutoff_time: int + ) -> dict[str, np.floating]: + """Calculates the residence time of selected species around cation. Args: times: The time series. @@ -882,8 +865,8 @@ def get_neighbor_trj( neighbor_cutoff: float, center_atom: str = "cation", index: int = 0, - ) -> Dict[str, np.ndarray]: - """Returns the distance between one center atom and neighbors as a function of time + ) -> dict[str, np.ndarray]: + """Returns the distance between one center atom and neighbors as a function of time. Args: run_start: Start frame of analysis. @@ -917,7 +900,7 @@ def get_hopping_freq_dist( floating_atom: str = "cation", smooth: int = 51, mode: str = "full", - ) -> Tuple[np.floating, np.floating]: + ) -> tuple[np.floating, np.floating]: """Calculates the cation hopping rate and hopping distance. Args: @@ -967,7 +950,7 @@ def get_hopping_freq_dist( def shell_evolution( self, - distance_dict: Dict[str, float], + distance_dict: dict[str, float], run_start: int, run_end: int, lag_step: int, @@ -977,8 +960,8 @@ def shell_evolution( cool: int = 0, binding_site: str = "anion", center_atom: str = "cation", - duplicate_run: Optional[List[MdRun]] = None, - ) -> Dict[str, Dict[str, Union[int, np.ndarray]]]: + duplicate_run: list[MdRun] | None = None, + ) -> dict[str, dict[str, int | np.ndarray]]: """Calculates the coordination number evolution of species around ``center_atom`` as a function of time, the coordination numbers are averaged over all time steps around events when the center_atom hopping to and hopping out from the ``binding_site``. If ``duplicate_run`` is given, it is also averaged over @@ -1001,8 +984,8 @@ def shell_evolution( A dictionary containing the number of trj logged, the averaged coordination number and standard deviation for each species, and the corresponding time sequence. """ - in_list: Dict[str, List[np.ndarray]] = {} - out_list: Dict[str, List[np.ndarray]] = {} + in_list: dict[str, list[np.ndarray]] = {} + out_list: dict[str, list[np.ndarray]] = {} for k in list(distance_dict): in_list[k] = [] out_list[k] = [] @@ -1059,17 +1042,17 @@ def get_heat_map( run_start: int, run_end: int, cluster_center: str, - cluster_terminal: Union[str, List[str]], + cluster_terminal: str | list[str], binding_cutoff: float, hopping_cutoff: float, floating_atom: str = "cation", cartesian_by_ref: np.ndarray = None, - sym_dict: Dict[str, List[np.ndarray]] = None, - sample: Optional[int] = None, + sym_dict: dict[str, list[np.ndarray]] | None = None, + sample: int | None = None, smooth: int = 51, dim: str = "xyz", ) -> np.ndarray: - """Calculates the heatmap matrix of floating ion around a cluster + """Calculates the heatmap matrix of floating ion around a cluster. Args: run_start: Start frame of analysis. @@ -1095,7 +1078,7 @@ def get_heat_map( nvt_run = self.wrapped_run floating_atoms = nvt_run.select_atoms(self.select_dict.get(floating_atom)) if isinstance(cluster_terminal, str): - terminal_atom_type: Union[str, List[str]] = self.select_dict.get(cluster_terminal, "Not defined") + terminal_atom_type: str | list[str] = self.select_dict.get(cluster_terminal, "Not defined") assert terminal_atom_type != "Not defined", f"{cluster_terminal} not defined in select_dict" else: terminal_atom_type = [] @@ -1124,7 +1107,7 @@ def get_heat_map( run_start, run_end, ) - if not coords.size == 0: + if coords.size != 0: coord_list = np.concatenate((coord_list, coords), axis=0) coord_list = coord_list[1:] if sym_dict: @@ -1134,7 +1117,7 @@ def get_heat_map( def get_cluster_distance( self, run_start: int, run_end: int, neighbor_cutoff: float, cluster_center: str = "center" ) -> np.floating: - """Calculates the average distance of the center of clusters/molecules + """Calculates the average distance of the center of clusters/molecules. Args: run_start: Start frame of analysis. diff --git a/mdgo/core/run.py b/mdgo/core/run.py index ffe60cba..ed1abec3 100644 --- a/mdgo/core/run.py +++ b/mdgo/core/run.py @@ -1,21 +1,15 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. -""" -This module implements a core class MdRun for molecular dynamics job setup. -""" +"""This module implements a core class MdRun for molecular dynamics job setup.""" +from __future__ import annotations class MdJob: - """ - A core class for MD results analysis. - """ + """A core class for MD results analysis.""" def __init__(self, name): - """ - Base constructor - """ + """Base constructor.""" self.name = name @classmethod @@ -24,6 +18,7 @@ def from_dict(cls): Constructor. Returns: + name: The name of the class """ return cls("name") @@ -34,6 +29,6 @@ def from_recipe(cls): Constructor. Returns: - + name: The name of the class """ return cls("name") diff --git a/mdgo/forcefield/__init__.py b/mdgo/forcefield/__init__.py index 69f445c9..64225456 100644 --- a/mdgo/forcefield/__init__.py +++ b/mdgo/forcefield/__init__.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -7,6 +6,8 @@ modifying MD force filed data. """ +from __future__ import annotations + __author__ = "Tingzheng Hou, Ryan Kingsbury" __version__ = "0.3.1" __maintainer__ = "Tingzheng Hou, Ryan Kingsbury" @@ -14,7 +15,7 @@ __date__ = "Dec 19, 2023" -from .aqueous import IonLJData, Aqueous +from .aqueous import Aqueous, IonLJData from .charge import ChargeWriter from .crawler import FFcrawler from .maestro import MaestroRunner diff --git a/mdgo/forcefield/aqueous.py b/mdgo/forcefield/aqueous.py index 23a1ce5d..27e766ee 100644 --- a/mdgo/forcefield/aqueous.py +++ b/mdgo/forcefield/aqueous.py @@ -1,15 +1,14 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. -""" -A class for retrieving water and ion force field parameters. -""" +"""A class for retrieving water and ion force field parameters.""" + +from __future__ import annotations import os import re from dataclasses import dataclass -from typing import Optional, Union, Final, Literal +from typing import Final, Literal from monty.json import MSONable from monty.serialization import loadfn @@ -17,7 +16,6 @@ from pymatgen.core.ion import Ion from pymatgen.io.lammps.data import ForceField, LammpsData, Topology, lattice_2_lmpbox - MODULE_DIR: Final[str] = os.path.dirname(os.path.abspath(__file__)) DATA_DIR: Final[str] = os.path.join(MODULE_DIR, "data") DATA_MODELS: Final[dict] = { @@ -125,6 +123,7 @@ def get_water(model: str = "spce") -> LammpsData: model: Water model to use. Valid choices are "spc", "spce", "opc3", "tip3pew", "tip3pfb", "tip4p2005", "tip4pew", "tip4pfb", and "opc". (Default: "spce") + Returns: LammpsData: Force field parameters for the chosen water model. If you specify an invalid water model, None is returned. @@ -136,10 +135,10 @@ def get_water(model: str = "spce") -> LammpsData: @staticmethod def get_ion( - ion: Union[Ion, str], + ion: Ion | str, parameter_set: str = "auto", water_model: str = "auto", - mixing_rule: Optional[str] = None, + mixing_rule: str | None = None, ) -> LammpsData: """ Retrieve force field parameters for an ion in water. @@ -178,7 +177,7 @@ def get_ion( Sachini et al., Systematic Comparison of the Structural and Dynamic Properties of Commonly Used Water Models for Molecular Dynamics Simulations. J. Chem. Inf. Model. - 2021, 61, 9, 4521–4536. https://doi.org/10.1021/acs.jcim.1c00794 + 2021, 61, 9, 4521-4536. https://doi.org/10.1021/acs.jcim.1c00794 mixing_rule: The mixing rule to use for the ion parameter. Default to None, which does not change the original mixing rule of the parameter set. Available choices are 'LB' @@ -226,10 +225,7 @@ def get_ion( parameter_set = alias.get(parameter_set, parameter_set) # Make the Ion object to get mass and charge - if isinstance(ion, Ion): - ion_obj = ion - else: - ion_obj = Ion.from_formula(ion.capitalize()) + ion_obj = ion if isinstance(ion, Ion) else Ion.from_formula(ion.capitalize()) # load ion data as a list of IonLJData objects ion_data = loadfn(os.path.join(DATA_DIR, "ion_lj_params.json")) diff --git a/mdgo/forcefield/charge.py b/mdgo/forcefield/charge.py index 3ab0d1b8..ae3b30d9 100644 --- a/mdgo/forcefield/charge.py +++ b/mdgo/forcefield/charge.py @@ -1,10 +1,9 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. -""" -A class for writing, overwriting, scaling charges of a LammpsData object. -""" +"""A class for writing, overwriting, scaling charges of a LammpsData object.""" + +from __future__ import annotations import numpy as np from pymatgen.io.lammps.data import LammpsData @@ -15,7 +14,7 @@ class ChargeWriter: A class for write, overwrite, scale charges of a LammpsData object. TODO: Auto determine number of significant figures of charges TODO: write to obj or write separate charge file - TODO: Read LammpsData or path + TODO: Read LammpsData or path. Args: data: The provided LammpsData obj. @@ -29,7 +28,7 @@ def __init__(self, data: LammpsData, precision: int = 10): def scale(self, factor: float) -> LammpsData: """ - Scales the charge in of the in self.data and returns a new one. TODO: check if non-destructive + Scales the charge in of the in self.data and returns a new one. TODO: check if non-destructive. Args: factor: The charge scaling factor diff --git a/mdgo/forcefield/crawler.py b/mdgo/forcefield/crawler.py index 6047a768..a50b66ca 100644 --- a/mdgo/forcefield/crawler.py +++ b/mdgo/forcefield/crawler.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -13,18 +12,15 @@ matches your Chrome version via https://chromedriver.chromium.org/downloads """ +from __future__ import annotations + import os import shutil import time -from typing import Optional - from pymatgen.io.lammps.data import LammpsData from selenium import webdriver -from selenium.common.exceptions import ( - TimeoutException, - WebDriverException, -) +from selenium.common.exceptions import TimeoutException, WebDriverException from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.ui import WebDriverWait @@ -50,7 +46,6 @@ class FFcrawler: Default to False. Examples: - >>> lpg = FFcrawler('/path/to/work/dir', '/path/to/chromedriver') >>> lpg.data_from_pdb("/path/to/pdb") """ @@ -58,7 +53,7 @@ class FFcrawler: def __init__( self, write_dir: str, - chromedriver_dir: Optional[str] = None, + chromedriver_dir: str | None = None, headless: bool = True, xyz: bool = False, gromacs: bool = False, @@ -96,10 +91,7 @@ def __init__( print("LigParGen server connected.") def quit(self): - """ - Method for quiting ChromeDriver. - - """ + """Method for quiting ChromeDriver.""" self.web.quit() def data_from_pdb(self, pdb_dir: str): diff --git a/mdgo/forcefield/maestro.py b/mdgo/forcefield/maestro.py index 29702cb8..4ea465f2 100644 --- a/mdgo/forcefield/maestro.py +++ b/mdgo/forcefield/maestro.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -17,12 +16,14 @@ """ +from __future__ import annotations + import os import signal import subprocess import time from string import Template -from typing import Optional, Final +from typing import Final from mdgo.util.reformat import ff_parser @@ -66,7 +67,6 @@ class MaestroRunner: $SCHRODINGER/mmshare-vversion/data/f14/ Examples: - >>> mr = MaestroRunner('/path/to/structure', '/path/to/working/dir') >>> mr.get_mae() >>> mr.get_ff() @@ -81,7 +81,7 @@ def __init__( structure_dir: str, working_dir: str, out: str = "lmp", - cmd_template: Optional[str] = None, + cmd_template: str | None = None, assign_bond: bool = False, ): """Base constructor.""" @@ -99,11 +99,11 @@ def __init__( self.cmd_template = cmd_template else: if assign_bond: - with open(self.template_assignbond, "r") as f: + with open(self.template_assignbond) as f: cmd_template = f.read() self.cmd_template = cmd_template else: - with open(self.template_noassignbond, "r") as f: + with open(self.template_noassignbond) as f: cmd_template = f.read() self.cmd_template = cmd_template @@ -124,7 +124,7 @@ def get_mae(self, wait: float = 30): shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - preexec_fn=os.setsid, + start_new_session=True, ) except subprocess.CalledProcessError as e: raise ValueError(f"Maestro failed with errorcode {e.returncode} and stderr: {e.stderr}") from e @@ -146,8 +146,7 @@ def get_ff(self): FFLD.format(self.mae + ".mae", self.ff), check=True, shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, ) except subprocess.CalledProcessError as e: raise ValueError(f"Maestro failed with errorcode {e.returncode} and stderr: {e.stderr}") from e diff --git a/mdgo/forcefield/pubchem.py b/mdgo/forcefield/pubchem.py index e488aa53..56cab4a6 100644 --- a/mdgo/forcefield/pubchem.py +++ b/mdgo/forcefield/pubchem.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -7,24 +6,22 @@ can be used to retrieve compound structure and information. """ +from __future__ import annotations + import os import time -from typing import Optional, Final +from typing import Final from urllib.parse import quote import pubchempy as pcp from selenium import webdriver -from selenium.common.exceptions import ( - NoSuchElementException, - TimeoutException, -) +from selenium.common.exceptions import NoSuchElementException, TimeoutException from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.ui import WebDriverWait from mdgo.util.reformat import sdf_to_pdb - MAESTRO: Final[str] = "$SCHRODINGER/maestro -console -nosplash" FFLD: Final[str] = "$SCHRODINGER/utilities/ffld_server -imae {} -version 14 -print_parameters -out_file {}" MolecularWeight: Final[str] = "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{}/property/MolecularWeight/txt" @@ -87,14 +84,11 @@ def __init__( print("PubChem server connected.") def quit(self): - """ - Method for quiting ChromeDriver. - - """ + """Method for quiting ChromeDriver.""" if not self.api: self.web.quit() - def obtain_entry(self, search_text: str, name: str, output_format: str = "sdf") -> Optional[str]: + def obtain_entry(self, search_text: str, name: str, output_format: str = "sdf") -> str | None: """ Search the PubChem database with a text entry and save the structure in desired format. @@ -116,8 +110,6 @@ def smiles_to_pdb(self, smiles: str): Args: smiles: SMILES code. - Returns: - """ convertor_url = "https://cactus.nci.nih.gov/translate/" input_xpath = "/html/body/div/div[2]/div[1]/form/table[1]/tbody/tr[2]/td[1]/input[1]" @@ -138,7 +130,7 @@ def smiles_to_pdb(self, smiles: str): print(".", end="") print("\nStructure file saved.") - def _obtain_entry_web(self, search_text: str, name: str, output_format: str) -> Optional[str]: + def _obtain_entry_web(self, search_text: str, name: str, output_format: str) -> str | None: cid = None try: @@ -147,7 +139,7 @@ def _obtain_entry_web(self, search_text: str, name: str, output_format: str) -> self.web.get(url) loaded_element_path = '//*[@id="main-results"]/div[1]/div/ul' self.wait.until(EC.presence_of_element_located((By.XPATH, loaded_element_path))) - best_xpath = '//*[@id="featured-results"]/div/div[2]' "/div/div[1]/div[2]/div[1]/a/span/span" + best_xpath = '//*[@id="featured-results"]/div/div[2]/div/div[1]/div[2]/div[1]/a/span/span' relevant_xpath = ( '//*[@id="collection-results-container"]' "/div/div/div[2]/ul/li[1]/div/div/div[1]" @@ -195,7 +187,7 @@ def _obtain_entry_web(self, search_text: str, name: str, output_format: str) -> self.quit() return cid - def _obtain_entry_api(self, search_text, name, output_format) -> Optional[str]: + def _obtain_entry_api(self, search_text, name, output_format) -> str | None: cid = None cids = pcp.get_cids(search_text, "name", record_type="3d") if len(cids) == 0: diff --git a/mdgo/msd.py b/mdgo/msd.py index 0fd30f71..f8c6d6c9 100644 --- a/mdgo/msd.py +++ b/mdgo/msd.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -10,7 +9,9 @@ http://stackoverflow.com/questions/34222272/computing-mean-square-displacement-using-python-and-fft#34222273 """ -from typing import List, Dict, Tuple, Union, Optional, Literal +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal try: import MDAnalysis.analysis.msd as mda_msd @@ -24,8 +25,9 @@ import numpy as np from tqdm.auto import trange -from MDAnalysis import Universe, AtomGroup -from MDAnalysis.core.groups import Atom +if TYPE_CHECKING: + from MDAnalysis import AtomGroup, Universe + from MDAnalysis.core.groups import Atom __author__ = "Tingzheng Hou" __version__ = "0.3.0" @@ -55,7 +57,7 @@ def total_msd( start: Start frame of analysis. end: End frame of analysis. select: A selection string. Defaults to “all” in which case all atoms are selected. - msd_type: Desired dimensions to be included in the MSD. Defaults to ‘xyz’. + msd_type: Desired dimensions to be included in the MSD. Defaults to "xyz". fft: Whether to use FFT to accelerate the calculation. Default to True. built_in: Whether to use built in method to calculate msd or use function from mds. Default to True. center_of_mass: Whether to subtract center of mass at each step for atom coordinates. Default to True. @@ -169,11 +171,11 @@ def create_position_arrays( atom_group = nvt_run.select_atoms(select) atom_positions = np.zeros((end - start, len(atom_group), 3)) if center_of_mass: - for ts in nvt_run.trajectory[start:end]: + for _ts in nvt_run.trajectory[start:end]: atom_positions[time, :, :] = atom_group.positions - nvt_run.atoms.center_of_mass() time += 1 else: - for ts in nvt_run.trajectory[start:end]: + for _ts in nvt_run.trajectory[start:end]: atom_positions[time, :, :] = atom_group.positions time += 1 return atom_positions @@ -197,7 +199,7 @@ def onsager_ii_self( start: Start frame of analysis. end: End frame of analysis. select: A selection string. Defaults to “all” in which case all atoms are selected. - msd_type: Desired dimensions to be included in the MSD. Defaults to ‘xyz’. + msd_type: Desired dimensions to be included in the MSD. Defaults to "xyz". center_of_mass: Whether to subtract center of mass at each step for atom coordinates. Default to True. fft: Whether to use FFT to accelerate the calculation. Default to True. @@ -223,8 +225,7 @@ def onsager_ii_self( r = atom_positions[:, atom_num, dim[0] : dim[1] : dim[2]] msd_temp = msd_straight_forward(np.array(r)) # [start:end] ii_self += msd_temp - msd = np.array(ii_self) / n_atoms - return msd + return np.array(ii_self) / n_atoms def mda_msd_wrapper( @@ -238,7 +239,7 @@ def mda_msd_wrapper( start: Start frame of analysis. end: End frame of analysis. select: A selection string. Defaults to “all” in which case all atoms are selected. - msd_type: Desired dimensions to be included in the MSD. Defaults to ‘xyz’. + msd_type: Desired dimensions to be included in the MSD. Defaults to "xyz". fft: Whether to use FFT to accelerate the calculation. Default to True. Warning: @@ -259,7 +260,7 @@ def mda_msd_wrapper( return total_array -def parse_msd_type(msd_type: DIM) -> List[int]: +def parse_msd_type(msd_type: DIM) -> list[int]: """ Sets up the desired dimensionality of the MSD. @@ -283,9 +284,7 @@ def parse_msd_type(msd_type: DIM) -> List[int]: try: dim = keys[msd_type_str] except KeyError: - raise ValueError( - f"invalid msd_type: {msd_type_str} specified, please specify one of xyz, " "xy, xz, yz, x, y, z" - ) + raise ValueError(f"invalid msd_type: {msd_type_str} specified, please specify one of xyz, xy, xz, yz, x, y, z") return dim @@ -311,11 +310,10 @@ def _total_msd(nvt_run: Universe, run_start: int, run_end: int, select: str = "a current_coord = ts[li_atom.id - 1] coords.append(current_coord) all_list.append(np.array(coords)) - total_array = msd_from_frags(all_list, run_end - run_start - 1) - return total_array + return msd_from_frags(all_list, run_end - run_start - 1) -def msd_from_frags(coord_list: List[np.ndarray], largest: int) -> np.ndarray: +def msd_from_frags(coord_list: list[np.ndarray], largest: int) -> np.ndarray: """ Calculates the MSD using a list of fragments of trajectory with the conventional algorithm. @@ -326,14 +324,14 @@ def msd_from_frags(coord_list: List[np.ndarray], largest: int) -> np.ndarray: Returns: The MSD series. """ - msd_dict: Dict[Union[int, np.integer], np.ndarray] = {} + msd_dict: dict[int | np.integer, np.ndarray] = {} for state in coord_list: n_frames = state.shape[0] lag_times = np.arange(1, min(n_frames, largest)) for lag in lag_times: disp = state[:-lag, :] - state[lag:, :] sqdist = np.square(disp).sum(axis=-1) - if lag in msd_dict.keys(): + if lag in msd_dict: msd_dict[lag] = np.concatenate((msd_dict[lag], sqdist), axis=0) else: msd_dict[lag] = sqdist @@ -345,23 +343,22 @@ def msd_from_frags(coord_list: List[np.ndarray], largest: int) -> np.ndarray: assert msds is not None msds_by_state[kw] = msds.mean() timeseries.append(msds_by_state[kw]) - timeseries = np.array(timeseries) - return timeseries + return np.array(timeseries) def states_coord_array( nvt_run: Universe, atom: Atom, - select_dict: Dict[str, str], + select_dict: dict[str, str], distance: float, run_start: int, run_end: int, binding_site: str = "anion", -) -> Tuple[List[np.ndarray], List[np.ndarray]]: +) -> tuple[list[np.ndarray], list[np.ndarray]]: """Cuts the trajectory of an atom into fragments. Each fragment contains consecutive timesteps of coordinates of the atom in either attached or free state. The Attached state is when the atom coordinates with the ``binding_site`` species (distance < ``distance``), and vice versa for the free state. - TODO: check if need wrapped trj + TODO: check if need wrapped trj. Args: nvt_run: An MDAnalysis ``Universe`` containing unwrapped trajectory. @@ -426,12 +423,12 @@ def partial_msd( nvt_run: Universe, atoms: AtomGroup, largest: int, - select_dict: Dict[str, str], + select_dict: dict[str, str], distance: float, run_start: int, run_end: int, binding_site: str = "anion", -) -> Tuple[Optional[List[np.ndarray]], Optional[List[np.ndarray]]]: +) -> tuple[list[np.ndarray] | None, list[np.ndarray] | None]: """Calculates the mean square displacement (MSD) of the ``atoms`` according to coordination states. The returned ``free_data`` include the MSD when ``atoms`` are not coordinated to ``binding_site``. The ``attach_data`` includes the MSD of ``atoms`` are not coordinated to ``binding_site``. diff --git a/mdgo/residence_time.py b/mdgo/residence_time.py index 288af091..5808ee6a 100644 --- a/mdgo/residence_time.py +++ b/mdgo/residence_time.py @@ -1,21 +1,22 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. -""" -This module calculates species correlation lifetime (residence time). -""" +"""This module calculates species correlation lifetime (residence time).""" + +from __future__ import annotations + import os -from typing import List, Dict, Union, Tuple +from typing import TYPE_CHECKING -import numpy as np import matplotlib.pyplot as plt -from statsmodels.tsa.stattools import acovf +import numpy as np from scipy.optimize import curve_fit +from statsmodels.tsa.stattools import acovf from tqdm.auto import tqdm -from MDAnalysis import Universe -from MDAnalysis.core.groups import Atom +if TYPE_CHECKING: + from MDAnalysis import Universe + from MDAnalysis.core.groups import Atom __author__ = "Kara Fong, Tingzheng Hou" __version__ = "0.3.0" @@ -28,11 +29,11 @@ def neighbors_one_atom( nvt_run: Universe, center_atom: Atom, species: str, - select_dict: Dict[str, str], + select_dict: dict[str, str], distance: float, run_start: int, run_end: int, -) -> Dict[str, np.ndarray]: +) -> dict[str, np.ndarray]: """ Create adjacency matrix for one center atom. @@ -50,9 +51,8 @@ def neighbors_one_atom( A neighbor dict with neighbor atom id as keys and arrays of adjacent boolean (0/1) as values. """ bool_values = {} - time_count = 0 - for ts in nvt_run.trajectory[run_start:run_end:]: - if species in select_dict.keys(): + for time_count, _ts in enumerate(nvt_run.trajectory[run_start:run_end:]): + if species in select_dict: selection = ( "(" + select_dict[species] @@ -69,13 +69,12 @@ def neighbors_one_atom( if str(atom.id) not in bool_values: bool_values[str(atom.id)] = np.zeros(int((run_end - run_start) / 1)) bool_values[str(atom.id)][time_count] = 1 - time_count += 1 return bool_values -def calc_acf(a_values: Dict[str, np.ndarray]) -> List[np.ndarray]: +def calc_acf(a_values: dict[str, np.ndarray]) -> list[np.ndarray]: """ - Calculate auto-correlation function (ACF) + Calculate auto-correlation function (ACF). Args: a_values: A dict of adjacency matrix with neighbor atom id as keys and arrays @@ -85,20 +84,20 @@ def calc_acf(a_values: Dict[str, np.ndarray]) -> List[np.ndarray]: A list of auto-correlation functions for each neighbor species. """ acfs = [] - for atom_id, neighbors in a_values.items(): - # atom_id_numeric = int(re.search(r"\d+", atom_id).group()) + for neighbors in a_values.values(): # for _atom_id, neighbors in a_values.items(): + # atom_id_numeric = int(re.search(r"\d+", _atom_id).group()) acfs.append(acovf(neighbors, demean=False, unbiased=True, fft=True)) return acfs def exponential_func( - x: Union[float, np.floating, np.ndarray], - a: Union[float, np.floating, np.ndarray], - b: Union[float, np.floating, np.ndarray], - c: Union[float, np.floating, np.ndarray], -) -> Union[np.floating, np.ndarray]: + x: float | np.floating | np.ndarray, + a: float | np.floating | np.ndarray, + b: float | np.floating | np.ndarray, + c: float | np.floating | np.ndarray, +) -> np.floating | np.ndarray: """ - An exponential decay function + An exponential decay function. Args: x: Independent variable. @@ -114,21 +113,21 @@ def exponential_func( def calc_neigh_corr( nvt_run: Universe, - distance_dict: Dict[str, float], - select_dict: Dict[str, str], + distance_dict: dict[str, float], + select_dict: dict[str, str], time_step: float, run_start: int, run_end: int, center_atom: str = "cation", -) -> Tuple[np.ndarray, Dict[str, np.ndarray]]: +) -> tuple[np.ndarray, dict[str, np.ndarray]]: """Calculates the neighbor auto-correlation function (ACF) of selected species around center atom. Args: nvt_run: An MDAnalysis ``Universe``. - distance_dict: - select_dict: - time_step: + distance_dict: A dict of coordination cutoff distance of the neighbor species. + select_dict: A dictionary of atom species selection. + time_step: Timestep between each frame, in ps. run_start: Start frame of analysis. run_end: End frame of analysis. center_atom: The center atom to calculate the ACF for. Default to "cation". @@ -138,15 +137,13 @@ def calc_neigh_corr( """ # Set up times array times = [] - step = 0 center_atoms = nvt_run.select_atoms(select_dict[center_atom]) - for ts in nvt_run.trajectory[run_start:run_end]: + for step, _ts in enumerate(nvt_run.trajectory[run_start:run_end]): times.append(step * time_step) - step += 1 times = np.array(times) acf_avg = {} - for kw in distance_dict.keys(): + for kw in distance_dict: acf_all = [] for atom in tqdm(center_atoms[::]): distance = distance_dict.get(kw) @@ -161,22 +158,21 @@ def calc_neigh_corr( run_end, ) acfs = calc_acf(adjacency_matrix) - for acf in acfs: - acf_all.append(acf) + acf_all.extend(list(acfs)) acf_avg[kw] = np.mean(acf_all, axis=0) return times, acf_avg def fit_residence_time( times: np.ndarray, - acf_avg_dict: Dict[str, np.ndarray], + acf_avg_dict: dict[str, np.ndarray], cutoff_time: int, time_step: float, - save_curve: Union[str, bool] = False, -) -> Dict[str, np.floating]: + save_curve: str | bool = False, +) -> dict[str, np.floating]: """ Use the ACF to fit the residence time (Exponential decay constant). - TODO: allow defining the residence time according to a threshold value of the decay + TODO: allow defining the residence time according to a threshold value of the decay. Args: times: A time series. diff --git a/mdgo/util/__init__.py b/mdgo/util/__init__.py index eedbf9ee..88100714 100644 --- a/mdgo/util/__init__.py +++ b/mdgo/util/__init__.py @@ -1,18 +1,19 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. """The util package implements various utilities that are commonly used by various packages.""" +from __future__ import annotations + __author__ = "Tingzheng Hou" __version__ = "0.3.0" __maintainer__ = "Tingzheng Hou" __email__ = "tingzheng_hou@berkeley.edu" __date__ = "Jul 19, 2021" -from typing import Final, Dict +from typing import Final -MM_of_Elements: Final[Dict[str, float]] = { +MM_of_Elements: Final[dict[str, float]] = { "H": 1.00794, "He": 4.002602, "Li": 6.941, diff --git a/mdgo/util/coord.py b/mdgo/util/coord.py index 2d536c7b..32570f0b 100644 --- a/mdgo/util/coord.py +++ b/mdgo/util/coord.py @@ -1,13 +1,16 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. """Utilities for manipulating coordinates under periodic boundary conditions.""" -from typing import List, Union +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np -from MDAnalysis.core.groups import Atom +if TYPE_CHECKING: + from MDAnalysis.core.groups import Atom def atom_vec(atom1: Atom, atom2: Atom, dimension: np.ndarray) -> np.ndarray: @@ -35,9 +38,9 @@ def atom_vec(atom1: Atom, atom2: Atom, dimension: np.ndarray) -> np.ndarray: def position_vec( - pos1: Union[List[float], np.ndarray], - pos2: Union[List[float], np.ndarray], - dimension: Union[List[float], np.ndarray], + pos1: list[float] | np.ndarray, + pos2: list[float] | np.ndarray, + dimension: list[float] | np.ndarray, ) -> np.ndarray: """ Calculate the vector from pos2 to pos2. @@ -50,7 +53,7 @@ def position_vec( Return: The obtained vector. """ - vec: List[Union[int, float, np.floating]] = [0, 0, 0] + vec: list[int | float | np.floating] = [0, 0, 0] for i in range(3): diff = pos1[i] - pos2[i] if diff > dimension[i] / 2: diff --git a/mdgo/util/dict_utils.py b/mdgo/util/dict_utils.py index 945dbe5f..2dec52c1 100644 --- a/mdgo/util/dict_utils.py +++ b/mdgo/util/dict_utils.py @@ -1,23 +1,25 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. """Utilities for manipulating dictionaries.""" -import string -import re +from __future__ import annotations + import math -from typing import Dict, Union -import numpy as np -import pandas as pd +import re +import string +from typing import TYPE_CHECKING +import numpy as np from pymatgen.io.lammps.data import CombinedData -from MDAnalysis import Universe -from MDAnalysis.core.groups import Residue, AtomGroup - from . import MM_of_Elements +if TYPE_CHECKING: + import pandas as pd + from MDAnalysis import Universe + from MDAnalysis.core.groups import AtomGroup, Residue + def mass_to_name(masses: np.ndarray) -> np.ndarray: """ @@ -38,9 +40,10 @@ def mass_to_name(masses: np.ndarray) -> np.ndarray: return np.array(names) -def lmp_mass_to_name(df: pd.DataFrame) -> Dict[int, str]: +def lmp_mass_to_name(df: pd.DataFrame) -> dict[int, str]: """ Create a dict for mapping atom type id to element from the mass information. + Args: df: The masses attribute from LammpsData object Return: @@ -65,7 +68,7 @@ def assign_name(u: Universe, names: np.ndarray): u.add_TopologyAttr("name", values=names) -def assign_resname(u: Universe, res_dict: Dict[str, str]): +def assign_resname(u: Universe, res_dict: dict[str, str]): """ Assign resnames to residues in a MDAnalysis.universe object. The function will not overwrite existing resnames. @@ -82,7 +85,7 @@ def assign_resname(u: Universe, res_dict: Dict[str, str]): res_group.residues.resnames = res_names -def res_dict_from_select_dict(u: Universe, select_dict: Dict[str, str]) -> Dict[str, str]: +def res_dict_from_select_dict(u: Universe, select_dict: dict[str, str]) -> dict[str, str]: """ Infer res_dict (residue selection) from select_dict (atom selection) in a MDAnalysis.universe object. @@ -91,7 +94,7 @@ def res_dict_from_select_dict(u: Universe, select_dict: Dict[str, str]) -> Dict[ select_dict: A dictionary of atom species, where each atom species name is a key and the corresponding values are the selection language. - return: + Return: A dictionary of resnames. """ saved_select = [] @@ -112,18 +115,18 @@ def res_dict_from_select_dict(u: Universe, select_dict: Dict[str, str]) -> Dict[ return res_dict -def res_dict_from_datafile(filename: str) -> Dict[str, str]: +def res_dict_from_datafile(filename: str) -> dict[str, str]: """ Infer res_dict (residue selection) from a LAMMPS data file. Args: filename: Path to the data file. The data file must be generated by a CombinedData object. - return: + Return: A dictionary of resnames. """ res_dict = {} - with open(filename, "r") as f: + with open(filename) as f: lines = f.readlines() if lines[0] == "Generated by pymatgen.io.lammps.data.LammpsData\n" and lines[1].startswith("#"): elyte_info = re.findall(r"\w+", lines[1]) @@ -148,14 +151,14 @@ def res_dict_from_datafile(filename: str) -> Dict[str, str]: raise ValueError("The LAMMPS data file should be generated by pymatgen.io.lammps.data.") -def res_dict_from_lammpsdata(lammps_data: CombinedData) -> Dict[str, str]: +def res_dict_from_lammpsdata(lammps_data: CombinedData) -> dict[str, str]: """ Infer res_dict (residue selection) from a LAMMPS data file. Args: lammps_data: A CombinedData object. - return: + Return: A dictionary of resnames. """ assert isinstance(lammps_data, CombinedData) @@ -184,7 +187,7 @@ def res_dict_from_lammpsdata(lammps_data: CombinedData) -> Dict[str, str]: return res_dict -def select_dict_from_resname(u: Universe) -> Dict[str, str]: +def select_dict_from_resname(u: Universe) -> dict[str, str]: """ Infer select_dict (possibly interested atom species selection) from resnames in a MDAnalysis.universe object. The resname must be pre-assigned already. @@ -192,10 +195,10 @@ def select_dict_from_resname(u: Universe) -> Dict[str, str]: Args: u: The universe object to work with. - return: + Return: A dictionary of atom species. """ - select_dict: Dict[str, str] = {} + select_dict: dict[str, str] = {} resnames = np.unique(u.residues.resnames) for resname in resnames: if resname == "": @@ -206,9 +209,9 @@ def select_dict_from_resname(u: Universe) -> Dict[str, str]: for i, frag in enumerate(residue.atoms.fragments): charge = np.sum(frag.charges) if charge > 0.001: - extract_atom_from_ion(True, frag, select_dict) + extract_atom_from_cation(frag, select_dict) elif charge < -0.001: - extract_atom_from_ion(False, frag, select_dict) + extract_atom_from_anion(frag, select_dict) else: extract_atom_from_molecule(resname, frag, select_dict, number=i + 1) elif len(residue.atoms.fragments) >= 2: @@ -218,10 +221,10 @@ def select_dict_from_resname(u: Universe) -> Dict[str, str]: for frag in residue.atoms.fragments: charge = np.sum(frag.charges) if charge > 0.001: - extract_atom_from_ion(True, frag, select_dict, cation_number) + extract_atom_from_cation(frag, select_dict, cation_number) cation_number += 1 elif charge < -0.001: - extract_atom_from_ion(False, frag, select_dict, anion_number) + extract_atom_from_anion(frag, select_dict, anion_number) anion_number += 1 else: extract_atom_from_molecule(resname, frag, select_dict, molecule_number) @@ -229,66 +232,69 @@ def select_dict_from_resname(u: Universe) -> Dict[str, str]: else: extract_atom_from_molecule(resname, residue, select_dict) elif residue.charge > 0: - extract_atom_from_ion(True, residue, select_dict) + extract_atom_from_cation(residue, select_dict) else: - extract_atom_from_ion(False, residue, select_dict) + extract_atom_from_anion(residue, select_dict) return select_dict -def extract_atom_from_ion(positive: bool, ion: Union[Residue, AtomGroup], select_dict: Dict[str, str], number: int = 0): +def extract_atom_from_cation(ion: Residue | AtomGroup, select_dict: dict[str, str], number: int = 0): """ - Assign the most most charged atom and/or one unique atom in the ion into select_dict. + Assign the most charged atom and/or one unique atom in the cation into select_dict. Args: - positive: Whether the charge of ion is positive. Otherwise negative. Default to True. ion: Residue or AtomGroup select_dict: A dictionary of atom species, where each atom species name is a key and the corresponding values are the selection language. - number: The serial number of the ion. + number: The serial number of the cation. """ - if positive: - if number == 0: - cation_name = "cation" - else: - cation_name = "cation_" + str(number) - if len(ion.atoms.types) == 1: - select_dict[cation_name] = "type " + ion.atoms.types[0] - else: - # The most positively charged atom in the cation - pos_center = ion.atoms[np.argmax(ion.atoms.charges)] - unique_types = np.unique(ion.atoms.types, return_counts=True) - # One unique atom in the cation - uni_center = unique_types[0][np.argmin(unique_types[1])] - if pos_center.type == uni_center: - select_dict[cation_name] = "type " + uni_center - else: - select_dict[cation_name + "_" + pos_center.name + pos_center.type] = "type " + pos_center.type - select_dict[cation_name] = "type " + uni_center + cation_name = "cation" if number == 0 else "cation_" + str(number) + if len(ion.atoms.types) == 1: + select_dict[cation_name] = "type " + ion.atoms.types[0] else: - if number == 0: - anion_name = "anion" + # The most positively charged atom in the cation + pos_center = ion.atoms[np.argmax(ion.atoms.charges)] + unique_types = np.unique(ion.atoms.types, return_counts=True) + # One unique atom in the cation + uni_center = unique_types[0][np.argmin(unique_types[1])] + if pos_center.type == uni_center: + select_dict[cation_name] = "type " + uni_center else: - anion_name = "anion_" + str(number) - if len(ion.atoms.types) == 1: - select_dict[anion_name] = "type " + ion.atoms.types[0] + select_dict[cation_name + "_" + pos_center.name + pos_center.type] = "type " + pos_center.type + select_dict[cation_name] = "type " + uni_center + + +def extract_atom_from_anion(ion: Residue | AtomGroup, select_dict: dict[str, str], number: int = 0): + """ + Assign the most charged atom and/or one unique atom in the anion into select_dict. + + Args: + ion: Residue or AtomGroup + select_dict: A dictionary of atom species, where each atom species name is a key + and the corresponding values are the selection language. + number: The serial number of the anion. + """ + anion_name = "anion" if number == 0 else "anion_" + str(number) + if len(ion.atoms.types) == 1: + select_dict[anion_name] = "type " + ion.atoms.types[0] + else: + # The most negatively charged atom in the anion + neg_center = ion.atoms[np.argmin(ion.atoms.charges)] + unique_types = np.unique(ion.atoms.types, return_counts=True) + # One unique atom in the anion + uni_center = unique_types[0][np.argmin(unique_types[1])] + if neg_center.type == uni_center: + select_dict[anion_name] = "type " + uni_center else: - # The most negatively charged atom in the anion - neg_center = ion.atoms[np.argmin(ion.atoms.charges)] - unique_types = np.unique(ion.atoms.types, return_counts=True) - # One unique atom in the anion - uni_center = unique_types[0][np.argmin(unique_types[1])] - if neg_center.type == uni_center: - select_dict[anion_name] = "type " + uni_center - else: - select_dict[anion_name + "_" + neg_center.name + neg_center.type] = "type " + neg_center.type - select_dict[anion_name] = "type " + uni_center + select_dict[anion_name + "_" + neg_center.name + neg_center.type] = "type " + neg_center.type + select_dict[anion_name] = "type " + uni_center def extract_atom_from_molecule( - resname: str, molecule: Union[Residue, AtomGroup], select_dict: Dict[str, str], number: int = 0 + resname: str, molecule: Residue | AtomGroup, select_dict: dict[str, str], number: int = 0 ): """ - Assign the most negatively charged atom in the molecule into select_dict + Assign the most negatively charged atom in the molecule into select_dict. Args: resname: The name of the molecule diff --git a/mdgo/util/num.py b/mdgo/util/num.py index ef68b57d..12089111 100644 --- a/mdgo/util/num.py +++ b/mdgo/util/num.py @@ -1,13 +1,12 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. """Utilities for manipulating numbers in data structures.""" -from typing import List, Union, Optional +from __future__ import annotations -def strip_zeros(items: Union[List[Union[str, float, int]], str]) -> Optional[List[int]]: +def strip_zeros(items: list[str | float | int] | str) -> list[int] | None: """ Strip the trailing zeros of a sequence. diff --git a/mdgo/util/packmol.py b/mdgo/util/packmol.py index da5cabb1..67a38d6c 100644 --- a/mdgo/util/packmol.py +++ b/mdgo/util/packmol.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -13,15 +12,16 @@ set the folder of the packmol executable to the PATH environment variable. """ +from __future__ import annotations + import os import subprocess from pathlib import Path -from typing import Dict, List, Optional, Union from shutil import which + from pymatgen.core import Molecule # from pymatgen.io.core import InputFile, InputSet, InputGenerator - from mdgo.util.volume import molecular_volume __author__ = "Tingzheng Hou, Ryan Kingsbury" @@ -39,7 +39,6 @@ class PackmolWrapper: molecules into a one single unit. Examples: - >>> molecules = [{"name": "EMC", "number": 2, "coords": "/Users/th/Downloads/test_selenium/EMC.lmp.xyz"}] @@ -54,13 +53,13 @@ class PackmolWrapper: def __init__( self, path: str, - molecules: List[Dict], - box: Optional[List[float]] = None, + molecules: list[dict], + box: list[float] | None = None, tolerance: float = 2.0, seed: int = 1, - control_params: Optional[Dict] = None, - inputfile: Union[str, Path] = "packmol.inp", - outputfile: Union[str, Path] = "packmol_out.xyz", + control_params: dict | None = None, + inputfile: str | Path = "packmol.inp", + outputfile: str | Path = "packmol_out.xyz", ): """ Args: @@ -72,16 +71,17 @@ def __init__( 2. "number" - the number of that molecule to pack into the box 3. "coords" - Coordinates in the form of either a Molecule object or a path to a file. - Example: - {"name": "water", - "number": 500, - "coords": "/path/to/input/file.xyz"} + For Example, + {"name": "water", + "number": 500, + "coords": "/path/to/input/file.xyz"} box: A list of box dimensions xlo, ylo, zlo, xhi, yhi, zhi, in Å. If set to None (default), mdgo will estimate the required box size based on the volumes of the provided molecules using mdgo.volume.molecular_volume() tolerance: Tolerance for packmol, in Å. seed: Random seed for packmol. Use a value of 1 (default) for deterministic output, or -1 to generate a new random seed from the current time. + control_params: Specify custom control parapeters, e,g, "maxit" and "nloop", in a dict inputfile: Path to the input file. Default to 'packmol.inp'. outputfile: Path to the output file. Default to 'output.xyz'. """ @@ -97,8 +97,10 @@ def __init__( def run_packmol(self, timeout=30): """Run packmol and write out the packed structure. + Args: timeout: Timeout in seconds. + Raises: ValueError if packmol does not succeed in packing the box. TimeoutExpiredError if packmold does not finish within the timeout. @@ -117,8 +119,7 @@ def run_packmol(self, timeout=30): check=True, shell=True, timeout=timeout, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, ) # this workaround is needed because packmol can fail to find # a solution but still return a zero exit code @@ -143,17 +144,13 @@ def run_packmol(self, timeout=30): # InputSet def make_packmol_input(self): """Make a Packmol usable input file.""" - if self.box: box_list = " ".join(str(i) for i in self.box) else: # estimate the total volume of all molecules net_volume = 0.0 - for idx, d in enumerate(self.molecules): - if not isinstance(d["coords"], Molecule): - mol = Molecule.from_file(d["coords"]) - else: - mol = d["coords"] + for _idx, d in enumerate(self.molecules): + mol = Molecule.from_file(d["coords"]) if not isinstance(d["coords"], Molecule) else d["coords"] # molecular volume in cubic Å vol = molecular_volume(mol, radii_type="pymatgen", molar_volume=False) # pad the calculated length by an amount related to the tolerance parameter @@ -172,7 +169,7 @@ def make_packmol_input(self): if isinstance(v, list): out.write(f'{k} {" ".join(str(x) for x in v)}\n') else: - out.write(f"{k} {str(v)}\n") + out.write(f"{k} {v!s}\n") out.write(f"seed {self.seed}\n") out.write(f"tolerance {self.tolerance}\n\n") @@ -182,7 +179,7 @@ def make_packmol_input(self): else: out.write(f"output {self.output}\n\n") - for i, d in enumerate(self.molecules): + for _i, d in enumerate(self.molecules): if isinstance(d["coords"], str): if " " in d["coords"]: out.write(f'structure "{d["coords"]}"\n') @@ -190,9 +187,9 @@ def make_packmol_input(self): out.write(f'structure {d["coords"]}\n') elif isinstance(d["coords"], Path): if " " in str(d["coords"]): - out.write(f'structure "{str(d["coords"])}"\n') + out.write(f'structure "{d["coords"]!s}"\n') else: - out.write(f'structure {str(d["coords"])}\n') + out.write(f'structure {d["coords"]!s}\n') elif isinstance(d["coords"], Molecule): fname = os.path.join(self.path, f'packmol_{d["name"]}.xyz') d["coords"].to(filename=fname) @@ -200,7 +197,7 @@ def make_packmol_input(self): out.write(f'structure "{fname}"\n') else: out.write(f"structure {fname}\n") - out.write(f' number {str(d["number"])}\n') + out.write(f' number {d["number"]!s}\n') out.write(f" inside box {box_list}\n") out.write("end structure\n\n") diff --git a/mdgo/util/reformat.py b/mdgo/util/reformat.py index 4b967967..ec5d4a68 100644 --- a/mdgo/util/reformat.py +++ b/mdgo/util/reformat.py @@ -1,21 +1,21 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. -""" -Utilities for converting data file formats. -""" +"""Utilities for converting data file formats.""" + +from __future__ import annotations -from io import StringIO import re -from typing import List, Dict, Any, Final +from io import StringIO +from typing import Any, Final + import pandas as pd from mdgo.util.dict_utils import MM_of_Elements -from . import __author__ +from . import __author__ -SECTION_SORTER: Final[Dict[str, Dict[str, Any]]] = { +SECTION_SORTER: Final[dict[str, dict[str, Any]]] = { "atoms": { "in_kw": None, "in_header": ["atom", "charge", "sigma", "epsilon"], @@ -68,9 +68,7 @@ }, } -BOX: Final[ - str -] = """{0:6f} {1:6f} xlo xhi +BOX: Final[str] = """{0:6f} {1:6f} xlo xhi {0:6f} {1:6f} ylo yhi {0:6f} {1:6f} zlo zhi""" @@ -87,12 +85,12 @@ def ff_parser(ff_dir: str, xyz_dir: str) -> str: Return: The output LAMMPS data string. """ - with open(xyz_dir, "r") as f_xyz: + with open(xyz_dir) as f_xyz: molecule = pd.read_table(f_xyz, skiprows=2, delim_whitespace=True, names=["atom", "x", "y", "z"]) coordinates = molecule[["x", "y", "z"]] lo = coordinates.min().min() - 0.5 hi = coordinates.max().max() + 0.5 - with open(ff_dir, "r") as f: + with open(ff_dir) as f: lines_org = f.read() lines = lines_org.split("\n\n") atoms = "\n".join(lines[4].split("\n", 4)[4].split("\n")[:-1]) @@ -118,7 +116,7 @@ def ff_parser(ff_dir: str, xyz_dir: str) -> str: counts = {} counts["atoms"] = len(dfs["atoms"].index) mass_list = [] - for index, row in dfs["atoms"].iterrows(): + for _index, row in dfs["atoms"].iterrows(): mass_list.append(MM_of_Elements.get(re.split(r"(\d+)", row["atom"])[0])) mass_df = pd.DataFrame(mass_list) mass_df.index += 1 @@ -173,15 +171,14 @@ def ff_parser(ff_dir: str, xyz_dir: str) -> str: stats_template = "{:>" + str(max_stats) + "} {}" count_lines = [stats_template.format(v, k) for k, v in counts.items()] type_lines = [stats_template.format(v, k[:-1] + " types") for k, v in counts.items()] - stats = "\n".join(count_lines + [""] + type_lines) + stats = "\n".join([*count_lines, "", *type_lines]) header = [ f"LAMMPS data file created by mdgo (by {__author__})\n" "# OPLS force field: harmonic, harmonic, opls, cvff", stats, BOX.format(lo, hi), ] - data_string = "\n\n".join(header + masses + ff + topo) + "\n" - return data_string + return "\n\n".join(header + masses + ff + topo) + "\n" def sdf_to_pdb( @@ -203,16 +200,12 @@ def sdf_to_pdb( credit: Whether to credit line (remark 888) in the pdb file. Default to True. pubchem: Whether the sdf is downloaded from PubChem. Default to True. """ - # parse sdf file file - with open(sdf_file, "r") as inp: + with open(sdf_file) as inp: sdf_lines = inp.readlines() sdf = list(map(str.strip, sdf_lines)) - if pubchem: - title = "cid_" - else: - title = "" - pdb_atoms: List[Dict[str, Any]] = [] + title = "cid_" if pubchem else "" + pdb_atoms: list[dict[str, Any]] = [] # create pdb list of dictionaries atoms = 0 bonds = 0 @@ -243,12 +236,12 @@ def sdf_to_pdb( "z": float(line_split[2]), "occupancy": 1.00, "tempFactor": 0.00, - "altLoc": str(""), - "chainID": str(""), - "iCode": str(""), + "altLoc": "", + "chainID": "", + "iCode": "", "element": str(line_split[3]), - "charge": str(""), - "segment": str(""), + "charge": "", + "segment": "", } pdb_atoms.append(newline) elif i in list(range(4 + atoms, 4 + atoms + bonds)): @@ -266,13 +259,13 @@ def sdf_to_pdb( pass # write pdb file - with open(pdb_file, "wt") as outp: + with open(pdb_file, "w") as outp: if write_title: outp.write(f"TITLE {title:70s}\n") if version: outp.write("REMARK 4 COMPLIES WITH FORMAT V. 3.3, 21-NOV-2012\n") if credit: - outp.write("REMARK 888\n" "REMARK 888 WRITTEN BY MDGO (CREATED BY TINGZHENG HOU)\n") + outp.write("REMARK 888\nREMARK 888 WRITTEN BY MDGO (CREATED BY TINGZHENG HOU)\n") for n in range(atoms): line_dict = pdb_atoms[n].copy() if len(line_dict["name"]) == 3: @@ -308,7 +301,7 @@ def sdf_to_pdb( bond_lines[atom].append(atom2s[i]) for i, atom in enumerate(atom2s): bond_lines[atom].append(atom1s[i]) - for i, odr in enumerate(orders): + for _i, odr in enumerate(orders): for j, ln in enumerate(bond_lines): if ln[0] == odr[0]: bond_lines.insert(j + 1, odr) diff --git a/mdgo/util/volume.py b/mdgo/util/volume.py index 2a77bba7..923fd073 100644 --- a/mdgo/util/volume.py +++ b/mdgo/util/volume.py @@ -1,4 +1,3 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. @@ -13,21 +12,21 @@ the cube is defined by the -xsize, -ysize and -zsize options. """ +from __future__ import annotations -import sys -import os import argparse -from typing import Optional, List, Dict, Union, Tuple, Final +import os +import sys +from typing import Final import numpy as np -from pymatgen.core import Molecule, Element - +from pymatgen.core import Element, Molecule DEFAULT_VDW = 1.5 # See Ev:130902 -MOLAR_VOLUME: Final[Dict[str, float]] = {"lipf6": 18, "litfsi": 100} # empirical value +MOLAR_VOLUME: Final[dict[str, float]] = {"lipf6": 18, "litfsi": 100} # empirical value -ALIAS: Final[Dict[str, str]] = { +ALIAS: Final[dict[str, str]] = { "ethylene carbonate": "ec", "ec": "ec", "propylene carbonate": "pc", @@ -68,7 +67,7 @@ } # From PubChem -MOLAR_MASS: Final[Dict[str, float]] = { +MOLAR_MASS: Final[dict[str, float]] = { "ec": 88.06, "pc": 102.09, "dec": 118.13, @@ -88,7 +87,7 @@ } # from Sigma-Aldrich -DENSITY: Final[Dict[str, float]] = { +DENSITY: Final[dict[str, float]] = { "ec": 1.321, "pc": 1.204, "dec": 0.975, @@ -284,9 +283,9 @@ def parse_command_line(): return args -def get_max_dimensions(mol: Molecule) -> Tuple[float, float, float, float, float, float]: +def get_max_dimensions(mol: Molecule) -> tuple[float, float, float, float, float, float]: """ - Calculates the dimension of a Molecule + Calculates the dimension of a Molecule. Args: mol: A Molecule object. @@ -294,7 +293,6 @@ def get_max_dimensions(mol: Molecule) -> Tuple[float, float, float, float, float Returns: xmin, xmax, ymin, ymax, zmin, zmax """ - xmin = 9999 ymin = 9999 zmin = 9999 @@ -319,7 +317,7 @@ def get_max_dimensions(mol: Molecule) -> Tuple[float, float, float, float, float def set_max_dimensions( x: float = 0.0, y: float = 0.0, z: float = 0.0, x_size: float = 10.0, y_size: float = 10.0, z_size: float = 10.0 -) -> Tuple[float, float, float, float, float, float]: +) -> tuple[float, float, float, float, float, float]: """ Set the max dimensions for calculating active site volume. @@ -345,7 +343,7 @@ def set_max_dimensions( def round_dimensions( x_min: float, x_max: float, y_min: float, y_max: float, z_min: float, z_max: float, mode: str = "lig" -) -> Tuple[float, float, float, float, float, float]: +) -> tuple[float, float, float, float, float, float]: """ Round dimensions to a larger box size (+ buffer). @@ -375,7 +373,7 @@ def round_dimensions( def dsq(a1: float, a2: float, a3: float, b1: float, b2: float, b3: float) -> float: """ - Squared distance between a and b + Squared distance between a and b. Args: a1: x coordinate of a @@ -388,13 +386,12 @@ def dsq(a1: float, a2: float, a3: float, b1: float, b2: float, b3: float) -> flo Returns: squared distance """ - d2 = (b1 - a1) ** 2 + (b2 - a2) ** 2 + (b3 - a3) ** 2 - return d2 + return (b1 - a1) ** 2 + (b2 - a2) ** 2 + (b3 - a3) ** 2 def get_dimensions( x0: float, x1: float, y0: float, y1: float, z0: float, z1: float, res: float = 0.1 -) -> Tuple[int, int, int]: +) -> tuple[int, int, int]: """ Mesh dimensions in unit of res. @@ -433,12 +430,10 @@ def make_matrix(x_num: int, y_num: int, z_num: int) -> np.ndarray: Returns: matrix """ - - matrix = np.array([[[None for _ in range(z_num)] for _ in range(y_num)] for _ in range(x_num)]) - return matrix + return np.array([[[None for _ in range(z_num)] for _ in range(y_num)] for _ in range(x_num)]) -def get_radii(radii_type: str = "Bondi") -> Dict[str, float]: +def get_radii(radii_type: str = "Bondi") -> dict[str, float]: """ Get a radii dict by type. @@ -533,15 +528,14 @@ def fill_volume_matrix( for a in mol.sites: element = str(a.species.elements[0]) - if exclude_h: - if element == "H": - continue + if exclude_h and element == "H": + continue radius = radii.get(element, DEFAULT_VDW) - for i in range(0, xsteps): + for i in range(xsteps): if abs(a.x - (x0 + 0.5 * res + i * res)) < radius: - for j in range(0, ysteps): + for j in range(ysteps): if abs(a.y - (y0 + 0.5 * res + j * res)) < radius: - for k in range(0, zsteps): + for k in range(zsteps): if matrix[i][j][k] != 1: if abs(a.z - (z0 + 0.5 * res + k * res)) < radius: if dsq( @@ -560,7 +554,7 @@ def fill_volume_matrix( return matrix -def get_occupied_volume(matrix: np.ndarray, res: float, name: Optional[str] = None, molar_volume=True) -> float: +def get_occupied_volume(matrix: np.ndarray, res: float, name: str | None = None, molar_volume=True) -> float: """ Get the occupied volume of the molecule in the box. @@ -581,7 +575,7 @@ def get_occupied_volume(matrix: np.ndarray, res: float, name: Optional[str] = No return v # Å^3 -def get_unoccupied_volume(matrix: np.ndarray, res: float, name: Optional[str] = None, molar_volume=True) -> float: +def get_unoccupied_volume(matrix: np.ndarray, res: float, name: str | None = None, molar_volume=True) -> float: """ Get the unoccupied volume of the molecule in the box. @@ -603,8 +597,8 @@ def get_unoccupied_volume(matrix: np.ndarray, res: float, name: Optional[str] = def molecular_volume( - mol: Union[str, Molecule], - name: Optional[str] = None, + mol: str | Molecule, + name: str | None = None, res: float = 0.1, radii_type: str = "Bondi", molar_volume: bool = True, @@ -618,7 +612,7 @@ def molecular_volume( z_size: float = 10.0, ) -> float: """ - Estimate the molar volume in cm^3/mol or volume in Å^3 + Estimate the molar volume in cm^3/mol or volume in Å^3. Args: mol: Molecule object or path to .xyz or other file that can be read @@ -647,10 +641,7 @@ def molecular_volume( Returns: The molar volume in cm^3/mol or volume in Å^3. """ - if isinstance(mol, str): - molecule = Molecule.from_file(mol) - else: - molecule = mol + molecule = Molecule.from_file(mol) if isinstance(mol, str) else mol if mode == "lig": print("Calculating occupied volume...") x_min, x_max, y_min, y_max, z_min, z_max = get_max_dimensions(molecule) @@ -674,18 +665,18 @@ def molecular_volume( def concentration_matcher( concentration: float, - salt: Union[float, int, str, Molecule], - solvents: List[Union[str, Dict[str, float]]], - solv_ratio: List[float], + salt: float | str | Molecule, + solvents: list[str | dict[str, float]], + solv_ratio: list[float], num_salt: int = 100, mode: str = "v", radii_type: str = "Bondi", -) -> Tuple[List, float]: +) -> tuple[list, float]: """ Estimate the number of molecules of each species in a box, given the salt concentration, salt type, solvent molecular weight, solvent density, solvent ratio and total number of salt. - TODO: Auto box size according to Debye screening length + TODO: Auto box size according to Debye screening length. Args: concentration: Salt concentration in mol/L. @@ -745,7 +736,7 @@ def concentration_matcher( sys.exit(1) name = os.path.splitext(os.path.split(salt)[-1])[0] ext = os.path.splitext(os.path.split(salt)[-1])[1] - if not ext == ".xyz": + if ext != ".xyz": print("Error: Wrong file format, please use a .xyz file.\n") sys.exit(1) salt_molar_volume = molecular_volume(salt, name, radii_type=radii_type) diff --git a/pyproject.toml b/pyproject.toml index e1480f13..348fcae9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,6 @@ [build-system] requires = [ - "setuptools>=42", - "wheel" + "setuptools>=65.0.0", ] build-backend = "setuptools.build_meta" @@ -75,3 +74,45 @@ lint.isort.split-on-trailing-comma = false "pymatgen/vis/*" = ["D"] "pymatgen/io/*" = ["D"] "dev_scripts/*" = ["D"] + +[tool.pytest.ini_options] +addopts = "--durations=30 --quiet -r xXs --color=yes -p no:warnings --import-mode=importlib" + +[tool.coverage.run] +parallel = true + +[tool.coverage.report] +exclude_also = [ + "@deprecated", + "@np.deprecate", + "def __repr__", + "except ImportError:", + "if 0:", + "if TYPE_CHECKING:", + "if __name__ == .__main__.:", + "if self.debug:", + "if settings.DEBUG", + "if typing.TYPE_CHECKING:", + "pragma: no cover", + "raise AssertionError", + "raise NotImplementedError", + "show_plot", +] + +[tool.mypy] +ignore_missing_imports = true +namespace_packages = true +explicit_package_bases = true +no_implicit_optional = false +disable_error_code = "annotation-unchecked" + +[[tool.mypy.overrides]] +module = ["requests.*", "tabulate.*"] +ignore_missing_imports = true + +[tool.codespell] +ignore-words-list = """ +titel,alls,ans,nd,mater,nwo,te,hart,ontop,ist,ot,fo,nax,coo,coul,ser,leary,thre,fase, +rute,reson,titels,ges,scalr,strat,struc,hda,nin,ons,pres,kno,loos,lamda,lew,atomate +""" +check-filenames = true \ No newline at end of file diff --git a/setup.py b/setup.py index d2dc1d78..fa75a4ea 100644 --- a/setup.py +++ b/setup.py @@ -1,16 +1,17 @@ -# coding: utf-8 # Copyright (c) Tingzheng Hou. # Distributed under the terms of the MIT License. """Setup.py for MDGO.""" +from __future__ import annotations + import os -from setuptools import setup, find_packages +from setuptools import find_packages, setup module_dir = os.path.dirname(os.path.abspath(__file__)) -with open(os.path.join(module_dir, "README.md"), "r") as f: +with open(os.path.join(module_dir, "README.md")) as f: readme = f.read() INSTALL_REQUIRES = [ @@ -70,7 +71,7 @@ "Topic :: Scientific/Engineering :: Physics", "Topic :: Scientific/Engineering :: Chemistry", "Topic :: Software Development :: Libraries :: Python Modules", - ], + ], packages=find_packages(), install_requires=INSTALL_REQUIRES, extras_require={ diff --git a/tasks.py b/tasks.py index 825cde6a..0f46e223 100644 --- a/tasks.py +++ b/tasks.py @@ -1,7 +1,11 @@ """ Pyinvoke tasks.py file for automating releases and admin stuff. + +To cut a new mdgo release, use `invoke update-changelog` followed by `invoke release`. + Author: Tingzheng Hou """ +from __future__ import annotations import glob import json @@ -109,35 +113,35 @@ def set_ver(ctx, version): contents = f.read() contents = re.sub(r"__version__ = .*\n", '__version__ = "%s"\n' % version, contents) - with open("mdgo/__init__.py", "wt") as f: + with open("mdgo/__init__.py", "w") as f: f.write(contents) with open("mdgo/core/__init__.py") as f: contents = f.read() contents = re.sub(r"__version__ = .*\n", '__version__ = "%s"\n' % version, contents) - with open("mdgo/core/__init__.py", "wt") as f: + with open("mdgo/core/__init__.py", "w") as f: f.write(contents) with open("mdgo/forcefield/__init__.py") as f: contents = f.read() contents = re.sub(r"__version__ = .*\n", '__version__ = "%s"\n' % version, contents) - with open("mdgo/forcefield/__init__.py", "wt") as f: + with open("mdgo/forcefield/__init__.py", "w") as f: f.write(contents) with open("mdgo/util/__init__.py") as f: contents = f.read() contents = re.sub(r"__version__ = .*\n", '__version__ = "%s"\n' % version, contents) - with open("mdgo/util/__init__.py", "wt") as f: + with open("mdgo/util/__init__.py", "w") as f: f.write(contents) with open("setup.py") as f: contents = f.read() contents = re.sub(r"version=([^,]+),", 'version="%s",' % version, contents) - with open("setup.py", "wt") as f: + with open("setup.py", "w") as f: f.write(contents) @@ -178,8 +182,8 @@ def update_changelog(ctx, version, sim=False): output = subprocess.check_output(["git", "log", "--pretty=format:%s", "v%s..HEAD" % CURRENT_VER]) lines = [] misc = [] - for l in output.decode("utf-8").strip().split("\n"): - m = re.match(r"Merge pull request \#(\d+) from (.*)", l) + for line in output.decode("utf-8").strip().split("\n"): + m = re.match(r"Merge pull request \#(\d+) from (.*)", line) if m: pr_number = m.group(1) contrib, pr_name = m.group(2).split("/", 1) @@ -190,22 +194,22 @@ def update_changelog(ctx, version, sim=False): ll = ll.strip() if ll in ["", "## Summary"]: continue - elif ll.startswith("## Checklist") or ll.startswith("## TODO"): + if ll.startswith(("## Checklist", "## TODO")): break lines.append(f" {ll}") - misc.append(l) + misc.append(line) with open("CHANGES.rst") as f: contents = f.read() - l = "==========" - toks = contents.split(l) + line = "==========" + toks = contents.split(line) head = "\n\nv%s\n" % version + "-" * (len(version) + 1) + "\n" toks.insert(-1, head + "\n".join(lines)) if not sim: with open("CHANGES.rst", "w") as f: - f.write(toks[0] + l + "".join(toks[1:])) + f.write(toks[0] + line + "".join(toks[1:])) ctx.run("open CHANGES.rst") else: - print(toks[0] + l + "".join(toks[1:])) + print(toks[0] + line + "".join(toks[1:])) print("The following commit messages were not included...") print("\n".join(misc)) diff --git a/tests/test_conductivity.py b/tests/test_conductivity.py index 9e1d3066..b5f864da 100644 --- a/tests/test_conductivity.py +++ b/tests/test_conductivity.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import os -import sys -import tempfile -from io import StringIO import unittest + import MDAnalysis -from mdgo.conductivity import * +import numpy as np + +from mdgo.conductivity import calc_cond_msd, get_beta test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files") @@ -23,27 +25,15 @@ def setUpClass(cls) -> None: cls.time_array = np.array([i * 10 for i in range(cls.gen2.trajectory.n_frames - 100)]) def test_calc_cond_msd(self): - self.assertEqual(-2.9103830456733704e-11, self.cond_array[0]) - self.assertEqual(112.66080481783138, self.cond_array[1]) - self.assertEqual(236007.76624833583, self.cond_array[-1]) + assert self.cond_array[0] == -2.9103830456733704e-11 + assert self.cond_array[1] == 112.66080481783138 + assert self.cond_array[-1] == 236007.76624833583 def test_get_beta(self): - self.assertEqual( - (0.8188201425517928, 0.2535110576154693), - get_beta(self.cond_array, self.time_array, 10, 100), - ) - self.assertEqual( - (1.2525648107674503, 1.0120346984003845), - get_beta(self.cond_array, self.time_array, 1000, 2000), - ) - self.assertEqual( - (1.4075552564189142, 1.3748981878979976), - get_beta(self.cond_array, self.time_array, 1500, 2500), - ) - self.assertEqual( - (1.5021915651236932, 51.79451695748163), - get_beta(self.cond_array, self.time_array, 2000, 4000), - ) + assert get_beta(self.cond_array, self.time_array, 10, 100) == (0.8188201425517928, 0.2535110576154693) + assert get_beta(self.cond_array, self.time_array, 1000, 2000) == (1.2525648107674503, 1.0120346984003845) + assert get_beta(self.cond_array, self.time_array, 1500, 2500) == (1.4075552564189142, 1.3748981878979976) + assert get_beta(self.cond_array, self.time_array, 2000, 4000) == (1.5021915651236932, 51.79451695748163) if __name__ == "__main__": diff --git a/tests/test_coordination.py b/tests/test_coordination.py index d6cfbef8..cf1cc95f 100644 --- a/tests/test_coordination.py +++ b/tests/test_coordination.py @@ -1,10 +1,11 @@ -import os +from __future__ import annotations + import unittest class MyTestCase(unittest.TestCase): def test_something(self): - self.assertEqual(True, True) + assert True is True if __name__ == "__main__": diff --git a/tests/test_core.py b/tests/test_core.py index d6cfbef8..cf1cc95f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,10 +1,11 @@ -import os +from __future__ import annotations + import unittest class MyTestCase(unittest.TestCase): def test_something(self): - self.assertEqual(True, True) + assert True is True if __name__ == "__main__": diff --git a/tests/test_forcefield.py b/tests/test_forcefield.py index 464d7d5f..2a5703e0 100644 --- a/tests/test_forcefield.py +++ b/tests/test_forcefield.py @@ -1,4 +1,7 @@ -import io +from __future__ import annotations + +import os +import shutil import sys import tempfile import unittest @@ -6,9 +9,10 @@ import numpy as np import pytest +from pymatgen.io.lammps.data import LammpsData -from mdgo.forcefield.crawler import * -from mdgo.forcefield.aqueous import * +from mdgo.forcefield.aqueous import Aqueous, Ion +from mdgo.forcefield.crawler import FFcrawler test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files") @@ -34,43 +38,34 @@ def test_chrome(self) -> None: lpg = FFcrawler(download_dir, xyz=True, gromacs=True) lpg.data_from_pdb(os.path.join(test_dir, "EMC.pdb")) - self.assertIn( - "LigParGen server connected.\n" - "Structure info uploaded. Rendering force field...\n", - out.getvalue(), - ) - self.assertIn( - "Force field file downloaded.\n" - ".xyz file saved.\n" - "Force field file saved.\n", - out.getvalue(), - ) - self.assertTrue(os.path.exists(os.path.join(download_dir, "EMC.lmp"))) - self.assertTrue(os.path.exists(os.path.join(download_dir, "EMC.lmp.xyz"))) - self.assertTrue(os.path.exists(os.path.join(download_dir, "EMC.gro"))) - self.assertTrue(os.path.exists(os.path.join(download_dir, "EMC.itp"))) + assert "LigParGen server connected.\nStructure info uploaded. Rendering force field...\n" in out.getvalue() + assert "Force field file downloaded.\n.xyz file saved.\nForce field file saved.\n" in out.getvalue() + assert os.path.exists(os.path.join(download_dir, "EMC.lmp")) + assert os.path.exists(os.path.join(download_dir, "EMC.lmp.xyz")) + assert os.path.exists(os.path.join(download_dir, "EMC.gro")) + assert os.path.exists(os.path.join(download_dir, "EMC.itp")) with open(os.path.join(download_dir, "EMC.lmp")) as f: pdf_actual = f.readlines() - self.assertListEqual(pdf, pdf_actual) + assert pdf == pdf_actual with open(os.path.join(download_dir, "EMC.lmp.xyz")) as f: xyz_actual = f.readlines() - self.assertListEqual(xyz, xyz_actual) + assert xyz == xyz_actual with open(os.path.join(download_dir, "EMC.gro")) as f: gro_actual = f.readlines() - self.assertListEqual(gro, gro_actual) + assert gro == gro_actual with open(os.path.join(download_dir, "EMC.itp")) as f: itp_actual = f.readlines() - self.assertListEqual(itp, itp_actual) + assert itp == itp_actual lpg = FFcrawler(download_dir) lpg.data_from_smiles("CCOC(=O)OC") with open(os.path.join(download_dir, "CCOC(=O)OC.lmp")) as f: smiles_actual = f.readlines() - self.assertListEqual(smiles[:13], smiles_actual[:13]) - self.assertListEqual(smiles[18:131], smiles_actual[18:131]) - self.assertEqual(" 1 1 1 -0.28", smiles_actual[131][:26]) - self.assertEqual(" 2 1 2 0.01", smiles_actual[132][:25]) - self.assertEqual(" 15 1 15 0.10", smiles_actual[145][:25]) - self.assertListEqual(smiles_actual[146:], smiles[146:]) + assert smiles_actual[:13] == smiles[:13] + assert smiles_actual[18:131] == smiles[18:131] + assert smiles_actual[131][:26] == " 1 1 1 -0.28" + assert smiles_actual[132][:25] == " 2 1 2 0.01" + assert smiles_actual[145][:25] == " 15 1 15 0.10" + assert smiles_actual[146:] == smiles[146:] finally: sys.stdout = saved_stdout shutil.rmtree(download_dir) diff --git a/tests/test_mdgopackmol.py b/tests/test_mdgopackmol.py index 9ffa91b3..9e21d7e9 100644 --- a/tests/test_mdgopackmol.py +++ b/tests/test_mdgopackmol.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import os import tempfile from pathlib import Path from subprocess import TimeoutExpired -import pytest import numpy as np +import pytest from pymatgen.core import Molecule from mdgo.util.packmol import PackmolWrapper @@ -12,7 +14,7 @@ test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files") -@pytest.fixture +@pytest.fixture() def ethanol(): """ Returns a Molecule of ethanol @@ -33,7 +35,7 @@ def ethanol(): return Molecule(ethanol_atoms, ethanol_coords) -@pytest.fixture +@pytest.fixture() def water(): """ Returns a Molecule of water @@ -120,11 +122,11 @@ def test_control_params(self, water, ethanol): control_params={"maxit": 0, "nloop": 0}, ) pw.make_packmol_input() - with open(os.path.join(scratch_dir, "packmol.inp"), "r") as f: + with open(os.path.join(scratch_dir, "packmol.inp")) as f: input_string = f.read() assert "maxit 0" in input_string assert "nloop 0" in input_string - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Packmol failed with 1"): pw.run_packmol() def test_timeout(self, water, ethanol): @@ -158,10 +160,10 @@ def test_no_return_and_box(self, water, ethanol): box=[0, 0, 0, 2, 2, 2], ) pw.make_packmol_input() - with open(os.path.join(scratch_dir, "packmol.inp"), "r") as f: + with open(os.path.join(scratch_dir, "packmol.inp")) as f: input_string = f.read() assert "inside box 0 0 0 2 2 2" in input_string - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Packmol failed with"): pw.run_packmol() def test_random_seed(self, water, ethanol): diff --git a/tests/test_msd.py b/tests/test_msd.py index e3b70506..bc7dab97 100644 --- a/tests/test_msd.py +++ b/tests/test_msd.py @@ -1,15 +1,28 @@ +from __future__ import annotations + import os import unittest import MDAnalysis +import numpy as np +from numpy.testing import assert_allclose try: import tidynamics as td except ImportError: td = None -from mdgo.msd import * +import pytest +from mdgo.msd import ( + create_position_arrays, + mda_msd_wrapper, + msd_fft, + msd_straight_forward, + onsager_ii_self, + parse_msd_type, + total_msd, +) test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files") @@ -29,105 +42,123 @@ def setUpClass(cls) -> None: cls.dims = ["x", "y", "z"] def test_msd_straight_forward(self): - assert np.allclose(self.fft, msd_straight_forward(self.arr1)) + assert_allclose( + self.fft, + msd_straight_forward(self.arr1), + atol=1e-12, + ) def test_msd_fft(self): - assert np.allclose(self.fft, msd_fft(self.arr1)) - assert np.allclose(msd_straight_forward(self.arr2), msd_fft(self.arr2)) + assert_allclose(self.fft, msd_fft(self.arr1), atol=1e-12) + assert_allclose(msd_straight_forward(self.arr2), msd_fft(self.arr2), atol=1e-12) def test_create_position_arrays(self): - assert np.allclose( + assert_allclose( np.array([21.53381769, 14.97501839, -3.87998785]), - create_position_arrays(self.gen2, 0, 100, select="type 3")[50][2] + create_position_arrays(self.gen2, 0, 100, select="type 3")[50][2], ) - assert np.allclose( + assert_allclose( np.array([-2.78550047, -11.85487624, -17.1221954]), - create_position_arrays(self.gen2, 0, 100, select="type 3")[99][10] + create_position_arrays(self.gen2, 0, 100, select="type 3")[99][10], ) - assert np.allclose( - np.array([41.1079216 , 34.95127106, 18.00482368]), - create_position_arrays(self.gen2, 0, 100, select="type 3", center_of_mass=False)[50][2] + assert_allclose( + np.array([41.1079216, 34.95127106, 18.00482368]), + create_position_arrays(self.gen2, 0, 100, select="type 3", center_of_mass=False)[50][2], ) - assert np.allclose( - np.array([16.98478317, 8.27190208, 5.07116079]), - create_position_arrays(self.gen2, 0, 100, select="type 3", center_of_mass=False)[99][10] + assert_allclose( + np.array([16.98478317, 8.27190208, 5.07116079]), + create_position_arrays(self.gen2, 0, 100, select="type 3", center_of_mass=False)[99][10], ) def test_parse_msd_type(self): xyz = parse_msd_type("xyz") - self.assertEqual(["x", "y", "z"], self.dims[xyz[0]:xyz[1]:xyz[2]]) + assert ["x", "y", "z"] == self.dims[xyz[0] : xyz[1] : xyz[2]] xy = parse_msd_type("xy") - self.assertEqual(["x", "y"], self.dims[xy[0]:xy[1]:xy[2]]) + assert ["x", "y"] == self.dims[xy[0] : xy[1] : xy[2]] yz = parse_msd_type("yz") - self.assertEqual(["y", "z"], self.dims[yz[0]:yz[1]:yz[2]]) + assert ["y", "z"] == self.dims[yz[0] : yz[1] : yz[2]] xz = parse_msd_type("xz") - self.assertEqual(["x", "z"], self.dims[xz[0]:xz[1]:xz[2]]) + assert ["x", "z"] == self.dims[xz[0] : xz[1] : xz[2]] x = parse_msd_type("x") - self.assertEqual(["x"], self.dims[x[0]:x[1]:x[2]]) + assert ["x"] == self.dims[x[0] : x[1] : x[2]] y = parse_msd_type("y") - self.assertEqual(["y"], self.dims[y[0]:y[1]:y[2]]) + assert ["y"] == self.dims[y[0] : y[1] : y[2]] z = parse_msd_type("z") - self.assertEqual(["z"], self.dims[z[0]:z[1]:z[2]]) + assert ["z"] == self.dims[z[0] : z[1] : z[2]] def test_onsager_ii_self(self): onsager_ii_self_fft = onsager_ii_self(self.gen2, 0, 100, select="type 3") onsager_ii_self_nocom = onsager_ii_self(self.gen2, 0, 100, select="type 3", center_of_mass=False) onsager_ii_self_nofft = onsager_ii_self(self.gen2, 0, 100, select="type 3", fft=False) - self.assertAlmostEqual(32.14254152556588, onsager_ii_self_fft[50]) - self.assertAlmostEqual(63.62190983, onsager_ii_self_fft[98]) - self.assertAlmostEqual(67.29990019, onsager_ii_self_fft[99]) - self.assertAlmostEqual(32.14254152556588, onsager_ii_self_nofft[50]) - self.assertAlmostEqual(63.62190983, onsager_ii_self_nofft[98]) - self.assertAlmostEqual(67.29990019, onsager_ii_self_nofft[99]) - self.assertAlmostEqual(32.338364098424634, onsager_ii_self_nocom[50]) - self.assertAlmostEqual(63.52915984813752, onsager_ii_self_nocom[98]) - self.assertAlmostEqual(67.29599346166411, onsager_ii_self_nocom[99]) + assert_allclose(32.14254152556588, onsager_ii_self_fft[50]) + assert_allclose(63.62190983, onsager_ii_self_fft[98]) + assert_allclose(67.29990019, onsager_ii_self_fft[99]) + assert_allclose(32.14254152556588, onsager_ii_self_nofft[50]) + assert_allclose(63.62190983, onsager_ii_self_nofft[98]) + assert_allclose(67.29990019, onsager_ii_self_nofft[99]) + assert_allclose(32.338364098424634, onsager_ii_self_nocom[50]) + assert_allclose(63.52915984813752, onsager_ii_self_nocom[98]) + assert_allclose(67.29599346166411, onsager_ii_self_nocom[99]) def test_mda_msd_wrapper(self): mda_msd_cation = mda_msd_wrapper(self.gen2, 0, 100, select="type 3", fft=False) mda_msd_anion = mda_msd_wrapper(self.gen2, 0, 100, select="type 1", fft=False) - self.assertAlmostEqual(32.338364098424634, mda_msd_cation[50]) - self.assertAlmostEqual(63.52915984813752, mda_msd_cation[98]) - self.assertAlmostEqual(67.29599346166411, mda_msd_cation[99]) - self.assertAlmostEqual(42.69200176568008, mda_msd_anion[50]) - self.assertAlmostEqual(86.9209518, mda_msd_anion[98]) - self.assertAlmostEqual(89.84668178, mda_msd_anion[99]) - assert np.allclose( + assert_allclose(32.338364098424634, mda_msd_cation[50]) + assert_allclose(63.52915984813752, mda_msd_cation[98]) + assert_allclose(67.29599346166411, mda_msd_cation[99]) + assert_allclose(42.69200176568008, mda_msd_anion[50]) + assert_allclose(86.9209518, mda_msd_anion[98]) + assert_allclose(89.84668178, mda_msd_anion[99]) + assert_allclose( onsager_ii_self(self.gen2, 0, 10, select="type 3", msd_type="x", center_of_mass=False), mda_msd_wrapper(self.gen2, 0, 10, select="type 3", msd_type="x", fft=False), + atol=1e-12, ) - assert np.allclose( + assert_allclose( onsager_ii_self(self.gen2, 0, 10, select="type 3", msd_type="y", center_of_mass=False), mda_msd_wrapper(self.gen2, 0, 10, select="type 3", msd_type="y", fft=False), + atol=1e-12, ) - assert np.allclose( + assert_allclose( onsager_ii_self(self.gen2, 0, 10, select="type 3", msd_type="z", center_of_mass=False), mda_msd_wrapper(self.gen2, 0, 10, select="type 3", msd_type="z", fft=False), + atol=1e-12, ) - assert np.allclose( + assert_allclose( onsager_ii_self(self.gen2, 0, 100, select="type 3", msd_type="xy", center_of_mass=False), mda_msd_wrapper(self.gen2, 0, 100, select="type 3", msd_type="xy", fft=False), + atol=1e-12, ) - assert np.allclose( + assert_allclose( onsager_ii_self(self.gen2, 0, 100, select="type 3", msd_type="yz", center_of_mass=False), mda_msd_wrapper(self.gen2, 0, 100, select="type 3", msd_type="yz", fft=False), + atol=1e-12, ) - assert np.allclose( + assert_allclose( onsager_ii_self(self.gen2, 0, 100, select="type 3", msd_type="xz", center_of_mass=False), mda_msd_wrapper(self.gen2, 0, 100, select="type 3", msd_type="xz", fft=False), + atol=1e-12, ) if td is not None: - assert np.allclose(mda_msd_cation, mda_msd_wrapper(self.gen2, 0, 100, select="type 3")) + assert_allclose( + mda_msd_cation, + mda_msd_wrapper(self.gen2, 0, 100, select="type 3"), + atol=1e-12, + ) def test_total_msd(self): total_builtin_cation = total_msd(self.gen2, 0, 100, select="type 3", fft=True, built_in=True) total_mda_cation = total_msd( self.gen2, 0, 100, select="type 3", fft=False, built_in=False, center_of_mass=False ) - self.assertAlmostEqual(32.14254152556588, total_builtin_cation[50]) - self.assertAlmostEqual(32.338364098424634, total_mda_cation[50]) - with self.assertRaises(ValueError): + assert_allclose(total_builtin_cation[50], 32.14254152556588) + assert_allclose(total_mda_cation[50], 32.338364098424634) + with pytest.raises( + ValueError, + match="Warning! MDAnalysis does not support subtracting center " + "of mass. Calculating without subtracting...", + ): total_msd(self.gen2, 0, 100, select="type 3", fft=True, built_in=False, center_of_mass=True) diff --git a/tests/test_residence_time.py b/tests/test_residence_time.py index d6cfbef8..cf1cc95f 100644 --- a/tests/test_residence_time.py +++ b/tests/test_residence_time.py @@ -1,10 +1,11 @@ -import os +from __future__ import annotations + import unittest class MyTestCase(unittest.TestCase): def test_something(self): - self.assertEqual(True, True) + assert True is True if __name__ == "__main__": diff --git a/tests/test_util.py b/tests/test_util.py index d6cfbef8..cf1cc95f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,10 +1,11 @@ -import os +from __future__ import annotations + import unittest class MyTestCase(unittest.TestCase): def test_something(self): - self.assertEqual(True, True) + assert True is True if __name__ == "__main__": diff --git a/tests/test_volume.py b/tests/test_volume.py index b1bdc356..9a2f893e 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -1,6 +1,11 @@ +from __future__ import annotations + import os import unittest + +from numpy.testing import assert_allclose from pymatgen.core import Molecule + from mdgo.util.volume import molecular_volume test_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_files") @@ -23,23 +28,23 @@ def test_molecular_volume(self) -> None: lipf6_volume_3 = molecular_volume(self.lipf6, radii_type="Lange") lipf6_volume_4 = molecular_volume(self.lipf6, radii_type="pymatgen") lipf6_volume_5 = molecular_volume(self.lipf6, molar_volume=False) - self.assertAlmostEqual(lipf6_volume_1, 47.62, places=2) - self.assertAlmostEqual(lipf6_volume_2, 43.36, places=2) - self.assertAlmostEqual(lipf6_volume_3, 41.49, places=2) - self.assertAlmostEqual(lipf6_volume_4, 51.94, places=2) - self.assertAlmostEqual(lipf6_volume_5, 79.08, places=2) + assert_allclose(lipf6_volume_1, 47.62, atol=0.01) + assert_allclose(lipf6_volume_2, 43.36, atol=0.01) + assert_allclose(lipf6_volume_3, 41.49, atol=0.01) + assert_allclose(lipf6_volume_4, 51.94, atol=0.01) + assert_allclose(lipf6_volume_5, 79.08, atol=0.01) ec_volume_1 = molecular_volume(self.ec) ec_volume_2 = molecular_volume(self.ec, exclude_h=False) ec_volume_3 = molecular_volume(self.ec, res=1.0) ec_volume_4 = molecular_volume(self.ec, radii_type="Lange") ec_volume_5 = molecular_volume(self.ec, radii_type="pymatgen") ec_volume_6 = molecular_volume(self.ec, molar_volume=False) - self.assertAlmostEqual(ec_volume_1, 38.44, places=2) - self.assertAlmostEqual(ec_volume_2, 43.17, places=2) - self.assertAlmostEqual(ec_volume_3, 40.95, places=2) - self.assertAlmostEqual(ec_volume_4, 41.07, places=2) - self.assertAlmostEqual(ec_volume_5, 38.44, places=2) - self.assertAlmostEqual(ec_volume_6, 63.83, places=2) + assert_allclose(ec_volume_1, 38.44, atol=0.01) + assert_allclose(ec_volume_2, 43.17, atol=0.01) + assert_allclose(ec_volume_3, 40.95, atol=0.01) + assert_allclose(ec_volume_4, 41.07, atol=0.01) + assert_allclose(ec_volume_5, 38.44, atol=0.01) + assert_allclose(ec_volume_6, 63.83, atol=0.01) litfsi_volume_1 = molecular_volume(self.litfsi) litfsi_volume_2 = molecular_volume(self.litfsi, exclude_h=False) litfsi_volume_3 = molecular_volume(self.litfsi, res=1.0) @@ -47,13 +52,13 @@ def test_molecular_volume(self) -> None: litfsi_volume_5 = molecular_volume(self.litfsi, radii_type="pymatgen") litfsi_volume_6 = molecular_volume(self.litfsi, molar_volume=False) litfsi_volume_7 = molecular_volume(self.litfsi, mode="act", x_size=8, y_size=8, z_size=8) - self.assertAlmostEqual(litfsi_volume_1, 100.16, places=2) - self.assertAlmostEqual(litfsi_volume_2, 100.16, places=2) - self.assertAlmostEqual(litfsi_volume_3, 99.37, places=2) - self.assertAlmostEqual(litfsi_volume_4, 90.78, places=2) - self.assertAlmostEqual(litfsi_volume_5, 105.31, places=2) - self.assertAlmostEqual(litfsi_volume_6, 166.32, places=2) - self.assertAlmostEqual(litfsi_volume_7, 124.66, places=2) + assert_allclose(litfsi_volume_1, 100.16, atol=0.01) + assert_allclose(litfsi_volume_2, 100.16, atol=0.01) + assert_allclose(litfsi_volume_3, 99.37, atol=0.01) + assert_allclose(litfsi_volume_4, 90.78, atol=0.01) + assert_allclose(litfsi_volume_5, 105.31, atol=0.01) + assert_allclose(litfsi_volume_6, 166.32, atol=0.01) + assert_allclose(litfsi_volume_7, 124.66, atol=0.01) if __name__ == "__main__":