Skip to content

Commit

Permalink
typing to annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
htz1992213 committed Feb 4, 2024
1 parent 5540f89 commit 043fe22
Show file tree
Hide file tree
Showing 13 changed files with 124 additions and 115 deletions.
8 changes: 4 additions & 4 deletions mdgo/conductivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def calc_cond_msd(
anions: AtomGroup,
cations: AtomGroup,
run_start: int,
cation_charge: Union[int, float] = 1,
anion_charge: Union[int, float] = -1,
cation_charge: int | float = 1,
anion_charge: int | float = -1,
) -> np.ndarray:
"""Calculates the conductivity "mean square displacement" over time
Expand Down Expand Up @@ -128,11 +128,11 @@ def choose_msd_fitting_region(
def conductivity_calculator(
time_array: np.ndarray,
cond_array: np.ndarray,
v: Union[int, float],
v: int | float,
name: str,
start: int,
end: int,
T: Union[int, float],
T: int | float,
units: str = "real",
) -> float:
"""Calculates the overall conductivity of the system
Expand Down
105 changes: 53 additions & 52 deletions mdgo/coordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from __future__ import annotations
from collections.abc import Callable

import numpy as np
from tqdm.auto import tqdm
Expand All @@ -30,9 +31,9 @@ def neighbor_distance(
run_start: int,
run_end: int,
species: str,
select_dict: Dict[str, str],
select_dict: dict[str, str],
distance: float,
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""
Calculates a dictionary of distances between the ``center_atom`` and neighbor atoms.
Expand Down Expand Up @@ -74,12 +75,12 @@ def neighbor_distance(


def find_nearest(
trj: Dict[str, np.ndarray],
trj: dict[str, np.ndarray],
time_step: float,
binding_cutoff: float,
hopping_cutoff: float,
smooth: int = 51,
) -> Tuple[List[int], Union[float, np.floating], List[int]]:
) -> tuple[list[int], float | np.floating, list[int]]:
"""Using the dictionary of neighbor distance ``trj``, finds the nearest neighbor ``sites`` that the center atom
binds to, and calculates the ``frequency`` of hopping between each neighbor, and ``steps`` when each binding site
exhibits the closest distance to the center atom.
Expand All @@ -101,7 +102,7 @@ def find_nearest(
for kw in list(trj):
trj[kw] = savgol_filter(trj.get(kw), smooth, 2)
site_distance = [100 for _ in range(time_span)]
sites: List[Union[int, np.integer]] = [0 for _ in range(time_span)]
sites: list[int | np.integer] = [0 for _ in range(time_span)]
start_site = min(trj, key=lambda k: trj[k][0])
kw_start = trj.get(start_site)
assert kw_start is not None
Expand Down Expand Up @@ -133,7 +134,7 @@ def find_nearest(
sites = [int(i) for i in sites]
sites_and_distance_array = np.array([[sites[i], site_distance[i]] for i in range(len(sites))])
steps = []
closest_step: Optional[int] = 0
closest_step: int | None = 0
previous_site = sites_and_distance_array[0][0]
if previous_site == 0:
closest_step = None
Expand Down Expand Up @@ -162,12 +163,12 @@ def find_nearest(


def find_nearest_free_only(
trj: Dict[str, np.ndarray],
trj: dict[str, np.ndarray],
time_step: float,
binding_cutoff: float,
hopping_cutoff: float,
smooth: int = 51,
) -> Tuple[List[int], Union[float, np.floating], List[int]]:
) -> tuple[list[int], float | np.floating, list[int]]:
"""Using the dictionary of neighbor distance ``trj``, finds the nearest neighbor ``sites`` that the ``center_atom``
binds to, and calculates the ``frequency`` of hopping between each neighbor, and ``steps`` when each binding site
exhibits the closest distance to the center atom.
Expand All @@ -190,7 +191,7 @@ def find_nearest_free_only(
for kw in list(trj):
trj[kw] = savgol_filter(trj.get(kw), smooth, 2)
site_distance = [100 for _ in range(time_span)]
sites: List[Union[int, np.integer]] = [0 for _ in range(time_span)]
sites: list[int | np.integer] = [0 for _ in range(time_span)]
start_site = min(trj, key=lambda k: trj[k][0])
kw_start = trj.get(start_site)
assert kw_start is not None
Expand Down Expand Up @@ -222,7 +223,7 @@ def find_nearest_free_only(
sites = [int(i) for i in sites]
sites_and_distance_array = np.array([[sites[i], site_distance[i]] for i in range(len(sites))])
steps = []
closest_step: Optional[int] = 0
closest_step: int | None = 0
previous_site = sites_and_distance_array[0][0]
previous_zero = False
if previous_site == 0:
Expand Down Expand Up @@ -257,8 +258,8 @@ def find_nearest_free_only(


def find_in_n_out(
trj: Dict[str, np.ndarray], binding_cutoff: float, hopping_cutoff: float, smooth: int = 51, cool: int = 20
) -> Tuple[List[int], List[int]]:
trj: dict[str, np.ndarray], binding_cutoff: float, hopping_cutoff: float, smooth: int = 51, cool: int = 20
) -> tuple[list[int], list[int]]:
"""Finds the frames when the center atom binds with the neighbor (binding) or hopping out (hopping)
according to the dictionary of neighbor distance.
Expand Down Expand Up @@ -309,8 +310,8 @@ def find_in_n_out(
sites = [int(i) for i in sites]

last = sites[0]
steps_in: List[int] = []
steps_out: List[int] = []
steps_in: list[int] = []
steps_out: list[int] = []
in_cool = cool
out_cool = cool
for i, s in enumerate(sites):
Expand Down Expand Up @@ -339,13 +340,13 @@ def find_in_n_out(
def check_contiguous_steps(
nvt_run: Universe,
center_atom: Atom,
distance_dict: Dict[str, float],
select_dict: Dict[str, str],
distance_dict: dict[str, float],
select_dict: dict[str, str],
run_start: int,
run_end: int,
checkpoints: np.ndarray,
lag: int = 20,
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""Calculates the distance between the center atom and the neighbor atom
in the checkpoint +/- lag time range.
Expand All @@ -364,7 +365,7 @@ def check_contiguous_steps(
An array of distance between the center atom and the neighbor atoms
in the checkpoint +/- lag time range.
"""
coord_num: Dict[str, Union[List[List[int]], np.ndarray]] = {
coord_num: dict[str, list[list[int]] | np.ndarray] = {
x: [[] for _ in range(lag * 2 + 1)] for x in distance_dict
}
trj_analysis = nvt_run.trajectory[run_start:run_end:]
Expand Down Expand Up @@ -393,8 +394,8 @@ def check_contiguous_steps(
def heat_map(
nvt_run: Universe,
floating_atom: Atom,
cluster_center_sites: List[int],
cluster_terminal: Union[str, List[str]],
cluster_center_sites: list[int],
cluster_terminal: str | list[str],
cartesian_by_ref: np.ndarray,
run_start: int,
run_end: int,
Expand Down Expand Up @@ -441,7 +442,7 @@ def heat_map(
for species in cluster_terminal
]
bind_atoms_xyz = [nvt_run.select_atoms(sel, periodic=True) for sel in selections]
vertex_atoms: List[Atom] = []
vertex_atoms: list[Atom] = []
for atoms in bind_atoms_xyz:
if len(atoms) == 1:
vertex_atoms.append(atoms[0])
Expand Down Expand Up @@ -501,10 +502,10 @@ def heat_map(

def process_evol(
nvt_run: Universe,
select_dict: Dict[str, str],
in_list: Dict[str, List[np.ndarray]],
out_list: Dict[str, List[np.ndarray]],
distance_dict: Dict[str, float],
select_dict: dict[str, str],
in_list: dict[str, list[np.ndarray]],
out_list: dict[str, list[np.ndarray]],
distance_dict: dict[str, float],
run_start: int,
run_end: int,
lag: int,
Expand Down Expand Up @@ -572,10 +573,10 @@ def process_evol(

def get_full_coords(
coords: np.ndarray,
reflection: Optional[List[np.ndarray]] = None,
rotation: Optional[List[np.ndarray]] = None,
inversion: Optional[List[np.ndarray]] = None,
sample: Optional[int] = None,
reflection: list[np.ndarray] | None = None,
rotation: list[np.ndarray] | None = None,
inversion: list[np.ndarray] | None = None,
sample: int | None = None,
dim: str = "xyz",
) -> np.ndarray:
"""
Expand Down Expand Up @@ -640,12 +641,12 @@ def get_full_coords(

def cluster_coordinates( # TODO: rewrite the method
nvt_run: Universe,
select_dict: Dict[str, str],
select_dict: dict[str, str],
run_start: int,
run_end: int,
species: List[str],
species: list[str],
distance: float,
basis_vectors: Optional[Union[List[np.ndarray], np.ndarray]] = None,
basis_vectors: list[np.ndarray] | np.ndarray | None = None,
cluster_center: str = "center",
) -> np.ndarray:
"""Calculates the average position of a cluster.
Expand Down Expand Up @@ -708,15 +709,15 @@ def cluster_coordinates( # TODO: rewrite the method
def num_of_neighbor(
nvt_run: Universe,
center_atom: Atom,
distance_dict: Dict[str, float],
select_dict: Dict[str, str],
distance_dict: dict[str, float],
select_dict: dict[str, str],
run_start,
run_end,
write=False,
structure_code=None,
write_freq=0,
write_path=None,
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""Calculates the coordination number of each specified neighbor species and the total coordination number
in the specified frame range.
Expand Down Expand Up @@ -778,11 +779,11 @@ def num_of_neighbor(
def num_of_neighbor_simple(
nvt_run: Universe,
center_atom: Atom,
distance_dict: Dict[str, float],
select_dict: Dict[str, str],
distance_dict: dict[str, float],
select_dict: dict[str, str],
run_start: int,
run_end: int,
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""Calculates solvation structure type (1 for SSIP, 2 for CIP and 3 for AGG) with respect to the ``enter_atom``
in the specified frame range.
Expand Down Expand Up @@ -830,12 +831,12 @@ def num_of_neighbor_simple(
def angular_dist_of_neighbor(
nvt_run: Universe,
center_atom: Atom,
distance_dict: Dict[str, float],
select_dict: Dict[str, str],
distance_dict: dict[str, float],
select_dict: dict[str, str],
run_start: int,
run_end: int,
cip: bool = True,
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""
Calculates the angle of a-c-b of center atom c in the specified frames.
Expand Down Expand Up @@ -893,12 +894,12 @@ def angular_dist_of_neighbor(
def num_of_neighbor_specific(
nvt_run: Universe,
center_atom: Atom,
distance_dict: Dict[str, float],
select_dict: Dict[str, str],
distance_dict: dict[str, float],
select_dict: dict[str, str],
run_start: int,
run_end: int,
counter_atom: str = "anion",
) -> Dict[str, np.ndarray]:
) -> dict[str, np.ndarray]:
"""
Calculates the coordination number of each specific solvation structure type (SSIP, CIP, AGG).
Expand Down Expand Up @@ -962,7 +963,7 @@ def full_solvation_structure( # TODO: rewrite the method
center_atom: Atom,
center_species: str,
counter_species: str,
select_dict: Dict[str, str],
select_dict: dict[str, str],
distance: float,
run_start: int,
run_end: int,
Expand Down Expand Up @@ -1022,8 +1023,8 @@ def counter_shell(this_shell, this_layer, frame):
trj_analysis = nvt_run.trajectory[run_start:run_end:]
cn_values = np.zeros((int(len(trj_analysis)), depth))
for ts in trj_analysis:
center_ion_list: List[np.int_] = [center_atom.id]
counter_ion_list: List[np.int_] = []
center_ion_list: list[np.int_] = [center_atom.id]
counter_ion_list: list[np.int_] = []
first_shell = nvt_run.select_atoms(
select_counter_ion(counter_selection, distance, center_atom),
periodic=True,
Expand All @@ -1036,12 +1037,12 @@ def concat_coord_array(
nvt_run: Universe,
func: Callable,
center_atoms: AtomGroup,
distance_dict: Dict[str, float],
select_dict: Dict[str, str],
distance_dict: dict[str, float],
select_dict: dict[str, str],
run_start: int,
run_end: int,
**kwargs: Union[bool, str],
) -> Dict[str, np.ndarray]:
**kwargs: bool | str,
) -> dict[str, np.ndarray]:
"""
A helper function to analyze the coordination number/structure of every atoms in an ``AtomGroup`` using the
specified function.
Expand Down Expand Up @@ -1104,7 +1105,7 @@ def write_out(center_pos: np.ndarray, center_name: str, neighbors: AtomGroup, pa


def select_shell(
select: Union[Dict[str, str], str], distance: Union[Dict[str, float], str], center_atom: Atom, kw: str
select: dict[str, str] | str, distance: dict[str, float] | str, center_atom: Atom, kw: str
) -> str:
"""
Select a group of atoms that is within a distance of an ``center_atom``.
Expand Down
Loading

0 comments on commit 043fe22

Please sign in to comment.