From 306aa9f7af7abf731a79e72f70f8e31c13a44dc9 Mon Sep 17 00:00:00 2001
From: s2123329
Date: Tue, 23 Jan 2024 18:59:17 +0000
Subject: [PATCH] Update analyses, uc arrangement and D11 ucs

---
 package/ClayCode/analysis/__init__.py         |  299 ++
 package/ClayCode/analysis/analysisbase.py     |  205 +-
 package/ClayCode/analysis/config/__init__.py  |    0
 .../ClayCode/analysis/config/defaults.yaml    |    9 +
 package/ClayCode/analysis/consts.py           |    6 +
 package/ClayCode/analysis/coordination.py     |  672 +++
 package/ClayCode/analysis/data/__init__.py    |    0
 package/ClayCode/analysis/dataclasses.py      | 4055 +++++++++++++++++
 package/ClayCode/analysis/lib.py              | 1845 ++++++++
 package/ClayCode/analysis/utils.py            |  581 +++
 package/ClayCode/analysis/zdist.py            |  490 ++
 package/ClayCode/builder/assembly.py          |  313 +-
 package/ClayCode/builder/claycomp.py          |  193 +-
 package/ClayCode/builder/utils.py             |    6 +-
 package/ClayCode/data/data/UCS/D11/D1001.gro  |   37 +
 package/ClayCode/data/data/UCS/D11/D1001.itp  |   54 +
 package/ClayCode/data/data/UCS/D11/D1002.gro  |   37 +
 package/ClayCode/data/data/UCS/D11/D1002.itp  |   54 +
 package/ClayCode/data/data/UCS/D11/D1003.gro  |   37 +
 package/ClayCode/data/data/UCS/D11/D1003.itp  |   54 +
 package/ClayCode/data/data/UCS/charge_occ.csv |    3 +-
 pyproject.toml                                |    1 +
 22 files changed, 8739 insertions(+), 212 deletions(-)
 create mode 100644 package/ClayCode/analysis/__init__.py
 create mode 100644 package/ClayCode/analysis/config/__init__.py
 create mode 100644 package/ClayCode/analysis/config/defaults.yaml
 create mode 100644 package/ClayCode/analysis/consts.py
 create mode 100644 package/ClayCode/analysis/coordination.py
 create mode 100644 package/ClayCode/analysis/data/__init__.py
 create mode 100644 package/ClayCode/analysis/dataclasses.py
 create mode 100644 package/ClayCode/analysis/lib.py
 create mode 100755 package/ClayCode/analysis/utils.py
 create mode 100755 package/ClayCode/analysis/zdist.py
 create mode 100644 package/ClayCode/data/data/UCS/D11/D1001.gro
 create mode 100644 package/ClayCode/data/data/UCS/D11/D1001.itp
 create mode 100644 package/ClayCode/data/data/UCS/D11/D1002.gro
 create mode 100644 package/ClayCode/data/data/UCS/D11/D1002.itp
 create mode 100644 package/ClayCode/data/data/UCS/D11/D1003.gro
 create mode 100644 package/ClayCode/data/data/UCS/D11/D1003.itp

diff --git a/package/ClayCode/analysis/__init__.py b/package/ClayCode/analysis/__init__.py
new file mode 100644
index 00000000..b21dc965
--- /dev/null
+++ b/package/ClayCode/analysis/__init__.py
@@ -0,0 +1,299 @@
+# from functools import wraps
+#
+import logging
+from datetime import datetime, timezone
+from pathlib import Path
+
+__all__ = [
+    "plots",
+    "multidist",
+    "veldist",
+    "analysisbase",
+    "lib",
+    "zdist",
+    "ph",
+    "setup",
+    "utils",
+    "exec_time",
+    "exec_date",
+    "AA",
+    "FF",
+    "MDP",
+    "CLAYS",
+    "IONS",
+    "SOL",
+    "SOL_DENSITY",
+]
+
+import MDAnalysis
+
+# silence the TPR parser logger (setLevel() returns None, so no handle is kept)
+logging.getLogger("MDAnalysis.topology.TPRparser").setLevel(
+    level=logging.WARNING
+)
+
+PATH = Path(__file__)
+DATA = (PATH.parent / "../data").resolve()
+AA = (DATA / "AA").resolve()
+FF = (DATA / "FF").resolve()
+MDP = (DATA / "MDP").resolve()
+CLAYS = (DATA / "CLAYS").resolve()
+UCS = (DATA / "UCS").resolve()
+
+IONS = ["Cl", "Na", "Ca", "K", "Mg", "Cs"]
+SOL_DENSITY = 1000  # g L-1
+SOL = "SOL"
+
+shandler = logging.StreamHandler()
+shandler.setLevel(logging.INFO)
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(name)-7s - %(message)s",  # %(levelname)s -
+    datefmt="%Y/%m/%d",
+    handlers=[shandler],
+)
+
+logger = 
logging.getLogger("ClayAnalysis") +logger.debug(f"Using MDAnalysis {MDAnalysis.__version__}") + +exec_time = datetime.now(timezone.utc).strftime("%y%m%d-%H%M") +exec_date = datetime.now(timezone.utc).strftime("%y%m%d") + +FILE_SEARCHSTR_LIST = ["_7", "_06", "_n", "_neutral"] + +# AA = Path('/storage/aa_test/aa.csv') # pl.Path(__file__).parent / "AA/aa.csv" +# IONS = ['Cl', 'Na', 'Ca', 'K', 'Mg', 'Cs'] + +# import MDAnalysis +# import numpy as np +# import re +# import warnings +# +# warnings.filterwarnings('ignore', category=DeprecationWarning) +# +# class ClayFFAtomAttributes: +# Masses = { +# 'Fe': 55.850, +# 'Si': 28.090, +# 'O': 16.000, +# 'H': 1.008, +# 'Al': 26.980, +# 'Mg': 24.310, +# 'Na':23, +# 'Cl':35, +# 'K': 1, +# 'Ca': 1, +# 'Mg': 2 +# +# } +# Types = { +# 'feo': ['feo', 'Fe', 3], +# 'fet': ['fet', 'Fe', 2], +# 'fe2': ['fe2', 'Fe', 2], +# 'st': ['st', 'Si', 4], +# 'at': ['at', 'Al', 3], +# 'ao': ['ao', 'Al', 3], +# 'mgo': ['mgo', 'Mg', 2], +# 'obts': ['obts', 'O', -2], +# 'obos': ['obos', 'O', -2], +# 'obt': ['obts', 'O', -2], +# 'obo': ['obos', 'O', -2], +# 'ob': ['ob', 'O', -2], +# 'obs': ['obss', 'O', -2], +# 'ohs': ['ohs', 'O', -2], +# 'ho': ['ho', 'H', 1], +# 'oh': ['oh', 'O', -2], +# 'hw': ['HW', 'H', 1], +# 'ow': ['OW', 'O', -2], +# 'na': ['Na', 'Na', 1], +# 'cl': ['Cl', 'Cl', -1], +# 'ca': ['Ca', 'Ca', 2], +# 'mg': ['Mg', 'Mg', 2], +# 'k': ['K', 'K', 1] +# } +# +# def __init__(self, name): +# self.name = name +# self.__type = None +# +# +# @property +# def type(self): +# if self.__type is None: +# try: +# type_match = re.search(r'|'.join(ClayFFAtomAttributes.Types.keys()), +# self.name.lower()).group(0) +# self.__type = ClayFFAtomAttributes.Types[type_match][0] +# except: +# pass +# return self.__type +# +# +# @property +# def element(self): +# return ClayFFAtomAttributes.Types[self.type.lower()][1] +# +# +# @property +# def mass(self): +# return ClayFFAtomAttributes.Masses[self.element] +# +# @property +# def charge(self): +# return ClayFFAtomAttributes.Types[self.type.lower()][2] +# +# def add_method(cls): +# def decorator(func): +# @wraps(func) +# def wrapper(self, *args, **kwargs): +# return func(self, *args, **kwargs) +# +# setattr(cls, func.__name__, wrapper) +# +# return decorator +# +# @add_method(MDAnalysis.Universe) +# def elements(self): +# elements_dict = dict({k: v} for (k, v, _) in ClayFFAtomAttributes.Types.values()) +# elements = np.full_like(crdin.atoms.names, None) +# +# for el, element in enumerate(self.atoms): +# print(element) +# try: +# elements[el] = elements_dict[element.type.lower()] +# print(elements[el]) +# except: +# pass +# return elements +# + +from typing import Dict + +ITP_KWDS = { + "defaults": ["nbfunc", "comb-rule", "gen-pairs", "fudgeLJ", "fudgeQQ"], + "atomtypes": [ + "at-type", + "at-number", + "mass", + "charge", + "ptype", + "sigma", + "epsilon", + ], + "bondtypes": ["ai", "aj", "b0", "kb"], + "pairtypes": ["ai", "aj", "V", "W"], + "angletypes": ["ai", "aj", "ak", "theta0", "ktheta"], + "dihedraltypes": ["ai", "aj", "ak", "al", "phi0", "phitheta"], + "constrainttypes": ["ai", "aj", "b0"], + "nonbond_params": ["ai", "aj", "V", "W"], + "moleculetype": ["res-name", "n-excl"], + "atoms": [ + "id", + "at-type", + "res-number", + "res-name", + "at-name", + "charge-nr", + "charge", + "mass", + ], + "bonds": ["ai", "aj", "funct", "b0", "kb"], + "pairs": ["ai", "aj", "funct", "theta0", "ktheta"], + "angles": ["ai", "aj", "ak"], + "dihedrals": ["ai", "aj", "ak", "al"], + "system": ["sys-name"], + "molecules": ["res-name", "mol-number"], + 
"settles": ["at-type", "func", "doh", "dhh"], + "exclusions": ["ai", "aj", "ak"], + "nonbond_params": ["ai", "aj", "V", "W"], +} +DTYPES = { + "at-type": "str", + "at-number": "int32", + "ptype": "str", + "sigma": "float64", + "epsilon": "float64", + "id": "int32", + "res-number": "int32", + "res-name": "str", + "at-name": "str", + "charge-nr": "int32", + "charge": "float64", + "mass": "float64", + "FF": "str", + "itp": "str", + "ai": "int16", + "aj": "int16", + "ak": "int16", + "al": "int16", + "k0": "float64", + "b0": "float64", + "kb": "float64", + "theta0": "float64", + "ktheta": "float64", + "phi0": "float64", + "phitheta": "float64", + "V": "str", + "W": "str", + "nbfunc": "int16", + "func": "int16", + "comb-rule": "int16", + "gen-pairs": "str", + "fudgeLJ": "float32", + "fudgeQQ": "float32", + "n-excl": "int16", + "doh": "float32", + "dhh": "float32", + "funct": "int16", + "sys-name": "str", + "mol-number": "int32", +} + +# GRO_KWDS = {"titel": ["sys-name"], +# "n-atoms": ["n-atoms"], +# "coordinates": +# ["res-number", +# "res-name", "at-name", "at-number", +# "x", "y", "y", "vx", "vy", "vz", +# "box"]} + +GRO_KWDS = {} +MDP_KWDS = {} +TOP_KWDS = ITP_KWDS + + +def set_globals() -> Dict[str, Dict[str, str]]: + """ + Combine '*._KWD' dictionaries and add datatype mapping + :return: Combined keyword dictionary + :rtype: Dict[str, Dict[str, str]] + """ + import re + + combined_dict = {} + global_dict = lambda key: globals()[key] + + # set_global = lambda key, value: globals().__setitem__(key, value) + + del_global = lambda key: globals().__delitem__(key) + # set_global('KWD_DICT', {}) + kwds = sorted( + re.findall(r"[A-Z]+_KWDS", " ".join(globals().keys())), reverse=True + ) + for kwd_dict in kwds: + kwd = kwd_dict.split("_")[0] + # assert len(dicts) % 2 == 0, ValueError(f'Expected even number of KWD and DTYPE dictionaries.') + new_dict = {} + for key, vals in global_dict(kwd_dict).items(): + new_dict[key] = {} + for val in vals: + new_dict[key][val] = global_dict("DTYPES")[val] + combined_dict[f".{kwd.lower()}"] = new_dict + del_global(kwd_dict) + del_global("DTYPES") + return combined_dict + + +KWD_DICT = set_globals() diff --git a/package/ClayCode/analysis/analysisbase.py b/package/ClayCode/analysis/analysisbase.py index c5ca7657..d8aae996 100644 --- a/package/ClayCode/analysis/analysisbase.py +++ b/package/ClayCode/analysis/analysisbase.py @@ -133,12 +133,19 @@ import numpy as np import pandas as pd +import zarr +from ClayCode.analysis.lib import Bins, Cutoff +from ClayCode.core.utils import ( + SubprocessProgressBar, + get_header, + get_subheader, +) from MDAnalysis import coordinates from MDAnalysis.core.groups import AtomGroup from MDAnalysis.lib.log import ProgressBar from numpy.typing import NDArray -logger = logging.getLogger(Path(__file__).stem) +logger = logging.getLogger(__name__) analysis_class = TypeVar("analysis_class") analysis_data = TypeVar("analysis_data") @@ -282,10 +289,16 @@ class AnalysisData(UserDict): _default_bin_step = 0.1 _default_cutoff = 20 - # __slots__ = ['name', 'bins', 'timeseries', 'hist', 'edges', '_n_bins', 'n_bins', 'cutoff', 'bin_step'] + # __slots__ = ['name', 'bins', 'timeseries', 'hist', 'ads_edges', '_n_bins', 'n_bins', 'cutoff', 'bin_step'] def __init__( - self, name: str, cutoff=None, bin_step=None, n_bins=None, min=0.0 + self, + name: str, + cutoff=None, + bin_step=None, + n_bins=None, + min=0.0, + verbose=True, ): self.name = name assert ( @@ -310,10 +323,11 @@ def __init__( elif self.n_bins is None: # print('n_bins None') self.n_bins = 
np.rint(self.cutoff / self.bin_step) - logger.info(f"{name!r}:") - logger.info(f"cutoff: {self.cutoff}") - logger.info(f"n_bins: {self.n_bins}") - logger.info(f"bin_step: {self.bin_step}") + if verbose: + logger.info(get_subheader(f" Initialising {name!r} analysis")) + logger.finfo(f"{self.cutoff}", kwd_str=f"cutoff: ") + logger.finfo(f"{self.n_bins}", kwd_str=f"n_bins: ") + logger.finfo(f"{self.bin_step}", kwd_str=f"bin_step: ") hist, edges = np.histogram( [-1], bins=self.n_bins, range=(self._min, self.cutoff) ) @@ -326,42 +340,59 @@ def __init__( self.df = pd.DataFrame(index=self.bins) self.df.index.name = "bins" self.hist2d = {} + self._verbose = verbose @property - def n_bins(self): + def n_bins(self) -> int: + """Number of bins in data histogram""" return self._n_bins @n_bins.setter - def n_bins(self, n_bins): + def n_bins(self, n_bins: Union[int, str]): + """Number of bins in data histogram""" if n_bins is not None: self._n_bins = int(n_bins) if self.cutoff / self.n_bins != self.bin_step: self.bin_step = self.cutoff / self.n_bins @property - def cutoff(self): + def cutoff(self) -> float: + """Maximum value for included perpendicular distance from a surface""" return self._cutoff @cutoff.setter - def cutoff(self, cutoff): + def cutoff(self, cutoff: Union[float, int, str]) -> None: + """Maximum value for included perpendicular distance from a surface""" if cutoff is not None: self._cutoff = float(cutoff) # else: # self._cutoff = self._default_cutoff @property - def bin_step(self): + def bin_step(self) -> float: + """Bin size in data histogram""" return self._bin_step @bin_step.setter - def bin_step(self, bin_step): + def bin_step(self, bin_step: Union[float, int, str]): + """Bin size in data histogram""" if bin_step is not None: self._bin_step = float(bin_step) - def get_hist_data(self, use_abs=True, guess_min=True): - data = np.ravel(self.timeseries) + def get_hist_data( + self, use_abs: bool = True + ) -> None: # , guess_min=True): + r"""Create histogram data from timeseries. + :param use_abs: use absolute values of timeseries + :type use_abs: bool + """ + data = zarr.array( + np.ravel(self.timeseries), + chunks=(1000000,), + write_empty_chunks=True, + ) if use_abs == True: - data = np.abs(data) + data[:] = np.abs(data) # if guess_min == True: # ll = np.min(data) # else: @@ -372,36 +403,43 @@ def get_hist_data(self, use_abs=True, guess_min=True): hist = hist / len(self.timeseries) self.hist[:] = hist - def get_rel_data(self, other: analysis_data, use_abs=True, **kwargs): + def get_rel_data( + self, other: analysis_data, use_abs=True, **kwargs + ) -> analysis_data: r"""Create new instance with modified data values. 
- :param other: - :type other: - :param use_abs: - :type use_abs: - :param kwargs: - :type kwargs: - :return: - :rtype: + :param other: other data to compare + :type other: analysis_data + :param use_abs: use absolute values of timeseries + :type use_abs: bool + :param kwargs: additional arguments + :type kwargs: dict + :return: new instance with modified data values + :rtype: analysis_data """ - data = np.concatenate( - [[np.ravel(self.timeseries)], [np.ravel(other.timeseries)]], axis=0 + data = zarr.array( + np.concatenate( + [[np.ravel(self.timeseries)], [np.ravel(other.timeseries)]], + axis=0, + ), + chunks=(1000000,), + write_empty_chunks=True, ) if use_abs == False: pass elif use_abs == True: - data = np.abs(data) + data[:] = np.abs(data) elif len(use_abs) == 2: use_abs = np.array(use_abs) mask = np.broadcast_to(use_abs[:, np.newaxis], data.shape) data[mask] = np.abs(data[mask]) - data = np.divide(data[0], data[1], where=data != 0)[0] + data[:] = np.divide(data[0], data[1], where=data != 0)[0] if "cutoff" in kwargs.keys(): - cutoff = kwargs["cutoff"] + cutoff = float(kwargs["cutoff"]) else: cutoff = self.cutoff if "bin_step" in kwargs.keys(): - bin_step = kwargs["bin_step"] + bin_step = float(kwargs["bin_step"]) else: bin_step = self.bin_step new_data = self.__class__( @@ -472,7 +510,7 @@ def get_df(self): def __repr__(self): return ( f"AnalysisData({self.name!r}, " - f"edges = ({self._min}, {self.cutoff}), " + f"ads_edges = ({self._min}, {self.cutoff}), " f"bin_step = {self.bin_step}, " f"n_bins = {self.n_bins}, " f"has_data = {self.has_hist})" @@ -615,7 +653,7 @@ def _conclude(self): """ # histogram attributes format: # -------------------------- - # name: [name, bins, timeseries, hist, hist2d, edges, n_bins, cutoff, bin_step] + # name: [name, bins, timeseries, hist, hist2d, ads_edges, n_bins, cutoff, bin_step] _attrs = [] @@ -625,6 +663,7 @@ def __init__(self, trajectory, verbose=False, **kwargs): self.results = Results() self.sel_n_atoms = None self._abs = True + self._get_new_data = True def _init_data(self, **kwargs): data = self.__class__._attrs @@ -632,10 +671,13 @@ def _init_data(self, **kwargs): data = [self.__class__.__name__.lower()] self.data = {} for item in data: - if item in kwargs.keys(): - args = kwargs[item] - else: - args = kwargs + args = {} + for key, val in kwargs.items(): + if isinstance(val, dict): + if item in val.keys(): + args[key] = val[item] + else: + args[key] = val self.data[item] = AnalysisData(item, **args) def _setup_frames( @@ -723,8 +765,14 @@ def _conclude(self, n_atoms: Optional[int] = None, **kwargs): # n_atoms = int(n_atoms) if type(self._abs) == bool: self._abs = [self._abs for a in range(len(self.data))] + logger.finfo( + f"Using absolute " + ", ".join(self.data.keys()) + " values.\n" + ) # print(self._abs, len(self.data)) for vi, v in enumerate(self.data.values()): + logger.finfo( + f"Getting {v.name!r} histogram", initial_linebreak=True + ) v.get_hist_data(use_abs=self._abs[vi]) v.get_norm(self.sel_n_atoms) v.get_df() @@ -738,6 +786,7 @@ def _get_results(self): for key, val in self.__dict__.items(): if not key.startswith("_") and key != "results": self.results[key] = val + # logger.info(f"{val}") logger.info(f"{self.save}") if self.save is False: pass @@ -745,7 +794,7 @@ def _get_results(self): outdir = Path(self.save).parent logger.info(f"Saving results in {str(outdir.absolute())!r}") if not outdir.is_dir(): - os.mkdir(outdir) + os.makedirs(outdir, exist_ok=True) logger.info(f"Created {outdir}") with open(f"{self.save}.p", "wb") as 
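outfile:
                 pkl.dump(self.results, outfile)

The pickled `Results` mapping written here can be read straight back with the
standard library. A minimal sketch (the file name is hypothetical; the `.p`
suffix follows from `_get_results()` above):

    import pickle as pkl

    with open("zdens_NAu-1_Na.p", "rb") as results_file:
        results = pkl.load(results_file)  # an MDAnalysis Results mapping
    print(sorted(results.keys()))  # attributes collected by _get_results()
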
@@ -778,42 +827,63 @@ def run(self, start=None, stop=None, step=None, frames=None, verbose=True):
             frame indices in the
             `frames` keyword argument.

         """
-        logger.info("Choosing frames to analyze")
-        # if verbose unchanged, use class default
-        verbose = (
-            getattr(self, "_verbose", False) if verbose is None else verbose
-        )
+        if self._get_new_data is False:
+            logger.finfo(
+                "Not running new analysis. Output file exists and overwrite not selected."
+            )
+        else:
+            logger.info(get_subheader("Starting analysis run"))
+            logger.finfo("Choosing frames to analyze")
+            # if verbose unchanged, use class default
+            verbose = (
+                getattr(self, "_verbose", False)
+                if verbose is None
+                else verbose
+            )

-        self._setup_frames(
-            self._trajectory, start=start, stop=stop, step=step, frames=frames
-        )
-        logger.info("Starting preparation")
-        self._prepare()
-        logger.info(
-            "Starting analysis loop over %d trajectory frames", self.n_frames
-        )
-        for i, ts in enumerate(
-            ProgressBar(self._sliced_trajectory, verbose=verbose)
-        ):
-            self._frame_index = i
-            self._ts = ts
-            self.frames[i] = ts.frame
-            self.times[i] = ts.time
-            self._single_frame()
-        logger.info("Finishing up")
-        self._conclude(self.sel_n_atoms)
-        logger.info(f"Getting results")
-        self._get_results()
-        self._save()
+            self._setup_frames(
+                self._trajectory,
+                start=start,
+                stop=stop,
+                step=step,
+                frames=frames,
+            )
+            logger.info("Starting preparation")
+            self._prepare()
+            logger.finfo(
+                f"Starting analysis loop over {self.n_frames} trajectory frames"
+            )
+            for i, ts in enumerate(
+                ProgressBar(self._sliced_trajectory, verbose=verbose)
+            ):
+                self._frame_index = i
+                self._ts = ts
+                self.frames[i] = ts.frame
+                self.times[i] = ts.time
+                self._single_frame()
+            self._post_process()
+            logger.info(get_subheader("Finishing up"))
+            progress_bar = SubprocessProgressBar(label="Getting results")
+            progress_bar.run_with_progress(
+                lambda: self._conclude(self.sel_n_atoms)
+            )
+            logger.info(get_header("Writing results"))
+            self._get_results()
+            self._save()
+            logger.info("Done!\n")
         return self

+    def _post_process(self):
+        """Post-processing of data."""
+        pass
+
     def _save(self):
         pass

     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(data: {self._attrs}, "
-            f"loaded frames: {self.n_frames})"
+            f"{self.__class__.__name__}(data: " + ", ".join(self._attrs) + ", "
+            f"loaded frames: {self._universe.trajectory.n_frames})"
         )


@@ -869,6 +939,9 @@ def rotation_matrix(mobile, ref):
     """

     def __init__(self, function, trajectory=None, *args, **kwargs):
+        logger.info(
+            get_header(f"Selected {self.__class__.__name__} analysis.")
+        )
         if (trajectory is not None) and (
             not isinstance(trajectory, coordinates.base.ProtoReader)
         ):
diff --git a/package/ClayCode/analysis/config/__init__.py b/package/ClayCode/analysis/config/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/package/ClayCode/analysis/config/defaults.yaml b/package/ClayCode/analysis/config/defaults.yaml
new file mode 100644
index 00000000..049cf63f
--- /dev/null
+++ b/package/ClayCode/analysis/config/defaults.yaml
@@ -0,0 +1,9 @@
+
+# =============================================================================
+# General specifications for clay analysis runs
+# =============================================================================
+
+# =============================================================================
+# Required Parameters
+# =============================================================================
+
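
`AnalysisData` (modified in `analysisbase.py` above) is the per-quantity
histogram container behind all analyses. A minimal usage sketch with
illustrative values; `verbose=False` skips the custom logging calls, and
`timeseries` is assumed to be initialised as a list in the part of
`__init__` not shown in the hunk:

    import numpy as np
    from ClayCode.analysis.analysisbase import AnalysisData

    # 20 A cutoff with 0.1 A bins -> 200 bins on [0, 20)
    zdens = AnalysisData("zdens", cutoff=20, bin_step=0.1, verbose=False)
    zdens.timeseries.append(np.random.uniform(0, 20, 100))  # one array per frame
    zdens.get_hist_data(use_abs=True)  # histogram normalised by frame count
    print(zdens.hist.sum())

diff --git 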
a/package/ClayCode/analysis/consts.py b/package/ClayCode/analysis/consts.py new file mode 100644 index 00000000..bd396898 --- /dev/null +++ b/package/ClayCode/analysis/consts.py @@ -0,0 +1,6 @@ +from pathlib import Path + +from importlib_resources import files + +ANALYSIS_DATA = files("ClayCode.analysis.data") +PE_DATA: Path = ANALYSIS_DATA / "peaks_edges" diff --git a/package/ClayCode/analysis/coordination.py b/package/ClayCode/analysis/coordination.py new file mode 100644 index 00000000..07ebbd5d --- /dev/null +++ b/package/ClayCode/analysis/coordination.py @@ -0,0 +1,672 @@ +#!/usr/bin/env python3 +import logging +import pickle as pkl +import sys +import warnings +from argparse import ArgumentParser +from pathlib import Path +from typing import Any, List, Literal, NoReturn, Optional, Union + +import MDAnalysis as mda +import numpy as np +from ClayCode.analysis.analysisbase import ClayAnalysisBase +from ClayCode.analysis.consts import PE_DATA +from ClayCode.analysis.lib import ( + check_traj, + exclude_xyz_cutoff, + exclude_z_cutoff, + get_dist, + get_edge_fname, + get_paths, + get_selections, + process_box, + read_edge_file, + run_analysis, +) +from ClayCode.analysis.zdist import ZDens +from ClayCode.builder.utils import get_checked_input, select_input_option +from ClayCode.core.utils import get_subheader +from MDAnalysis import Universe +from MDAnalysis.core.groups import AtomGroup +from MDAnalysis.lib.distances import ( + apply_PBC, + distance_array, + self_distance_array, +) + +warnings.filterwarnings("ignore", category=DeprecationWarning) + +__all__ = ["CrdDist"] + +logger = logging.getLogger(Path(__file__).stem) + + +class CrdDist(ClayAnalysisBase): + _attrs = ["rdf", "groups"] + _abs = [True, True] + + def __init__( + self, + sysname: str, + sel: AtomGroup, + clay: AtomGroup, + other: Optional[AtomGroup] = None, + n_bins: Optional[int] = None, + bin_step: Optional[Union[int, float]] = None, + cutoff: Optional[Union[float, int]] = None, + edges: Optional[Union[str, Path]] = None, + zdist: Optional[Union[str, Path]] = None, + save: Union[bool, str] = True, + check_traj_len: Union[Literal[False], int] = False, + guess_steps: bool = False, + **basekwargs: Any, + ) -> NoReturn: + super(CrdDist, self).__init__(sel.universe.trajectory, **basekwargs) + self.sysname = sysname + self._ags = [sel] + self._universe = self._ags[0].universe + self.sel = sel + self.sel_n_atoms = sel.n_atoms + self.clay = clay + if other is None: + other = sel + # assert isinstance(sel, AtomGroup) + # assert isinstance(other, AtomGroup) + if other == sel: + self.other = self.sel + self.self = True + else: + self.other = other + self.self = False + self.other_n_atoms = other.n_atoms + if edges is None: + edge_file = get_edge_fname( + atom_type=self.sel.resnames[0], + cutoff=cutoff, + bins=bin_step, + name="pe", + ) + if type(edges) == str: + edge_file = Path(edges) + elif type(edges) == Path: + edge_file = edges + if not edge_file.is_file(): + edge_file = PE_DATA / edge_file.with_suffix(".p").name + if edge_file.is_file(): + logger.info( + f"Using edge file {str(edge_file.name)!r} from database" + ) + if not edge_file.is_file(): + logger.info(f"Edge file {str(edge_file.name)!r} does not exist.\n") + edge_selection = select_input_option( + instance_or_manual_setup=True, + query="Use one of the edge files in database? 
[y]es/[n]o (default yes)\n", + options=["y", "n", ""], + result_map={"y": True, "n": False, "": True}, + ) + if edge_selection: + options = sorted( + [f for f in PE_DATA.glob(f"{self.sel.resnames[0]}_*.p")] + ) + logger.info("Available edge files:") + for i, f in enumerate(options): + logger.finfo(f"{f.name}", kwd_str=f"{i}: ", indent="\t") + edge_file = get_checked_input( + result_type=int, + result=edge_file, + exit_val="e", + check_value="|".join(list(map(str, range(len(options))))), + query="Select edge file: (exit with e)\n", + ) + if edge_file == "e": + logger.info("Exiting.") + sys.exit(0) + else: + edge_file = options[edge_file] + + assert edge_file.is_file(), f"edge file {edge_file} does not exist" + self._edges_file = edge_file + self._edges = read_edge_file(self._edges_file, cutoff, skip=False) + self._edges = self._edges[self._edges > 0] + while self._edges[-1] > cutoff: + self._edges = self._edges[:-1] + if self._edges[-1] < cutoff: + self._edges = np.append(self._edges, cutoff) + self._attrs.extend( + [f"group_{edge}" for edge in range(len(self._edges))] + ) + self.zdist = zdist + bin_step = dict( + map( + lambda x: (x, bin_step) if x != "groups" else (x, 1), + self._attrs, + ) + ) + cutoff = dict( + map( + lambda x: (x, cutoff) + if x != "groups" + else (x, len(self._edges)), + self._attrs, + ) + ) + verbose = dict( + map( + lambda x: (x, True) if x in ["groups", "rdf"] else (x, False), + self._attrs, + ) + ) + self._init_data( + n_bins=n_bins, + bin_step=bin_step, + cutoff=cutoff, + min=0, + verbose=verbose, + ) + self.save = save + if self.save is False: + pass + else: + if type(self.save) == bool: + self.save = ( + f"{self.__class__.__name__.lower()}_" + f"{self.sysname}_{self.sel.resnames[0]}_{self.other.resnames[0]}" + ) + check_traj(self, check_traj_len) + self._guess_steps = guess_steps + # print(datargs.files, data['edge_file'].shape) + # cutoff = np.ravel(self.edge_file), + # cutoff = np.rint(np.max(np.ravel(self.edge_file))) + + # if r_cutoff is None: + # self.r_cutoff = np.array([*np.max(self.sel.universe.dimensions[:2]), 5.0]) + # elif type(r_cutoff) in [int, float]: + # self.r_cutoff = np.array([float(r_cutoff) for c in range(3)]) + # elif type(r_cutoff) == list and len(r_cutoff) == 3: + # self.r_cutoff = np.array(r_cutoff) + # else: + # raise ValueError('Wrong type or length for cutoff!') + # self.r_cutoff = self.r_cutoff.astype(np.float64) + # print('r cutoff', self.r_cutoff) + # self.save = save + # if self.save is False: + # pass + # else: + # if type(self.save) == bool: + # self.save = ( + # f"{self.__class__.__name__.lower()}_" + # f"{self.sysname}_{self.sel.resnames[0]}" + # ) + # self._other_dist_f = distance_array + # self._provide_args = lambda: self.sel.positions, self.other.positions + + # check_traj(self, check_traj_len) + # self._guess_steps = guess_steps + + def _prepare(self) -> NoReturn: + logger.info( + f"Starting run:\n" + f"Frames start: {self.start}, " + f"stop: {self.stop}, " + f"step: {self.step}\n" + ) + zdist = None + overwrite = False + while zdist is None: + zdist = self.zdist + if type(zdist) == str: + zdist = Path(zdist) + if zdist is None or not Path(zdist).is_file(): + zdist = ZDens( + sysname=sysname, + sel=sel, + clay=clay, + n_bins=self.data["rdf"].n_bins, + bin_step=self.data["rdf"].bin_step, + cutoff=self.data["rdf"].cutoff, + save=False, + write=self.zdist, + overwrite=overwrite, + ) + run_analysis( + zdist, start=args.start, stop=args.stop, step=args.step + ) + zdist = Path(zdist.write) + zdata = np.load(zdist) + self.mask = 
zdata["mask"] + self.zdata = np.ma.masked_array(zdata["zdist"], mask=self.mask) + start, stop, step = zdata["run_prms"] + if len( + np.arange(start=self.start, stop=self.stop, step=self.step) + ) != len(self.zdata): + logger.info( + "Selected Trajectory slicing does not match zdens data!" + ) + if self._guess_steps == True: + logger.info( + "Using slicing from zdens:\n" + f"start: {start: d}, " + f"stop: {stop:d}, " + f"step: {step:d}.\n" + ) + self._setup_frames(self._trajectory, start, stop, step) + else: + logger.finfo( + "Slicing error!\n" + f"Expected start: {start:d}, " + f"stop: {stop:d}, " + f"step: {step:d}.\n" + f"Found start: {self.start:d}, " + f"stop: {self.stop:d}, " + f"step: {self.step:d}\n" + f"Will overwrite {zdist.name!r} with new zdens data.\n" + ) + self.zdist = zdist = None + overwrite = True + continue + if self.sel_n_atoms != zdata["sel_n_atoms"]: + raise ValueError( + f"Atom number mismatch between z-data ({zdata['sel_n_atoms']}) and selection atoms ({self.sel.n_atoms})!" + ) + self._z_cutoff = np.rint(zdata["cutoff"]) + process_box(self) + + self._dist = np.empty( + (self.sel.n_atoms, self.other.n_atoms, 3), + dtype=np.float64, + ) + + self._dist_m = np.ma.array( + self._dist, + dtype=np.float64, + fill_value=np.nan, + ) + + self._rad = np.empty( + (self.sel.n_atoms, self.other.n_atoms), + dtype=np.float64, + ) + self._rad_m = np.ma.array( + self._rad, + fill_value=np.nan, + dtype=np.float64, + ) + # self._z_dist = np.margs.empty( + # (self.sel.n_atoms, self.other.n_atoms), fill_value=np.nan, dtype=np.float64 + # ) + + # _attrs absolute + # self._abs = [True, True, False] + + self._sel_pos = np.empty((self.sel_n_atoms, 3), dtype=np.float64) + self._sel_pos_m = np.ma.array( + self._sel_pos, fill_value=np.nan, dtype=np.float64 + ) + # self._other_pos = np.margs.empty((self._other_pos, 3), + # dtype=np.float64, + # fill_value=np.nan) + if self.self is False: + self._other_pos = np.empty( + (self.other_n_atoms, 3), dtype=np.float64 + ) + self.diag_mask = False + else: + self.other = None + dist_slice = self._dist_m[..., 0] + diag_idx = np.diag_indices_from(dist_slice) + diag_mask = np.ma.getmaskarray(dist_slice) + diag_mask[diag_idx] = True + diag_mask = np.broadcast_to( + diag_mask[..., np.newaxis], self._dist.shape + ) + self.diag_mask = np.bitwise_or.accumulate(diag_mask, axis=2).copy() + + def _single_frame(self) -> NoReturn: + self._edge_numbers = np.digitize( + self.zdata[self._frame_index], self._edges + ) + self._rad.fill(0) + self._rad_m.mask = False + self._dist.fill(0) + self._dist_m.fill(0) + self._dist_m.mask = False # [..., 2] = self.mask[self._frame_index] + self._sel_pos.fill(0) + self._sel_pos_m.fill(0) + self._sel_pos_m.soften_mask() + self._sel_pos_m.mask = np.broadcast_to( + self.mask[self._frame_index, :, np.newaxis], self._sel_pos_m.shape + ) + self._sel_pos_m.harden_mask() + self._dist_m.mask = np.broadcast_to( + self._sel_pos_m.mask[:, np.newaxis, :], self._dist_m.shape + ) + # self._dist.fill(0) + # self._dist.mask = self.mask[self._frame_index] + + self._sel_pos_m[:] = self.sel.positions + self._sel_pos_m[:] = apply_PBC(self._sel_pos_m, self._ts.dimensions) + + if self.self is False: + self._other_pos.fill(0) + self._other_pos[:] = apply_PBC( + self.other.positions, self._ts.dimensions + ) + else: + self._dist_m.mask += self.diag_mask + self._other_pos = self._sel_pos + # if self.self: + + # if self.self is True: + # sel = self._dist == 0.0 + # sel = np.bitwise_or.accumulate(sel, axis=2) + # self._dist.mask += sel + + get_dist( + 
+        self._process_distances(self._dist_m, self._ts.dimensions)
+        new_mask = np.any(
+            np.greater(np.abs(self._dist_m), self.data["rdf"].cutoff), axis=2
+        )
+        self._dist_m.mask += new_mask[:, :, np.newaxis]
+        self._rad[:] = np.linalg.norm(self._dist_m, axis=2)
+        self._rad[:] = np.where(
+            self._rad <= self.data["rdf"].cutoff, self._rad, 0
+        )
+        self._rad[:] = np.where(self._rad == 0, np.NaN, self._rad)
+        # if self.self is True:
+        #     self._rad.mask += self._rad == 0.0
+        self.data["groups"].timeseries.append(self._edge_numbers)
+        self.data["rdf"].timeseries.append(self._rad)
+        # # categorise by edge
+        # # self._rad_m.mask = np.broadcast_to(
+        # self._dist_m.mask += (
+        #     np.abs(self._dist) >= self.data["zmtd"].cutoff
+        # )
+        # self._dist_m.mask = np.broadcast_to(
+        #     self._dist.mask[..., 2, np.newaxis], self._dist.shape
+        # )
+        #
+        # exclude_z_cutoff(self._dist, self.data["zmtdist"].cutoff)
+        #
+        # self.data["zmtdist"].timeseries.append(self._dist[:, :, 2])
+        # self._rad[:] = np.linalg.norm(self._dist, axis=2)
+        # self._rad.mask = self._rad > self.data["rdf"].cutoff
+        # self.data["rdf"].timeseries.append(self._rad)
+        # # print(self._rad.shape, self._dist.shape)
+        # rdist = np.min(self._rad, axis=1)
+        # self.data["rmtdist"].timeseries.append(rdist)
+        # print(np.min(self._rad, axis=1).shape)
+
+    def _post_process(self) -> NoReturn:
+        logger.finfo(
+            "Grouping RDF data by adsorption shell on clay surface.",
+            initial_linebreak=True,
+        )
+        prev = 0
+        for i, edge in enumerate(self._edges):
+            self.data[f"group_{i}"].timeseries = np.where(
+                np.array(self.data["groups"].timeseries)[:, :, np.newaxis]
+                != i,
+                np.array(self.data["rdf"].timeseries),
+                np.NaN,
+            )
+            logger.finfo(
+                f'"group_{i}"',
+                kwd_str=f"{prev:.1f} <= z < {edge:.2f}: ",
+                indent="\t",
+            )
+            prev = edge
+
+    def _save(self):
+        if self.save is False:
+            pass
+        else:
+            for v in self.data.values():
+                v.save(
+                    self.save,
+                    sel_names=np.unique(self.sel.names),
+                    n_atoms=self.sel.n_atoms,
+                    n_frames=self.n_frames,
+                    other=np.unique(self.other.names),
+                    n_other_atoms=self.other.n_atoms,
+                )
+            logger.info("Done!")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(
+        prog="coordination",
+        description="Compute radial distributions between 2 atom types.",
+        add_help=True,
+        allow_abbrev=False,
+    )
+    parser.add_argument(
+        "-name", type=str, help="System name", dest="sysname", required=True
+    )
+
+    parser.add_argument(
+        "-inp",
+        type=str,
+        help="Input file names",
+        nargs=2,
+        metavar=("coordinates", "trajectory"),
+        dest="infiles",
+        required=False,
+    )
+    parser.add_argument(
+        "-inpname",
+        type=str,
+        help="Input file names",
+        metavar="name_stem",
+        dest="inpname",
+        required=False,
+    )
+    parser.add_argument(
+        "-zdist",
+        type=str,
+        help="z-dist data filename",
+        dest="zdist",
+        required=False,
+    )
+    parser.add_argument(
+        "-uc",
+        type=str,
+        help="Clay unit cell type",
+        dest="clay_type",
+        required=True,
+    )
+    parser.add_argument(
+        "-sel",
+        type=str,
+        nargs="+",
+        help="Atom type selection",
+        dest="sel",
+        required=True,
+    )
+    parser.add_argument(
+        "-other",
+        type=str,
+        nargs="+",
+        help="Other atomtype for distance selection",
+        dest="other",
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "-edges",
+        type=str,
+        help="Adsorption shell upper limits",
+        required=False,
+        dest="edges",
+        default=None,
+    )
+    parser.add_argument(
+        "-n_bins",
+        default=None,
+        type=int,
+        help="Number of bins in histogram",
+        dest="n_bins",
+    )
+    parser.add_argument(
"-bin_step", + type=float, + default=None, + help="bin size in histogram", + dest="bin_step", + ) + + parser.add_argument( + "-check_traj", + type=int, + default=False, + help="Expected trajectory length.", + dest="check_traj_len", + ) + + parser.add_argument( + "-cutoff", + type=float, + default=None, + help="cutoff in x,x2,z-direction", + dest="cutoff", + ) + + # parser.add_argument('-cutoff', + # type=float, + # default=None, + # help='radial cutoff', + # dest='cutoff') + + parser.add_argument( + "-start", + type=int, + default=None, + help="First frame for analysis.", + dest="start", + ) + parser.add_argument( + "-step", + type=int, + default=None, + help="Frame steps for analysis.", + dest="step", + ) + parser.add_argument( + "-stop", + type=int, + default=None, + help="Last frame for analysis.", + dest="stop", + ) + parser.add_argument( + "-out", + type=str, + help="Filename for results pickle.", + dest="save", + default=False, + ) + parser.add_argument( + "-path", + default=False, + help="File with analysis data paths.", + dest="path", + ) + parser.add_argument( + "--in_mem", + default=False, + action="store_true", + help="Read trajectory in memory.", + dest="in_mem", + ) + + +if __name__ == "__main__": + args = parser.parse_args(sys.argv[1:]) + + sysname = args.sysname + + gro, trr, path = get_paths( + infiles=args.infiles, inpname=args.inpname, path=args.path + ) + + logger.finfo(f"{sysname!r}", kwd_str=f"System name: ") + + if args.save is None: + outpath = path + else: + outpath = Path(args.save) + if outpath.is_dir(): + outname = f'{gro}_{args.sel[-1].strip("*")}' + outname = (path / outname).resolve() + else: + outname = Path(args.save).resolve() + logger.finfo(f"{str(outname.resolve())!r}", kwd_str=f"Output path: ") + + # if args.r_cutoff is None: + # r_cutoff = args.cutoff + # if len(args.r_cutoff) == 1: + # r_cutoff = [args.r_cutoff[0] for c in range(3)] + # elif len(args.r_cutoff) == 3: + # r_cutoff = args.r_cutoff + # else: + # raise ValueError('Expected either 1 or 3 arguments for r_cutoff!') + # + # print(r_cutoff) + + coords = gro + traj = trr + logger.debug(f"Using {coords.name} and {traj.name}.") + try: + u = Universe(str(coords), str(traj)) + new = False + if not args.check_traj_len: + logger.finfo( + "Skipping trajectory length check.", initial_linebreak=True + ) + else: + if not u.trajectory.n_frames == args.check_traj_len: + logger.finfo( + f"Wrong frame number, found {u.trajectory.n_frames}, expected {args.check_traj_len}!", + initial_linebreak=True, + ) + new = True + else: + logger.finfo( + f"Trajectory has correct frame number of {args.check_traj_len}.", + initial_linebreak=True, + ) + except: + logger.info("Could not construct universe!", initial_linebreak=True) + new = True + logger.info(get_subheader("Getting atom groups")) + sel, clay, other = get_selections( + infiles=(coords, traj), + sel=args.sel, + clay_type=args.clay_type, + other=args.other, + in_memory=args.in_mem, + ) + + if args.save == "True": + args.save = True + elif args.save == "False": + args.save = False + # if args.write == "True": + # args.write = True + # elif args.write == "False": + # args.write = False + + dist = CrdDist( + sysname=args.sysname, + sel=sel, + clay=clay, + other=other, + n_bins=args.n_bins, + bin_step=args.bin_step, + cutoff=args.cutoff, + edges=args.edges, + zdist=args.zdist, + save=args.save, + check_traj_len=args.check_traj_len, + ) + run_analysis(dist, start=args.start, stop=args.stop, step=args.step) diff --git a/package/ClayCode/analysis/data/__init__.py 
b/package/ClayCode/analysis/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/package/ClayCode/analysis/dataclasses.py b/package/ClayCode/analysis/dataclasses.py new file mode 100644 index 00000000..836ab84e --- /dev/null +++ b/package/ClayCode/analysis/dataclasses.py @@ -0,0 +1,4055 @@ +import os +import pickle as pkl +from functools import cached_property +from pathlib import Path +from typing import List, Literal, Optional, Tuple, Union + +import numpy as np +import pandas as pd +from ClayCode.analysis.consts import PE_DATA +from ClayCode.analysis.lib import Bins, Cutoff, get_edge_fname, read_edge_file +from ClayCode.analysis.plots import HistData, Timeseries, logger +from ClayCode.analysis.utils import make_1d, redirect_tqdm +from matplotlib import colormaps +from matplotlib import colors as mpc +from matplotlib import pyplot as plt +from scipy.signal import find_peaks +from tqdm import tqdm + + +class Data: + """ + Class for histogram analysis data processing and plotting. + Reads files in `indir` that match the naming pattern + "`namestem`*_`cutoff`_`bins`.dat" + (or "`namestem`*_`cutoff`_`bins`_`analysis`.dat" if `analysis` != `None`. + The data is stored in a :class:`pandas.DataFrame` + :param indir: Data directory + :type indir: Union[str, Path] + :param cutoff: Maximum value in the histogram bins + :type cutoff: Union[int, float] + :param bins: Histogram bin size + :type bins: float + :param ions: List of ion types in solvent + :type ions: List[Literal['Na', 'K', 'Ca', 'Mg']], optional + :param atoms: List of atom types in selection, defaults to `None` + :type atoms: List[Literal['ions', 'OT', 'N', 'CA', 'OW']], optional + :param other: Optional list of atom types in a second selection, + defaults to `None` + :type other: List[Literal['ions', 'OT', 'N', 'CA', 'OW']], optional + :param clays: List of clay types, + defaults to `None` + :type clays: List[Literal['NAu-1', 'NAu-2']], optional + :param aas: List of amino acid types in lower case 3-letter code, + defaults to `None` + :type aas: Optional[List[Literal['ala', 'arg', 'asn', 'asp', + 'ctl', 'cys', 'gln', 'glu', + 'gly', 'his', 'ile', 'leu', + 'lys', 'met', 'phe', 'pro', + 'ser', 'thr', 'trp', 'tyr', + 'val']]] + :param load: Load, defaults to False + :type load: Union[str, Literal[False], Path], optional + :param odir: Output directory, defaults to `None` + :type odir: str, optional + :param nameparts: number of `_`-separated partes in `namestem` + :type nameparts: int, defaults to 1 + :param namestem: leading string in naming pattern, optional + :type namestem: str, defaults to '' + :param analysis: trailing string in naming pattern, optional + defaults to `None` + :type analysis: str, optional + :param df: :class: `pandas.DataFrame` + """ + + aas = [ + "ala", + "arg", + "asn", + "asp", + "ctl", + "cys", + "gln", + "glu", + "gly", + "his", + "ile", + "leu", + "lys", + "met", + "phe", + "pro", + "ser", + "thr", + "trp", + "tyr", + "val", + ] + + ions = ["Na", "K", "Ca", "Mg"] + atoms = ["ions", "OT", "N", "CA"] + clays = ["NAu-1", "NAu-2"] # , 'L31'] + + def __init__( + self, + indir: Union[str, Path], + cutoff: Union[int, float], + bins: float, + ions: List[Literal["Na", "K", "Ca", "Mg"]] = None, + atoms: List[Literal["ions", "OT", "N", "CA", "OW"]] = None, + other: List[Literal["ions", "OT", "N", "CA", "OW"]] = None, + clays: List[Literal["NAu-1", "NAu-2"]] = None, + aas: List[ + Literal[ + "ala", + "arg", + "asn", + "asp", + "ctl", + "cys", + "gln", + "glu", + "gly", + "his", + "ile", + "leu", + 
"lys", + "met", + "phe", + "pro", + "ser", + "thr", + "trp", + "tyr", + "val", + ] + ] = None, + load: Union[str, Literal[False], Path] = False, + odir: Optional[str] = None, + nameparts: int = 1, + namestem: str = "", + analysis: Optional[str] = None, + ): + """Constructor method""" + logger.info(f"Initialising {self.__class__.__name__}") + self.filelist: list = [] + self.bins: Bins = Bins(bins) + self.cutoff: float = Cutoff(cutoff) + self.analysis: Union[str, None] = analysis + + if type(indir) != Path: + indir = Path(indir) + + self._indir = indir + + if self.analysis is None: + logger.info( + rf"Getting {namestem}*_" + rf"{self.cutoff}_" + rf"{self.bins}.dat from {str(indir.resolve())!r}" + ) + self.filelist: List[Path] = sorted( + list( + indir.glob( + rf"{namestem}*_" rf"{self.cutoff}_" rf"{self.bins}.dat" + ) + ) + ) + else: + logger.info( + rf"Getting {namestem}*_" + rf"{self.cutoff}_" + rf"{self.bins}_" + rf"{analysis}.dat from {str(indir.resolve())!r}" + ) + self.filelist: List[Path] = sorted( + list( + indir.glob( + rf"{namestem}*_" + rf"{self.cutoff}_" + rf"{self.bins}_" + rf"{self.analysis}.dat" + ) + ) + ) + logger.info(f"Found {len(self.filelist)} files.") + + if load != False: + load = Path(load.resolve()) + self.df: pd.DataFrame = pkl.load(load) + logger.info(f"Using data from {load!r}") + else: + if ions == None: + ions = self.__class__.ions + logger.info( + f"ions not specified, using default {self.__class__.ions}" + ) + else: + logger.info(f"Using custom {ions} for ions") + if atoms == None: + atoms = self.__class__.atoms + logger.info( + f"atoms not specified, using default {self.__class__.atoms}" + ) + else: + logger.info(f"Using custom {atoms} for atoms") + if aas == None: + aas = self.__class__.aas + logger.info( + f"aas not specified, using default {self.__class__.aas}" + ) + else: + logger.info(f"Using custom {aas} for aas") + if clays == None: + clays = self.__class__.clays + logger.info( + f"clays not specified, using default {self.__class__.clays}" + ) + else: + logger.info(f"Using custom {clays} for clays") + + f = self.filelist[0] + # print(f) + + x = pd.read_csv(f, delimiter="\s+", comment="#").to_numpy() + x = x[:, 0] + + cols = pd.Index(["NAu-1", "NAu-2"], name="clays") + + if other != None: + if other is True: + other = atoms + other.append("OW") + idx = pd.MultiIndex.from_product( + [ions, aas, atoms, other, x], + names=["ions", "aas", "atoms", "other", "x"], + ) + self.other: List[str] = other + logger.info(f"Setting second atom selection to {self.other}") + else: + idx = pd.MultiIndex.from_product( + [ions, aas, atoms, x], names=["ions", "aas", "atoms", "x"] + ) + self.other: None = None + self.df: pd.DataFrame = pd.DataFrame(index=idx, columns=cols) + + self._get_data(nameparts) + + self.df.dropna(inplace=True, how="all", axis=0) + + setattr(self, self.df.columns.name, list(self.df.columns)) + + self.df.reset_index(level=["ions", "atoms"], inplace=True) + self.df["_atoms"] = self.df["atoms"].where( + self.df["atoms"] != "ions", self.df["ions"], axis=0 + ) + self.df.set_index( + ["ions", "atoms", "_atoms"], inplace=True, append=True + ) + self.df.index = self.df.index.reorder_levels([*idx.names, "_atoms"]) + self._atoms = self.df.index.get_level_values("_atoms").tolist() + self.df["x_bins"] = np.NaN + self.df.set_index(["x_bins"], inplace=True, append=True) + + for iid, i in enumerate(self.df.index.names): + value: List[Union[str, float]] = ( + self.df.index._get_level_values(level=iid).unique().tolist() + ) + logger.info(f"Setting {i} to {value}") + 
setattr(self, i, value) + + if odir != None: + self.odir: Path = Path(odir) + else: + self.odir: Path = Path(".").cwd() + + logger.info(f"Output directory set to {str(self.odir.resolve())!r}\n") + self.__bin_df = pd.DataFrame(columns=self.df.columns) + + self.__edges = {} + self.__peaks = {} + + def _get_data(self, nameparts): + idsl = pd.IndexSlice + for f in self.filelist: + namesplit = f.stem.split("_") + if self.analysis is not None: + namesplit.pop(-1) + else: + self.analysis = "zdist" + name = namesplit[:nameparts] + namesplit = namesplit[nameparts:] + if self.other != None: + # other = namesplit[5] + other = namesplit.pop(5) + if other in self.ions: + other = "ions" + try: + clay, ion, aa, pH, atom, cutoff, bins = namesplit + assert cutoff == self.cutoff + assert bins == self.bins + array = pd.read_csv(f, delimiter="\s+", comment="#").to_numpy() + try: + self.df.loc[idsl[ion, aa, atom, :], clay] = array[:, 2] + except ValueError: + self.df.loc[idsl[ion, aa, atom, other, :], clay] = array[ + :, 2 + ] + except IndexError: + try: + self.df.loc[idsl[ion, aa, atom, :], clay] = array[:, 1] + except ValueError: + self.df.loc[ + idsl[ion, aa, atom, other, :], clay + ] = array[:, 1] + except KeyError: + pass + except IndexError: + logger.info(f"Encountered IndexError while getting data") + except ValueError: + print(namesplit) + logger.info(f"Encountered ValueError while getting data") + self.name = "_".join(name) + + def __repr__(self): + return self.df[self.clays].dropna().__repr__() + + @property + def densdiff(self): + try: + return self.df["diff"].dropna() + except KeyError: + self._get_densdiff() + return self.df["diff"].dropna() + + def _get_densdiff(self): + self.df["diff"] = -self.df.diff(axis=1)[self.df.columns[-1]] + + def plot( + self, + x: Literal["clays", "aas", "ions", "atoms", "other"], + y: Literal["clays", "ions", "aas", "atoms", "other"], + select: Literal["clays", "ions", "aas", "atoms", "other"], + rowlabel: str = "y", + columnlabel: str = "x", + figsize=None, + dpi=None, + diff=False, + xmax=50, + ymax=50, + save=False, + xlim=None, + ylim=None, + odir=".", + plot_table=None, + ): + aas_classes = [ + ["arg", "lys", "his"], + ["glu", "gln"], + ["cys"], + ["gly"], + ["pro"], + ["ala", "val", "ile", "leu", "met"], + ["phe", "tyr", "trp"], + ["ser", "thr", "asp", "gln"], + ] + ions_classes = [["Na", "Ca"], ["Ca", "Mg"]] + atoms_classes = [["ions"], ["N"], ["OT"], ["CA"]] + clays_classes = [["NAu-1"], ["NAu-2"]] + cmaps_seq = ["Purples", "Blues", "Greens", "Oranges", "Reds"] + cmaps_single = ["Dark2"] + sel_list = ("clays", "ions", "aas", "atoms") + # for color, attr in zip([''], sel_list): + # cmaps_dict[attr] = {} + # cm.get_cmap() + cmap_dict = {"clays": []} + + title_dict = { + "clays": "Clay type", + "ions": "Ion type", + "aas": "Amino acid", + "atoms": "Atom type", + "other": "Other atom type", + } + + sel_list = ["clays", "ions", "aas", "atoms"] + + if self.other != None: + sel_list.append("other") + + separate = [s for s in sel_list if (s != x and s != y and s != select)] + idx = pd.Index([s for s in sel_list if (s != x and s not in separate)]) + + sep = pd.Index(separate) + + vx = getattr(self, x) + # print(vx) + + if diff == True: + vx = "/".join(vx) + lx = 1 + else: + lx = len(vx) + + vy = getattr(self, y) + + ly = len(vy) + # print(ly) + + yid = np.ravel(np.where(np.array(idx) == y))[0] + # print(yid) + + label_key = idx.difference( + pd.Index([x, y, *separate]), sort=False + ).values[0] + + label_id = idx.get_loc(key=label_key) + + # label_classes = 
locals()[f'{label_key}_classes'] + # cmap_dict = {} + # single_id = 0 + # seq_id = 0 + # for category in label_classes: + # if len(category) == 1: + # cmap = matplotlib.cycler('color', cm.Dark2.colors) + # single_id += 1 + # else: + # cmap = getattr(cm, cmaps_seq[seq_id])(np.linspace(0, 1, len(category))) + # + # cmap = matplotlib.cycler('color', cmap) + # # cmap = cmap(np.linspace(0, 1, len(category))).colors + # # viridis(np.linspace(0,1,N))) + # # cm.get_cmap(cmaps_seq[seq_id], len(category)) + # seq_id += 1 + # for item_id, item in enumerate(category): + # cmap_dict[item] = cmap.__getitem__(item_id) + # + n_plots = len(sep) + + x_dict = dict(zip(vx, np.arange(lx))) + + if diff == True: + diffstr = "diff" + sel = "diff" + self._get_densdiff() + else: + sel = self.clays + diffstr = "" + + plot_df = self.df[sel].copy() + plot_df.index = plot_df.index.droplevel(["_atoms", "x_bins"]) + plot_df.reset_index().set_index([*idx, "x"]) + # print(plot_df.head(5)) + + if figsize == None: + figsize = tuple( + [ + 5 * lx if (10 * lx) < xmax else xmax, + 5 * ly if (5 * ly) < ymax else ymax, + ] + ) + + if dpi == None: + dpi = 300 + + iters = np.array( + np.meshgrid(*[getattr(self, idxit) for idxit in idx]) + ).T.reshape(-1, len(idx)) + + logger.info(f"Printing plots for {sep}\nColumns: {vx}\nRows: {vy}") + + label_mod = lambda l: ", ".join( + [li.upper() if namei == "aas" else li for li, namei in l] + ) + + sep_it = np.array( + np.meshgrid(*[getattr(self, idxit) for idxit in sep]) + ).T.reshape(-1, len(sep)) + # print(vx, vy, lx, ly) + + for pl in sep_it: + # print(pl) + # try: + # fig.clear() + # except: + # pass + y_dict = dict(zip(vy, np.arange(ly))) + if separate == "atoms" and pl != "": + ... + + legends_list = [(a, b) for a in range(ly) for b in range(lx)] + + legends = dict( + zip(legends_list, [[] for a in range(len(legends_list))]) + ) + + # if type(pl) in [list, tuple, np.ndarray]: + # viewlist = [] + # for p in pl: + # viewlist.append(plot_df.xs((p), level=separate, axis=0)) + # + # sepview = pd.concat(viewlist) + # plsave = 'ions' + # + # else: + sepview = plot_df.xs((pl), level=separate, axis=0) + plsave = pl + + fig, ax = plt.subplots( + nrows=ly, + ncols=lx, + figsize=figsize, + sharey=True, + dpi=dpi, + constrained_layout=True, + ) + + fig.suptitle( + ( + ", ".join([title_dict[s].upper() for s in separate]) + + f": {label_mod(list(tuple(zip(pl, separate))))}" + ), + size=16, + weight="bold", + ) + pi = 0 + for col in vx: + try: + view = sepview.xs(col, axis=1) + pi = 1 + except ValueError: + view = sepview + col = vx + pi += 1 + for it in iters: + try: + values = view.xs( + tuple(it), level=idx.tolist() + ).reset_index(drop=False) + values = values.values + if np.all(values) >= 0: + try: + x_id, y_id = x_dict[col], y_dict[it[yid]] + ax[y_id, x_id].plot( + values[:, 0], + values[:, 1], + label=it[label_id], + ) + except: + x_id, y_id = 0, y_dict[it[yid]] + ax[y_id].plot( + values[:, 0], + values[:, 1], + label=it[label_id], + ) + if pi == 1: + legends[y_id, x_id].append(it[label_id]) + else: + logger.info("NaN values") + except KeyError: + logger.info(f"No data for {pl}, {vx}, {it}") + for i in range(ly): + try: + ax[i, 0].set_ylabel( + f"{label_mod([(vy[i], y)])}\n" + rowlabel + ) + for j in range(lx): + ax[i, j].legend( + [ + label_mod([(leg, label_key)]) + for leg in legends[i, j] + ], + ncol=3, + ) + if xlim != None: + ax[i, j].set_xlim((0.0, float(xlim))) + if ylim != None: + ax[i, j].set_ylim((0.0, float(ylim))) + ax[ly - 1, j].set_xlabel( + columnlabel + f"\n{label_mod([(vx[j], 
x)])}" + ) + except IndexError: + ax[i].set_ylabel(f"{label_mod([(vy[i], y)])}\n" + rowlabel) + ax[i].legend( + [ + label_mod([(leg, label_key)]) + for leg in legends[i, 0] + ], + ncol=3, + ) + ax[ly - 1].set_xlabel( + columnlabel + f"\n{label_mod([(vx[0], x)])}" + ) + + fig.supxlabel(f"{title_dict[x]}s", size=14) + fig.supylabel(f"{title_dict[y]}s", size=14) + if save != False: + odir = Path(odir) + if not odir.is_dir(): + os.makedirs(odir) + if type(save) == str: + fig.savefig(odir / f"{save}.png") + else: + logger.info( + f"Saving to {self.name}_{diffstr}_{x}_{y}_{plsave}_{self.cutoff}_{self.bins}.png" + ) + fig.savefig( + odir + / f"{self.name}_{diffstr}_{x}_{y}_{plsave}_{self.cutoff}_{self.bins}.png" + ) + else: + plt.show() + self.fig.clear() + + def _get_edge_fname( + self, + atom_type: str, + other: Optional[str], + name: Union[Literal["pe"], Literal["edge"]] = "pe", + ): + return get_edge_fname(atom_type, name, other, PE_DATA) + + def _get_edges( + self, + height: Union[float, int] = 0.01, + distance: Union[float, int] = 2, + width: int = 1, + wlen: int = 11, + peak_cutoff: Union[int, float] = 10, + prominence: float = 0.005, + atom_type="all", + other=None, + **kwargs, + ) -> List[float]: + """Identifies ads_edges and maxima of position density peaks. + Peak and edge identification based on the ``scipy.signal`` :func:`find_peaks` + :param height: Required height of peaks. + :type height: Union[float, int] + :param distance: Required minimal horizontal distance (>= 1) in samples between + neighbouring peaks. + :type distance: Union[float, int] + :param width: Required width of peaks in samples. + :type width: int + :param wlen: A window length in samples that optionally limits the evaluated area + for each peak to a subset of the evaluated sequence. 
+ :type wlen: int + :return: list of peak ads_edges + :rtype: List[float]""" + from ClayAnalysis.peaks import Peaks + from sklearn.neighbors import KernelDensity + + p = Peaks(self) + edge_df: pd.DataFrame = self.df.copy() + if other is None: + logger.info( + f'Found atom types {edge_df.index.unique("_atoms").tolist()}' + ) + edge_df.index = edge_df.index.droplevel(["ions", "atoms"]) + # logger.info(edge_df.groupby(["_atoms", "x"]).count()) + edge_df = edge_df.groupby(["_atoms", "x"]).sum() + # Take sum for all columns (clay types) + edge_df = edge_df.aggregate("sum", axis=1) + if atom_type != "all": + atom_types = [atom_type] + else: + atom_types = edge_df.index.unique(level="_atoms").tolist() + else: + logger.info( + f'Found atom types {edge_df.index.unique("_atoms").tolist()}' + ) + logger.info( + f'Found atom types {edge_df.index.unique("other").tolist()}' + ) + edge_df.index = edge_df.index.droplevel(["ions", "atoms"]) + edge_df = edge_df.groupby(["other", "_atoms", "x"]).sum() + + logger.info(f"Getting peaks ads_edges for {atom_types}") + for atom_type in atom_types: + outname = self._get_edge_fname( + atom_type=atom_type, other=other, name="pe" + ) + if not outname.is_file(): + if atom_type == "OT": + atom_peak_cutoff = peak_cutoff + 5 + else: + atom_peak_cutoff = peak_cutoff + df_slice = edge_df.xs( + atom_type, level="_atoms", drop_level=False + ) + expanded_x = np.expand_dims( + df_slice.index.get_level_values("x"), axis=1 + ) + kde: KernelDensity = KernelDensity( + kernel="tophat", bandwidth=0.01 + ).fit(X=expanded_x, sample_weight=df_slice.to_numpy()) + score = kde.score_samples(expanded_x) + score -= np.amin(score, where=~np.isinf(score), initial=0) + score = np.where( + np.logical_or( + np.isinf(score), + df_slice.index.get_level_values("x") + >= atom_peak_cutoff, + ), + 0, + score, + ) + cut = np.argwhere(score != 0) + plt.plot( + df_slice.index.get_level_values("x")[cut], + 10 * score[cut] / np.sum(score[cut]), + linestyle="dotted", + ) + # score = np.where(df_slice.index.get_level_values('x') > peak_cutoff, 0, score) + peak_prominence = prominence * score + # plt.plot(x_vals, score) + + peaks, peak_dict = find_peaks( + score, + # df_slice.to_numpy(), + # height=height, + distance=distance, + width=width, + wlen=wlen, + prominence=peak_prominence, + ) + plt.plot( + df_slice.index.get_level_values("x"), + df_slice.to_numpy() / np.sum(df_slice.to_numpy()), + ) + # plt.plot(df_slice.index.get_level_values("x"), df_slice.to_numpy()) + x_vals = df_slice.index.get_level_values("x").to_numpy() + logger.info(f"Found {len(peaks)} peaks: {x_vals[peaks]}") + edges = [0] + for id in range(0, len(peaks) - 1): + window = np.s_[peaks[id] : peaks[id + 1]] + edge_id = np.argwhere( + score[window] == np.min(score[window]) + )[0] + edges.append(x_vals[window.start + edge_id][0]) + # edge_id = np.argwhere( + # df_slice.values == np.min(df_slice.values[window]) + # )[0] + # ads_edges.append(x_vals[edge_id][0]) + # logger.info(f"ads_edges: {ads_edges}") + # ads_edges.append(x_vals[peak_dict["right_bases"][-1]]) + # logger.info(f"ads_edges: {ads_edges}") + + # logger.info(f"ads_edges: {ads_edges}") + try: + right_base = peak_dict["right_bases"][-1] + except IndexError: + right_base = None + # logger.info(f'{right_base}') + # print(right_base) + if len(peaks) <= 1: + logger.info("l1") + right_edge = right_base + else: + logger.info("ln1") + final_slice = np.s_[ + window.stop : window.stop + + np.min([edge_id[0], window.stop - window.start]) + ] + if score[right_base] <= np.min(score[final_slice]): 
+ logger.info("a") + right_edge = right_base + else: + right_edge = ( + window.stop + + np.argwhere( + score[final_slice] + == np.min(score[final_slice]) + )[-1][0] + ) + logger.info(f"{right_base}, {right_edge}") + if right_edge is not None: + edges.append(x_vals[right_edge]) # + edges.append(x_vals[-1]) + # print(ads_edges) + plt.scatter( + x_vals[peaks], + df_slice[peaks] / np.sum(df_slice.to_numpy()), + color="red", + ) + edge_dict = { + "ads_edges": edges, + "cutoff": self.cutoff, + "peak": x_vals[peaks], + } + # for peak in edge_dict["peak"]: + # plt.axvline(peak, 0, 1, color="green") + # for edge in edge_dict["ads_edges"]: + # plt.axvline(edge, 0, 0.5, color="orange") + # plt.suptitle(atom_type) + for p in peaks: + # print(peak) + # print((peak, df_slice[peak_i])) + plt.annotate( + rf"{np.round(x_vals[p], 1):2.1f} \AA", + xy=( + x_vals[p], + df_slice[p] / np.sum(df_slice.to_numpy()), + ), + textcoords="offset points", + verticalalignment="bottom", + ) + logger.info(edge_dict) + for ei, edge in enumerate(edge_dict["ads_edges"]): + plt.axvline( + edge, 0, 2, color="orange" + ) # , label = f'edge {ei}: {edge:2.1f}') + # plt.annotate(fr'{edge:2.1f} \AA', xy=(edge, 0.8)) + plt.suptitle(f"Edges: {atom_type}") + plt.xlabel(r"z-distance (\AA)") + plt.ylabel(r"density") + plt.xticks(np.arange(0, 21, 2)) + # plt.ylim(0, 0.5) + # plt.show() + # plt.savefig( + # Path(self._get_edge_fname(atom_type=atom_type)).with_suffix(".png") + # ) + # plt.close() + with open(outname, "wb") as edge_file: + pkl.dump(edge_dict, edge_file) + logger.info(f"Wrote {atom_type} ads_edges to {outname}.") + # plt.show() + # p.get_peaks(atom_type=atom_type) + # ads_edges = self._read_edge_file(atom_type=atom_type, skip=False) + # self.__edges[atom_type] = ads_edges + # self.__peaks[atom_type] = x_vals[peaks] + # + + def _read_edge_file(self, atom_type: str, skip=True, other=None): + fname = self._get_edge_fname(atom_type, name="ads_edges", other=other) + return read_edge_file(fname, self.cutoff, skip) + + # # def _read_peak_file(self, atom_type): + # # fname = self._get_edge_fname(atom_type) + # # if not fname.exists(): + # # logger.info("does not exist") + # # os.mkdir(fname.parent) + # # from ClayAnalysis.peaks import Peaks + # # pks = Peaks(self) + # # pks.get_peaks(atom_type=atom_type) + # # with open(fname, "rb") as edges_file: + # # logger.info(f"Reading peaks {edges_file.name}") + # # p = pkl.load(edges_file)["peaks"] + # # print(p) + # # return p + # + # # @property + # # def peaks(self): + # # if len(self.__peaks) == len(self._atoms): + # # pass + # # else: + # # for atom_type in self._atoms: + # # # try: + # # + # # self.__edges[atom_type] = self._read_peak_file(atom_type) + # # logger.info(f"Reading peaks") + # # # except FileNotFoundError: + # # # logger.info(f"Getting new peaks") + # # # self._get_edges(atom_type) + # # return self.__peaks + + @property + def edges(self): + if len(self.__edges) == len(self._atoms): + pass + else: + for atom_type in self._atoms: + # try: + self.__edges[atom_type] = self._read_edge_file(atom_type) + # logger.info(f"Reading peaks") + # except FileNotFoundError: + # logger.info(f"Getting new ads_edges") + # self._get_edges(atom_type) + return self.__edges + + def get_bin_df(self): + idx = self.df.index.names + bin_df = self.df.copy() + atom_types = bin_df.index.get_level_values("_atoms").unique().tolist() + bin_df.reset_index(["x_bins", "x", "_atoms"], drop=False, inplace=True) + for atom_type in atom_types: + # logger.info(f"{atom_type}") + try: + edges = 
self.__edges[atom_type] + except KeyError: + # edge_fname = self._get_edge_fname(atom_type) + edges = self._read_edge_file( + atom_type=atom_type, other=self.other + ) + # if edge_fname.is_file(): + # self.__edges[atom_type] = self._read_edge_file(atom_type) + # else: + # raise + # self._get_edges(atom_type=atom_type) + # ads_edges = self.__edges[atom_type] + # print(ads_edges, bin_df['x_bins'].where(bin_df['_atoms'] == atom_type)) + bin_df["x_bins"].where( + bin_df["_atoms"] != atom_type, + pd.cut(bin_df["x"], [*edges]), + inplace=True, + ) + bin_df.reset_index(drop=False, inplace=True) + + bin_df.set_index(idx, inplace=True) + self.df = bin_df.copy() + + @property + def bin_df(self): + if not self.df.index.get_level_values("x_bins").is_interval(): + logger.info("No Interval") + self.get_bin_df() + else: + logger.info("Interval") + return self.df + + # area_df = self.df.copy() + # atom_col = edge_df.loc['atoms'] + # edge_df['atoms'].where(edge_df['atoms'] == 'ions') + # data_slices = edge_df.groupby(['ions', 'atoms', 'x']).sum() + # data_slices = data_slices.aggregate('sum', axis=1) + # ion_slices = data_slices.xs('ions', level='atoms') + # # other_slices = + # + # peaks = find_peaks(data_slices.to_numpy(), + # height=height, + # distance=distance, + # width=width, + # wlen=wlen) + # check_logger.info(f'Found peaks {peaks[0]}') + # + # colours = ['blue', 'orange'] + # fig, ax = plt.subplots(len(data_slices.index.unique('atoms'))) + # y = [] + # fig = plt.figure(figsize=(16, 9)) + # for atom_type in data_slices.index.unique('atoms'): + # data_slice = data_slices.xs(atom_type, level='atoms') + # plt_slice = data_slice + # if atom_type == 'ions': + # for ion_type in data_slice.index.unique('ions'): + # plt_slice = data_slice.xs(ion_type, level='ions') + # y.append((plt_slice.reset_index()['x'].to_numpy(), plt_slice.to_numpy())) + # else: + # y.append((plt_slice.reset_index()['x'].to_numpy(), plt_slice.to_numpy())) + # + # for y_data in y: + # # y = plt_slice.to_numpy() + # # x = plt_slice.reset_index()['x'].to_numpy()#atom_type) + # plt.plot(*y_data) + # plt.vlines(data_slice.reset_index()['x'].to_numpy()[peaks[0]], -1, 1, color='red') + # plt.xlim(0, 7) + + # + # group = data.index.droplevel('x') + # + # # new_idx = pd.MultiIndex.from_product(group = data.index.droplevel('x').get_level_values) + # + # ads_edges = np.array(ads_edges, dtype=np.float32) + # if ads_edges[0] != min: + # np.insert(ads_edges, 0, min) + # if ads_edges[-1] < self.cutoff: + # ads_edges.append(self.cutoff) + # # intervals = pd.IntervalIndex.from_breaks(ads_edges) + # + # data = data.reset_index(drop=False).set_index(group.names) + # print(data.index.names) + # print(data.columns) + # data['bins'] = pd.cut(data['x'], [min, *ads_edges, self.cutoff]) + # print(data['bins'].head(5)) + # data.set_index(['bins'], append=True, inplace=True) + # data = data.loc[:, self.clays] + # grouped = data.groupby([*group.names, 'bins']).sum() + # + # + # + # # data.set_index('bins', append=True, inplace=True) + # # data = data.reset_index(level='x').set_index('bins', append=True) + # + # + # # if type(sel_level) == str: + # # sel_level = [sel_level] + # # # group = [g for g in group if g not in sel_level] + # # # group.append('area_bins') + # # x = data.groupby(by=[*group.names, 'bins']).cumsum() + # + # return grouped + # # def _get_areas(self, sel, sel_level, ads_edges, min = 0.0): + # # idsl = pd.IndexSlice + # # data = self.df.xs(sel, + # # level=sel_level, + # # drop_level=False).copy() + # # group = data.index.droplevel('x') + # # 
+    # #     return grouped
+
+    def _get_bin_label(self, x_bin):
+        # closed bin labels within the cutoff; the outermost bin is
+        # left open towards the bulk
+        if x_bin.right < np.max(self.x):
+            label = f"${x_bin.left} - {x_bin.right}$ \AA"
+            # barwidth = x_bin.right - x_bin.left
+        else:
+            label = f"$ > {x_bin.left}$ \AA"
+        return label
+
+    # def plot_bars(self,
+    #               bars: Literal['clays', 'aas', 'ions', 'atoms', 'other'],
+    #               x: Literal['clays', 'aas', 'ions', 'atoms', 'other'],
+    #               y: Literal['clays', 'aas', 'ions', 'atoms', 'other'],
+    #               # y: Literal['clays', 'ions', 'aas', 'atoms', 'other'],
+    #               # select: Literal['clays', 'ions', 'aas'],
+    #               rowlabel: str = 'y',
+    #               columnlabel: str = 'x',
+    #               figsize=None,
+    #               dpi=None,
+    #               # diff=False,
+    #               xmax=50,
+    #               ymax=50,
+    #               save=False,
+    #               ylim=None,
+    #               odir='.',
+    #               barwidth=0.75,
+    #               xpad=0.25,
+    #               cmap='winter'
+    #               ):
+    #     """Create stacked Histogram adsorption shell populations.
+ # + # """ + # aas_classes = [['arg', 'lys', 'his'], + # ['glu', 'gln'], + # ['cys'], + # ['gly'], + # ['pro'], + # ['ala', 'val', 'ile', 'leu', 'met'], + # ['phe', 'tyr', 'trp'], + # ['ser', 'thr', 'asp', 'gln']] + # ions_classes = [['Na', 'Ca'], + # ['Ca', 'Mg']] + # atoms_classes = [['ions'], + # ['N'], + # ['OT'], + # ['CA']] + # clays_classes = [['NAu-1'], + # ['NAu-2']] + # cmaps_seq = ['Purples', 'Blues', 'Greens', 'Oranges', 'Reds'] + # cmaps_single = ['Dark2'] + # sel_list = ('clays', 'ions', 'aas', 'atoms') + # # for color, attr in zip([''], sel_list): + # # cmaps_dict[attr] = {} + # # cm.get_cmap() + # cmap_dict = {'clays': []} + # + # title_dict = {'clays': 'Clay type', + # 'ions': 'Ion type', + # 'aas': 'Amino acid', + # 'atoms': 'Atom type', + # 'other': 'Other atom type'} + # + # sel_list = ['clays', 'ions', 'aas', 'atoms'] + # + # # if self.other != None: + # # sel_list.append('other') + # + # separate = [s for s in sel_list if s not in [x, y, bars]] # (s != x and s != y and s != bars and s != groups)] + # + # idx = pd.Index([s for s in sel_list if (s != x and s != bars and s not in separate)]) + # + # + # sep = pd.Index(separate) + # + # vx = getattr(self, x) + # logger.info(f'x = {x}: {vx}') + # lx = len(vx) + # + # vy = getattr(self, y) + # logger.info(f'y = {y}: {vy}') + # ly = len(vy) + # + # + # vbars = getattr(self, bars) + # lbars = len(vbars) + # logger.info(f'bars = {bars}: {vbars}') + # + # bar_dict = dict(zip(vbars, np.arange(lbars))) + # + # yid = np.ravel(np.where(np.array(idx) == y))[0] + # + # + # # label_key = idx.difference(pd.Index([x, y, *separate]), sort=False).values[0] + # + # # label_id = idx.get_loc(key=label_key) + # n_plots = len(sep) + # + # x_dict = dict(zip(vx, np.arange(lx))) + # + # + # sel = self.clays + # + # # get data for plotting + # plot_df = self.bin_df[sel].copy() + # + # # move clays category from columns to index + # idx_names = ['clays', *plot_df.index.droplevel(['x', '_atoms']).names] + # # DataFrame -> Series + # plot_df = plot_df.stack() + # + # # get values for atom types (including separate ions) + # atoms = plot_df.index.get_level_values('_atoms') + # + # # make new DataFrame from atom_type index level and values + # plot_df.index = plot_df.index.droplevel(['x', '_atoms']) + # plot_df = pd.DataFrame({'values': plot_df, + # '_atoms': atoms}) + # + # # list of matplotlib sequential cmaps + # cmaps = [ + # 'spring', 'summer', 'autumn', 'winter', 'cool', 'Wistia', + # 'hot', 'afmhot', 'gist_heat', 'copper'] + # + # # map unique atom types to colour map + # atom_types = atoms.unique() + # colour_dict = dict(zip(atom_types, cmaps[:len(atom_types)])) + # plot_df['_atoms'] = plot_df['_atoms'].transform(lambda x: colour_dict[x]) + # + # # reorder index for grouping + # plot_df = plot_df.reorder_levels(idx_names) + # + # # group and sum densities within adsorption shell bins + # plot_df = plot_df.groupby( + # plot_df.index.names).agg(values=pd.NamedAgg('values', 'sum'), + # colours=pd.NamedAgg('_atoms', 'first') + # ) + # + # # separate colour column from plot_df -> yields 2 Series + # colour_df = plot_df['colours'] + # plot_df = plot_df['values'] + # + # # add missing atom probabilities from bulk to the largest bin + # # (bin.left, cutoff] -> (bin.left, all bulk]) + # inner_sum = plot_df.groupby(plot_df.index.droplevel('x_bins').names).sum() + # extra = 1 - inner_sum + # plot_df.where(np.rint(plot_df.index.get_level_values('x_bins').right + # ) != int(self.cutoff), + # lambda x: x + extra, + # inplace=True + # ) + # + # # determine 
largest shell bin limit + # max_edge = list(map(lambda x: np.max(x[:-1]), self.ads_edges.values())) + # max_edge = np.max(max_edge) + # + # # normalise colour map from 0 to max_edge + # cnorm = mpc.Normalize(vmin=0, vmax=max_edge, clip=False) + # + # # set figure size + # if figsize == None: + # figsize = tuple([5 * lx if (10 * lx) < xmax else xmax, + # 5 * ly if (5 * ly) < ymax else ymax]) + # + # # set resultion + # if dpi == None: + # dpi = 100 + # + # # get plotting iter from index + # iters = np.array(np.meshgrid(*[getattr(self, idxit) for idxit in idx]) + # ).T.reshape(-1, len(idx)) + # + # logger.info(f'Printing bar plots for {sep}\nColumns: {vx}\nRows: {vy}') + # + # # set label modifier function + # label_mod = lambda l: ', '.join([li.upper() if namei == 'aas' + # else li for li, namei in l]) + # + # + # try: + # # iterator for more than one plot + # sep_it = np.array(np.meshgrid(*[getattr(self, idxit) for idxit in sep]) + # ).T.reshape(-1, len(sep)) + # except ValueError: + # # only one plot + # sep_it = [None] + # + # # iterate over separate plots + # for pl in sep_it: + # # index map for y values + # y_dict: dict = dict(zip(vy, np.arange(ly))) + # print(y_dict) + # + # # initialise legends + # legends_list: list = [(a, b) for a in range(ly) for b in range(lx)] + # legends: dict = dict(zip(legends_list, [[] for a in range(len(legends_list))])) + # + # # generate figure and axes array + # fig, ax = plt.subplots(nrows=ly, + # ncols=lx, + # figsize=figsize, + # sharey=True, + # dpi=dpi, + # constrained_layout=True, + # # sharex=True + # ) + # # only one plot + # if pl is None: + # logger.info(f'Generating plot') + # sepview = plot_df.view() + # plsave = '' + # + # # multiple plots + # else: + # logger.info(f'Generating {pl} plot') + # sepview = plot_df.xs((pl), + # level=separate, + # axis=0, + # drop_level=False) + # plsave = pl + # fig.suptitle((', '.join([title_dict[s].upper() for s in separate]) + + # f': {label_mod(list(tuple(zip(pl, separate))))}'), size=16, + # weight='bold') + # + # # set plot index + # pi = 0 + # + # #iterate over subplot columns + # for col in vx: + # logger.info(col) + # try: + # + # view = sepview.xs(col, + # level=x, + # axis=0, + # drop_level=False) + # + # pi = 1 + # except ValueError: + # view = sepview + # col = vx + # pi += 1 + # + # table_text = [] + # table_col = [] + # table_cmap = None + # table_rows = [] + # for it in iters: + # try: + # + # + # values = view.xs(tuple(it), + # level=idx.tolist(), + # drop_level=False) + # + # # x_grouplabels = [] + # x_labels = [] + # x_ticks = [] + # + # + # + # + # + # for vbar in vbars: + # # bulk_pad = 2 + # # bulk_edge = np.rint((max_edge + bulk_pad)) + # x_ticks.append(bar_dict[vbar] * (barwidth + xpad)) + # x_labels.append(vbar) + # + # + # bottom = 0.0 + # + # + # # x_tick = bar_dict[vbar] * (barwidth + xpad) + # + # + # bar_vals = values.xs(vbar, + # level=bars, + # drop_level=False) + # + # cmap = colormaps[colour_df.loc[bar_vals.index].values[0]] + # if table_cmap is None: + # table_cmap = cmap + # + # if len(self.__peaks) != len(self.df.index.get_level_values('_atoms').unique()): + # self._get_edges() + # try: + # peaks = self.__peaks[bar_vals.index.get_level_values('atoms').unique().tolist()[0]] + # except: + # peaks = self.__peaks[bar_vals.index.get_level_values('ions').unique().tolist()[0]] + # + # # bar_vals.values >= 0) + # if np.all(bar_vals.values) >= 0: + # table_text.append([f'${v * 100:3.1f} %$' for v in bar_vals.values]) + # + # + # x_id, y_id = x_dict[col], y_dict[it[yid]] + # + # 
bar_val_view = bar_vals + # bar_val_view.index = bar_val_view.index.get_level_values('x_bins') + # + # + # x_tick = x_ticks[-1] + # # x_ticks.append(x_tick) + # # bar_plots = [] + # for bar_id, bar_val in enumerate(bar_val_view.items()): + # + # + # x_bin, y_val = bar_val + # + # try: + # peak = peaks[bar_id] + # except IndexError: + # peak = x_bin.right + # colour = cmap(cnorm(peak)) + # if colour not in table_col and cmap == table_cmap: + # print('colour', colour) + # table_col.append(colour) + # if x_bin.right < np.max(self.x): + # label = f'${x_bin.left} - {x_bin.right}$ \AA' + # # barwidth = x_bin.right - x_bin.left + # else: + # label = f'$ > {x_bin.left}$ \AA' + # if label not in table_rows and cmap == table_cmap: + # table_rows.append(label) + # if y_val >= 0.010: + # + # # barwidth = bulk_edge - x_bin.left + # # try: + # # x_tick = x_ticks[-1] + barwidth + # # x_ticks.append(x_tick) + # # except IndexError: + # # x_tick = x_bin.left + # + # + # + # + # try: + # p = ax[y_id, x_id].bar(x_tick, + # y_val, + # label=label, + # bottom=bottom, + # width=barwidth, + # align='edge', + # color=colour + # ) + # ax[y_id, x_id].bar_label(p, labels=[label], + # fmt='%s', + # label_type='center') + # except IndexError: + # p = ax[y_id].bar(x_tick, + # y_val, + # label=label, + # bottom=bottom, + # width=barwidth, + # align='edge', + # color=colour + # ) + # ax[y_id, x_id].bar_label(p, labels=[label], + # fmt='%s', + # label_type='center') + # # finally: + # bottom += y_val + # print(table_text, table_col, table_rows) + # table = ax[y_id, x_id].table(cellText=table_text, + # rowColours=table_col, + # rowLabels=table_rows, + # # colLables=..., + # loc='bottom') + # # x_ticks = x_ticks[:-1] + # + # # x_ticks.append(bar_dict[vbar] * (barwidth + xpad)) + # + # + # + # # values = values + # # print('try 1 done') + # # + # # # for shell in values: + # # # view for group and bars + # # label = f'${lims.left} - {lims.right}$ \AA' + # # + # # try: + # # print('try 2') + # # print(x_dict[col], y_dict[it[yid]]) + # + # # except: + # # # raise ValueError + # # x_id, y_id = 0, y_dict[it[yid]] + # # label = f'${x_bin.left} - {x_bin.right}$ \AA' + # + # # if pi == 1: + # # legends[y_id, x_id].append(it[label_id]) + # # else: + # # check_logger.info('NaN values') + # + # except KeyError: + # logger.info(f'No data for {pl}, {vx}, {it}') + # + # # x_ticks = [np.linspace(n_bar * bulk_edge + xpad, + # # n_bar * bulk_edge + bulk_edge, int(bulk_edge)) for n_bar in range(lbars)] + # # x_ticks = np.ravel(x_ticks) + # # x_labels = np.tile(np.arange(0, bulk_edge, 1), lbars) + # + # for i in range(ly): + # try: + # ax[i, 0].set_ylabel(f'{label_mod([(vy[i], y)])}\n' + rowlabel) + # ax[i, 0].set_yticks(np.arange(0.0, 1.1, 0.2)) + # for j in range(lx): + # ax[i, j].spines[['top', 'right']].set_visible(False) + # ax[i, j].hlines(1.0, -xpad, lbars * (barwidth + xpad) + xpad, linestyle='--') + # # ax[i, j].legend(ncol=2, loc='lower center')#[leg for leg in legends[i, j]], ncol=3) + # # if xlim != None: + # ax[i, j].set_xlim((-xpad, lbars * (barwidth + xpad))) + # ax[i, j].set_xticks([], []) + # # if ylim != None: + # ax[i, j].set_ylim((0.0, 1.25)) + # + # ax[ly - 1, j].set_xticks(np.array(x_ticks) + 0.5 * barwidth, x_labels) + # ax[ly - 1, j].set_xlabel(bars + f'\n{label_mod([(vx[j], x)])}') + # except IndexError: + # ... 
+ # # ax[i].set_ylabel(f'{label_mod([(vy[i], y)])}\n' + rowlabel) + # # ax[i].legend([label_mod([(leg, label_key)]) for leg in legends[i, 0]], ncol=3) + # # ax[ly - 1].set_xlabel(columnlabel + f'\n{label_mod([(vx[0], x)])}') + # # # + # fig.supxlabel(f'{title_dict[x]}s', size=14) + # fig.supylabel(f'{title_dict[y]}s', size=14) + # # # if save != False: + # # # odir = Path(odir) + # # # if not odir.is_dir(): + # # # os.makedirs(odir) + # # # if type(save) == str: + # # # fig.savefig(odir / f'{save}.png') + # # # else: + # # # fig.savefig(odir / f'{self.name}_{diffstr}_{x}_{y}_{plsave}_{self.cutoff}_{self.bins}.png') + # # # fig.clear() + # fig.show() + # return fig + + @cached_property + def binned_plot_colour_dfs_1d(self) -> Tuple[pd.DataFrame, pd.DataFrame]: + logger.info(f"Getting binned plot and colour dfs") + sel = self.clays + + # get data for plotting + plot_df = self.bin_df[sel].copy() + + # move clays category from columns to index + idx_names = ["clays", *plot_df.index.droplevel(["x"]).names] + # DataFrame -> Series + plot_df = plot_df.stack() + + # get values for atom types (including separate ions) + atoms = plot_df.index.get_level_values("_atoms") + + # make new DataFrame from atom_type index level and values + plot_df.index = plot_df.index.droplevel(["x"]) + # plot_df.index.names = [name.strip('_') for name in plot_df.index.names] + + plot_df = pd.DataFrame({"values": plot_df, "colours": atoms}) + + # list of matplotlib sequential cmaps + cmaps = [ + "spring", + "summer", + "autumn", + "winter", + "cool", + "Wistia", + "hot", + "afmhot", + "gist_heat", + "copper", + ] + + # map unique atom types to colour map + atom_types = atoms.unique() + colour_dict = dict(zip(atom_types, cmaps[: len(atom_types)])) + plot_df["colours"] = plot_df["colours"].transform( + lambda x: colour_dict[x] + ) + + # reorder index for grouping + plot_df = plot_df.reorder_levels(idx_names) + + # group and sum densities within adsorption shell bins + plot_df = plot_df.groupby(plot_df.index.names).agg( + values=pd.NamedAgg("values", "sum"), + colours=pd.NamedAgg("colours", "first"), + ) + + # separate colour column from plot_df -> yields 2 Series + colour_df = plot_df["colours"] + plot_df = plot_df["values"] + + # add missing atom probabilities from bulk to the largest bin + # (bin.left, cutoff] -> (bin.left, all bulk]) + inner_sum = plot_df.groupby( + plot_df.index.droplevel("x_bins").names + ).sum() + extra = 1 - inner_sum + plot_df.where( + np.rint(plot_df.index.get_level_values("x_bins").right) + != int(self.cutoff), + lambda x: x + extra, + inplace=True, + ) + return plot_df, colour_df + + @cached_property + def binned_df(self): + return self.binned_plot_colour_dfs_1d[0] + + @cached_property + def colour_df(self): + return self.binned_plot_colour_dfs_1d[1] + + @property + def max_shell_edge(self) -> float: + # determine largest shell bin limit + max_edge = list(map(lambda x: np.max(x[:-1]), self.edges.values())) + max_edge = np.max(max_edge) + return max_edge + + def plot_columns(self, sepview, col, x, vx, pi): + # logger.info(col) + try: + view = sepview.xs(col, level=x, axis=0, drop_level=False) + pi = 1 + except ValueError: + view = sepview + col = vx + pi += 1 + return view, col, pi + + def get_bar_peaks(self, atom_type, other=None): + # if len(self.__peaks) != len(self.df.index.get_level_values("_atoms").unique()): + peaks = self._read_edge_file( + atom_type=atom_type, other=other + ) # ['peaks'] + return peaks + # print(peaks) + # logger.info(f"Found peaks {peaks}") + # try: + # print(peaks) + # 
print(bar_vals.index.get_level_values('atoms')) + # bar_peaks = peaks[ + # bar_vals.index.get_level_values("atoms").unique().tolist()#[0] + # ] + # print(bar_peaks) + # bar_peaks=bar_peaks[0] + # except: + # bar_peaks = peaks[ + # bar_vals.index.get_level_values("ions").unique().tolist()#[0] + # ] + # print(bar_peaks) + # bar_peaks = bar_peaks[0] + # return bar_peaks + + # def plot_bars(self, + # bars: Literal['clays', 'aas', 'ions', 'other'], + # x: Literal['clays', 'aas', 'ions', 'other'], + # y: Literal['clays', 'aas', 'ions', 'other'], + # rowlabel: str = 'y', + # columnlabel: str = 'x', + # figsize=None, + # dpi=None, + # # diff=False, + # xmax=50, + # ymax=50, + # save=False, + # ylim=None, + # odir='.', + # barwidth=0.75, + # xpad=0.25, + # cmap='winter' + # ): + # """Create stacked Histogram adsorption shell populations. + # + # """ + # aas_classes = [['arg', 'lys', 'his'], + # ['glu', 'gln'], + # ['cys'], + # ['gly'], + # ['pro'], + # ['ala', 'val', 'ile', 'leu', 'met'], + # ['phe', 'tyr', 'trp'], + # ['ser', 'thr', 'asp', 'gln']] + # ions_classes = [['Na', 'Ca'], + # ['Ca', 'Mg']] + # atoms_classes = [['ions'], + # ['N'], + # ['OT'], + # ['CA']] + # clays_classes = [['NAu-1'], + # ['NAu-2']] + # cmaps_seq = ['Purples', 'Blues', 'Greens', 'Oranges', 'Reds'] + # cmaps_single = ['Dark2'] + # sel_list = ('clays', 'ions', 'aas', 'atoms') + # # for color, attr in zip([''], sel_list): + # # cmaps_dict[attr] = {} + # # cm.get_cmap() + # cmap_dict = {'clays': []} + # + # title_dict = {'clays': 'Clay type', + # 'ions': 'Ion type', + # 'aas': 'Amino acid', + # '_atoms': 'Atom type', + # 'other': 'Other atom type'} + # + # sel_list = ['clays', 'ions', 'aas', '_atoms'] + # + # if self.other != None: + # sel_list.append('other') + # + # assert x in sel_list and x != '_atoms' + # assert y in sel_list and y != '_atoms' + # assert bars in sel_list and bars != '_atoms' + # + # plot_df, colour_df = self._get_binned_plot_df_1d() + # + # cnorm = self._get_cnorm() + # + # # get data for plotting + # + # bins = 'x_bins' + # group = '_atoms' + # + # separate = [s for s in plot_df.index.names if s not in [x, y, bars, bins]] # (s != x and s != y and s != bars and s != groups)] + # logger.info(f'Separate plots: {separate}') + # idx = pd.Index([s for s in plot_df.index.names if (s != x and s != bars and s not in [*separate, bins])]) + # logger.info(f'Iteration index: {idx}') + # + # sep = pd.Index(separate) + # + # vx = getattr(self, x) + # logger.info(f'x = {x}: {vx}') + # lx = len(vx) + # + # vy = getattr(self, y) + # logger.info(f'y = {y}: {vy}') + # ly = len(vy) + # + # + # vbars = getattr(self, bars) + # lbars = len(vbars) + # logger.info(f'bars = {bars}: {vbars}') + # + # bar_dict = dict(zip(vbars, np.arange(lbars))) + # + # yid = np.ravel(np.where(np.array(idx) == y))[0] + # + # sys.exit(2) + # + # # label_key = idx.difference(pd.Index([x, y, *separate]), sort=False).values[0] + # + # # label_id = idx.get_loc(key=label_key) + # n_plots = len(sep) + # + # x_dict = dict(zip(vx, np.arange(lx))) + # + # + # sel = self.clays + # + # # set figure size + # if figsize == None: + # figsize = self.get_figsize(lx=lx, + # ly=ly, + # xmax=xmax, + # ymax=ymax) + # + # # set resultion + # if dpi == None: + # dpi = 100 + # + # # get plotting iter from index + # iters = self._get_idx_iter(idx=idx) + # + # logger.info(f'Printing bar plots for {sep}\nColumns: {vx}\nRows: {vy}') + # + # # set label modifier function + # label_mod = self.modify_plot_labels + # + # try: + # # iterator for more than one plot + # sep_it = 
self._get_idx_iter(idx=sep) + # except ValueError: + # # only one plot + # sep_it = [None] + # + # # iterate over separate plots + # for pl in sep_it: + # # index map for y values + # y_dict: dict = dict(zip(vy, np.arange(ly))) + # print(y_dict) + # + # legends = self.init_legend(ly=ly, + # lx=lx) + # print(legends) + # + # # generate figure and axes array + # fig, ax = plt.subplots(nrows=ly, + # ncols=lx, + # figsize=figsize, + # sharey=True, + # dpi=dpi, + # constrained_layout=True, + # # sharex=True + # ) + # # only one plot + # if pl is None: + # logger.info(f'Generating plot') + # sepview = plot_df.view() + # plsave = '' + # + # # multiple plots + # else: + # logger.info(f'Generating {pl} plot') + # print(plot_df.head(20),'\n', separate, pl) + # sepview = plot_df.xs((pl), + # level=separate, + # drop_level=False) + # plsave = pl + # print(pl) + # print(separate) + # fig.suptitle((', '.join([title_dict[s].upper() for s in separate]) + + # f': {label_mod(list(tuple(zip(pl, separate))))}'), size=16, + # weight='bold') + # + # # set plot index + # pi = 0 + # + # #iterate over subplot columns + # for col in vx: + # # logger.info(col) + # # view, col, pi = self.plot_columns(sepview=sepview, + # # col=col, + # # x=x, + # # vx=vx, + # # pi=pi) + # logger.info(col) + # try: + # view = sepview.xs(col, + # level=x, + # axis=0, + # drop_level=False) + # pi = 1 + # except ValueError: + # view = sepview + # col = vx + # pi += 1 + # + # table_text = [] + # table_col = [] + # table_cmap = None + # table_rows = [] + # for it in iters: + # try: + # values = view.xs(tuple(it), + # level=idx.tolist(), + # drop_level=False) + # + # x_labels = [] + # x_ticks = [] + # for vbar in vbars: + # print(vbar) + # x_ticks.append(bar_dict[vbar] * (barwidth + xpad)) + # x_labels.append(vbar) + # bottom = 0.0 + # bar_vals = values.xs(vbar, + # level=bars, + # drop_level=False) + # + # cmap = colormaps[colour_df.loc[bar_vals.index].values[0]] + # if table_cmap is None: + # table_cmap = cmap + # + # peaks = self.get_bar_peaks(bar_vals=bar_vals) + # + # if np.all(bar_vals.values) >= 0: + # table_text.append([f'${v * 100:3.1f} %$' for v in bar_vals.values]) + # x_id, y_id = x_dict[col], y_dict[it[yid]] + # + # bar_val_view = bar_vals + # bar_val_view.index = bar_val_view.index.get_level_values('x_bins') + # + # x_tick = x_ticks[-1] + # + # for bar_id, bar_val in enumerate(bar_val_view.items()): + # + # x_bin, y_val = bar_val + # + # try: + # peak = peaks[bar_id] + # except IndexError: + # peak = x_bin.right + # colour = cmap(cnorm(peak)) + # if colour not in table_col and cmap == table_cmap: + # print('colour', colour) + # table_col.append(colour) + # + # label = self._get_bin_label(x_bin) + # + # # if x_bin.right < np.max(self.x): + # # label = f'${x_bin.left} - {x_bin.right}$ \AA' + # # else: + # # label = f'$ > {x_bin.left}$ \AA' + # if label not in table_rows and cmap == table_cmap: + # table_rows.append(label) + # if y_val >= 0.010: + # + # # barwidth = bulk_edge - x_bin.left + # # try: + # # x_tick = x_ticks[-1] + barwidth + # # x_ticks.append(x_tick) + # # except IndexError: + # # x_tick = x_bin.left + # + # + # + # + # try: + # p = ax[y_id, x_id].bar(x_tick, + # y_val, + # label=label, + # bottom=bottom, + # width=barwidth, + # align='edge', + # color=colour + # ) + # ax[y_id, x_id].bar_label(p, labels=[label], + # fmt='%s', + # label_type='center') + # except IndexError: + # p = ax[y_id].bar(x_tick, + # y_val, + # label=label, + # bottom=bottom, + # width=barwidth, + # align='edge', + # color=colour + # ) + # 
ax[y_id, x_id].bar_label(p, labels=[label], + # fmt='%s', + # label_type='center') + # # finally: + # bottom += y_val + # print(table_text, table_col, table_rows) + # # table = ax[y_id, x_id].table(cellText=table_text, + # # rowColours=table_col, + # # rowLabels=table_rows, + # # # colLables=..., + # # loc='bottom') + # # x_ticks = x_ticks[:-1] + # + # # x_ticks.append(bar_dict[vbar] * (barwidth + xpad)) + # + # + # + # # values = values + # # print('try 1 done') + # # + # # # for shell in values: + # # # view for group and bars + # # label = f'${lims.left} - {lims.right}$ \AA' + # # + # # try: + # # print('try 2') + # # print(x_dict[col], y_dict[it[yid]]) + # + # # except: + # # # raise ValueError + # # x_id, y_id = 0, y_dict[it[yid]] + # # label = f'${x_bin.left} - {x_bin.right}$ \AA' + # + # # if pi == 1: + # # legends[y_id, x_id].append(it[label_id]) + # # else: + # # check_logger.info('NaN values') + # + # except KeyError: + # logger.info(f'No data for {pl}, {vx}, {it}') + # + # # x_ticks = [np.linspace(n_bar * bulk_edge + xpad, + # # n_bar * bulk_edge + bulk_edge, int(bulk_edge)) for n_bar in range(lbars)] + # # x_ticks = np.ravel(x_ticks) + # # x_labels = np.tile(np.arange(0, bulk_edge, 1), lbars) + # + # for i in range(ly): + # try: + # ax[i, 0].set_ylabel(f'{label_mod([(vy[i], y)])}\n' + rowlabel) + # ax[i, 0].set_yticks(np.arange(0.0, 1.1, 0.2)) + # for j in range(lx): + # ax[i, j].spines[['top', 'right']].set_visible(False) + # ax[i, j].hlines(1.0, -xpad, lbars * (barwidth + xpad) + xpad, linestyle='--') + # # ax[i, j].legend(ncol=2, loc='lower center')#[leg for leg in legends[i, j]], ncol=3) + # # if xlim != None: + # ax[i, j].set_xlim((-xpad, lbars * (barwidth + xpad))) + # ax[i, j].set_xticks([], []) + # # if ylim != None: + # ax[i, j].set_ylim((0.0, 1.25)) + # + # ax[ly - 1, j].set_xticks(np.array(x_ticks) + 0.5 * barwidth, x_labels) + # ax[ly - 1, j].set_xlabel(bars + f'\n{label_mod([(vx[j], x)])}') + # except IndexError: + # ... 
+ # # ax[i].set_ylabel(f'{label_mod([(vy[i], y)])}\n' + rowlabel) + # # ax[i].legend([label_mod([(leg, label_key)]) for leg in legends[i, 0]], ncol=3) + # # ax[ly - 1].set_xlabel(columnlabel + f'\n{label_mod([(vx[0], x)])}') + # # # + # fig.supxlabel(f'{title_dict[x]}s', size=14) + # fig.supylabel(f'{title_dict[y]}s', size=14) + # # # if save != False: + # # # odir = Path(odir) + # # # if not odir.is_dir(): + # # # os.makedirs(odir) + # # # if type(save) == str: + # # # fig.savefig(odir / f'{save}.png') + # # # else: + # # # fig.savefig(odir / f'{self.name}_{diffstr}_{x}_{y}_{plsave}_{self.cutoff}_{self.bins}.png') + # # # fig.clear() + # fig.show() + # return fig + + def plot_bars_shifted( + self, + bars: Literal["clays", "aas", "ions", "atoms", "other"], + x: Literal["clays", "aas", "ions", "atoms", "other"], + y: Literal["clays", "aas", "ions", "atoms", "other"], + # y: Literal['clays', 'ions', 'aas', 'atoms', 'other'], + # select: Literal['clays', 'ions', 'aas'], + rowlabel: str = "y", + columnlabel: str = "x", + figsize=None, + dpi=None, + # diff=False, + xmax=50, + ymax=50, + save=False, + ylim=None, + odir=".", + # barwidth = 0.75, + xpad=0.25, + cmap="winter", + ): + """Create stacked Histogram adsorption shell populations.""" + aas_classes = [ + ["arg", "lys", "his"], + ["glu", "gln"], + ["cys"], + ["gly"], + ["pro"], + ["ala", "val", "ile", "leu", "met"], + ["phe", "tyr", "trp"], + ["ser", "thr", "asp", "gln"], + ] + ions_classes = [["Na", "Ca"], ["Ca", "Mg"]] + atoms_classes = [["ions"], ["N"], ["OT"], ["CA"]] + clays_classes = [["NAu-1"], ["NAu-2"]] + cmaps_seq = ["Purples", "Blues", "Greens", "Oranges", "Reds"] + cmaps_single = ["Dark2"] + sel_list = ("clays", "ions", "aas", "atoms") + # for color, attr in zip([''], sel_list): + # cmaps_dict[attr] = {} + # cm.get_cmap() + cmap_dict = {"clays": []} + + title_dict = { + "clays": "Clay type", + "ions": "Ion type", + "aas": "Amino acid", + "atoms": "Atom type", + "other": "Other atom type", + } + + sel_list = ["clays", "ions", "aas", "atoms"] + + # if self.other != None: + # sel_list.append('other') + cmap = colormaps[cmap] + separate = [ + s for s in sel_list if s not in [x, y, bars] + ] # (s != x and s != y and s != bars and s != groups)] + # print(separate) + idx = pd.Index( + [ + s + for s in sel_list + if (s != x and s != bars and s not in separate) + ] + ) + # print(idx) + + sep = pd.Index(separate) + + vx = getattr(self, x) + logger.info(f"x = {x}: {vx}") + lx = len(vx) + + vy = getattr(self, y) + logger.info(f"y = {y}: {vy}") + ly = len(vy) + # print(ly) + + vbars = getattr(self, bars) + lbars = len(vbars) + logger.info(f"bars = {bars}: {vbars}") + + bar_dict = dict(zip(vbars, np.arange(lbars))) + + yid = np.ravel(np.where(np.array(idx) == y))[0] + # print(yid) + + # label_key = idx.difference(pd.Index([x, y, *separate]), sort=False).values[0] + + # label_id = idx.get_loc(key=label_key) + n_plots = len(sep) + + x_dict = dict(zip(vx, np.arange(lx))) + # print(x_dict) + + sel = self.clays + + plot_df = self.bin_df[sel].copy() + idx_names = ["clays", *plot_df.index.droplevel(["x", "_atoms"]).names] + # print("idx", idx_names) + plot_df = plot_df.stack() + # atoms = plot_df.index.get_level_values('_atoms') + plot_df.index = plot_df.index.droplevel(["x", "_atoms"]) + # plot_df = pd.DataFrame({'values': plot_df, + # '_atoms': atoms}) + # idx_names.remove('_atoms') + plot_df = plot_df.reorder_levels(idx_names) + # + plot_df.name = "values" + # print(plot_df.head(3)) + plot_df = plot_df.groupby(plot_df.index.names).sum() + # 
print(plot_df.head(3)) + inner_sum = plot_df.groupby( + plot_df.index.droplevel("x_bins").names + ).sum() + extra = 1 - inner_sum + plot_df.where( + np.rint(plot_df.index.get_level_values("x_bins").right) + != int(self.cutoff), + lambda x: x + extra, + inplace=True, + ) + # print(type(self.ads_edges)) + max_edge = list(map(lambda x: np.max(x[:-1]), self.edges.values())) + max_edge = np.max(max_edge) + # print("max edge", max_edge) + # max_edge = np.ravel(np.array(*self.ads_edges.values())) + # print(max_edge) + cnorm = mpc.Normalize(vmin=0, vmax=max_edge, clip=False) + + if figsize == None: + figsize = tuple( + [ + 5 * lx if (10 * lx) < xmax else xmax, + 5 * ly if (5 * ly) < ymax else ymax, + ] + ) + # + if dpi == None: + dpi = 100 + # + iters = np.array( + np.meshgrid(*[getattr(self, idxit) for idxit in idx]) + ).T.reshape(-1, len(idx)) + # + logger.info(f"Printing bar plots for {sep}\nColumns: {vx}\nRows: {vy}") + # + label_mod = lambda l: ", ".join( + [li.upper() if namei == "aas" else li for li, namei in l] + ) + # + try: + sep_it = np.array( + np.meshgrid(*[getattr(self, idxit) for idxit in sep]) + ).T.reshape(-1, len(sep)) + except ValueError: + sep_it = [None] + # check_logger.info(vx, vy, lx, ly) + # + for pl in sep_it: + # print(pl) + # try: + # fig.clear() + # except: + # pass + y_dict = dict(zip(vy, np.arange(ly))) + # print(y_dict) + # if separate == 'atoms' and pl != '': + # ... + # + legends_list = [(a, b) for a in range(ly) for b in range(lx)] + # + legends = dict( + zip(legends_list, [[] for a in range(len(legends_list))]) + ) + # + # if type(pl) in [list, tuple, np.ndarray]: + # # viewlist = [] + # # for p in pl: + # # viewlist.append(plot_df.xs((p), level=separate, axis=0)) + # # + # # sepview = pd.concat(viewlist) + # # plsave = 'ions' + # # + # # else: + fig, ax = plt.subplots( + nrows=ly, + ncols=lx, + figsize=figsize, + sharey=True, + dpi=dpi, + constrained_layout=True, + # sharex=True + ) + if pl is None: + sepview = plot_df.view() + plsave = "" + else: + sepview = plot_df.xs( + (pl), level=separate, axis=0, drop_level=False + ) + plsave = pl + # + # + + # + fig.suptitle( + ( + ", ".join([title_dict[s].upper() for s in separate]) + + f": {label_mod(list(tuple(zip(pl, separate))))}" + ), + size=16, + weight="bold", + ) + pi = 0 + for col in vx: + try: + view = sepview.xs(col, level=x, axis=0, drop_level=False) + pi = 1 + except ValueError: + view = sepview + col = vx + pi += 1 + # print("column", col) + for it in iters: + try: + values = view.xs( + tuple(it), level=idx.tolist(), drop_level=False + ) + x_grouplabels = [] + x_labels = [] + x_ticks = [] + for vbar in vbars: + bulk_pad = 2 + bulk_edge = np.rint((max_edge + bulk_pad)) + x_ticks.append(bar_dict[vbar] * bulk_edge + xpad) + # print(values) + bottom = 0.0 + x_grouplabels.append(vbar) + # x_tick = bar_dict[vbar] * (barwidth + xpad) + bar_vals = values.xs( + vbar, level=bars, drop_level=False + ) + + if len(self.__peaks) != len( + self.df.index.get_level_values( + "_atoms" + ).unique() + ): + self._get_edges() + try: + peaks = self.__peaks[ + bar_vals.index.get_level_values("atoms") + .unique() + .tolist()[0] + ] + except: + peaks = self.__peaks[ + bar_vals.index.get_level_values("ions") + .unique() + .tolist()[0] + ] + # bar_vals.values >= 0) + if np.all(bar_vals.values) >= 0: + # print("All > 0") + x_id, y_id = x_dict[col], y_dict[it[yid]] + + bar_val_view = bar_vals + bar_val_view.index = ( + bar_val_view.index.get_level_values( + "x_bins" + ) + ) + # x_ticks.append(x_tick + 0.5 * barwidth) + # bar_plots = [] 
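+                                # draw one bar segment per x_bin interval;
+                                # the segment colour encodes the position
+                                # of the matching density peak via cnorm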
+                                for bar_id, bar_val in enumerate(
+                                    bar_val_view.items()
+                                ):
+                                    x_bin, y_val = bar_val
+
+                                    try:
+                                        peak = peaks[bar_id]
+                                    except IndexError:
+                                        peak = x_bin.right
+                                    colour = cmap(cnorm(peak))
+
+                                    if x_bin.right < np.max(self.x):
+                                        label = rf"${x_bin.left:3.1f} - {x_bin.right:3.1f}$ \AA"
+                                        barwidth = x_bin.right - x_bin.left
+
+                                    else:
+                                        label = rf"$ > {x_bin.left:3.1f}$ \AA"
+                                        barwidth = bulk_edge - x_bin.left
+                                    x_tick = x_ticks[-1] + barwidth
+                                    x_ticks.append(x_tick)
+                                    x_labels.append(x_bin.left)
+                                    try:
+                                        p = ax[y_id, x_id].bar(
+                                            x_tick,
+                                            y_val,
+                                            label=label,
+                                            bottom=bottom,
+                                            width=-barwidth,
+                                            align="edge",
+                                            color=colour,
+                                        )
+                                        ax[y_id, x_id].bar_label(
+                                            p,
+                                            labels=[label],
+                                            fmt="%s",
+                                            label_type="center",
+                                        )
+                                    except IndexError:
+                                        p = ax[y_id].bar(
+                                            x_tick,
+                                            y_val,
+                                            label=label,
+                                            bottom=bottom,
+                                            width=-barwidth,
+                                            align="edge",
+                                            color=colour,
+                                        )
+                                        ax[y_id].bar_label(
+                                            p,
+                                            labels=[label],
+                                            fmt="%s",
+                                            label_type="center",
+                                        )
+                                    bottom += y_val
+                                x_ticks = x_ticks[:-1]
+
+                    except KeyError:
+                        logger.info(f"No data for {pl}, {vx}, {it}")
+
+            x_ticks = [
+                np.linspace(
+                    n_bar * bulk_edge + xpad,
+                    n_bar * bulk_edge + bulk_edge,
+                    int(bulk_edge),
+                )
+                for n_bar in range(lbars)
+            ]
+            x_ticks = np.ravel(x_ticks)
+            x_labels = np.tile(np.arange(0, bulk_edge, 1), lbars)
+            for i in range(ly):
+                try:
+                    ax[i, 0].set_ylabel(
+                        f"{label_mod([(vy[i], y)])}\n" + rowlabel
+                    )
+                    ax[i, 0].set_yticks(np.arange(0.0, 1.1, 0.2))
+                    for j in range(lx):
+                        ax[i, j].spines[["top", "right"]].set_visible(False)
+                        ax[i, j].hlines(1.0, -xpad, lbars, linestyle="--")
+                        ax[i, j].set_xticks([], [])
+                        ax[i, j].set_ylim((0.0, 1.25))
+                        ax[ly - 1, j].set_xticks(x_ticks, x_labels)
+                        # ax[ly - 1, j].set_xlabel(columnlabel + f'\n{label_mod([(vx[j], x)])}')
+                except IndexError:
+                    ...
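+        # usage sketch (hypothetical call, assuming a fully initialised
+        # instance `data` of this class with the index attributes used
+        # above):
+        #     data.plot_bars_shifted(bars="clays", x="ions", y="aas")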
+ # ax[i].set_ylabel(f'{label_mod([(vy[i], y)])}\n' + rowlabel) + # ax[i].legend([label_mod([(leg, label_key)]) for leg in legends[i, 0]], ncol=3) + # ax[ly - 1].set_xlabel(columnlabel + f'\n{label_mod([(vx[0], x)])}') + # # + # # fig.supxlabel(f'{title_dict[x]}s', size=14) + # # fig.supylabel(f'{title_dict[y]}s', size=14) + # # if save != False: + # # odir = Path(odir) + # # if not odir.is_dir(): + # # os.makedirs(odir) + # # if type(save) == str: + # # fig.savefig(odir / f'{save}.png') + # # else: + # # fig.savefig(odir / f'{self.name}_{diffstr}_{x}_{y}_{plsave}_{self.cutoff}_{self.bins}.png') + # # fig.clear() + + def plot_hbars( + self, + bars: Literal["clays", "aas", "ions", "atoms", "other"], + x: Literal["clays", "aas", "ions", "atoms", "other"], + y: Literal["clays", "aas", "ions", "atoms", "other"], + # y: Literal['clays', 'ions', 'aas', 'atoms', 'other'], + # select: Literal['clays', 'ions', 'aas'], + rowlabel: str = "y", + columnlabel: str = "x", + figsize=None, + dpi=None, + # diff=False, + xmax=50, + ymax=50, + save=False, + ylim=None, + odir=".", + # barwidth = 0.75, + xpad=0.25, + cmap="winter", + ): + """Create stacked Histogram adsorption shell populations.""" + aas_classes = [ + ["arg", "lys", "his"], + ["glu", "gln"], + ["cys"], + ["gly"], + ["pro"], + ["ala", "val", "ile", "leu", "met"], + ["phe", "tyr", "trp"], + ["ser", "thr", "asp", "gln"], + ] + ions_classes = [["Na", "Ca"], ["Ca", "Mg"]] + atoms_classes = [["ions"], ["N"], ["OT"], ["CA"]] + clays_classes = [["NAu-1"], ["NAu-2"]] + cmaps_seq = ["Purples", "Blues", "Greens", "Oranges", "Reds"] + cmaps_single = ["Dark2"] + sel_list = ("clays", "ions", "aas", "atoms") + # for color, attr in zip([''], sel_list): + # cmaps_dict[attr] = {} + # cm.get_cmap() + cmap_dict = {"clays": []} + + title_dict = { + "clays": "Clay type", + "ions": "Ion type", + "aas": "Amino acid", + "atoms": "Atom type", + "other": "Other atom type", + } + + sel_list = ["clays", "ions", "aas", "atoms"] + + # if self.other != None: + # sel_list.append('other') + cmap = colormaps[cmap] + separate = [ + s for s in sel_list if s not in [x, y, bars] + ] # (s != x and s != y and s != bars and s != groups)] + # print(separate) + idx = pd.Index( + [ + s + for s in sel_list + if (s != x and s != bars and s not in separate) + ] + ) + # print(idx) + + sep = pd.Index(separate) + + vx = getattr(self, x) + logger.info(f"x = {x}: {vx}") + lx = len(vx) + + vy = getattr(self, y) + logger.info(f"y = {y}: {vy}") + ly = len(vy) + # print(ly) + + vbars = getattr(self, bars) + lbars = len(vbars) + logger.info(f"bars = {bars}: {vbars}") + + bar_dict = dict(zip(vbars, np.arange(lbars))) + + yid = np.ravel(np.where(np.array(idx) == y))[0] + # print(yid) + + # label_key = idx.difference(pd.Index([x, y, *separate]), sort=False).values[0] + + # label_id = idx.get_loc(key=label_key) + n_plots = len(sep) + + x_dict = dict(zip(vx, np.arange(lx))) + # print(x_dict) + + sel = self.clays + + plot_df = self.bin_df[sel].copy() + idx_names = ["clays", *plot_df.index.droplevel(["x", "_atoms"]).names] + # print("idx", idx_names) + plot_df = plot_df.stack() + # atoms = plot_df.index.get_level_values('_atoms') + plot_df.index = plot_df.index.droplevel(["x", "_atoms"]) + # plot_df = pd.DataFrame({'values': plot_df, + # '_atoms': atoms}) + # idx_names.remove('_atoms') + plot_df = plot_df.reorder_levels(idx_names) + # + plot_df.name = "values" + # print(plot_df.head(3)) + plot_df = plot_df.groupby(plot_df.index.names).sum() + # print(plot_df.head(3)) + inner_sum = plot_df.groupby( + 
plot_df.index.droplevel("x_bins").names + ).sum() + extra = 1 - inner_sum + plot_df.where( + np.rint(plot_df.index.get_level_values("x_bins").right) + != int(self.cutoff), + lambda x: x + extra, + inplace=True, + ) + # print(type(self.ads_edges)) + max_edge = list(map(lambda x: np.max(x[:-1]), self.edges.values())) + max_edge = np.max(max_edge) + # print("max edge", max_edge) + # max_edge = np.ravel(np.array(*self.ads_edges.values())) + # print(max_edge) + cnorm = mpc.Normalize(vmin=0, vmax=max_edge, clip=False) + + if figsize == None: + figsize = tuple( + [ + 5 * lx if (10 * lx) < xmax else xmax, + 5 * ly if (5 * ly) < ymax else ymax, + ] + ) + # + if dpi == None: + dpi = 100 + # + iters = np.array( + np.meshgrid(*[getattr(self, idxit) for idxit in idx]) + ).T.reshape(-1, len(idx)) + # + logger.info(f"Printing bar plots for {sep}\nColumns: {vx}\nRows: {vy}") + # + label_mod = lambda l: ", ".join( + [li.upper() if namei == "aas" else li for li, namei in l] + ) + # + try: + sep_it = np.array( + np.meshgrid(*[getattr(self, idxit) for idxit in sep]) + ).T.reshape(-1, len(sep)) + except ValueError: + sep_it = [None] + # check_logger.info(vx, vy, lx, ly) + # + for pl in sep_it: + # print(pl) + # try: + # fig.clear() + # except: + # pass + y_dict = dict(zip(vy, np.arange(ly))) + # print(y_dict) + # if separate == 'atoms' and pl != '': + # ... + # + legends_list = [(a, b) for a in range(ly) for b in range(lx)] + # + legends = dict( + zip(legends_list, [[] for a in range(len(legends_list))]) + ) + # + # if type(pl) in [list, tuple, np.ndarray]: + # # viewlist = [] + # # for p in pl: + # # viewlist.append(plot_df.xs((p), level=separate, axis=0)) + # # + # # sepview = pd.concat(viewlist) + # # plsave = 'ions' + # # + # # else: + fig, ax = plt.subplots( + nrows=ly, + ncols=lx, + figsize=figsize, + sharey=True, + dpi=dpi, + constrained_layout=True, + # sharex=True + ) + if pl is None: + sepview = plot_df.view() + plsave = "" + else: + sepview = plot_df.xs( + (pl), level=separate, axis=0, drop_level=False + ) + plsave = pl + # + # + + # + fig.suptitle( + ( + ", ".join([title_dict[s].upper() for s in separate]) + + f": {label_mod(list(tuple(zip(pl, separate))))}" + ), + size=16, + weight="bold", + ) + pi = 0 + for col in vx: + try: + view = sepview.xs(col, level=x, axis=0, drop_level=False) + pi = 1 + except ValueError: + view = sepview + col = vx + pi += 1 + # print("column", col) + for it in iters: + try: + values = view.xs( + tuple(it), level=idx.tolist(), drop_level=False + ) + x_grouplabels = [] + x_labels = [] + x_ticks = [] + for vbar in vbars: + bulk_pad = 2 + bulk_edge = np.rint((max_edge + bulk_pad)) + x_ticks.append(bar_dict[vbar] * bulk_edge + xpad) + # print(values) + bottom = 0.0 + x_grouplabels.append(vbar) + # x_tick = bar_dict[vbar] * (barwidth + xpad) + bar_vals = values.xs( + vbar, level=bars, drop_level=False + ) + + if len(self.__peaks) != len( + self.df.index.get_level_values( + "_atoms" + ).unique() + ): + self._get_edges() + try: + peaks = self.__peaks[ + bar_vals.index.get_level_values("atoms") + .unique() + .tolist()[0] + ] + except: + peaks = self.__peaks[ + bar_vals.index.get_level_values("ions") + .unique() + .tolist()[0] + ] + # bar_vals.values >= 0) + if np.all(bar_vals.values) >= 0: + # print("All > 0") + x_id, y_id = x_dict[col], y_dict[it[yid]] + + bar_val_view = bar_vals + bar_val_view.index = ( + bar_val_view.index.get_level_values( + "x_bins" + ) + ) + # x_ticks.append(x_tick + 0.5 * barwidth) + # bar_plots = [] + for bar_id, bar_val in enumerate( + 
bar_val_view.items() + ): + x_bin, y_val = bar_val + + try: + peak = peaks[bar_id] + except IndexError: + peak = x_bin.right + colour = cmap(cnorm(peak)) + + if x_bin.right < np.max(self.x): + label = f"${x_bin.left} - {x_bin.right}$ \AA" + barwidth = x_bin.right - x_bin.left + + else: + label = f"$ > {x_bin.left}$ \AA" + # print(bar_val_view.index[-1].left) + barwidth = bulk_edge - x_bin.left + # try: + x_tick = x_ticks[-1] + barwidth + x_ticks.append(x_tick) + # except IndexError: + # x_tick = x_bin.left + x_labels.append(x_bin.left) + # print(x_ticks, x_tick, x_bin) + # print(peaks, bar_id, "label", "peak", label, peak) + # print(label) + try: + p = ax[y_id, x_id].barh( + x_tick, + y_val, + label=label, + left=bottom, + height=-barwidth, + align="edge", + color=colour, + ) + ax[y_id, x_id].bar_label( + p, + labels=[label], + fmt="%s", + label_type="center", + ) + except IndexError: + p = ax[y_id].barh( + x_tick, + y_val, + label=label, + left=bottom, + height=-barwidth, + align="edge", + color=colour, + ) + ax[y_id, x_id].bar_label( + p, + labels=[label], + fmt="%s", + label_type="center", + ) + # finally: + bottom += y_val + x_ticks = x_ticks[:-1] + # values = values + # print('try 1 done') + # + # # for shell in values: + # # view for group and bars + # label = f'${lims.left} - {lims.right}$ \AA' + # + # try: + # print('try 2') + # print(x_dict[col], y_dict[it[yid]]) + + # except: + # # raise ValueError + # x_id, y_id = 0, y_dict[it[yid]] + # label = f'${x_bin.left} - {x_bin.right}$ \AA' + + # if pi == 1: + # legends[y_id, x_id].append(it[label_id]) + # else: + # check_logger.info('NaN values') + + except KeyError: + logger.info(f"No data for {pl}, {vx}, {it}") + + x_ticks = [ + np.linspace( + n_bar * bulk_edge + xpad, + n_bar * bulk_edge + bulk_edge, + int(bulk_edge), + ) + for n_bar in range(lbars) + ] + x_ticks = np.ravel(x_ticks) + x_labels = np.tile(np.arange(0, bulk_edge, 1), lbars) + # print(x_ticks, x_labels) + for i in range(ly): + try: + ax[i, 0].set_ylabel( + f"{label_mod([(vy[i], y)])}\n" + rowlabel + ) + ax[i, 0].set_xticks(np.arange(0.0, 1.1, 0.2)) + for j in range(lx): + ax[i, j].spines[["top", "right"]].set_visible(False) + ax[i, j].hlines(1.0, -xpad, lbars, linestyle="--") + # ax[i, j].legend(ncol=2, loc='lower center')#[leg for leg in legends[i, j]], ncol=3) + # if xlim != None: + # ax[i, j].set_xlim((-xpad, lbars)) + ax[i, j].set_yticks([], []) + # if ylim != None: + ax[i, j].set_ylim((0.0, 1.25)) + # print(x_ticks, x_labels) + ax[ly - 1, j].set_yticks(x_ticks, x_labels) + # ax[ly - 1, j].set_xlabel(columnlabel + f'\n{label_mod([(vx[j], x)])}') + except IndexError: + ... + # ax[i].set_ylabel(f'{label_mod([(vy[i], y)])}\n' + rowlabel) + # ax[i].legend([label_mod([(leg, label_key)]) for leg in legends[i, 0]], ncol=3) + # ax[ly - 1].set_xlabel(columnlabel + f'\n{label_mod([(vx[0], x)])}') + # # + # # fig.supxlabel(f'{title_dict[x]}s', size=14) + # # fig.supylabel(f'{title_dict[y]}s', size=14) + # # if save != False: + # # odir = Path(odir) + # # if not odir.is_dir(): + # # os.makedirs(odir) + # # if type(save) == str: + # # fig.savefig(odir / f'{save}.png') + # # else: + # # fig.savefig(odir / f'{self.name}_{diffstr}_{x}_{y}_{plsave}_{self.cutoff}_{self.bins}.png') + # # fig.clear() + + +class Data2D: + """ + Class for histogram analysis data processing and plotting. + Reads files in `indir` that match the naming pattern + "`namestem`*_`cutoff`_`bins`.dat" + (or "`namestem`*_`cutoff`_`bins`_`analysis`.dat" if `analysis` != `None`. 
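+    A name such as "veldist_NAu-1_Na_ala_7_ions_50_0.02_zvel.dat" would,
+    for example, match `namestem` 'veldist', `cutoff` 50, `bins` 0.02 and
+    `analysis` 'zvel' (illustrative filename only).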
+ The data is stored in a :class:`pandas.DataFrame` + :param indir: Data directory + :type indir: Union[str, Path] + :param cutoff: Maximum value in the histogram bins + :type cutoff: Union[int, float] + :param bins: Histogram bin size + :type bins: float + :param ions: List of ion types in solvent + :type ions: List[Literal['Na', 'K', 'Ca', 'Mg']], optional + :param atoms: List of atom types in selection, defaults to `None` + :type atoms: List[Literal['ions', 'OT', 'N', 'CA', 'OW']], optional + :param other: Optional list of atom types in a second selection, + defaults to `None` + :type other: List[Literal['ions', 'OT', 'N', 'CA', 'OW']], optional + :param clays: List of clay types, + defaults to `None` + :type clays: List[Literal['NAu-1', 'NAu-2']], optional + :param aas: List of amino acid types in lower case 3-letter code, + defaults to `None` + :type aas: Optional[List[Literal['ala', 'arg', 'asn', 'asp', + 'ctl', 'cys', 'gln', 'glu', + 'gly', 'his', 'ile', 'leu', + 'lys', 'met', 'phe', 'pro', + 'ser', 'thr', 'trp', 'tyr', + 'val']]] + :param load: Load, defaults to False + :type load: Union[str, Literal[False], Path], optional + :param odir: Output directory, defaults to `None` + :type odir: str, optional + :param nameparts: number of `_`-separated partes in `namestem` + :type nameparts: int, defaults to 1 + :param namestem: leading string in naming pattern, optional + :type namestem: str, defaults to '' + :param analysis: trailing string in naming pattern, optional + defaults to `None` + :type analysis: str, optional + :param df: :class: `pandas.DataFrame` + """ + + aas = [ + "ala", + "arg", + "asn", + "asp", + "ctl", + "cys", + "gln", + "glu", + "gly", + "his", + "ile", + "leu", + "lys", + "met", + "phe", + "pro", + "ser", + "thr", + "trp", + "tyr", + "val", + ] + + ions = ["Na", "K", "Ca", "Mg"] + atoms = ["ions", "OT", "N", "CA"] + clays = ["NAu-1", "NAu-2"] # , 'L31'] + analyses = ["zvel", "surfzvel", "xyvel", "totvel"] + + @redirect_tqdm + def __init__( + self, + indir: Union[str, Path], + zdir: Union[str, Path], + cutoff: Union[int, float], + bins: float, + ions: List[Literal["Na", "K", "Ca", "Mg"]] = None, + atoms: List[Literal["ions", "OT", "N", "CA", "OW"]] = None, + other: List[Literal["ions", "OT", "N", "CA", "OW"]] = None, + clays: List[Literal["NAu-1", "NAu-2"]] = None, + aas: List[ + Literal[ + "ala", + "arg", + "asn", + "asp", + "ctl", + "cys", + "gln", + "glu", + "gly", + "his", + "ile", + "leu", + "lys", + "met", + "phe", + "pro", + "ser", + "thr", + "trp", + "tyr", + "val", + ] + ] = None, + load: Union[str, Literal[False], Path] = False, + odir: Optional[str] = None, + nameparts: int = 1, + namestem: str = "veldist", + zstem: str = "zdist", + analyses: Optional[List[str]] = None, + zname: Optional[str] = "zdens", + vel_bins: float = 0.05, + velrange: Tuple[float, float] = (0.0, 1.0), + ): + """Constructor method""" + logger.info(f"Initialising {self.__class__.__name__}") + self.filelist: dict = {} + self.bins: Bins = Bins(bins) + self.cutoff: float = Cutoff(cutoff) + x = None + + self.zname = zname + + if type(indir) != Path: + indir = Path(indir) + + if type(zdir) != Path: + zdir = Path(zdir) + + self._indir = indir + self._zdir = zdir + + # if self.analysis is None: + # logger.info( + # rf"Getting {namestem}*_" + # rf"{self.cutoff}_" + # rf"{self.bins}.p from {str(indir.resolve())!r}" + # ) + # self.filelist: List[Path] = sorted( + # list( + # indir.glob(rf"{namestem}*_" rf"{self.cutoff}_" rf"{self.bins}.dat") + # ) + # ) + # else: + # logger.info( + # rf"Getting 
{namestem}*_" + # rf"{self.cutoff}_" + # rf"{self.bins}_" + # rf"{analysis}.dat from {str(indir.resolve())!r}" + # ) + # self.filelist: List[Path] = sorted( + # list( + # indir.glob( + # rf"{namestem}*_" + # rf"{self.cutoff}_" + # rf"{self.bins}_" + # rf"{self.analysis}.dat" + # ) + # ) + # ) + # logger.info(f"Found {len(self.filelist)} files.") + + if load != False: + load = Path(load).resolve() + print(load) + self.df: pd.DataFrame = pkl.load(load) + logger.info(f"Using data from {load!r}") + else: + if ions == None: + ions = self.__class__.ions + logger.info( + f"ions not specified, using default {self.__class__.ions}" + ) + else: + logger.info(f"Using custom {ions} for ions") + if atoms == None: + atoms = self.__class__.atoms + logger.info( + f"atoms not specified, using default {self.__class__.atoms}" + ) + else: + logger.info(f"Using custom {atoms} for atoms") + if aas == None: + aas = self.__class__.aas + logger.info( + f"aas not specified, using default {self.__class__.aas}" + ) + else: + logger.info(f"Using custom {aas} for aas") + if clays == None: + clays = self.__class__.clays + logger.info( + f"clays not specified, using default {self.__class__.clays}" + ) + else: + logger.info(f"Using custom {clays} for clays") + if analyses == None: + analyses = self.__class__.analyses + logger.info( + f"clays not specified, using default {self.__class__.analyses}" + ) + else: + logger.info(f"Using custom {analyses} for analyses") + logger.info( + rf"Getting {zstem}*_" + rf"{self.cutoff}_" + rf"{self.bins}_" + rf"{self.zname}.p from {str(zdir.resolve())!r}" + ) + self.filelist[self.zname]: List[Path] = sorted( + list( + zdir.glob( + rf"{zstem}*_" + rf"{self.cutoff}_" + rf"{self.bins}_" + rf"{self.zname}.p" + ) + ) + ) + logger.info(f"Found {len(self.filelist[self.zname])} files.") + + for analysis in analyses: + logger.info( + rf"Getting {namestem}*_" + rf"{self.cutoff}_" + rf"{self.bins}_" + rf"{analysis}.p from {str(indir.resolve())!r}" + ) + self.filelist[analysis]: List[Path] = sorted( + list( + indir.glob( + rf"{namestem}*_" + rf"{self.cutoff}_" + rf"{self.bins}_" + rf"{analysis}.p" + ) + ) + ) + + logger.info(f"Found {len(self.filelist[analysis])} files.") + + cols = pd.Index([*analyses], name="analyses") + + if other is not None: + if other is True: + other = atoms + other.append("OW") + idx = pd.MultiIndex.from_product( + [clays, ions, aas, atoms, other], + names=["clays", "ions", "aas", "atoms", "other"], + ) + self.other: List[str] = other + logger.info(f"Setting second atom selection to {self.other}") + else: + idx = pd.MultiIndex.from_product( + [clays, ions, aas, atoms], + names=["clays", "ions", "aas", "atoms"], + ) + self.other: None = None + + self.df: pd.DataFrame = pd.DataFrame(index=idx, columns=cols) + self.zf = pd.Series( + index=self.df.index, name=self.zname, dtype="object" + ) + if other is not None: + self.zf.reset_index("other", drop=True, inplace=True) + # logger.info(f'Getting {self.__class__.__name__} data') + self._get_data(nameparts, analyses) + # logger.info('Finished getting data') + + setattr(self, f"{self.df}.columns.name", list(self.df.columns)) + self.df.reset_index(level=["ions", "atoms"], inplace=True) + self.df["_atoms"] = self.df["atoms"].where( + self.df["atoms"] != "ions", self.df["ions"], axis=0 + ) + self.df.set_index(["ions", "atoms"], inplace=True, append=True) + atoms_col = self.df["_atoms"].copy() + atoms_col = atoms_col.reset_index().set_index(self.zf.index.names) + atoms_col[self.zf.name] = np.nan + atoms_col[self.zf.name].update(self.zf) + 
+        self.zf = atoms_col
+        for df in [self.zf, self.df]:
+            df.set_index(["_atoms"], inplace=True, append=True)
+            # df.sort_index(inplace=True)
+            df.dropna(inplace=True, how="all", axis=0)
+        self.df.index = self.df.index.reorder_levels([*idx.names, "_atoms"])
+        # self.zf.sort_index(inplace=True, sort_remaining=True)
+        # self.df.sort_index(inplace=True, sort_remaining=True)
+        _atoms = self.df.index.get_level_values("_atoms").tolist()
+        # print(self.zf, "\n", self.zf.index.names)
+        # print(self.df, "\n", self.df.index.names)
+
+        for iid, i in enumerate(self.df.index.names):
+            value: List[Union[str, float]] = (
+                self.df.index._get_level_values(level=iid).unique().tolist()
+            )
+            logger.info(f"Setting {i} to {value}")
+            setattr(self, i, value)
+
+        if odir is not None:
+            self.odir: Path = Path(odir)
+        else:
+            self.odir: Path = Path.cwd()
+
+        logger.info(f"Output directory set to {str(self.odir.resolve())!r}\n")
+        self.__bin_df = pd.DataFrame(columns=self.df.columns)
+
+        self.__edges = {}
+        self.__peaks = {}
+        self.__z_bins = None
+        self.__other_bins = {}
+
+    def get_fname(self, fname):
+        return Path(fname).parent / f"{Path(fname).stem}.tar.xz"
+
+    def save(self, savename):
+        savename = self.get_fname(savename)
+        logger.info(f"Writing dataframe to {savename}")
+        pd.to_pickle(self.df, savename)
+
+    @property
+    def z_bins(self):
+        if self.__z_bins is None:
+            with open(self.filelist[self.zname][0], "rb") as file:
+                data = pkl.load(file)
+            self.__z_bins = HistData(edges=data.edges)
+        return self.__z_bins
+
+    @property
+    def other_bins(self):
+        for analysis in self.analyses:
+            if analysis not in self.__other_bins.keys():
+                with open(self.filelist[analysis][0], "rb") as file:
+                    data = pkl.load(file)
+                self.__other_bins[analysis] = HistData(edges=data.edges)
+        return self.__other_bins
+
+    def _get_data(self, nameparts, analyses):
+        idsl = pd.IndexSlice
+        for f in self.filelist[self.zname]:
+            namesplit = f.stem.split("_")
+            namesplit.pop(-1)
+            namesplit = namesplit[nameparts:]
+            try:
+                clay, ion, aa, pH, atom, cutoff, bins = namesplit
+                assert cutoff == self.cutoff
+                assert bins == self.bins
+                try:
+                    self.zf.loc[idsl[clay, ion, aa, atom]] = f
+                except KeyError:
+                    logger.debug(f"No index entry for {f.stem!r}")
+            except IndexError:
+                logger.info("Encountered IndexError while getting data")
+            except ValueError:
+                logger.info(
+                    f"Encountered ValueError while getting data: {f.stem!r}"
+                )
+
+        logger.info(f"Getting {self.__class__.__name__} data")
+        for analysis in tqdm(
+            analyses, leave=False, position=0, desc="analyses"
+        ):
+            logger.info(f"\tanalysis: {analysis}")
+            for f in tqdm(
+                self.filelist[analysis], leave=False, position=1, desc="files"
+            ):
+                namesplit = f.stem.split("_")
+                namesplit.pop(-1)
+                namesplit = namesplit[nameparts:]
+                if self.other is not None:
+                    other = namesplit.pop(5)
+                    if other in self.ions:
+                        other = "ions"
+                try:
+                    clay, ion, aa, pH, atom, cutoff, bins = namesplit
+                    assert cutoff == self.cutoff
+                    assert bins == self.bins
+                    try:
+                        zdata = self.zf.loc[idsl[clay, ion, aa, atom]]
+                        timeseries = Timeseries(f)
+                        timeseries.zdata = zdata
+                        if self.other is None:
+                            self.df.loc[
+                                idsl[clay, ion, aa, atom], analysis
+                            ] = timeseries
+                        else:
+                            self.df.loc[
+                                idsl[clay, ion, aa, atom, other], analysis
+                            ] = timeseries
+                    except KeyError:
+                        logger.debug(f"No index entry for {f.stem!r}")
+                except IndexError:
+                    logger.info("Encountered IndexError while getting data")
+                except ValueError:
+                    logger.info(
+                        f"Encountered ValueError while getting data: {f.stem!r}"
+                    )
+        logger.info("Finished getting data")
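+        # NOTE (editor comment): file stems are expected to follow
+        # <namestem>_<clay>_<ion>_<aa>_<pH>_<atom>_<cutoff>_<bins>_<analysis>,
+        # e.g. "NAu-1-fe_Ca_ctl_7_OT_20_02_rdf" (hypothetical name); stems
+        # that do not unpack into these fields are logged and skipped above.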
logger.info(f'{self.df.dropna(how="all", axis=0)}') + logger.debug(self.zf[self.zf.isna()]) + logger.debug(self.df[self.df.isna()]) + + def __repr__(self): + return self.df.dropna(how="all", axis=0).__repr__() + + def _get_edge_fname( + self, + atom_type: str, + name: Union[Literal["pe"], Literal["edge"]] = "pe", + ): + # fname = Path.cwd() / f"edge_data/edges_{atom_type}_{self.cutoff}_{self.bins}.p" + fname = ( + Path(__file__).parent + / f"pe_data/{atom_type}_{name}_data_{self.cutoff}_{self.bins}.p" + ) + logger.info(f"Peak/edge Filename: {fname}") + return fname + + def _read_edge_file(self, atom_type: str, skip=True): + fname = self._get_edge_fname(atom_type, name="ads_edges") + if not fname.exists(): + logger.info("No edge file found.") + os.makedirs(fname.parent, exist_ok=True) + logger.info(f"{fname.parent}") + if skip is True: + logger.info(f"Continuing without ads_edges") + p = [0, self.cutoff] + else: + # self._get_edges(atom_type=atom_type) + raise FileNotFoundError(f"No edge file found {fname}.") + + else: + with open(fname, "rb") as edges_file: + logger.info(f"Reading ads_edges {edges_file.name}") + p = pkl.load(edges_file)["edges"] + logger.info(f"ads_edges:{p}") + return p + + # def get_bin_df(self): + # idx = self.df.index.names + # bin_df = self.df.copy() + # atom_types = bin_df.index.get_level_values("_atoms").unique().tolist() + # bin_df.reset_index(["x_bins", "x", "_atoms"], drop=False, inplace=True) + # for atom_type in atom_types: + # # logger.info(f"{atom_type}") + # try: + # ads_edges = self.__edges[atom_type] + # except KeyError: + # # edge_fname = self._get_edge_fname(atom_type) + # ads_edges = self._read_edge_file(atom_type=atom_type) + # # if edge_fname.is_file(): + # # self.__edges[atom_type] = self._read_edge_file(atom_type) + # # else: + # # raise + # # self._get_edges(atom_type=atom_type) + # # ads_edges = self.__edges[atom_type] + # # print(ads_edges, bin_df['x_bins'].where(bin_df['_atoms'] == atom_type)) + # bin_df["x_bins"].where( + # bin_df["_atoms"] != atom_type, + # pd.cut(bin_df["x"], [*ads_edges]), + # inplace=True, + # ) + # bin_df.reset_index(drop=False, inplace=True) + # + # bin_df.set_index( + # idx, + # inplace=True, + # ) + # self.df = bin_df.copy() + # + # @property + # def bin_df(self): + # if not self.df.index.get_level_values("x_bins").is_interval(): + # logger.info("No Interval") + # self.get_bin_df() + # else: + # logger.info("Interval") + # return self.df + + # with open(self.zdist, "rb") as zdist_file: + # zdist_data = pkl.load(zdist_file) + # self.zdist_timeseries = zdist_data.timeseries + # self.zdist_edges = zdist_data.ads_edges + # self.zdist_bins = zdist_data.bins + # timeseries_arrays = {} + # hist_arrays = {} + # for analysis in self.analyses: + # for f in self.filelist[analysis]: + # namesplit = f.stem.split("_") + # # if self.analysis is not None: + # namesplit.pop(-1) + # # else: + # # self.analysis = "zdist" + # name = namesplit[:nameparts] + # namesplit = namesplit[nameparts:] + # # if self.other != None: + # # other = namesplit[4] + # # namesplit.pop(4) + # # if other in self.ions: + # # other = "ions" + # # try: + # clay, ion, aa, pH, atom, cutoff, bins = namesplit + # assert cutoff == self.cutoff + # assert bins == self.bins + # with open(f, "rb") as file: + # array = pkl.load(f) + # timeseries_arrays[analysis] = (array.ads_edges, array.timeseries) + # hist_arrays[analysis] = np.histogramdd( + # [ + # np.ravel(self.zdist_timeseries), + # np.ravel(timeseries_arrays[analysis]), + # ], + # bins=[self.zdist_edges, 
self.vel_edges], + # ) + # analysis_cols = [a for a in self.analyses if self.analyses != "totvel"] + # for analysis_col in analysis_cols: + # hist_arrays[f"{analysis_col}/totvel"] = np.histogramdd( + # [ + # np.ravel(self.zdist_timeseries), + # np.divide( + # np.ravel(timeseries_arrays[analysis_col]), + # np.ravel(timeseries_arrays["totvel"]), + # where=np.ravel(timeseries_arrays["totvel"]) != 0, + # ), + # ], + # bins=[self.zdist_edges, self.vel_edges], + # ) + + # array = pd.read_csv(f, delimiter="\s+", comment="#").to_numpy() + # try: + # self.df.loc[idsl[ion, aa, atom, :], clay] = array[:, 2] + # except ValueError: + # logger.info('Using second atom species') + # self.df.loc[idsl[ion, aa, atom, other, :], clay] = array[:, 2] + # except KeyError: + # logger.info('Using 1st column') + # self.df.loc[idsl[ion, aa, atom, :], clay] = array[:, 1] + # except IndexError: + # logger.info(f"Encountered IndexError while getting data") + # except ValueError: + # logger.info(f"Encountered ValueError while getting data") + # self.name = "_".join(name) + + # def __repr__(self): + # return self.df[self.clays].dropna().__repr__() + + +class RawData: + """ + Class for raw simulation data checking and setup. + Reads files in `rootdir` that follow a directory structure + "`clay_type/ion_type/aa_type/`". + The data is stored in a :class:`pandas.DataFrame` + :param root_dir: Data directory + :type root_dir: Union[str, Path] + :param alt_dir: Alternative data directory, defaults to `None` + :type alt_dir: Union[str, Path], optional + :param ions: List of ion types in solvent + :type ions: List[Literal['Na', 'K', 'Ca', 'Mg']], optional + :param atoms: List of atom types in selection, defaults to `None` + :type atoms: List[Literal['ions', 'OT', 'N', 'CA', 'OW']], optional + :param other: Optional list of atom types in a second selection, + defaults to `None` + :type other: List[Literal['ions', 'OT', 'N', 'CA', 'OW']], optional + :param clays: List of clay types, + defaults to `None` + :type clays: List[Literal['NAu-1', 'NAu-2']], optional + :param aas: List of amino acid types in lower case 3-letter code, + defaults to `None` + :type aas: Optional[List[Literal['ala', 'arg', 'asn', 'asp', + 'ctl', 'cys', 'gln', 'glu', + 'gly', 'his', 'ile', 'leu', + 'lys', 'met', 'phe', 'pro', + 'ser', 'thr', 'trp', 'tyr', + 'val']]] + :param load: Load, defaults to False + :type load: Union[str, Literal[False], Path], optional + :param odir: Output directory, defaults to `None` + :type odir: str, optional + :param nameparts: number of `_`-separated partes in `namestem` + :type nameparts: int, defaults to 1 + :param namestem: leading string in naming pattern, optional + :type namestem: str, defaults to '' + :param analysis: trailing string in naming pattern, optional + defaults to `None` + :type analysis: str, optional + :param df: :class: `pandas.DataFrame` + """ + + aas = [ + "ala", + "arg", + "asn", + "asp", + "ctl", + "cys", + "gln", + "glu", + "gly", + "his", + "ile", + "leu", + "lys", + "met", + "phe", + "pro", + "ser", + "thr", + "trp", + "tyr", + "val", + ] + + ions = ["Na", "K", "Ca", "Mg"] + clays = ["NAu-1", "NAu-2"] + new_dirs = ["neutral", "setup"] + idx_names = ["root", "clays", "ions", "aas"] + + def __init__( + self, + root_dir: Union[str, Path], + alt_root: Optional[Union[str, Path]] = None, + ions: List[Literal["Na", "K", "Ca", "Mg"]] = None, + clays: List[Literal["NAu-1", "NAu-2"]] = None, + aas: List[ + Literal[ + "ala", + "arg", + "asn", + "asp", + "ctl", + "cys", + "gln", + "glu", + "gly", + "his", + "ile", 
+                "leu",
+                "lys",
+                "met",
+                "phe",
+                "pro",
+                "ser",
+                "thr",
+                "trp",
+                "tyr",
+                "val",
+            ]
+        ] = None,
+        load: Union[str, Literal[False], Path] = False,
+        odir: Optional[str] = None,
+        new_dirs: Optional[List[str]] = None,
+    ):
+        self.filelist: list = []
+        if new_dirs is None:
+            self.new_dirs = self.__class__.new_dirs
+        else:
+            self.new_dirs = [Path(d).name for d in new_dirs]
+
+        if not isinstance(root_dir, Path):
+            root_dir = Path(root_dir)
+        if not root_dir.is_dir():
+            logger.error(f"No directory found for {root_dir!r}")
+
+        self.root = root_dir.resolve()
+        self.root_idx = [self.root]
+
+        if alt_root is not None:
+            indir = root_dir / alt_root
+            if not indir.is_dir():
+                indir = Path(alt_root).resolve()
+                if not indir.is_dir():
+                    logger.error(
+                        f"No alternative directory found for {indir!r}"
+                    )
+            if indir.is_dir():
+                self.alt = indir
+                self.root_idx.append(self.alt)
+                logger.info(f"Alternative root {self.alt.resolve()} specified")
+            else:
+                self.alt = None
+        # print(self.root_idx)
+
+        # self.root = SimDir(indir.resolve())
+        # else:
+        #     self._alt_dir = None
+
+        # self.filelist: List[Path] = sorted(
+        #     list(
+        #         root_dir.glob(rf"{namestem}*_" rf"{self.cutoff}_" rf"{self.bins}.dat")
+        #     )
+        # )
+        #
+        # logger.info(f"Found {len(self.filelist)} files.")
+
+        if load is not False:
+            load = Path(load).resolve()
+            with open(load, "rb") as load_file:
+                self.df: pd.DataFrame = pkl.load(load_file)
+            logger.info(f"Using data from {load!r}")
+        else:
+            if ions is None:
+                ions = self.__class__.ions
+                logger.info(
+                    f"ions not specified, using default {self.__class__.ions}"
+                )
+            else:
+                logger.info(f"Using custom {ions} for ions")
+            if aas is None:
+                aas = self.__class__.aas
+                logger.info(
+                    f"aas not specified, using default {self.__class__.aas}"
+                )
+            else:
+                logger.info(f"Using custom {aas} for aas")
+            if clays is None:
+                clays = self.__class__.clays
+                logger.info(
+                    f"clays not specified, using default {self.__class__.clays}"
+                )
+            else:
+                logger.info(f"Using custom {clays} for clays")
+
+            self.clays = clays = self.modify_list_str(clays, suffix="-fe")
+            self.ions = ions = self.modify_list_str(ions)
+            self.aas = aas = self.modify_list_str(aas, suffix="_7")
+
+            cols = pd.Index(["orig", *self.new_dirs], name="paths")
+            # cols = pd.Index(self.new_dirs, name="paths")
+
+            idx = pd.MultiIndex.from_product(
+                [self.root_idx, clays, ions, aas],
+                names=self.__class__.idx_names,
+            )
+            # df = pd.DataFrame(index=idx, columns=cols, dtype=object)
+            # for col in cols:
+            #     df[col] = np.NaN
+            # df.columns.name = 'paths'
+            # self.df = df
+            self.df: pd.DataFrame = pd.DataFrame(
+                index=idx, columns=cols, dtype=object
+            )
+            # self.idx_iter = self._get_idx_iter(self.df.index)
+
+            for iid, i in enumerate(self.df.index.names):
+                value: List[Union[str, float]] = (
+                    self.df.index._get_level_values(level=iid)
+                    .unique()
+                    .tolist()
+                )
+                logger.info(f"Setting {i} to {value}")
+                setattr(self, i, value)
+
+            # logger.info(f'Getting data for {self.__class__.__name__}')
+            self._get_data()
+            # logger.info('Finished getting data')
+
+        # self.df.dropna(inplace=True, how="all", axis=0)
+
+        setattr(self, self.df.columns.name, list(self.df.columns))
+
+        for iid, i in enumerate(self.df.index.names):
+            value: List[Union[str, float]] = (
+                self.df.index._get_level_values(level=iid).unique().tolist()
+            )
+            logger.info(f"Setting {i} to {value}")
+            setattr(self, i, value)
+
+        # print(self.root)
+
+        if odir is not None:
+            self.odir: Path = Path(odir)
+        else:
+            self.odir: Path = Path.cwd()
+
+        logger.info(f"Output directory set to {str(self.odir.resolve())!r}\n")
+
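+    # Example usage (added; paths and ion list are hypothetical):
+    #
+    #     data = RawData("/path/to/simulations", ions=["Na", "Ca"])
+    #     data.df       # "orig" is True where <root>/<clay>/<ion>/<aa> exists
+    #     data.save()   # pickles the DataFrame to <odir>/RawData.p
+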
def _get_data(self): + idsl = pd.IndexSlice + for dir_tree in self.idx_iter: + root, clay, ion, aa = dir_tree + p_str = np.apply_along_axis(lambda x: "/".join(x), 0, dir_tree[1:]) + path = root / p_str + print(root, p_str) + if path.is_dir(): + self.df.loc[idsl[root, clay, ion, aa], "orig"] = True + for new_dir in self.new_dirs: + if (path / new_dir).is_dir(): + self.df.loc[idsl[root, clay, ion, aa], new_dir] = True + + # def write_json(self, outpath=None): + # json_file = open_outfile(outpath=outpath, suffix='json', default='rawdata_paths') + # df = self.df.copy() + # df = df.stack(dropna=True) + # get_pd_idx_iter(self.df[self.df].index.unique()) + + @staticmethod + def regex_join(match_list, suffix="", prefix="") -> str: + if not isinstance(match_list[0], str) or suffix != "" or prefix != "": + match_list = list( + map(lambda x: f"{str(x.rstrip(suffix))}", match_list) + ) + match_list = list( + map(lambda x: f"{str(x.lstrip(prefix))}", match_list) + ) + match_list = list( + map(lambda x: f"{prefix}{str(x)}{suffix}", match_list) + ) + return "|".join(match_list) + + @staticmethod + def modify_list_str(match_list, suffix="", prefix="") -> List[str]: + # print(match_list) + if not isinstance(match_list[0], str) or suffix != "" or prefix != "": + match_list = list( + map(lambda x: f"{str(x.rstrip(suffix))}", match_list) + ) + match_list = list( + map(lambda x: f"{str(x.lstrip(prefix))}", match_list) + ) + match_list = list( + map(lambda x: f"{prefix}{str(x)}{suffix}", match_list) + ) + return match_list + + def save(self, savename=None, overwrite=True): + if savename is None: + savename = self.odir / f"{self.__class__.__name__}.p" + if not savename.is_file() or overwrite is True: + with open(savename, "wb") as outfile: + pkl.dump(self.df, outfile) + + def update_attrs(self): + for iid, i in enumerate(self.df.index.names): + value: List[Union[str, float]] = ( + self.df.index._get_level_values(level=iid).unique().tolist() + ) + logger.info(f"Setting {i} to {value}") + setattr(self, i, value) + + # @cached_property + @property + def idx_iter(self): + idx_values = [getattr(self, idxit) for idxit in self.df.index.names] + idx_product = np.array( + np.meshgrid(*[idx_value for idx_value in idx_values]) + ).T.reshape(-1, len(idx_values)) + # idx_product = np.apply_along_axis(lambda x: '/'.join(x), 1, idx_product) + return idx_product + + +class ArrayData2D: + aas = np.array( + [ + "ala", + "arg", + "asn", + "asp", + "ctl", + "cys", + "gln", + "glu", + "gly", + "his", + "ile", + "leu", + "lys", + "met", + "phe", + "pro", + "ser", + "thr", + "trp", + "tyr", + "val", + ] + ) + ions = ["Na", "K", "Ca", "Mg"] + atoms = ["ions", "OT", "N", "CA"] + clays = ["NAu-1", "NAu-2"] + + def __init__( + self, + indir: Union[str, Path], + namestem, + cutoff, + bins, + ions=None, + atoms=None, + other=None, + aas=None, + load=False, + odir=None, + nameparts: int = 1, + ): + self.filelist = [] + self.bins = Bins(bins) + self.cutoff = Cutoff(cutoff) + + if type(indir) != Path: + indir = Path(indir) + + # self.filelist = sorted(list(indir.glob(rf'*_{self.cutoff}_{self.bins}.dat'))) + # + + if load != False: + load = self.get_fname(load) + self.df = pd.read_pickle(load) + else: + if ions == None: + ions = self.__class__.ions + if atoms == None: + atoms = self.__class__.atoms + if aas == None: + aas = self.__class__.aas + self.filelist = sorted( + list(indir.glob(rf"{namestem}*_{self.cutoff}_{self.bins}.p")) + ) + + with open(self.filelist[0], "rb") as f: + data = pkl.load(f) + array, self.ybins, self.xbins = 
data.values()
+
+            cols = pd.Index(["NAu-1", "NAu-2"], name="clays")
+
+            if other is not None:
+                other = [*atoms, "OW"]
+                idx = pd.MultiIndex.from_product(
+                    [ions, aas, atoms, other],
+                    names=["ions", "aas", "atoms", "other"],
+                )
+                self.other = other
+            else:
+                idx = pd.MultiIndex.from_product(
+                    [ions, aas, atoms], names=["ions", "aas", "atoms"]
+                )
+                self.other = None
+
+            logger.info("Getting DataFrame")
+            self.df = pd.DataFrame(index=idx, columns=cols)
+
+            logger.info("Setting bins")
+            self.xbin_step = Bins(
+                np.round(np.abs(np.subtract.reduce(self.xbins[:2])), 2)
+            )
+            self.ybin_step = Bins(
+                np.round(np.abs(np.subtract.reduce(self.ybins[:2])), 2)
+            )
+            logger.info("Setting cutoff")
+            self.xcutoff = Cutoff(np.rint(self.xbins[-1]))
+            self.ycutoff = Cutoff(np.rint(self.ybins[-1]))
+
+            self._get_data(nameparts)
+            # self.df.dropna(inplace=True, how='all', axis=0)
+
+        setattr(self, self.df.columns.name, list(self.df.columns))
+
+        for iid, i in enumerate(self.df.index.names):
+            value = (
+                self.df.index._get_level_values(level=iid).unique().tolist()
+            )
+            setattr(self, i, value)
+
+        if odir is not None:
+            self.odir = Path(odir)
+        else:
+            self.odir = Path.cwd()
+
+    def __repr__(self):
+        return self.df.__repr__()
+
+    def get_fname(self, fname):
+        return Path(fname).with_suffix(".tar.xz")
+
+    def save(self, savename):
+        savename = self.get_fname(savename)
+        logger.info(f"Writing dataframe to {savename}")
+        pd.to_pickle(self.df, savename)
+
+    def _get_data(self, nameparts: int):
+        idsl = pd.IndexSlice
+        for f in self.filelist:
+            namesplit = f.stem.split("_")
+            name = namesplit[:nameparts]
+            namesplit = namesplit[nameparts:]
+            if self.other is not None:
+                other = namesplit[4]
+                namesplit.pop(4)
+                if other in self.ions:
+                    other = "ions"
+            try:
+                clay, ion, aa, pH, atom, cutoff, bins = namesplit
+
+                # clay, ion, aa, pH, cutoff, bins = f.stem.split(sep='_')
+                # atom = 'ions'
+                # assert cutoff == self.cutoff
+                # assert bins == self.xbin_step
+
+                with open(f, "rb") as file:
+                    data = pkl.load(file)
+                if isinstance(data, dict):
+                    data = list(data.values())
+                data = np.squeeze(data)
+                assert isinstance(
+                    data, np.ndarray
+                ), f"Expected array type, found {data.__class__.__name__!r}"
+                try:
+                    self.df.loc[idsl[ion, aa, atom], clay] = data
+                except ValueError:
+                    self.df.loc[idsl[ion, aa, atom, other], clay] = data
+            except IndexError:
+                logger.info("Encountered IndexError while getting data")
+            except ValueError:
+                logger.info("Encountered ValueError while getting data")
+            except Exception as e:
+                logger.debug(f"Skipped {f.stem!r}: {e}")
+        self.name = "_".join(name)
+
+    def plot(
+        self,
+        x: Literal["clays", "aas", "ions", "atoms", "other"],
+        y: Literal["clays", "ions", "aas", "atoms", "other"],
+        rowlabel: str = "y",
+        columnlabel: str = "x",
+        figsize=None,
+        dpi=None,
+        diff=False,
+        xmax=50,
+        ymax=50,
+        save=False,
+        xlim=None,
+        ylim=None,
+        cmap="magma",
+        odir=".",
+        plot_table=False,
+    ):
+        aas_classes = [
+            ["arg", "lys", "his"],
+            ["glu", "gln"],
+            ["cys"],
+            ["pro"],
+            ["gly"],
+            ["pro"],
+            ["ala", "val", "ile", "leu", "met"],
+            ["phe", "tyr", "trp"],
+            ["ser", "thr", "asp", "gln"],
+        ]
+        ions_classes = [["Na", "Ca"], ["Ca", "Mg"]]
+        atoms_classes = [["ions"], ["N"], ["OT"], ["CA"]]
+        clays_classes = [["NAu-1"], ["NAu-2"]]
+        cmaps_seq = ["Purples", "Blues", "Greens", "Oranges", "Reds"]
+        cmaps_single = ["Dark2"]
+        sel_list = ("clays", "ions", "aas", "atoms")
+        # for color, attr in zip([''], sel_list):
+        #     cmaps_dict[attr] = {}
+        # cm.get_cmap()
+        cmap_dict = {"clays": []}
+
+        title_dict = {
+            "clays": "Clay type",
+            "ions": "Ion type",
+            "aas": "Amino acid",
+            "atoms": "Atom type",
+            "other": "Other atom 
type", + } + + sel_list = ["clays", "ions", "aas", "atoms"] + if self.other != None: + sel_list.append("other") + + separate = [s for s in sel_list if (s != x and s != y)] + idx = pd.Index([s for s in sel_list if (s != x and s not in separate)]) + + sep = pd.Index(separate) + + vx = getattr(self, x) + + if diff == True: + vx = "/".join(vx) + lx = 1 + else: + lx = len(vx) + + vy = getattr(self, y) + + ly = len(vy) + + yid = np.ravel(np.where(np.array(idx) == y))[0] + + # label_key = idx.difference(pd.Index([x, y, sep]), sort=False).values[0] + # label_id = idx.get_loc(key=label_key) + # label_classes = locals()[f'{label_key}_classes'] + # cmap_dict = {} + # single_id = 0 + # seq_id = 0 + # for category in label_classes: + # if len(category) == 1: + # cmap = matplotlib.cycler('color', cm.Dark2.colors) + # single_id += 1 + # else: + # cmap = getattr(cm, cmaps_seq[seq_id])(np.linspace(0, 1, len(category))) + # + # cmap = matplotlib.cycler('color', cmap) + # # cmap = cmap(np.linspace(0, 1, len(category))).colors + # # viridis(np.linspace(0,1,N))) + # # cm.get_cmap(cmaps_seq[seq_id], len(category)) + # seq_id += 1 + # for item_id, item in enumerate(category): + # cmap_dict[item] = cmap.__getitem__(item_id) + # + + x_dict = dict(zip(vx, np.arange(lx))) + + if diff == True: + diffstr = "diff" + sel = "diff" + self._get_densdiff() + else: + sel = self.clays + diffstr = "" + + plot_df = self.df[sel].copy() + plot_df.reset_index().set_index([*idx]) + + if figsize == None: + figsize = tuple( + [ + 5 * lx if (10 * lx) < xmax else xmax, + 5 * ly if (5 * ly) < ymax else ymax, + ] + ) + + if dpi == None: + dpi = 100 + + iters = np.array( + np.meshgrid(*[getattr(self, idxit) for idxit in idx]) + ).T.reshape(-1, len(idx)) + + logger.info(f"Printing plots for {sep}\nColumns: {vx}\nRows: {vy}") + + label_mod = lambda l: ", ".join( + [li.upper() if namei == "aas" else li for li, namei in l] + ) + + sep_it = np.array( + np.meshgrid(*[getattr(self, idxit) for idxit in sep]) + ).T.reshape(-1, len(sep)) + + for pl in sep_it: + # try: + # fig.clear() + # except: + # pass + y_dict = dict(zip(vy, np.arange(ly))) + # if separate == 'atoms' and pl != '': + # ... 
+ + legends_list = [(a, b) for a in range(ly) for b in range(lx)] + + legends = dict( + zip(legends_list, [[] for a in range(len(legends_list))]) + ) + + # if type(pl) in [list, tuple, np.ndarray]: + # viewlist = [] + # for p in pl: + # viewlist.append(plot_df.xs((p), level=separate, axis=0)) + # + # sepview = pd.concat(viewlist) + # plsave = 'ions' + # + # else: + sepview = plot_df.xs(tuple(pl), level=sep.tolist(), axis=0) + plsave = "_".join(pl) + + fig, ax = plt.subplots( + nrows=ly, + ncols=lx, + figsize=figsize, + sharey=True, + dpi=dpi, + constrained_layout=True, + ) + + fig.suptitle( + ( + ", ".join([title_dict[s].upper() for s in separate]) + + f": {label_mod(list(tuple(zip(pl, separate))))}" + ), + size=16, + weight="bold", + ) + pi = 0 + for col in vx: + try: + view = sepview.xs(col, axis=1) + pi = 1 + except ValueError: + view = sepview + col = vx + pi += 1 + for it in iters: + try: + values = view.xs(it[0]) + # values = view.xs(tuple(it), level=idx.tolist()[0])#.reset_index(drop=False) + if type(values) == list: + values = np.squeeze(values) + if type(values) == np.ndarray: + values_array = values + data, xbins, ybins = values_array + if np.all(np.ravel(data)) >= 0: + levels = np.linspace( + np.min(data), np.max(data), 50 + ) + + try: + x_id, y_id = x_dict[col], y_dict[it[yid]] + ax[y_id, x_id].contourf( + xbins, ybins, data, cmap=cmap + ) + except: + x_id, y_id = 0, y_dict[it[yid]] + ax[y_id].contourf( + xbins, ybins, data, cmap=cmap + ) + if pi == 1: + legends[y_id, x_id].append( + it + ) # [label_id]) + else: + logger.info(f"Found {type(values)}: NaN values") + except KeyError: + logger.info(f"No data for {pl}, {vx}, {it}") + + for i in range(ly): + try: + ax[i, 0].set_ylabel( + f"{label_mod([(vy[i], y)])}\n" + rowlabel + ) + for j in range(lx): + # ax[i, j].legend([label_mod(leg, label_key) for leg in legends[i, j]], ncol=3) + if xlim != None: + ax[i, j].set_xlim((0.0, float(xlim))) + if ylim != None: + ax[i, j].set_ylim((0.0, float(ylim))) + ax[ly - 1, j].set_xlabel( + columnlabel + f"\n{label_mod([(vx[j], x)])}" + ) + except: + ax[i].set_ylabel(f"{label_mod([(vy[i], y)])}\n" + rowlabel) + # ax[i].legend([label_mod(leg, label_key) for leg in legends[i, 0]], ncol=3) + ax[ly - 1].set_xlabel( + columnlabel + f"\n{label_mod([(vx, x)])}" + ) + + fig.supxlabel(f"{title_dict[x]}s", size=14) + fig.supylabel(f"{title_dict[y]}s", size=14) + if save != False: + odir = Path(odir) + print(odir.absolute()) + if not odir.is_dir(): + os.makedirs(odir) + if type(save) == str: + fig.savefig(str(odir) / f"{save}.png") + else: + fig.savefig( + str(odir) + / f"{self.name}_{diffstr}_{x}_{y}_{plsave}_{self.cutoff}_{self.bins}.png" + ) + else: + fig.show + fig.clear() + + def make_df(self): + new_df = self.df.copy() + transformation = ( + lambda x: (make_1d(x[0]), x[1:]) if type(x) != float else x + ) + new_df = new_df.applymap(transformation) + return new_df diff --git a/package/ClayCode/analysis/lib.py b/package/ClayCode/analysis/lib.py new file mode 100644 index 00000000..e3fd6225 --- /dev/null +++ b/package/ClayCode/analysis/lib.py @@ -0,0 +1,1845 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import pathlib as pl +import pickle as pkl +import re +import shutil +import tempfile +from functools import partial, update_wrapper +from pathlib import Path, PosixPath +from typing import ( + Callable, + Dict, + List, + Literal, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) + +import numpy as np +import pandas as pd 
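+# NOTE (editor comment): the imports below pull in ClayCode internals and
+# MDAnalysis; Universe and AtomGroup provide the topology and trajectory
+# machinery used throughout this module.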
+from ClayCode import Dir, File, GROFile, PathType +from ClayCode.analysis.consts import PE_DATA +from ClayCode.analysis.utils import select_file +from ClayCode.core import gmx +from ClayCode.core.consts import IONS, SOL, SOL_DENSITY +from ClayCode.core.utils import get_header +from ClayCode.data.consts import AA, DATA, FF, MDP, UCS +from MDAnalysis import AtomGroup, Universe +from MDAnalysis.lib.distances import minimize_vectors +from MDAnalysis.lib.mdamath import triclinic_vectors +from MDAnalysis.transformations.translate import center_in_box +from MDAnalysis.transformations.wrap import wrap +from numpy.typing import NDArray + +tpr_logger = logging.getLogger("MDAnalysis.topology.TPRparser").setLevel( + level=logging.WARNING +) + +__name__ = "lib" +__all__ = [ + "process_orthogonal", + "process_triclinic", + "select_cyzone", + "get_dist", + "select_solvent", + "process_box", + "exclude_xyz_cutoff", + "check_traj", + "run_analysis", + "get_selections", + "get_mol_prms", + "get_system_charges", + "process_orthogonal_axes", + "process_triclinic_axes", +] + +logger = logging.getLogger(__name__) + + +# def init_log( +# logname, +# level: Union[ +# Literal[20], +# Literal[10], +# Dict[ +# Union[Literal["file"], Literal["stream"]], +# Union[Literal[20], Literal[10]], +# ], +# ] = None, +# handlers: Union[ +# Literal["file"], +# Literal["stream"], +# List[Union[Literal["file"], Literal["stream"]]], +# ] = ["file", "stream"], +# runname=None, +# fpath=None, +# ): +# logger = logging.getLogger(logname) +# format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") +# if isinstance(handlers, str): +# handlers = [handlers] +# if isinstance(level, str): +# level = [level * len(handlers)] +# level = dict(zip(handlers, level)) +# elif level is None: +# level = {"file": 20, "stream": 10} +# assert isinstance(level, dict), f"Unexpected type for level: {type(level)}" +# for key in level.keys(): +# assert key in handlers +# logger.setLevel(np.min(list(level.values()))) +# if "stream" in handlers: +# logger.debug("stream", level["stream"]) +# shandler = logging.StreamHandler() +# shandler.setLevel(level["stream"]) +# shandler.setFormatter(format) +# logger.addHandler(shandler) +# +# if "file" in handlers: +# logger.debug("file") +# if runname is None: +# pass +# elif runname == "__main__": +# runname = None +# fhandler = logging.FileHandler( +# get_logfname(logname, runname, time=exec_date, logpath=fpath), mode="w" +# ) +# fhandler.setLevel(level["file"]) +# fhandler.setFormatter(format) +# logger.addHandler(fhandler) +# +# return logger + + +def init_temp_inout( + inf: Union[str, Path], + outf: Union[str, Path], + new_tmp_dict: Dict[Union[Literal["crd"], Literal["top"]], bool], + which=Union[Literal["crd"], Literal["top"]], +) -> Tuple[ + Union[str, Path], + Union[str, Path], + Dict[Union[Literal["crd"], Literal["top"]], bool], +]: + inp = Path(inf) + outp = Path(outf) + if inp == outp: + temp_outp = tempfile.NamedTemporaryFile( + suffix=outp.suffix, prefix=outp.stem, dir=outp.parent + ) + # outp = inp.with_stem(inp.stem + "_tmp") + outp = Path(temp_outp.name) + new_tmp_dict[which] = True + logger.info(f"Creating temporary output file {outp.name!r}") + else: + temp_outp = None + if type(inf) == str: + inp = str(inp.resolve()) + if type(outf) == str: + temp_outp = str(outp.resolve()) + return inp, outp, new_tmp_dict, temp_outp + + +def fix_gro_residues(crdin: Union[Path, str], crdout: Union[Path, str]): + u = Universe(crdin) + if np.unique(u.residues.resnums).tolist() == [1]: + u = 
add_resnum(crdin=crdin, crdout=crdout) + if "iSL" not in u.residues.resnames: + u = rename_il_solvent(crdin=crdout, crdout=crdout) + return u + + +def temp_file_wrapper(f: Callable): + def wrapper(**kwargs): + kwargs_dict = locals()["kwargs"] + fargs_dict = {} + new_tmp = {} + for ftype in ["crd", "top"]: + if f"{ftype}in" in kwargs_dict.keys(): + fargs_dict[f"{ftype}in"] = kwargs_dict[f"{ftype}in"] + if f"{ftype}out" in kwargs_dict.keys(): + ( + fargs_dict[f"{ftype}in"], + fargs_dict[f"{ftype}out"], + new_tmp, + temp_outp, + ) = init_temp_inout( + kwargs_dict[f"{ftype}in"], + kwargs_dict[f"{ftype}out"], + new_tmp_dict=new_tmp, + which=ftype, + ) + elif f"{ftype}out" in kwargs_dict.keys(): + fargs_dict[f"{ftype}out"] = kwargs_dict[f"{ftype}out"] + for k, v in fargs_dict.items(): + locals()["kwargs"][k] = v + r = f(**kwargs_dict) + for ftype, new in new_tmp.items(): + if new is True: + infile = Path(fargs_dict[f"{ftype}in"]) + outfile = Path(fargs_dict[f"{ftype}out"]) + assert outfile.exists(), f"No file generated!" + shutil.copy(outfile, infile) + logger.info(f"Renaming {outfile.name!r} to {infile.name!r}") + return r + + return wrapper + + +def run_analysis( + instance: Type["ClayAnalysisBase"], start: int, stop: int, step: int +): + """Run MDAnalysis analysis. + :param start: First frame, defaults to None + :type start: int, optional + :param stop: Last frame, defaults to None + :type stop: int, optional + :param step: Iteration step, defaults to None + :type step: int, optional + """ + kwarg_dict = {} + for k, v in {"start": start, "step": step, "stop": stop}.items(): + if v is not None: + kwarg_dict[k] = v + instance.run(**kwarg_dict) + + +@overload +def get_selections( + infiles: Sequence[Union[str, Path, PosixPath]], + sel: Sequence[str], + clay_type: str, + other: Sequence[str], + in_memory: bool, +) -> Tuple[AtomGroup, AtomGroup, AtomGroup]: + ... + + +@overload +def get_selections( + infiles: Sequence[Union[str, Path, PosixPath]], + sel: Sequence[str], + clay_type: str, + other: None, + in_memory: bool, +) -> Tuple[AtomGroup, AtomGroup]: + ... + + +def get_selections( + infiles, sel, clay_type, other=None, in_memory=False +): # , save_new=True): + """Get MDAnalysis atom groups for clay, first and optional second selection. 
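+
+    Example (added; file and clay type names are hypothetical):
+        >>> sel, clay = get_selections(["md.tpr", "md.xtc"], ["SOL"], "D2")
+        >>> sel, clay, other = get_selections(
+        ...     ["md.tpr", "md.xtc"], ["GLU", "OT*"], "D2", other=["Na"]
+        ... )
+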
+ :param in_memory: store trajectory to memory + :type in_memory: bool + :param clay_type: Unit cell type + :type clay_type: str + :param infiles: Coordinate and trajectory files + :type infiles: Sequence[Union[str, Path, PosixPath]] + :param sel: selection keywords as [resname] or [resname, atom type] or 'not SOL' + :type sel: Sequence[str] + :param other: selection keywords as [resname] or [resname, atom type], defaults to None + :type other: Sequence[str], optional + # :raises ValueError: lengths of sel or other != in [1, 2] + # :return sel: atom group for sel + # :rtype sel: AtomGroup + # :return clay: atom group for clay + # :rtype clay: AtomGroup + # :return other: atom group for other, optional + # :rtype other: AtomGroup + """ + infiles = [str(Path(infile).absolute()) for infile in infiles] + for file in infiles: + logger.info(f"Reading: {file!r}") + u = Universe(*infiles, in_memory=in_memory) + # only resname specified + if len(sel) == 1: + sel = u.select_atoms(f"resname {sel[0]}") + # rename and atom type specified + elif len(sel) == 2: + # expand search string for terminal O atom types + if sel[1] == "OT*": + sel[1] = "OT* O OXT" + sel = u.select_atoms(f"resname {sel[0]}* and name {sel[1]}") + else: + raise ValueError('Expected 1 or 2 arguments for "sel"') + if other is None: + pass + elif len(other) == 1: + logger.debug(f"other: {other}") + other = u.select_atoms(f"resname {other[0]}") + elif len(other) == 2: + logger.debug(f"other: {other}") + if other[1] == "OT*": + other[1] = "OT* O OXT" + other = u.select_atoms(f"resname {other[0]}* and name {other[1]}") + else: + raise ValueError('Expected 1 or 2 arguments for "other"') + clay = u.select_atoms(f"resname {clay_type}* and name OB* o*") + logger.info( + f"'clay': Selected {clay.n_atoms} atoms of " + f"{clay.n_residues} {clay_type!r} unit cells" + ) + + sel = select_outside_clay_stack(sel, clay) + + # Clay + two other atom groups selected + if other is not None: + other = select_outside_clay_stack(other, clay) + return sel, clay, other + + # Only clay + one other atom group selected + else: + return sel, clay + + +def select_outside_clay_stack(atom_group: AtomGroup, clay: AtomGroup): + atom_group = atom_group.select_atoms( + f" prop z >= {np.max(clay.positions[:, 2] - 1)} or" + f" prop z <= {np.min(clay.positions[:, 2] + 1)}" + ) + logger.info( + f"'atom_group': Selected {atom_group.n_atoms} atoms of names: {np.unique(atom_group.names)} " + f"(residues: {np.unique(atom_group.resnames)})" + ) + return atom_group + + +def search_ndx_group(ndx_str: str, sel_name: str): + try: + sel_match = re.search( + f"\[ {sel_name} \]", ndx_str, flags=re.MULTILINE | re.DOTALL + ).group(0) + match = True + except AttributeError: + match = False + return match + + +def save_selection( + outname: Union[str, Path], + atom_groups: List[AtomGroup], + ndx=False, + traj=".trr", + pdbqt=False, +): + ocoords = Path(outname).with_suffix("._gro").resolve() + opdb = Path(outname).with_suffix(".pdbqt").resolve() + logger.info(f"Writing coordinates and trajectory ({traj!r})") + if ndx is True: + ondx = Path(outname).with_suffix(".ndx").resolve() + if ondx.is_file(): + with open(ondx, "r") as ndx_file: + ndx_str = ndx_file.read() + else: + ndx_str = "" + outsel = atom_groups[0] + if ndx is True: + group_name = "clay" + atom_name = np.unique(atom_groups[0].atoms.names)[0][:2] + group_name += f"_{atom_name}" + if not search_ndx_group(ndx_str=ndx_str, sel_name=group_name): + atom_groups[0].write(ondx, name=group_name, mode="a") + # logger.info(f'ag : 
{outsel.n_atoms} {outsel.atoms}') + for ag in atom_groups[1:]: + if ndx is True: + group_name = np.unique(ag.residues.resnames)[0] + group_name = re.match("[a-zA-Z]*", group_name).group(0) + atom_name = np.unique(ag.atoms.names)[0][:2] + group_name += f"_{atom_name}" + if not search_ndx_group(ndx_str=ndx_str, sel_name=group_name): + ag.write(ondx, name=group_name, mode="a") + outsel += ag + logger.info( + f"New trajectory from {len(atom_groups)} groups with {outsel.n_atoms} total atoms" + ) + logger.info(f"1. {ocoords!r}") + outsel.write(str(ocoords)) + if pdbqt is True: + logger.info(f"2. {opdb!r}") + outsel.write(str(opdb), frames=outsel.universe.trajectory[-1::1]) + if type(traj) != list: + traj = [traj] + for t in traj: + otraj = outname.with_suffix(t).resolve() + logger.info(f"3. {otraj!r}") + outsel.write(str(otraj), frames="all") + # logger.info(f'4. {otrr}') + # outsel.write(str(otrr), frames='all') + + +def check_traj( + instance: Type["ClayAnalysisBase"], check_len: Union[int, Literal[False]] +) -> None: + """Check length of trajectory in analysis class instance. + :param instance: analysis class instance + :type instance: Type['ClayAnalysisBase'] + :param check_len: expected number of trajectory frames, defaults to False + :type check_len: Union[int, Literal[False]] + :raises SystemExit: Error if trajectory length != check_len + """ + logger.debug(f"Checking trajectory length: {check_len}") + if type(check_len) == int: + if instance._universe.trajectory.n_frames != check_len: + raise SystemExit( + "Wrong number of frames: " + f"{instance._universe.trajectory.n_frames}" + ) + + +def process_box(instance: Type["ClayAnalysisBase"]) -> None: + """Assign distance minimisation function in orthogonal or triclinic box. + + Correct x, x2, z interatomic distances for periodic boundary conditions + in orthogonal box inplace + O* + +--------------+ +---------/----+ + | S | | S | + | \ | --> | | + | \ | | | + | O | | | + +--------------+ +--------------+ + + :param instance: analysis class instance + :type instance: Type['ClayAnalysisBase'] + """ + box = instance._universe.dimensions + if np.all(box[3:] == 90.0): + instance._process_distances = process_orthogonal + instance._process_axes = process_orthogonal_axes + else: + instance._process_distances = process_triclinic + instance._process_axes = process_triclinic_axes + + +def process_orthogonal_axes( + distances: NDArray[np.float64], + dimensions: NDArray[np.float64], + axes: List[int], +) -> None: + """ + Correct x, x2, z interatomic distances for periodic boundary conditions + in orthogonal box inplace + + :param axes: + :type axes: + :param distances: interatomic distance array of shape (n, m, 3) + :type distances: NDArray[np.float64] + :param dimensions: simulation box dimension array of shape (6,) + :type dimensions: NDArray[np.float64] + :return: no return + :rtype: NoReturn + """ + assert ( + distances.shape[-1] == len(axes) or distances.ndim == 2 + ), f"Shape of distance array ({distances.shape[-1]}) does not match selected axes {axes}" + # logger.info(distances / dimensions[:3][axes], np.rint(distances / dimensions[:3][axes])) + for idx, dist in np.ma.ndenumerate(distances): + distances[idx] -= dimensions[:3][axes] * np.rint( + dist / dimensions[:3][axes] + ) + + +def process_orthogonal( + distances: NDArray[np.float64], dimensions: NDArray[np.float64] +) -> None: + """ + Correct x, x2, z interatomic distances for periodic boundary conditions + in orthogonal box inplace + + :param distances: interatomic distance array of shape (n, 
m, 3)
+    :type distances: NDArray[np.float64]
+    :param dimensions: simulation box dimension array of shape (6,)
+    :type dimensions: NDArray[np.float64]
+    :return: no return
+    :rtype: NoReturn
+    """
+    # logger.info(distances[:3] / dimensions[:3], np.rint(distances / dimensions[:3]))
+    distances -= dimensions[:3] * np.rint(distances / dimensions[:3])
+
+
+def process_triclinic_axes(
+    distances: NDArray[np.float64],
+    dimensions: NDArray[np.float64],
+    axes: List[int],
+) -> None:
+    """
+    Correct x, y, z interatomic distances for periodic boundary conditions
+    in triclinic box inplace
+
+    :param axes: indices of the axes to correct
+    :type axes: List[int]
+    :param distances: interatomic distance array of shape (n, m, 3)
+    :type distances: NDArray[np.float64]
+    :param dimensions: simulation box dimension array of shape (6,)
+    :type dimensions: NDArray[np.float64]
+    :return: no return
+    :rtype: NoReturn
+    """
+    box = triclinic_vectors(dimensions)
+    assert distances.shape[-1] >= len(
+        axes
+    ), f"Shape of distance array ({distances.shape[-1]}) does not match selected axes {axes}"
+    # logger.debug(distances / np.diagonal(box)[..., axes])
+    distances -= np.diagonal(box)[..., axes] * np.rint(
+        distances / np.diagonal(box)[..., axes]
+    )
+
+
+def process_triclinic(
+    distances: NDArray[np.float64], dimensions: NDArray[np.float64]
+) -> None:
+    """
+    Correct x, y, z interatomic distances for periodic boundary conditions
+    in triclinic box inplace
+
+    :param distances: interatomic distance array of shape (n, m, 3)
+    :type distances: NDArray[np.float64]
+    :param dimensions: simulation box dimension array of shape (6,)
+    :type dimensions: NDArray[np.float64]
+    :return: no return
+    :rtype: NoReturn
+    """
+    box = triclinic_vectors(dimensions)
+    # logger.debug(distances / np.diagonal(box))
+    distances -= np.diagonal(box) * np.rint(distances / np.diagonal(box))
+
+
+def select_cyzone(
+    distances: MaskedArray,
+    z_dist: float,
+    xy_rad: float,
+    mask_array: MaskedArray,
+) -> None:
+    """
+    Select all distances corresponding to atoms within a cylindrical volume
+    of dimensions +- z_dist and radius xy_rad
+    :param distances: masked interatomic distance array of shape (n, m, 3)
+    :type distances: MaskedArray[np.float64]
+    :param z_dist: absolute value for cutoff in z direction
+    :type z_dist: float
+    :param xy_rad: absolute value for radius in xy plane
+    :type xy_rad: float
+    :param mask_array: array for temporary mask storage of shape (n, m)
+    :type mask_array: MaskedArray[np.float64]
+    :return: no return
+    :rtype: NoReturn
+    """
+    z_col = distances[:, :, 2]
+    z_col.mask = np.abs(distances[:, :, 2]) > z_dist
+    distances.mask = np.broadcast_to(
+        z_col.mask[:, :, np.newaxis], distances.shape
+    )
+    np.ma.sum(distances[:, :, [0, 1]] ** 2, axis=2, out=mask_array)
+    mask_array.harden_mask()
+    mask_array.mask = mask_array > xy_rad**2
+    np.copyto(distances.mask, mask_array.mask[:, :, np.newaxis])
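+
+
+# Added sketch (editor example, not part of the original API): a minimal,
+# self-contained illustration of the minimum-image convention that
+# process_orthogonal() applies in place; the function and values here are
+# hypothetical.
+def _minimum_image_example() -> None:
+    """In a 10 A orthogonal box, a +9 A separation wraps to -1 A."""
+    box = np.array([10.0, 10.0, 10.0])
+    dists = np.array([[9.0, 0.5, -9.5]])
+    # subtract whole box lengths so each component lies in [-box/2, box/2)
+    dists -= box * np.rint(dists / box)
+    assert np.allclose(dists, [[-1.0, 0.5, 0.5]])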
+ + +def exclude_xyz_cutoff(distances: NDArray[np.int64], cutoff: float) -> None: + """ + Select all distances corresponding to atoms within a box + with length 2* cutoff + :param distances: masked interatomic distance array of shape (n, m, 3) + :type distances: NDArray[np.float64] + :param cutoff: absolute value for maximum distance + :type cutoff: float + :return: no return + :rtype: NoReturn + """ + mask = np.any(np.abs(distances) >= cutoff, axis=2) + np.copyto(distances.mask, mask[:, :, np.newaxis]) + # distances.mask += np.abs(distances) > cutoff + # distances.mask = np.bitwise_and.accumulate(distances.mask, axis=2) + + +def exclude_z_cutoff(distances: NDArray[np.int64], cutoff: float) -> None: + """ + Select all distances corresponding to atoms within a box + with length 2* cutoff + :param distances: masked interatomic distance array of shape (n, m, 3) + :type distances: NDArray[np.float64] + :param cutoff: absolute value for maximum distance + :type cutoff: float + :return: no return + :rtype: NoReturn + """ + mask = np.abs(distances[..., 2]) > cutoff + distances.mask += np.broadcast_to(mask[..., np.newaxis], distances.shape) + + +def get_dist( + ag_pos: NDArray[np.float64], + ref_pos: NDArray[np.float64], + distances: NDArray[np.float64], + box: NDArray[np.float64], +) -> NoReturn: + """Calculate minimum elementwise x, x2, z distances + of selection atom positions to reference atom positions in box. + Output array shape(len(ag_pos), len(ref_pos), 3) + :param ag_pos: atom group positions of shape (n_atoms, 3) + :type ag_pos: NDArray[np.float64] + :param ref_pos: atom group positions of shape (n_atoms, 3) + :type ref_pos: NDArray[np.float64] + distances: result array of shape (len(ag_pos), len(ref_pos), 3) + :type distances: NDArray[np.float64] + :param box: Timestep dimensions array of shape (6, ) + :type box: NDArray[np.float64] + """ + for atom_id, atom_pos in enumerate(ag_pos): + distances[atom_id, :, :] = minimize_vectors(atom_pos - ref_pos, box) + + +def get_self_dist( + ag_pos: NDArray[np.float64], distances: NDArray[np.float64] +) -> NoReturn: + """Calculate minimum elementwise x, x2, z distances + of selection atom positions to reference atom positions in box. + Output array shape(len(ag_pos), len(ref_pos), 3) + :param ag_pos: atom group positions of shape (n_atoms, 3) + :type ag_pos: NDArray[np.float64] + distances: result array of shape (len(ag_pos), len(ref_pos), 3) + :type distances: NDArray[np.float64] + :param box: Timestep dimensions array of shape (6, ) + :type box: NDArray[np.float64] + """ + for atom_id, atom_pos in enumerate(ag_pos): + distances[atom_id, ...] 
= np.where( + np.ix_(ag_pos[..., 0]) != atom_id, atom_pos - ag_pos, 0 + ) + + +# @cython.boundscheck(False) +# @cython.wraparound(False) +# @cython.cdivision(True) +# def _minimize_vectors_ortho(cython.floating[:, :] vectors not None, cython.floating[:] box not None, +# cython.floating[:, :] output not None): +# cdef int i, n +# cdef cython.floating box_inverse[3] +# cdef cython.floating[:] box_inverse_view +# +# box_inverse[0] = 1.0 / box[0] +# box_inverse[1] = 1.0 / box[1] +# box_inverse[2] = 1.0 / box[2] +# +# box_inverse_view = box_inverse +# +# n = len(vectors) +# with nogil: +# for i in range(n): +# output[i, 0] = vectors[i, 0] +# output[i, 1] = vectors[i, 1] +# output[i, 2] = vectors[i, 2] +# _minimum_image_orthogonal(output[i, :], box, box_inverse_view) + +# https://github.com/MDAnalysis/mdanalysis/blob/develop/package/MDAnalysis/lib/c_distances.pyx + +# @cython.boundscheck(False) +# @cython.wraparound(False) +# cdef inline void _minimum_image_orthogonal(cython.floating[:] dx, +# cython.floating[:] box, +# cython.floating[:] inverse_box) nogil: +# """Minimize dx to be the shortest vector +# Parameters +# ---------- +# dx : numpy.array, shape (3,) +# vector to minimize +# box : numpy.array, shape (3,) +# box length in each dimension +# inverse_box : numpy.array, shape (3,) +# inverse of box +# Operates in-place on dx! +# """ +# cdef int i +# cdef cython.floating s +# +# for i in range(3): +# if box[i] > 0: +# s = inverse_box[i] * dx[i] +# dx[i] = box[i] * (s - cround(s)) + + +def select_solvent( + center_ag: str, solvent_ag: AtomGroup, radius: float +) -> AtomGroup: + """Select solvent OW* atoms within sphere of + specified radius around atom group + :param center_ag: solvated atom group + :type center_ag: AtomGroup + :param solvent_ag: solvent atom group + :type center_ag: AtomGroup + :param radius: sphere radius + :type radius: float + :return: subsection of solvent_ag + :rtype: AtomGroup + :""" + return solvent_ag.select_atoms( + f"name OW* and {radius} around global center_ag", updating=True + ) + + +def update_universe(f): + def wrapper(crdname: str, crdout: Union[str, Path], **kwargs) -> Universe: + u = Universe(str(crdname)) + f(u=u, crdout=crdout, **kwargs) + u = Universe(str(crdout)) + return u + + return wrapper + + +def get_n_mols( + conc: Union[float, int], + u: Universe, + solvent: str = SOL, + density: float = SOL_DENSITY, # g/dm^3 +): + sol = u.select_atoms(f"resname {solvent}") + m = np.sum(sol.masses) # g + V = m / density # L + n_mols = conc * V + n_mols = np.round(n_mols).astype(int) + logger.info( + "Calculating molecule numbers:\n" + f"Target concentration = {conc:.3f} mol L-1\n" + f"Bulk volume = {V:.2f} A3\n" + f"Density = {density:.2f} g L-1\n" + f"Molecules to add = {n_mols}\n" + ) + return n_mols + + +def write_insert_dat( + n_mols: Union[int, float], save: Union[str, Literal[False]] +): + pos = np.zeros((int(n_mols), 3), dtype=np.float16) + if save: + save = pl.Path(save) + if save.suffix != ".dat": + save = str(save.resolve()) + ".dat" + logger.debug(f"Saving {n_mols} insert positions to {save}") + np.savetxt(save, pos, fmt="%4.3f", delimiter=" ", newline="\n") + with open(save, "r") as file: + r = file.read() + + +@update_universe +def center_clay(u: Universe, crdout: Union[Path, str], uc_name: Optional[str]): + if uc_name is None: + clay = u.select_atoms("not resname SOL iSL" + " ".join(IONS)) + else: + clay = u.select_atoms(f"resname {uc_name}*") + for ts in u.trajectory: + ts = center_in_box(clay, wrap=True)(ts) + ts = wrap(u.atoms)(ts) + 
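+    # NOTE (editor comment): the transformations are applied while iterating,
+    # so the write below saves the final frame's centred, wrapped coordinates.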
u.atoms.write(crdout) + + +@temp_file_wrapper +def add_mol_list_to_top( + topin: Union[str, pl.Path], + topout: Union[str, pl.Path], + insert_list: List[str], + ff_path: Union[pl.Path, str] = FF, +): + logger.debug(insert_list) + with open(topin, "r") as topfile: + topstr = topfile.read().rstrip() + topmatch = re.search( + r"\[ system \].*", topstr, flags=re.MULTILINE | re.DOTALL + ).group(0) + ff_path = pl.Path(ff_path) + if not ff_path.is_dir(): + raise FileNotFoundError( + f"Specified force field path: {ff_path.resolve()!r} does not exist!" + ) + with open(ff_path / f"new_tophead.itp", "r") as tophead: + tophead = tophead.read() + if len(insert_list) != 0: + topstr = "\n".join([tophead, topmatch, *insert_list]) + else: + topstr = "\n".join([tophead, topmatch]) + with open(topout, "w") as topfile: + topfile.write(topstr) + assert Path(topout).exists() + + +@temp_file_wrapper +def neutralise_system( + odir: Path, crdin: Path, topin: Path, topout: Path, nion: str, pion: str +): + logger.debug("neutralise_system") + mdp = MDP / "genion.mdp" + assert mdp.exists(), f"{mdp.resolve()} does not exist" + odir = Path(odir).resolve() + assert odir.is_dir() + # make_opath = lambda p: odir / f"{p.stem}.{p.suffix}" + tpr = odir / "neutral.tpr" + ndx = odir / "neutral.ndx" + # isl = grep_file(crdin, 'iSL') + gmx.run_gmx_make_ndx(f=crdin, o=ndx) + if ndx.is_file(): + # if topin.resolve() == topout.resolve(): + # topout = topout.parent / f"{topout.stem}_n.top" + # otop_copy = True + # else: + # otop_copy = False + _, out = gmx.run_gmx_grompp( + f=MDP / "genion.mdp", + c=crdin, + p=topin, + o=tpr, + pp=topout, + po=tpr.with_suffix(".mdp"), + v="", + maxwarn=1, + # renum="", + ) + # isl = grep_file(crdin, 'iSL') + err = re.search(r"error", out) + if err is not None: + logger.error(f"gmx grompp raised an error!") + replaced = [] + else: + logger.debug(f"gmx grompp completed successfully.") + out = gmx.run_gmx_genion_neutralise( + s=tpr, + p=topout, + o=crdin, + n=ndx, + pname=pion, + pq=int(get_ion_charges()[pion]), + nname=nion, + nq=int(get_ion_charges()[nion]), + ) + if not topout.is_file(): + logger.error(f"gmx genion raised an error!") + else: + logger.info(f"gmx genion completed successfully.") + # add_resnum(crdname=crdin, crdout=crdin) + # rename_il_solvent(crdname=crdin, crdout=crdin) + # isl = grep_file(crdin, 'iSL') + # if otop_copy is True: + # shutil.move(topout, topin) + replaced = re.findall( + "Replacing solvent molecule", out.stderr, flags=re.MULTILINE + ) + logger.info(f"{crdin.name!r} add numbers, rename il solv") + add_resnum(crdin=crdin, crdout=crdin) + rename_il_solvent(crdin=crdin, crdout=crdin) + else: + logger.error(f"No index file {ndx.name} created!") + replaced = [] + return len(replaced) + + +@update_universe +def _remove_excess_gro_ions( + u: Universe, + crdout: Union[Path, str], + n_ions: int, + ion_type: str, +) -> None: + last_sol_id = u.select_atoms("resname SOL")[-1].index + ions = u.select_atoms(f"resname {ion_type}") + remove_ions = u.atoms.select_atoms( + f"index {last_sol_id + 1} - " f"{ions.indices[-1] - n_ions}" + ) + u.atoms -= remove_ions + logger.debug(f"Removing {remove_ions.n_atoms} " f"{ion_type} atoms") + u.atoms.write(str(crdout)) + logger.debug(f"Writing new coordinates to {Path(crdout).resolve()!r}") + return u + + +@temp_file_wrapper +def _remove_excess_top_ions( + topin: Union[Path, str], + topout: Union[Path, str], + n_ions: int, + ion_type: str, +) -> None: + with open(topin, "r") as topfile: + topstr = topfile.read() + ion_matches = re.search( + 
rf".*system.*({ion_type}\s+\d)*.*({ion_type}\s+{n_ions}.*)", + topstr, + flags=re.MULTILINE | re.DOTALL, + ) + sol_matches = re.search( + rf"(.*system.*SOL\s+\d+\n).*", topstr, flags=re.MULTILINE | re.DOTALL + ) + topstr = sol_matches.group(1) + ion_matches.group(2) + with open(topout, "w") as topfile: + logger.debug(f"Writing new topology to {Path(topout).resolve()!r}.") + topfile.write(topstr) + + +@temp_file_wrapper +def remove_excess_ions(crdin, topin, crdout, topout, n_ions, ion_type) -> None: + _remove_excess_top_ions( + topin=topin, topout=topout, n_ions=n_ions, ion_type=ion_type + ) + _remove_excess_gro_ions( + crdname=crdin, crdout=crdout, n_ions=n_ions, ion_type=ion_type + ) + + +@temp_file_wrapper +def rename_il_solvent(crdin: Universe, crdout: Union[Path, str]) -> None: + u = Universe(str(crdin)) + if "isl" not in list( + map(lambda n: n.lower(), np.unique(u.residues.resnames)) + ): + logger.info(f"Renaming interlayer SOL to iSL") + isl: AtomGroup = u.select_atoms("resname SOL").residues + idx: int = isl[np.ediff1d(isl.resnums, to_end=1) != 1][-1].resnum + isl: AtomGroup = isl.atoms.select_atoms(f"resnum 0 - {idx}") + isl.residues.resnames = "iSL" + if type(crdout) != Path: + crdout = Path(crdout) + crdout = str(crdout.resolve()) + u.atoms.write(crdout) + else: + logger.info(f"No interlayer SOL to rename") + if str(Path(crdin).resolve()) != str(Path(crdout.resolve())): + logger.info(f"Overwriting {crdin.name!r}") + shutil.move(crdin, crdout) + return u + + +@temp_file_wrapper +def add_resnum(crdin: Union[Path, str], crdout: Union[Path, str]) -> Universe: + u = Universe(str(crdin)) + # print(str(crdin)) + logger.info(f"Adding residue numbers to:\n{crdin.resolve()!r}") + res_n_atoms = get_system_n_atoms(crds=u, write=False) + atoms: AtomGroup = u.atoms + for i in np.unique(atoms.residues.resnames): + logger.info(f"Found residues: {i} - {res_n_atoms[i]} atoms") + res_idx = 1 + first_idx = 0 + last_idx = 0 + resids = [] + while last_idx < atoms.n_atoms: + resname = atoms[last_idx].residue.resname + n_atoms = res_n_atoms[resname] + last_idx = first_idx + n_atoms + first_idx = last_idx + resids.extend(np.full(n_atoms, res_idx).tolist()) + res_idx += 1 + logger.info(f"added {len(resids)} residues") + resids = list(map(lambda resid: f"{resid:5d}", resids)) + if type(crdout) != Path: + crdout = Path(crdout) + crdout = str(crdout.resolve()) + pattern = re.compile(r"^\s*\d+") + with open(crdin, "r") as crdfile: + crdlines = crdfile.readlines() + crdlines = [line for line in crdlines if re.match(r"\s*\n", line) is None] + new_lines = crdlines[:2] + for linenum, line in enumerate(crdlines[2:-1]): + line = re.sub(pattern, resids[linenum], line) + new_lines.append(line) + new_lines.append(crdlines[-1]) + with open(crdout, "w") as crdfile: + logger.debug(f"Writing coordinates to {str(crdout)!r}") + for line in new_lines: + crdfile.write(line) + logger.info(f"{crdfile.name!r} written") + u = Universe(str(crdout)) + return u + + # new_resnums = [] + # for res_idx, res in enumerate(atoms): + # new_resnums.append(np.full(res_n_atoms[res.resname], res_idx+1, dtype=int)) + # print(res, res_n_atoms[res.resname]) + + +# class PrmInfo: +# def __init__(self, +# name: Literal[Union['charges', 'n_atoms']], +# include_dir: Union[str, Path] = FF, +# write=False, +# force_update=False, +# ): +# self.name = + + +PRM_INFO_DICT = { + "n_atoms": cast( + Callable[[Universe], Dict[str, int]], + lambda u: dict( + [ + (r, u.select_atoms(f"moltype {r}").n_atoms) + for r in u.atoms.moltypes + ] + ), + ), + "charges": 
cast( + Callable[[Universe], Dict[str, float]], + lambda u: dict( + zip(u.atoms.moltypes, np.round(u.atoms.residues.charges, 4)) + ), + ), +} + + +def get_mol_prms( + prm_str: str, + itp_file: Union[str, pl.Path], + include_dir: Union[str, pl.Path] = FF, + write=False, + force_update=False, +) -> dict: + dict_func = PRM_INFO_DICT[prm_str] + residue_itp = Path(itp_file) + prop_file = DATA / f"{residue_itp.stem}_{prm_str}.p" + if (force_update is True) or (not prop_file.is_file()): + atom_u = Universe( + str(residue_itp), + topology_format="ITP", + include_dir=str(include_dir), + infer_system=True, + ) + prop_dict = dict_func(atom_u) + if write is True: + with open(prop_file, "wb") as prop_file: + pkl.dump(prop_dict, prop_file) + else: + with open(prop_file, "rb") as prop_file: + prop_dict = pkl.read(prop_file) + return prop_dict + + +get_mol_n_atoms = partial(get_mol_prms, prm_str="n_atoms") +update_wrapper(get_mol_n_atoms, "n_atoms") + +get_mol_charges = partial(get_mol_prms, prm_str="charges") +update_wrapper(get_mol_charges, "charges") + +PRM_METHODS = {"charges": get_mol_charges, "n_atoms": get_mol_n_atoms} +# def get_residue_n_atoms( +# residue_itp: Union[str, pl.Path], +# include_dir: Union[str, pl.Path] = FF, +# write=False, +# force_update=False, +# ): +# residue_itp = Path(residue_itp) +# n_atoms_file = DATA / f"{residue_itp.stem}_n_atoms.p" +# if not n_atoms_file.is_file() or force_update is True: +# atom_u = Universe( +# str(residue_itp), +# topology_format="ITP", +# include_dir=str(include_dir), +# infer_system=True, +# ) +# n_atoms_dict = dict([(r, atom_u.select_atoms(f'moltype {r}').n_atoms) for r in atom_u.atoms.moltypes]) +# if write is True: +# with open(n_atoms_file, "wb") as n_atoms_file: +# pkl.dump(n_atoms_dict, n_atoms_file) +# else: +# with open(n_atoms_file, "rb") as n_atoms_file: +# n_atoms_dict = pkl.read(n_atoms_file) +# return n_atoms_dict + +# def get_atom_type_charges( +# atom_itp: Union[str, pl.Path], +# include_dir: Union[str, pl.Path] = FF, +# write=False, +# force_update=False, +# ): +# atom_itp = Path(atom_itp) +# charge_file = DATA / f"{atom_itp.stem}_charges.p" +# if not charge_file.is_file() or force_update is True: +# atom_u = Universe( +# str(atom_itp), +# topology_format="ITP", +# include_dir=str(include_dir), +# infer_system=True, +# ) +# charge_dict = dict( +# zip(atom_u.atoms.moltypes, np.round(atom_u.atoms.residues.charges, 4)) +# ) +# if write == True: +# with open(charge_file, "wb") as charge_file: +# pkl.dump(charge_dict, charge_file) +# else: +# with open(charge_file, "rb") as charge_file: +# charge_dict = pkl.read(charge_file) +# return charge_dict + +ion_itp = FF / "Ion_Test.ff/ions.itp" +get_ion_charges = partial(get_mol_charges, itp_file=ion_itp) +update_wrapper(get_ion_charges, ion_itp) + +get_ion_n_atoms = partial(get_mol_n_atoms, itp_file=ion_itp) +update_wrapper(get_ion_charges, ion_itp) + + +def get_ion_prms(prm_str: str, **kwargs): + if prm_str == "charges": + prm_dict = get_ion_charges(**kwargs) + elif prm_str == "n_atoms": + prm_dict = get_ion_n_atoms(**kwargs) + else: + raise KeyError(f"Unexpected parameter: {prm_str!r}") + return prm_dict + + +# get_clay_charges = partial(get_atom_type_charges, atom_itp=FF/"ClayFF_Fe.ff/ffnonbonded.itp") + + +def get_clay_prms(prm_str: str, uc_name: str, uc_path=UCS, force_update=False): + prm_func = PRM_METHODS[prm_str] + prm_file = DATA / f"{uc_name.upper()}_{prm_str}.pkl" + if not prm_file.is_file() or force_update is True: + charge_dict = {} + uc_files = 
+def get_clay_prms(prm_str: str, uc_name: str, uc_path=UCS, force_update=False):
+    prm_func = PRM_METHODS[prm_str]
+    prm_file = DATA / f"{uc_name.upper()}_{prm_str}.pkl"
+    if not prm_file.is_file() or force_update is True:
+        charge_dict = {}
+        uc_files = uc_path.glob(rf"{uc_name}/*[0-9].itp")
+        for uc_file in uc_files:
+            uc_charge = prm_func(
+                itp_file=uc_file, write=False, force_update=force_update
+            )
+            charge_dict.update(uc_charge)
+    else:
+        with open(prm_file, "rb") as prm_fp:
+            charge_dict = pkl.load(prm_fp)
+    return charge_dict
+
+
+get_clay_charges = partial(get_clay_prms, prm_str="charges")
+update_wrapper(get_clay_charges, "charges")
+
+get_clay_n_atoms = partial(get_clay_prms, prm_str="n_atoms")
+update_wrapper(get_clay_n_atoms, "n_atoms")
+
+
+def get_sol_prms(
+    prm_str: str,
+    sol_path=FF / "ClayFF_Fe.ff",
+    include_dir: Union[str, pl.Path] = FF,
+    force_update=False,
+):
+    prm_func = PRM_METHODS[prm_str]
+    charge_file = DATA / f"SOL_{prm_str}.pkl"
+    if not charge_file.is_file() or (force_update is True):
+        charge_dict = {}
+        sol_fnames = ["interlayer_spc", "spc"]
+        for file in sol_fnames:
+            itp = f"{sol_path}/{file}.itp"
+            sol_charge = prm_func(
+                itp_file=itp,
+                write=False,
+                include_dir=include_dir,
+                force_update=force_update,
+            )
+            charge_dict.update(sol_charge)
+    else:
+        with open(charge_file, "rb") as charge_fp:
+            charge_dict = pkl.load(charge_fp)
+    return charge_dict
+
+
+get_sol_charges = partial(get_sol_prms, prm_str="charges")
+update_wrapper(get_sol_charges, "charges")
+
+get_sol_n_atoms = partial(get_sol_prms, prm_str="n_atoms")
+update_wrapper(get_sol_n_atoms, "n_atoms")
+
+
+def get_aa_prms(prm_str: str, aa_name: str, aa_path=AA, force_update=False):
+    prm_func = PRM_METHODS[prm_str]
+    charge_file = DATA / f"{aa_name.upper()}_{prm_str}.pkl"
+    if not charge_file.is_file() or force_update is True:
+        charge_dict = {}
+        aa_dirs = aa_path.glob(rf"pK[1-9]/{aa_name.upper()}[1-9].itp")
+        for aa_file in aa_dirs:
+            aa_charge = prm_func(
+                itp_file=aa_file, write=False, force_update=force_update
+            )
+            charge_dict.update(aa_charge)
+    else:
+        with open(charge_file, "rb") as charge_fp:
+            charge_dict = pkl.load(charge_fp)
+    return charge_dict
+
+
+get_aa_charges = partial(get_aa_prms, prm_str="charges")
+update_wrapper(get_aa_charges, "charges")
+
+get_aa_n_atoms = partial(get_aa_prms, prm_str="n_atoms")
+update_wrapper(get_aa_n_atoms, "n_atoms")
+
+
+def get_all_prms(prm_str, force_update=True, write=True, name=None):
+    if name is not None:
+        namestr = f"{name}_"
+    else:
+        namestr = ""
+    charge_file = DATA / f"{namestr}{prm_str}.pkl"
+    if not charge_file.is_file() or force_update is True:
+        ion_dict = get_ion_prms(prm_str=prm_str, force_update=force_update)
+        aa_types = [
+            "ala",
+            "arg",
+            "asn",
+            "asp",
+            "ctl",
+            "cys",
+            "gln",
+            "glu",
+            "gly",
+            "his",
+            "ile",
+            "leu",
+            "lys",
+            "met",
+            "phe",
+            "pro",
+            "ser",
+            "thr",
+            "trp",
+            "tyr",
+            "val",
+        ]
+        aa_dict = {}
+        for aa in aa_types:
+            aa_dict.update(
+                get_aa_prms(
+                    prm_str=prm_str, aa_name=aa, force_update=force_update
+                )
+            )
+        clay_types = ["D21"]
+        clay_dict = {}
+        for uc in clay_types:
+            clay_dict.update(
+                get_clay_prms(
+                    prm_str=prm_str, uc_name=uc, force_update=force_update
+                )
+            )
+        sol_dict = get_sol_prms(prm_str=prm_str, force_update=force_update)
+        charge_dict = {**ion_dict, **clay_dict, **sol_dict, **aa_dict}
+        if write is True:
+            with open(charge_file, "wb") as file:
+                pkl.dump(charge_dict, file)
+    else:
+        with open(charge_file, "rb") as file:
+            charge_dict = pkl.load(file)
+    return charge_dict
+
+
+get_all_charges = partial(get_all_prms, prm_str="charges")
+update_wrapper(get_all_charges, "charges")
+
+get_all_n_atoms = partial(get_all_prms, prm_str="n_atoms")
+update_wrapper(get_all_n_atoms, "n_atoms")
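+
+
+# Note on the merge in get_all_prms: ``{**ion_dict, **clay_dict, **sol_dict,
+# **aa_dict}`` resolves duplicate residue names in favour of the right-most
+# dictionary, i.e. amino-acid entries win over ion/clay/solvent entries
+# (illustrative values only):
+#
+#     >>> {**{"NA": 1.0}, **{"NA": 0.9}}
+#     {'NA': 0.9}
+
+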
+def get_system_prms(
+    prm_str, crds: Union[str, Path, Universe], write=True, force_update=True
+) -> Union[str, pd.Series, None]:
+    if isinstance(crds, Universe):
+        u = crds
+        name = "universe"
+    else:
+        try:
+            u = Universe(str(crds))
+            name = Path(crds).stem
+        except ValueError:
+            logger.error(f"Could not create Universe from {crds}")
+            return None
+    prm_df = pd.Series(
+        get_all_prms(prm_str, write=write, force_update=force_update),
+        name=name,
+    )
+    if prm_str == "charges":
+        residue_df = pd.Series(
+            u.residues.resnames, name="residues", dtype="str"
+        )
+        residue_df = residue_df.aggregate("value_counts")
+        prm_df = pd.concat(
+            {prm_str: prm_df, "counts": residue_df}, axis=1, join="inner"
+        )
+        prm_df["sum"] = prm_df.apply("product", axis=1).astype(int)
+        sys_prms = prm_df["sum"].sum().astype(int)
+    elif prm_str == "n_atoms":
+        sys_prms = prm_df
+    else:
+        raise KeyError(f"Unexpected parameter: {prm_str!r}")
+    return sys_prms
+
+
+get_system_charges = partial(get_system_prms, prm_str="charges")
+update_wrapper(get_system_charges, "charges")
+
+get_system_n_atoms = partial(get_system_prms, prm_str="n_atoms")
+update_wrapper(get_system_n_atoms, "n_atoms")
+
+
+def add_mols_to_top(
+    topin: Union[str, pl.Path],
+    topout: Union[str, pl.Path],
+    insert: Union[None, str, pl.Path],
+    n_mols: int,
+    include_dir,
+):
+    if n_mols != 0:
+        itp_path = pl.Path(insert).parent.resolve()
+        itp = pl.Path(insert).stem
+        insert = Universe(
+            str(itp_path / f"{itp}.itp"),
+            topology_format="ITP",
+            include_dir=include_dir,
+        )
+    else:
+        itp_path = pl.Path(topin).parent
+    topin = pl.Path(topin)
+    with open(topin, "r") as topfile:
+        topstr = topfile.read().rstrip()
+    topmatch = re.search(
+        r"\[ system \].*", topstr, flags=re.MULTILINE | re.DOTALL
+    ).group(0)
+    with open(itp_path.parent.parent / "FF/tophead.itp", "r") as tophead:
+        tophead = tophead.read()
+    with open(topout, "w") as topfile:
+        if n_mols != 0:
+            topfile.write(
+                "\n".join(
+                    [
+                        tophead,
+                        topmatch,
+                        f"\n{insert.residues.moltypes[0]} {n_mols}\n",
+                    ]
+                )
+            )
+        else:
+            topfile.write("\n".join([tophead, topmatch]))
+
+
+def remove_replaced_SOL(
+    topin: Union[str, pl.Path], topout: Union[str, pl.Path], n_mols: int
+):
+    if n_mols > 0:
+        with open(topin, "r") as topfile:
+            topstr = topfile.read()
+        substr = r"(SOL\s*)([0-9]*)"
+        pattern = rf"{substr}(?!.*{substr})"
+        try:
+            topmatch = re.search(
+                pattern, topstr, flags=re.MULTILINE | re.DOTALL
+            ).group(2)
+            n_sol = int(topmatch) - n_mols
+            logger.debug(f"Removing {n_mols} SOL residues from topology.")
+            if n_sol < 0:
+                raise ValueError(
+                    f"Topology has fewer SOL residues than the {n_mols} to remove."
+                )
+            else:
+                topstr = re.sub(
+                    pattern,
+                    rf"\1 {n_sol}",
+                    topstr,
+                    flags=re.MULTILINE | re.DOTALL,
+                )
+                with open(topout, "w") as topfile:
+                    logger.debug(
+                        f"New topology {pl.Path(topout).name!r} has {n_sol} SOL molecules."
+                    )
+                    topfile.write(topstr)
+        except AttributeError:
+            raise ValueError(f"No SOL entry found in {topin}.")
+
+
+@update_universe
+def center_clay_universe(
+    u: Universe, crdout: Union[str, Path], uc_name: Optional[str]
+) -> None:
+    if uc_name is None:
+        clay = u.select_atoms("not resname SOL iSL " + " ".join(IONS))
+    else:
+        clay = u.select_atoms(f"resname {uc_name}*")
+    for ts in u.trajectory:
+        ts = center_in_box(clay, wrap=True)(ts)
+        ts = wrap(u.atoms)(ts)
+    u.atoms.write(crdout)
+
+
+@update_universe
+def remove_ag(
+    u: Universe,
+    crdout: str,
+    selstr: str,
+    last: Union[bool, int],
+    first: Union[bool, int],
+) -> None:
+    sel = u.select_atoms(selstr)
+    logger.debug(
+        f"Before: {u.atoms.n_atoms}. 
" + f"Removing first {first} last {last} {np.unique(sel.residues.resnames)}" + ) + if first is not False: + if last is not False: + raise ValueError( + f"Not possible to select first and last ends of atom group at the same time" + ) + elif last is not False: + first = -last + logger.debug("last not false", first) + else: + first = 0 + u.atoms -= sel[first:] + logger.debug(f"After: {u.atoms.n_atoms}") + u.atoms.write(crdout) + + +def read_edge_file( + fname: Union[str, PathType], + cutoff: Union[int, str, float, Cutoff], + skip=False, +): + fname = File(fname, check=False) + if not fname.exists(): + logger.info("No edge file found.") + # os.makedirs(fname.parent, exist_ok=True) + # logger.info(f"{fname.parent}") + if skip is True: + logger.info(f"Continuing without ads_edges") + p = [0, float(cutoff)] + else: + raise FileNotFoundError(f"No edge file found {fname}.") + else: + with open(fname, "rb") as edges_file: + logger.info(f"Reading ads_edges {edges_file.name}") + p = pkl.load(edges_file)["edges"] + logger.finfo( + ", ".join(list(map(lambda e: f"{e:.2f}", p))), + kwd_str="ads_edges: ", + indent="\t", + ) + return p + + +def get_edge_fname( + atom_type: str, + cutoff: Union[int, str, float], + bins: Union[int, str, float], + other: Optional[str] = None, + path: Union[str, PathType] = PE_DATA, + name: Union[Literal["pe"], Literal["edge"]] = "pe", +): + if other is not None: + other = f"{other}_" + else: + other = "" + cutoff = Cutoff(cutoff) + bins = Bins(bins) + # fname = Path.cwd() / f"edge_data/edges_{atom_type}_{self.cutoff}_{self.bins}.p" + fname = ( + Dir(path) / f"{atom_type}_{other}{name}_data_{cutoff}_{bins}.p" + ).resolve() + logger.info(f"Peak/edge Filename: {fname.name!r}") + return fname + + +def get_paths( + path: Union[str, PathType] = None, + infiles: List[Union[PathType, str]] = None, + inpname: str = None, +): + logger.info(get_header(f"Getting run files")) + if path is not None: + path = Dir(path) + if infiles is None: + gro = select_file(path=path, suffix="gro", searchstr=inpname) + trr = select_file( + path=path, suffix="trr", how="largest", searchstr=inpname + ) + else: + gro = select_file( + path=path, suffix="gro", searchstr=infiles[0].strip(".gro") + ) + trr = select_file( + path=path, + suffix="trr", + how="largest", + searchstr=infiles[1].strip(".gro"), + ) + else: + gro, trr = infiles + gro, trr = GROFile(gro), File(trr) + path = Dir(gro.parent) + logger.finfo(f"{str(gro.resolve())!r}", kwd_str="Found coordinates: ") + logger.finfo(f"{str(trr.resolve())!r}", kwd_str="Found trajectory: ") + return gro, trr, path + + +class Cutoff(str): + def __new__(cls, length): + string = f"{int(length):02}" + return super().__new__(cls, string) + + def __init__(self, length): + self.num = float(length) + + def __float__(self): + return float(self.num) + + def __int__(self): + return int(self.num) + + def __str__(self): + return self + + +class Bins(str): + def __new__(cls, length): + string = f"{float(length):.02f}"[2:] + return super().__new__(cls, string) + + def __init__(self, length): + self.num = float(length) + + def __float__(self): + return float(self.num) + + def __int__(self): + return int(self.num) + + def __str__(self): + return self diff --git a/package/ClayCode/analysis/utils.py b/package/ClayCode/analysis/utils.py new file mode 100755 index 00000000..8b4ac59c --- /dev/null +++ b/package/ClayCode/analysis/utils.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python3 +import logging +import os +import re +import shutil +import subprocess as sp +import sys +import warnings 
+from functools import singledispatch
+from itertools import chain
+from pathlib import Path
+from typing import List, Literal, Optional, Union
+
+import MDAnalysis as mda
+import numpy as np
+import pandas as pd
+from ClayCode.core.consts import exec_date, exec_time
+from numpy.typing import NDArray
+from tqdm.contrib.logging import logging_redirect_tqdm
+
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+warnings.simplefilter("ignore")
+
+logger = logging.getLogger(Path(__file__).stem)
+
+__all__ = [
+    "remove_files",
+    "change_suffix",
+    "convert_num_to_int",
+    "get_sequence_element",
+    "get_first_item_as_int",
+    "check_file_exists",
+    "str_to_int",
+    "execute_bash_command",
+    "get_file_diff",
+    "grep_file",
+    "get_logfname",
+    "get_search_str",
+    "convert_str_list_to_arr",
+    "select_file",
+    "select_named_file",
+]
+
+
+def remove_files(path, searchstr):
+    backupfiles = list(path.glob(rf"{searchstr}"))
+    removing = False
+    for fname in backupfiles:
+        if fname.exists():
+            removing = True
+            os.remove(fname)
+            logger.debug(f"Removing {fname.name}.")
+        else:
+            logger.debug(f"No backups to remove {fname.name}.")
+    return removing
+
+
+def change_suffix(path: Path, new_suffix: str):
+    return path.parent / f'{path.stem}.{new_suffix.strip(".")}'
+
+
+def convert_num_to_int(f):
+    def wrapper(number: Union[int, float]):
+        if type(number) not in [float, int, np.int_, np.float_]:
+            raise TypeError(
+                f"Expected float or int type, found {type(number)}!"
+ ) + else: + return f(int(np.round(number, 0))) + + return wrapper + + +def get_sequence_element(f): + def wrapper(seq, id=0): + try: + if len(list(seq)) < 2: + pass + else: + logger.debug(1, seq) + except TypeError: + logger.debug(2, seq) + seq = [seq] + logger.debug(3, seq) + if type(seq) not in [list, tuple, np.array]: + raise TypeError(f"Expected sequence, found {type(seq)}") + if not isinstance(id, int): + raise TypeError(f"Expected int index, found {type(id)}") + else: + result = f(seq[id]) + logger.debug(4, result) + return result + + return wrapper + + +@get_sequence_element +@convert_num_to_int +def get_first_item_as_int(seq): + return seq + + +def check_file_exists(tempdir: Union[Path, str], file: Union[Path, str]): + if not Path(file).exists(): + os.rmdir(tempdir) + raise FileNotFoundError(f"{file!r} does not exist!") + + +@singledispatch +def str_to_int(str_obj): + raise TypeError(f"Invalid type for str_obj: {type(str_obj)}") + + +@str_to_int.register +def _(str_obj: list): + return list(map(lambda str_item: int(str_item), str_obj)) + + +@str_to_int.register +def _(str_obj: str): + return int(str_obj) + + +def execute_bash_command(command, **outputargs): + output = sp.run(["/bin/bash", "-c", command], **outputargs) + return output + + +def get_file_diff(file_1, file_2): + diff = execute_bash_command( + f"diff {file_1} {file_2}", capture_output=True, text=True + ) + return diff.stdout + + +def grep_file(file, regex: str): + diff = execute_bash_command( + f'grep -E "{regex}" {file}', capture_output=True, text=True + ) + return diff.stdout + + +def get_logfname( + logname: str, + run_name=None, + time: Union[Literal[exec_time], Literal[exec_date]] = exec_date, + logpath=None, +): + if logpath is None: + logpath = Path().cwd() / "logs" + if not logpath.is_dir(): + os.mkdir(logpath) + if run_name is None: + run_name = "" + else: + run_name += "-" + return f"{logpath}/{logname}-{run_name}{time}.log" + + +def get_search_str(match_dict: dict): + return "|".join(match_dict.keys()) + + +def convert_str_list_to_arr( + str_list: Union[List[str], List[List[str]]] +) -> np.array: + array = np.array(list(map(lambda x: x.split(), str_list)), dtype=str) + arr_strip = np.vectorize(lambda x: x.strip()) + try: + array = arr_strip(array) + except: + logger.debug("Could not convert list to array") + return array + + +def copy_final_setup(outpath: Path, tmpdir: Path, rm_tempfiles: bool = True): + if not outpath.is_dir(): + os.mkdir(outpath) + gro = select_file(tmpdir, suffix="_gro", how="latest") + top = select_file(tmpdir, suffix="top", how="latest") + tpr = select_file(tmpdir, suffix="tpr", how="latest") + log = select_file(tmpdir, suffix="log", how="latest") + mdp = select_file(tmpdir, suffix="mdp", how="latest") + new_files = [] + for file in [gro, top]: + shutil.move(file, outpath / file.name) + new_files.append(outpath / file.name) + for file in [tpr, log, mdp]: + shutil.move(file, outpath / file.with_stem(f"{file.name}_em").name) + logger.info(f"Done! 
Copied files to {outpath.name!r}") + if rm_tempfiles: + shutil.rmtree(tmpdir) + return tuple(new_files) + + +def select_named_file( + path: Union[Path, str], + searchstr: Optional[str] = None, + suffix=None, + searchlist: List[str] = ["*"], + how: Literal["latest", "largest"] = "latest", +): + path = Path(path) + if suffix is None: + suffix = "" + if searchstr is None: + searchstr = "" + f_iter = list( + path.glob(rf"*{searchstr.strip('*')}[.]*{suffix.strip('.')}") + ) + searchlist = list( + map( + lambda x: rf'.*{searchstr}{x.strip("*")}[.]*{suffix.strip(".")}', + searchlist, + ) + ) + searchstr = "|".join(searchlist) + pattern = re.compile(rf"{searchstr}", flags=re.DOTALL) + f_list = [ + path / pattern.search(f.name).group(0) + for f in f_iter + if pattern.search(f.name) is not None + ] + if len(f_list) == 1: + match = f_list[0] + elif len(f_list) == 0: + match = None + else: + logger.error( + f"Found {len(f_list)} matches: " + + ", ".join([f.name for f in f_list]) + ) + check_func_dict = { + "latest": lambda x: x.st_mtime, + "largest": lambda x: x.st_size, + } + check_func = check_func_dict[how] + prev_file_stat = 0 + last_file = None + for file in f_list: + if file.is_dir(): + pass + else: + if last_file is None: + last_file = file + filestat = os.stat(file) + last_file_stat = check_func(filestat) + if last_file_stat > prev_file_stat: + prev_file_stat = last_file_stat + last_file = file + match = last_file + logger.info(f"{how} file: {match.name!r}") + return match + + +def select_file( + path: Union[Path, str], + searchstr: Optional[str] = None, + suffix=None, + how: Literal["latest", "largest"] = "latest", +): + check_func_dict = { + "latest": lambda x: x.st_mtime, + "largest": lambda x: x.st_size, + } + check_func = check_func_dict[how] + logger.debug(f"Getting {how} file:") + if type(path) != Path: + path = Path(path) + if searchstr is None and suffix is None: + f_iter = path.iterdir() + else: + if suffix is None: + suffix = "" + if searchstr is None: + searchstr = "*" + f_iter = path.glob(rf"{searchstr}[.]*{suffix}") + backups = path.glob(rf"#{searchstr}[.]*{suffix}.[1-9]*#") + f_iter = chain(f_iter, backups) + prev_file_stat = 0 + last_file = None + for file in f_iter: + if file.is_dir(): + pass + else: + if last_file is None: + last_file = file + filestat = os.stat(file) + last_file_stat = check_func(filestat) + if last_file_stat > prev_file_stat: + prev_file_stat = last_file_stat + last_file = file + if last_file is None: + logger.debug(f"No matching files found in {path.resolve()}!") + else: + logger.debug(f"{last_file.name} matches") + return last_file + + +def get_pd_idx_iter(idx: pd.MultiIndex, name_sel: List[str]): + idx_names = idx.names + idx_values = [ + idx.get_level_values(level=name) + for name in idx_names + if name in name_sel + ] + idx_product = np.array( + np.meshgrid(*[idx_value for idx_value in idx_values]) + ).T.reshape(-1, len(idx_values)) + # idx_product = np.apply_along_axis(lambda x: '/'.join(x), 1, idx_product) + return idx_product + + +def get_u_files(path: Union[str, Path], suffices=["_gro", "top"]): + files = {} + path = Path(path) + largest_files = ["trr"] + how = "latest" + for selection in suffices: + selection = selection.strip(".") + if selection in largest_files: + how = "largest" + files[selection] = select_file(path=path, suffix=selection, how=how) + return files["_gro"], files["trr"] + + +def open_outfile(outpath: Union[Path, str], suffix: str, default: str): + if type(outpath) == bool: + outpath = Path(f"{default}.json") + elif type(outpath) 
in [str, Path]: + outpath = change_suffix(outpath, suffix) + if not outpath.parent.is_dir(): + os.makedirs(outpath.parent) + else: + raise ValueError(f"Could not interpret {outpath} as path or bool.") + return outpath + + +if __name__ == "__main__": + print(mda.__version__, "\n", np.__version__) + + +def _check_methods(C, *methods): + mro = C.__mro__ + for method in methods: + for B in mro: + if method in B.__dict__: + if B.__dict__[method] is None: + return NotImplemented + break + else: + return NotImplemented + return True + + +def redirect_tqdm(f): + def wrapper(*args, **kwargs): + with logging_redirect_tqdm(): + result = f(*args, **kwargs) + return result + + return wrapper + + +def make_1d(arr: NDArray, sel=None, return_dims=False): + idxs = np.arange(arr.ndim) + arr_list = [] + if sel == None: + sel = idxs + else: + sel = np.ravel(sel) + for idx in sel: + dimarr = np.add.reduce(arr, axis=tuple(idxs[idxs != idx])) + arr_list.append(dimarr) + if return_dims == True: + return np.array(*dimarr), sel + else: + return np.array(*dimarr) diff --git a/package/ClayCode/analysis/zdist.py b/package/ClayCode/analysis/zdist.py new file mode 100755 index 00000000..0100b986 --- /dev/null +++ b/package/ClayCode/analysis/zdist.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python3 +import logging +import pathlib as pl +import sys +import warnings +from argparse import ArgumentParser +from typing import Any, Literal, NoReturn, Optional, Union + +import MDAnalysis as mda +import numpy as np +from ClayCode.analysis.analysisbase import ClayAnalysisBase +from ClayCode.analysis.lib import check_traj, get_paths +from ClayCode.core.classes import Dir +from ClayCode.core.lib import ( + get_dist, + get_selections, + process_box, + run_analysis, + select_cyzone, +) +from ClayCode.core.utils import get_subheader +from MDAnalysis import Universe +from MDAnalysis.lib.distances import apply_PBC + +warnings.filterwarnings("ignore", category=DeprecationWarning) + +__all__ = ["ZDens"] + +logger = logging.getLogger(__name__) + + +class ZDens(ClayAnalysisBase): + # histogram attributes format: + # -------------------------- + # name: [name, bins, timeseries, hist, hist2d, edges, n_bins, cutoff, bin_step] + _attrs = ["zdens"] + _abs = [True] + """Calculate absolute densities of atom z-positions relative to clay surface O-atoms. 
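+
+    Minimal usage sketch (file names and atom selections below are
+    illustrative placeholders, not package defaults):
+
+    >>> u = mda.Universe("interface.gro", "traj.trr")  # doctest: +SKIP
+    >>> zdens = ZDens(
+    ...     sysname="NAu-1-fe",
+    ...     sel=u.select_atoms("resname Na"),
+    ...     clay=u.select_atoms("name OB*"),
+    ...     bin_step=0.02,
+    ...     cutoff=20,
+    ... )  # doctest: +SKIP
+    >>> run_analysis(zdens, start=0, stop=1000, step=10)  # doctest: +SKIP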
+ """ + + def __init__( + self, + sysname: str, + sel: mda.AtomGroup, + clay: mda.AtomGroup, + n_bins: Optional[int] = None, + bin_step: Optional[Union[int, float]] = None, + xy_rad: Union[float, int] = 3.0, + cutoff: Union[float, int] = 20.0, + save: Union[bool, str] = True, + write: Union[bool, str] = True, + overwrite: bool = False, + check_traj_len: Union[Literal[False], int] = False, + **basekwargs: Any, + ) -> None: + """ + :param sysname: system name + :type sysname: str + :param sel: adsorbed atom group + :type sel: MDAnalysis.core.groups.AtomGroup + :param clay: clay surface O-atoms + :type clay: MDAnalysis.core.groups.AtomGroup + :param n_bins: z-distance bins + :type n_bins: int, defaults to None + :param bin_step: default z-distance bin step in Angstrom, defaults to None + :type bin_step: float + :param xy_rad: cylinder radius for selecting surface O, defaults to 5.0 + :type xy_rad: float + :param cutoff: cylinder height for selecting surface O, defaults to 20.0 + :type cutoff: float + :param save: + :type save: + :param write: + :type write: + :param overwrite: + :type overwrite: + :param check_traj_len: + :type check_traj_len: + :param basekwargs: + :type basekwargs: + """ + super(ZDens, self).__init__(sel.universe.trajectory, **basekwargs) + self._init_data(n_bins=n_bins, bin_step=bin_step, cutoff=cutoff) + self._process_distances = None + self.sysname = sysname + self._ags = [sel] + self._universe = self._ags[0].universe + self.sel = sel + self.sel_n_atoms = sel.n_atoms + self.clay = clay + self.xy_rad = float(xy_rad) + self.save = save + self.write = write + if self.save is False: + pass + else: + try: + self.save = Dir(self.save) + except TypeError: + pass + if type(self.save) in [bool, Dir]: + savename = ( + f"{self.__class__.__name__.lower()}_" f"{self.sysname}" + ) + try: + self.save = self.save / savename + except TypeError: + self.save = savename + if self.write is not False: + if self.save is False: + savename = ( + f"{self.__class__.__name__.lower()}_" f"{self.sysname}" + ) + try: + self.write = self.write / savename + except TypeError: + self.write = savename + self.write = pl.Path(self.write).with_suffix(".npz") + elif type(self.write) == bool: + self.write = pl.Path(self.save).with_suffix(".npz") + else: + logger.error("Should not get here!") + sys.exit(1) + if pl.Path(self.write).is_file(): + if overwrite is False: + logger.finfo( + f"Done!\n{str(self.write)!r} already exists and overwrite not selected." 
+ ) + self._get_new_data = False + return + # raise FileExistsError(f"{self.write!r} already exists.") + check_traj(self, check_traj_len) + + def _prepare(self) -> NoReturn: + process_box(self) + logger.info( + f"Starting run:\n" + f"Frames start: {self.start}, " + f"stop: {self.stop}, " + f"step: {self.step}\n" + ) + self._dist_array = np.ma.empty( + (self.sel.n_atoms, self.clay.n_atoms, 3), + dtype=np.float64, + fill_value=np.nan, + ) + self._z_dist = np.empty(self.sel.n_atoms, dtype=np.float64) + self.mask = [] + self._sel = np.empty_like(self._z_dist, dtype=bool) + self._sel_mask = np.ma.empty_like( + self._dist_array[:, :, 0], dtype=np.float64 + ) + + def _single_frame(self) -> NoReturn: + self._dist_array.fill(0) + self._dist_array.mask = False + self._sel_mask.fill(0) + self._sel_mask.mask = False + self._z_dist.fill(0) + self._sel_mask.soften_mask() + # Wrap coordinates back into simulation box (use mda apply_PBC) + sel_pos = apply_PBC(self.sel.positions, self._ts.dimensions) + clay_pos = apply_PBC(self.clay.positions, self._ts.dimensions) + # get minimum x, y, z distances between sel and clay in box + get_dist(sel_pos, clay_pos, self._dist_array, self._ts.dimensions) + self._process_distances(self._dist_array, self._ts.dimensions) + # self._dist_array[:] = np.apply_along_axis(lambda x: minimize_vectors(x, self._ts.dimensions), axis=self._dist_array) + # consider only clay atoms within a cylinder around sel atoms + select_cyzone( + distances=self._dist_array, + xy_rad=self.xy_rad, + z_dist=self.data["zdens"].cutoff, + mask_array=self._sel_mask, + ) + + # get minimum z-distance to clay for each sel atom + self._z_dist = np.min( + np.abs(self._dist_array[:, :, 2]), axis=1, out=self._z_dist + ) + # if np.isnan(self._z_dist).any(): + # logger.info(f"{self._z_dist}, {self._z_dist.shape}") + # logger.info(np.argwhere(self._z_dist[np.isnan(self._z_dist)])) + # only_data = np.min( + # np.abs(self._dist_array.data[:, :, 2]), + # axis=1, + # out=self._z_dist, + # ) + # for i in range(len(self._z_dist)): + # logger.info(f"{i}: {self._z_dist[i]:.3f}, {only_data[i]:.3f}") + # logger.info(self._z_dist.shape) + # logger.info(np.isnan(self._z_dist).any()) + self._sel[:] = np.isnan(self._z_dist) + self.data["zdens"].timeseries.append(self._z_dist.copy()) + self.mask.append(self._sel.copy()) + + def _save(self) -> NoReturn: + if self.save is False: + pass + else: + for v in self.data.values(): + v.save( + self.save, + sel_names=np.unique(self.sel.names), + n_atoms=self.sel.n_atoms, + n_frames=self.n_frames, + ) + if self.write is not False: + with open(self.write, "wb") as outfile: + np.savez( + outfile, + zdist=np.array(self.data["zdens"].timeseries), + mask=np.array(self.mask), + frames=self.results["frames"], + times=self.results["times"], + run_prms=np.array([self.start, self.stop, self.step]), + cutoff=self.data["zdens"].cutoff, + bin_step=self.data["zdens"].bin_step, + sel_n_atoms=self.sel.n_atoms, + ) + assert len( + np.arange(start=self.start, stop=self.stop, step=self.step) + ) == len( + self.data["zdens"].timeseries + ), "Length of timeseries does not conform to selected start, stop and step!" 
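+            # The arrays saved above can be reloaded without ClayCode,
+            # e.g. (file name illustrative):
+            #
+            #     >>> dat = np.load("zdens_NAu-1-fe.npz")
+            #     >>> dat["zdist"].shape  # -> (n_frames, sel_n_atoms)
+            #     >>> dat["mask"]  # True where the z-distance was NaN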
+ # logger.info(f"{self.start}, {self.stop}, {self.step}") + + logger.finfo(f"Wrote z-dist array to {str(self.write)!r}") + # outsel = self.sel + self.clay + # ocoords = str(change_suffix(self.save, "pdbqt")) + # otraj = str(change_suffix(self.save, "traj")) + # outsel.write( + # otraj, frames=self._trajectory[self.start : self.stop : self.step] + # ) + # outsel.write( + # ocoords, frames=self._trajectory[self.start : self.stop : self.step][-1] + # ) + # logger.info( + # f"Wrote final coordinates to {ocoords.name} and trajectory to {otraj.name}" + # ) + + +parser = ArgumentParser( + prog="zdens", + description="Compute z-density relative to clay surface OB atoms.", + add_help=True, + allow_abbrev=False, +) +parser.add_argument( + "-name", type=str, help="System name", dest="sysname", required=True +) + +parser.add_argument( + "-inp", + type=str, + help="Input file names", + nargs=2, + metavar=("coordinates", "trajectory"), + dest="infiles", + required=False, +) +parser.add_argument( + "-inpname", + type=str, + help="Input file names", + metavar="name_stem", + dest="inpname", + required=False, +) +parser.add_argument( + "-uc", + type=str, + help="Clay unit cell type", + dest="clay_type", + required=True, +) +parser.add_argument( + "-sel", + type=str, + nargs="+", + help="Atom type selection", + dest="sel", + required=True, +) +parser.add_argument( + "-n_bins", + default=None, + type=int, + help="Number of bins in histogram", + dest="n_bins", +) +parser.add_argument( + "-bin_step", + type=float, + default=None, + help="bin size in histogram", + dest="bin_step", +) +parser.add_argument( + "-xyrad", + type=float, + default=3, + help="xy-radius for calculating z-position clay surface", + dest="xyrad", +) + +parser.add_argument( + "-cutoff", + type=float, + default=20, + help="cutoff in z-direction", + dest="cutoff", +) + +parser.add_argument( + "-start", + type=int, + default=None, + help="First frame for analysis.", + dest="start", +) +parser.add_argument( + "-step", + type=int, + default=None, + help="Frame steps for analysis.", + dest="step", +) +parser.add_argument( + "-stop", + type=int, + default=None, + help="Last frame for analysis.", + dest="stop", +) +parser.add_argument( + "-out", + type=str, + help="Filename for results pickle.", + dest="save", + default=True, +) +parser.add_argument( + "-check_traj", + type=int, + default=False, + help="Expected trajectory length.", + dest="check_traj_len", +) +parser.add_argument( + "--write_z", + type=str, + default=True, + help="Binary array output of selection z-distances.", + dest="write", +) +parser.add_argument( + "--overwrite", + action="store_true", + default=False, + help="Overwrite existing z-distance array data.", + dest="overwrite", +) +parser.add_argument( + "--update", + action="store_true", + default=False, + help="Overwrite existing trajectory and coordinate array data.", + dest="new", +) + +parser.add_argument( + "-path", default=False, help="File with analysis data paths.", dest="path" +) +parser.add_argument( + "--in_mem", + default=False, + action="store_true", + help="Read trajectory in memory.", + dest="in_mem", +) + +if __name__ == "__main__": + logger.info(f"Using MDAnalysis {mda.__version__}") + logger.info(f"Using numpy {np.__version__}") + args = parser.parse_args(sys.argv[1:]) + traj_format = ".xtc" + + sysname = args.sysname + + gro, trr, path = get_paths(args.path, args.infiles, args.inpname) + + logger.finfo(f"{sysname!r}", kwd_str=f"System name: ") + + if args.save is None: + outpath = path + else: + outpath = 
pl.Path(args.save) + if outpath.is_dir(): + outname = f'{gro}_{args.sel[-1].strip("*")}' + outname = (path / outname).resolve() + else: + outname = pl.Path(args.save).resolve() + logger.finfo(f"{str(outname.resolve())!r}", kwd_str=f"Output path: ") + # pdbqt = lambda: path._get_filelist(ext='.pdbqt') + # traj = lambda: path._get_filelist(ext=f'.{traj_format}') + # traj = outname.with_suffix(traj_format) + # coords = outname.with_suffix(".gro") + # try: + # u = mda.Universe(str(coords), str(traj)) + # except: + + # if (args.new is True) or (not traj.is_file() or not coords.is_file()): + # logger.info("Files missing") + # new = True + # else: + # try: + # u = mda.Universe(str(coords), str(traj)) + # new = False + # if not u.trajectory.n_frames == 35001: + # logger.info("Wrong frame number") + # new = True + # except: + # logger.info("Could not construct universe") + # new = True + # # if len(traj()) == 0: # or len(pdbqt()) == 0 or + # + # if new is True: + # logger.info(f"Saving selection coordinates and trajectory.") + # sel, clay = get_selections((gro, trr), args.sel, args.clay_type) + # save_selection(outname=outname, atom_groups=[clay, sel], traj=traj_format) + # pdbqt = pdbqt()[0] + # crds = select_file(path, suffix='crdin') + # traj = select_file(path, suffix=traj_format.strip('.')) + coords = gro + traj = trr + logger.debug(f"Using {coords.name} and {traj.name}.") + try: + u = Universe(str(coords), str(traj)) + new = False + if not args.check_traj_len: + logger.finfo( + "Skipping trajectory length check.", initial_linebreak=True + ) + else: + if not u.trajectory.n_frames == args.check_traj_len: + logger.finfo( + f"Wrong frame number, found {u.trajectory.n_frames}, expected {args.check_traj_len}!", + initial_linebreak=True, + ) + new = True + else: + logger.finfo( + f"Trajectory has correct frame number of {args.check_traj_len}.", + initial_linebreak=True, + ) + except: + logger.info("Could not construct universe!", initial_linebreak=True) + new = True + logger.info(get_subheader("Getting atom groups")) + sel, clay = get_selections( + infiles=(coords, traj), + sel=args.sel, + clay_type=args.clay_type, + in_memory=args.in_mem, + ) + + if args.save == "True": + args.save = True + elif args.save == "False": + args.save = False + if args.write == "True": + args.write = True + elif args.write == "False": + args.write = False + + zdens = ZDens( + sysname=sysname, + sel=sel, + clay=clay, + n_bins=args.n_bins, + bin_step=args.bin_step, + xy_rad=args.xyrad, + cutoff=args.cutoff, + save=args.save, + write=args.write, + overwrite=args.overwrite, + check_traj_len=args.check_traj_len, + ) + + run_analysis(zdens, start=args.start, stop=args.stop, step=args.step) diff --git a/package/ClayCode/builder/assembly.py b/package/ClayCode/builder/assembly.py index 1e47ba8c..3e44d0bd 100644 --- a/package/ClayCode/builder/assembly.py +++ b/package/ClayCode/builder/assembly.py @@ -1116,7 +1116,7 @@ def write_gro(self, backup: bool = False) -> None: while uc_array is False: uc_array = self.get_uc_sheet_array() logger.finfo(f"Unit cell arrangement in sheet {self.n_sheet}:") - for line in uc_array: + for line in uc_array.T: logger.finfo(" ".join(map(str, line)), indent="\t") sheet_df = pd.concat( [ @@ -1213,19 +1213,19 @@ def get_uc_sheet_array(self): remainder_choices = np.arange(max_ax_len) self.random_generator.shuffle(remainder_choices) lines = {} - if self.debug: - symbols = np.array(["x", "o", "+", "#", "-", "*"]) - symbols = itertools.cycle(symbols) - symbol_arr = np.full((self.x_cells, self.y_cells), 
fill_value=" ") - symbol_dict = {} + # if self.debug: + symbols = np.array(["x", "o", "+", "#", "-", "*"]) + symbols = itertools.cycle(symbols) + symbol_arr = np.full((self.x_cells, self.y_cells), fill_value=" ") + symbol_dict = {} for charge_group_id, ( charge, charge_group_n_ucs, charge_group, ) in enumerate(self.get_charge_groups()): - if self.debug: - symbol = next(symbols) - symbol_dict[symbol] = charge + # if self.debug: + symbol = next(symbols) + symbol_dict[symbol] = charge uc_array = charge_group.copy() self.random_generator.shuffle(uc_array) remaining_add[charge_group_id] = 0 @@ -1286,29 +1286,40 @@ def get_uc_sheet_array(self): occ_counts, ] ) - combined_counts = np.max( - [ - np.roll(diag_counts, axis_id), - np.roll(opposite_diag_counts, -axis_id), - occ_counts, - ], - axis=0, + combined_counts = np.rint( + np.mean( + [ + np.roll(diag_counts, axis_id), + np.roll(opposite_diag_counts, -axis_id), + occ_counts, + ], + axis=0, + ) ) free_cols = np.logical_and( free, - combined_counts < n_col_ucs + min(1, per_col_remainder), + combined_counts < n_col_ucs, ) init_i = np.array([0, 0, 0]) - if free_cols[free_cols].size <= n_add_ucs: + minmax = itertools.cycle([min, max]) + while free_cols[free_cols].size < n_add_ucs: free_cols = np.logical_and( free, combined_counts - < n_col_ucs + max(1, per_col_remainder), + < n_col_ucs + next(minmax)(1, per_col_remainder), + ) + extra_remainder = np.zeros_like(init_i, dtype=np.int32) + while np.any( + np.less( + init_i, extra_remainder + per_col_remainder + n_per_col ) - while np.all(init_i < per_col_remainder + n_per_col) and ( + ) and ( idx_choices is None or (idx_choices.flatten().size < n_add_ucs) ): + if free[free].flatten().size == n_add_ucs: + idx_choices = np.argwhere(free).flatten() + break allowed_cols = free_cols.copy() idx_choices = None occ_devs = np.std(counts, axis=1) @@ -1324,12 +1335,13 @@ def get_uc_sheet_array(self): logstr = list( map( lambda x, y, z: f"\t{logdict[x]:15}: {y:.1f} {pm} {z:.1f}", - order, - np.mean(counts, axis=1), - occ_devs, + np.sort(order), + np.mean(counts[np.argsort(order)], axis=1), + occ_devs[np.argsort(order)], ) ) logger.finfo("\n".join(logstr)) + prev_choices = None for occ_id in order: # self.random_generator.choice( # [0, 1, 2], 3, replace=False # ): @@ -1345,20 +1357,60 @@ def get_uc_sheet_array(self): free_cols, intersect_idxs=idx_choices, intersect_allowed=allowed_cols, - remainder=per_col_remainder, + remainder=per_col_remainder + + extra_remainder[occ_id], ) - if free[free].flatten().size == n_add_ucs: - idx_choices = np.argwhere(free).flatten() + if idx_choices.flatten().size == n_add_ucs: + print( + occ_id, + f": stopping with {idx_choices.flatten()}, n_add_ucs = {n_add_ucs}", + ) + break + elif ( + idx_choices.flatten().size < n_add_ucs + and prev_choices is not None + and prev_choices.flatten().size >= n_add_ucs + ): + idx_choices = prev_choices + + prev_choices = idx_choices if ( - idx_choices.flatten().size == prev.flatten().size - and np.equal( - idx_choices.flatten(), prev.flatten() - ).all() - and idx_choices.flatten().size - != free_cols[free_cols].flatten().size + idx_choices.flatten().size > n_add_ucs + and np.unique(counts).size > 1 ): - continue - elif idx_choices.flatten().size >= n_add_ucs: + min_count = np.argwhere( + np.min(counts, axis=0) == np.min(counts) + ).flatten() + try: + if ( + min_count.size >= n_add_ucs + and min_count.size != combined_counts.size + and np.all(min_count != idx_choices) + ): + pass + except ValueError: + pass + if ( + min_count.size >= n_add_ucs + and 
min_count.size != combined_counts.size + and np.all(min_count != idx_choices) + ): + intersect_idxs = np.intersect1d( + idx_choices, min_count + ).flatten() + if intersect_idxs.size >= n_add_ucs: + idx_choices = intersect_idxs + elif idx_choices.flatten().size < n_add_ucs: + return False + # if idx_choices.flatten().size > n_add_ucs: + # if prev.flatten().size != 0: + # _, prev_idxs, _ = np.intersect1d( + # idx_choices.flatten(), prev.flatten(), assume_unique=True, return_indices=True + # ) + # if prev_idxs.flatten().size != 0: + # remove_idxs = np.random.choice(prev_idxs, idx_choices.size - n_add_ucs, replace=False) + # idx_choices = np.delete(idx_choices, remove_idxs) + if idx_choices.flatten().size >= n_add_ucs: if self.debug: logger.finfo(f"Row {axis_id}:", indent="\t") logger.finfo( @@ -1367,44 +1419,145 @@ def get_uc_sheet_array(self): ) break else: - return False + if np.any( + np.greater_equal( + init_i, + extra_remainder + + per_col_remainder + + n_per_col, + ) + ): + if np.any( + np.greater_equal( + init_i, + np.sort(counts, axis=1)[:, n_add_ucs - 1], + ) + ): + init_i[ + np.argwhere( + np.greater_equal( + init_i, + np.sort(counts, axis=1)[ + :, n_add_ucs - 1 + ], + ) + ) + ] -= 1 + + else: + extra_remainder[ + np.argwhere( + np.greater_equal( + init_i, + extra_remainder + + per_col_remainder + + n_per_col, + ) + ) + ] = 1 + # init_i = np.apply_along_axis(lambda arr: np.where(sorted(arr)[:n_add_ucs], ) + continue + # if idx_choices.flatten().size == 0: + # return False + # if np.any(extra_remainder == 0) and idx_choices.flatten().size < n_col_ucs + per_col_remainder and np.all(init_i == n_per_col + per_col_remainder): + # if np.all(init_i + extra_remainder == n_per_col + per_col_remainder): + # extra_remainder[np.intersect1d(order, np.argwhere(extra_remainder == 0))[-1]] += 1 + # continue + # else: + # return False + # elif ( + # idx_choices.flatten().size == prev.flatten().size + # and np.equal( + # idx_choices.flatten(), prev.flatten() + # ).all()) and np.all(init_i + extra_remainder <= n_per_col + per_col_remainder): + # if np.all(init_i + extra_remainder == n_per_col + per_col_remainder): + # extra_remainder = 1 + # continue + # elif np.min(init_i) < n_per_col - per_col_remainder - 1: + # new_init_i = np.where(init_i == min(min(init_i), n_per_col - per_col_remainder - 1), init_i + 1, init_i) + # if not np.equal(new_init_i, init_i).all(): + # init_i = new_init_i + # continue + # elif n_col_ucs <= n_per_col: + # n_col_ucs = n_per_col + 1 + # continue + # elif n_col_ucs <= n_per_col: + # n_col_ucs += 1 + # elif n_add_ucs == idx_choices.flatten().size: + # if self.debug: + # logger.finfo(f"Row {axis_id}:", indent="\t") + # logger.finfo( + # f"Adding {n_add_ucs} from {idx_choices.flatten()}", + # indent="\t\t", + # ) + # break + # elif np.any(init_i == extra_remainder + n_per_col + per_col_remainder - 1): + # init_i_idxs = np.argwhere(init_i < n_per_col + per_col_remainder + extra_remainder - 1) + # if init_i_idxs.size == 0: + # init_i_idxs = np.argwhere(init_i < n_per_col + per_col_remainder + extra_remainder) + # remainder_idxs = np.argwhere(extra_remainder == 0) + # order_idxs = np.intersect1d(init_i_idxs, remainder_idxs, assume_unique=True) + # if order_idxs.size == 0: + # return False + # order_idxs = np.intersect1d(order, order_idxs, assume_unique=True) + # extra_remainder[order[order_idxs[-1]]] = 1 + # else: + # order_idxs = np.intersect1d(order, init_i_idxs, assume_unique=True) + # init_i[order[order_idxs[-1]]] += 1 + + # elif extra_remainder == 0 and np.all(init_i >= 
n_per_col + per_col_remainder - 1): + # extra_remainder = 1 + # continue + # elif np.any(init_i < extra_remainder + n_per_col + per_col_remainder) and np.any(extra_remainder == 0): + + # free_cols = np.logical_and( + # free, + # combined_counts + # < n_col_ucs + max(1, per_col_remainder), + # ) if ( idx_choices is None or idx_choices.flatten().size < n_add_ucs ): return False # if - continuous_choices = idx_choices[ - [ - *( - idx_choices[:-1] + 1 - == np.roll(idx_choices, -1)[:-1] - ), - *( - idx_choices[-1:] - 1 - == np.roll(idx_choices, 1)[-1:] - ), - ] - ] - if n_add_ucs <= continuous_choices.size // 2 and n_add_ucs > 1: - p = np.zeros_like(idx_choices, dtype=np.float_) - start_idx = self.random_generator.choice( - [0, 1], 1, replace=False - )[0] - p[np.sort(continuous_choices)[start_idx::2]] = 1 - if p[0] == 1 and p[-1] == 1: - p[ - self.random_generator.choice( - [0, -1], 1, replace=False - )[0] - ] = 0 - p = np.divide(p, np.sum(p), where=p != 0) - else: - p = np.full_like( - idx_choices, - np.divide(1, idx_choices.size), - dtype=np.float_, - ) + # continuous_choices = np.intersect1d( + # idx_choices, + # idx_choices[ + # [ + # *( + # idx_choices[:-1] + 1 + # == np.roll(idx_choices, -1)[:-1] + # ), + # *( + # idx_choices[-1:] - 1 + # == np.roll(idx_choices, 1)[-1:] + # ), + # ] + # ], + # assume_unique=True, + # return_indices=True, + # )[1] + # if n_add_ucs <= continuous_choices.size // 2 and n_add_ucs > 1: + # p = np.zeros_like(idx_choices, dtype=np.float_) + # start_idx = self.random_generator.choice( + # [0, 1], 1, replace=False + # )[0] + # p[np.sort(continuous_choices)[start_idx::2]] = 1 + # # pass + # if p[0] == 1 and p[-1] == 1: + # p[ + # self.random_generator.choice( + # [0, -1], 1, replace=False + # )[0] + # ] = 0 + # p = np.divide(p, np.sum(p), where=p != 0) + # else: + p = np.full_like( + idx_choices, + np.divide(1, idx_choices.size), + dtype=np.float_, + ) idx_sel = None if idx_choices.size == n_add_ucs: idx_sel = idx_choices @@ -1424,25 +1577,30 @@ def get_uc_sheet_array(self): idx_sel = idx_choices if self.debug: logger.finfo(f"Selected {idx_sel}", indent="\t") + logger.finfo( + f"Occupancy counts:\n" + + "\n\t".join([f"{c}" for c in counts.tolist()]), + indent="\t", + ) idxs_mask[axis_id, idx_sel] = charge uc_ids[axis_id, idx_sel], uc_array = np.split( uc_array, [n_add_ucs] ) - if self.debug: - symbol_arr[axis_id, idx_sel] = symbol + # if self.debug: + symbol_arr[axis_id, idx_sel] = symbol if idx_sel.size != 0: prev = np.sort(idx_sel) if max_dict[max_ax_len] == 1: uc_ids = uc_ids.T - if self.debug: - logger.finfo("Added charges:") - for k, v in symbol_dict.items(): - logger.finfo( - kwd_str=f"{k}: ", message=f"{v:2.1f}", indent="\t" - ) - logger.finfo("Final symbol matrix:") - for line in symbol_arr: - logger.finfo(" ".join(line), indent="\t") + else: + symbol_arr = symbol_arr.T + # if self.debug: + logger.finfo("Added charges:") + for k, v in symbol_dict.items(): + logger.finfo(kwd_str=f"{k}: ", message=f"{v:2.1f}", indent="\t") + logger.finfo("Final charge arrangement:") + for line in symbol_arr: + logger.finfo(" ".join(line), indent="\t") return uc_ids def get_all_diagonals(self, arr): @@ -1528,6 +1686,7 @@ def get_idxs( remainder=0, ): n_remaining = 0 + init_i = max(1, init_i) for i in range(init_i, per_col_ucs + remainder + 1): if i == per_col_ucs: n_remaining = remainder diff --git a/package/ClayCode/builder/claycomp.py b/package/ClayCode/builder/claycomp.py index 51b56204..9c21b325 100644 --- a/package/ClayCode/builder/claycomp.py +++ 
diff --git a/package/ClayCode/builder/claycomp.py b/package/ClayCode/builder/claycomp.py
index 51b56204..9c21b325 100644
--- a/package/ClayCode/builder/claycomp.py
+++ b/package/ClayCode/builder/claycomp.py
@@ -94,7 +94,7 @@ def uc_stem(self) -> str:
         """Unit cell stem
        :return: unit cell stem
        :rtype: str"""
-        return self.stem[:2]
+        return self.stem[:-3]

     @cached_property
     def atom_df(self) -> pd.DataFrame:
@@ -131,7 +131,7 @@ def __init__(
         from ClayCode.core.classes import ForceField

         if uc_stem is None:
-            self.uc_stem = self.itp_filelist[0].stem[:-3]
+            self.uc_stem = self.uc_itp_filelist[0].stem[:-3]
         else:
             self.uc_stem: str = uc_stem
         logger.info(get_subheader("Getting unit cell data"))
@@ -157,6 +157,10 @@ def __init__(
             self.df.index.get_level_values("sheet").unique().to_list()
         )

+    @property
+    def uc_gro_filelist(self):
+        return self.gro_filelist.filter(f"{self.uc_stem}[0-9][0-9][0-9]")
+
     def check_ucs(self):
         uc_error_charges = {}
         for uc in sorted(self.uc_list):
@@ -205,7 +209,7 @@ def get_uc_groups(self, write=True, reset=False):
         self.__bbox_height = {}
         extract_id = lambda file: file.stem[-3:]
         if not (self.path / "uc_groups.pkl").is_file() or not write:
-            for uc in sorted(self.gro_filelist):
+            for uc in sorted(self.uc_gro_filelist):
                 uc_dimensions = uc.universe.dimensions
                 bbox_height_new = np.ediff1d(uc.universe.atoms.bbox()[:, 2])[0]
                 dim_str = "_".join(uc_dimensions.round(3).astype(str))
@@ -253,6 +257,10 @@ def get_uc_groups(self, write=True, reset=False):
             self.__itp_groups = uc_groups["itp_groups"]
         self.__base_ucs = {}

+    @property
+    def uc_itp_filelist(self):
+        return self.itp_filelist.filter(f"{self.uc_stem}[0-9][0-9][0-9]")
+
     @property
     def bbox_height(self):
         if self.group_id is not None:
@@ -316,7 +324,7 @@ def _gro_df(self):
         gro_df.set_index("uc-id", inplace=True, append=True)
         gro_df.index = gro_df.index.reorder_levels(["uc-id", "atom-id"])
         gro_df.sort_index(level="uc-id", inplace=True, sort_remaining=True)
-        for gro in self.gro_filelist:
+        for gro in self.uc_gro_filelist:
             n_atoms = self.n_atoms.filter(regex=gro.stem[-3:]).values[0]
             gro_df.update(gro.df.set_index("atom-id", append=True))
         return gro_df
@@ -392,58 +400,86 @@ def df(self) -> pd.DataFrame:
         )

     def __get_full_df(self, write=True, reset=False):
-        if reset is True and (self.path / "full_df.csv").is_file():
-            os.remove(self.path / "full_df.csv")
-        idx = self.atomtypes.iloc[:, 0]
-        cols = [*self.uc_idxs, "charge", "sheet"]
-        self._full_df: pd.DataFrame = pd.DataFrame(
-            index=idx, columns=cols, dtype=np.float64
-        )
-        self._full_df["charge"].update(
-            self.atomtypes.set_index("at-type")["charge"]
-        )
-        self.__get_df_sheet_annotations()
-        self._full_df["sheet"].fillna("X", inplace=True)
-        self._full_df.fillna(0, inplace=True)
-        # self._full_df = dd.from_pandas(self._full_df, chunksize=1000)
-        if not (self.path / "full_df.csv").is_file() or not write:
-            for uc in self.uc_list:
-                try:
-                    atoms = uc["atoms"].df
-                    self._full_df[f"{uc.idx}"].update(
-                        atoms.value_counts("at-type")
-                    )
-                except AttributeError:
-                    logger.finfo(f"Invalid unit cell {uc.name!r}")
-                    for suffix in [".gro", ".itp"]:
-                        try:
-                            remove = select_input_option(
-                                instance_or_manual_setup=True,
-                                query=f"Remove invalid unit cell {uc.idx}? [y]es/[n]o (default y)\n",
-                                options=["y", "n", ""],
-                                result=None,
-                                result_map={"y": True, "n": False, "": True},
-                            )
-                            if remove:
-                                os.remove(uc.with_suffix(suffix))
-                        except FileNotFoundError:
-                            pass
-        self._full_df.set_index("sheet", append=True, inplace=True)
-        self._full_df.sort_index(
-            inplace=True, level=1, sort_remaining=True
+        finished = False
+        while finished is False:
+            finished = True
+            if reset is True and (self.path / "full_df.csv").is_file():
+                os.remove(self.path / "full_df.csv")
+                if (self.path / "uc_groups.pkl").is_file():
+                    os.remove(self.path / "uc_groups.pkl")
+            idx = self.atomtypes.iloc[:, 0].copy()
+            idx = idx.reindex([*idx.index.values, len(idx)])
+            idx.iloc[-1] = "itp_charge"
+            cols = [*self.uc_idxs, "charge", "sheet"]
+            self._full_df: pd.DataFrame = pd.DataFrame(
+                index=idx, columns=cols, dtype=np.float64
             )
-        self._full_df.index = self._full_df.index.reorder_levels(
-            ["sheet", "at-type"]
-        )
-        self.check_ucs()
-        if write:
-            self._full_df.to_csv(self.path / "full_df.csv")
-        else:
-            self._full_df = pd.read_csv(
-                self.path / "full_df.csv", index_col=[0, 1]
+            self._full_df["charge"].update(
+                self.atomtypes.set_index("at-type")["charge"]
             )
-        if len(self._full_df.columns) == len(self.itp_filelist):
-            return False
+            self.__get_df_sheet_annotations()
+            self._full_df["sheet"].fillna("X", inplace=True)
+            self._full_df.fillna(0, inplace=True)
+            # self._full_df = dd.from_pandas(self._full_df, chunksize=1000)
+            if not (self.path / "full_df.csv").is_file() or not write:
+                for uc in self.uc_list:
+                    try:
+                        atoms: pd.DataFrame = uc.atom_df
+                        self._full_df[f"{uc.idx}"].update(
+                            atoms.value_counts("at-type")
+                        )
+                        self._full_df.loc[
+                            "itp_charge", f"{uc.idx}"
+                        ] = uc.charge
+                        self._full_df[f"{uc.idx}"].fillna(0, inplace=True)
+                    except AttributeError:
+                        logger.finfo(f"Invalid unit cell {uc.name!r}")
+                        for suffix in [".gro", ".itp"]:
+                            try:
+                                remove = select_input_option(
+                                    instance_or_manual_setup=True,
+                                    query=f"Remove invalid unit cell {uc.idx}? [y]es/[n]o (default y)\n",
+                                    options=["y", "n", ""],
+                                    result=None,
+                                    result_map={
+                                        "y": True,
+                                        "n": False,
+                                        "": True,
+                                    },
+                                )
+                                if remove:
+                                    os.remove(uc.with_suffix(suffix))
+                            except FileNotFoundError:
+                                pass
+            self._full_df.set_index("sheet", append=True, inplace=True)
+            self._full_df.sort_index(
+                inplace=True, level=1, sort_remaining=True
+            )
+            self._full_df.index = self._full_df.index.reorder_levels(
+                ["sheet", "at-type"]
+            )
+            self.check_ucs()
+            if write:
+                self._full_df.to_csv(self.path / "full_df.csv")
+            else:
+                self._full_df = pd.read_csv(
+                    self.path / "full_df.csv", index_col=[0, 1]
+                )
+            try:
+                if not np.equal(
+                    self._full_df.loc["itp_charge", "itp_charge"],
+                    np.rint(self._full_df.loc["itp_charge", "itp_charge"]),
+                ).all():
+                    finished = False
+            except KeyError:
+                finished = False
+            if (
+                sorted(self._full_df.columns)[:-1] != sorted(self.uc_idxs)
+            ) or finished is False:
+                os.remove(self.path / "full_df.csv")
+                if (self.path / "uc_groups.pkl").is_file():
+                    os.remove(self.path / "uc_groups.pkl")
+                finished = False

     def __get_df_sheet_annotations(self):
         old_index = self._full_df.index
@@ -461,23 +497,27 @@ def __get_df_sheet_annotations(self):
         new_index = pd.MultiIndex.from_tuples(index_extension_list)
         new_index = new_index.to_frame().set_index(1)
         self._full_df["sheet"].update(new_index[0])
+        self._full_df.loc["itp_charge", "sheet"] = "itp_charge"

     def __get_df(self):
         # self._df = dd.from_pandas(self._full_df, chunksize=1000)
         # self._df.reset_index('at-type', inplace=True)
         # self._df.filter(regex=r"^(?![X].*)", axis=0, inplace=True)
-        self._df = self.full_df.reset_index("at-type").filter(
-            regex=r"^(?![X].*)", axis=0
+        self._df = (
+            self.full_df.copy()
+            .reset_index("at-type")
+            .filter(regex=r"^(?![X].*)", axis=0)
         )
         self._df = (
             self._df.reset_index()
             .set_index(["sheet", "at-type"])
             .sort_index(axis=1)
         )
+        self._df.drop("itp_charge", inplace=True)

     @cached_property
     def uc_list(self) -> List[UnitCell]:
-        uc_list = [UnitCell(itp) for itp in self.itp_filelist]
+        uc_list = [UnitCell(itp) for itp in self.uc_itp_filelist]
         return uc_list

     @cached_property
@@ -485,7 +525,7 @@ def occupancies(self) -> Dict[str, int]:
         return self._get_occupancies(self.df)

     @cached_property
-    def tot_charge(self) -> pd.Series:
+    def ff_charge(self) -> pd.Series:
         charge = self.full_df.apply(
             lambda x: x * self.full_df["charge"], raw=True
         )
@@ -504,9 +544,21 @@
         ]
         return total_charge

+    @cached_property
+    def tot_charge(self) -> pd.Series:
+        total_charge = self.full_df.loc[("itp_charge", "itp_charge")].filter(
+            regex="[0-9]+"
+        )
+        return total_charge.sort_index()
+
     @cached_property
     def n_atoms(self):
-        return self.full_df.filter(regex="[0-9]+").astype(int).sum(axis=0)
+        return (
+            self.full_df.drop("itp_charge")
+            .filter(regex="[0-9]+")
+            .astype(int)
+            .sum(axis=0)
+        )

     @cached_property
     def uc_composition(self) -> pd.DataFrame:
@@ -534,7 +586,9 @@ def check(self) -> None:

     @cached_property
     def available(self) -> List[ITPFile]:
-        return self.itp_filelist.extract_fstems()
+        return self.uc_itp_filelist._extract_parts(
+            part="stem", pre_reset=False
+        )

     # def __str__(self):
     #     return f"{self.__class__.__name__}({self.name!r})"
@@ -608,7 +662,7 @@ def _(
         ox_dict = UCData._get_ox_dict()
         # df = df.loc[['T','O']]
         ox_df: pd.DataFrame = df.copy()
-
+        # ox_df = ox_df.drop('itp_charge')
         try:
             ox_df = ox_df.loc[~(ox_df == 0).all(1), :]
         except ValueError:
@@ -618,6 +672,12 @@
             "at-type"
         ).to_frame()
         at_types.index = ox_df.index
+        # for idx_entry in [("O", "fe_tot"), ('itp_charge', 'itp_charge')]:
+        try:
+            at_types.drop(("O", "fe_tot"), inplace=True)
+
+        except KeyError:
+            pass
         at_types = at_types.applymap(lambda x: ox_dict[x])
         if tot_charge is not None:
             _ox_df = ox_df.loc[:, tot_charge.abs() == tot_charge.abs().min()]
@@ -649,10 +709,6 @@ def _(
         )
         # _ox_df = ox_df.groupby('sheet', group_keys=False, sort=True).apply(lambda x: x.sort_values(ascending=False).first())
         # ox_df[:] = ox_df.groupby('sheet', group_keys=True).sum()
-        try:
-            at_types.drop(("O", "fe_tot"), inplace=True)
-        except KeyError:
-            pass

         if type(ox_df) == pd.DataFrame:
             ox: pd.DataFrame = ox_df.apply(lambda x: x * at_types["at-type"])
@@ -983,9 +1039,14 @@ def __get_match_df(self, csv_file: Union[str, File]) -> pd.DataFrame:

     @property
     def df(self):
-        return self._df.dropna().sort_index(
+        df = self._df.dropna().sort_index(
             ascending=False, level="sheet", sort_remaining=True
         )
+        # try:
+        #     df = df.drop('itp_charge')
+        # except KeyError:
+        #     pass
+        return df

     def reduce_charge(
         self,
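The rewritten __get_full_df above wraps the dataframe build in a retry loop: after full_df.csv is read back, the per-unit-cell totals stored in the new itp_charge row must all be integral, otherwise the cached full_df.csv and uc_groups.pkl are deleted and the build restarts. A minimal sketch of that integrality test (the Series values are invented):

    import numpy as np
    import pandas as pd

    # Hypothetical per-cell total charges as read back from full_df.csv.
    itp_charge = pd.Series({"001": 0.0, "002": -1.0, "003": -0.9990})

    # Same check as in __get_full_df: a non-integer total marks a stale cache.
    finished = bool(np.equal(itp_charge, np.rint(itp_charge)).all())
    print(finished)  # False here, so the cached files would be removed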
diff --git a/package/ClayCode/builder/utils.py b/package/ClayCode/builder/utils.py
index 2f3de8e6..16a08b65 100644
--- a/package/ClayCode/builder/utils.py
+++ b/package/ClayCode/builder/utils.py
@@ -80,7 +80,7 @@ def select_input_option(
     :rtype: Any
     """
     while result not in options:
-        result = input(query).lower()
+        result = input(query).lower().strip(" ")
     if result_map is not None:
         result = result_map[result]
     return result
@@ -117,7 +117,9 @@ def get_checked_input(
     :rtype: Any
     """
     while not isinstance(result, result_type):
-        result_input = input(f"{query} (or exit with {exit_val!r})\n")
+        result_input = input(f"{query} (or exit with {exit_val!r})\n").strip(
+            " "
+        )
         if result_input == exit_val:
             logger.info(f"Selected {exit_val!r}, exiting.")
             sys.exit(0)
diff --git a/package/ClayCode/data/data/UCS/D11/D1001.gro b/package/ClayCode/data/data/UCS/D11/D1001.gro
new file mode 100644
index 00000000..896d784c
--- /dev/null
+++ b/package/ClayCode/data/data/UCS/D11/D1001.gro
@@ -0,0 +1,37 @@
+Dioctahedral 1:1 unit cell 1
+   34
+    1D101  AO1    1   0.061   0.433   0.332
+    1D101  AO2    2   0.321   0.283   0.332
+    1D101  AO3    3   0.320   0.880   0.332
+    1D101  AO4    4   0.064   0.730   0.332
+    1D101  ST1    5   0.237   0.749   0.065
+    1D101  ST2    6   0.500   0.594   0.067
+    1D101  ST3    7   0.493   0.301   0.065
+    1D101  ST4    8   0.244   0.147   0.067
+    1D101  OB1    9   0.225   0.751   0.226
+    1D101  OB2   10   0.255   0.135   0.227
+    1D101  OB3   11   0.258   0.000   0.000
+    1D101  OB4   12   0.359   0.651   0.021
+    1D101  OB5   13   0.360   0.236   0.001
+    1D101  OB6   14   0.480   0.304   0.226
+    1D101  OB7   15   0.510   0.582   0.227
+    1D101  OB8   16   0.002   0.447   0.000
+    1D101  OB9   17   0.100   0.204   0.021
+    1D101 OB10   18   0.104   0.683   0.001
+    1D101  OH1   19   0.223   0.413   0.232
+    1D101  OH2   20   0.123   0.581   0.433
+    1D101  OH3   21   0.164   0.855   0.431
+    1D101  OH4   22   0.162   0.306   0.434
+    1D101  OH5   23   0.480   0.860   0.232
+    1D101  OH6   24   0.379   0.134   0.433
+    1D101  OH7   25   0.420   0.408   0.431
+    1D101  OH8   26   0.420   0.753   0.434
+    1D101  HO1   27   0.530   0.940   0.233
+    1D101  HO2   28   0.410   0.129   0.527
+    1D101  HO3   29   0.400   0.434   0.522
+    1D101  HO4   30   0.137   0.264   0.519
+    1D101  HO5   31   0.272   0.497   0.233
+    1D101  HO6   32   0.150   0.576   0.527
+    1D101  HO7   33   0.136   0.880   0.522
+    1D101  HO8   34   0.400   0.712   0.519
+   0.51540   0.89420   0.63910
diff --git a/package/ClayCode/data/data/UCS/D11/D1001.itp b/package/ClayCode/data/data/UCS/D11/D1001.itp
new file mode 100644
index 00000000..97396626
--- /dev/null
+++ b/package/ClayCode/data/data/UCS/D11/D1001.itp
@@ -0,0 +1,54 @@
+;
+;
+[ moleculetype ]
+; name  nrexcl
+  D1001      1
+
+[ atoms ]
+;  nr  type  resnr  residue  atom  cgnr  charge  mass  typeB  chargeB  massB
+; residue 1 KAO rtp KAO q 0.0
+    1    ao    1  D1001   AO1    1    1.575    26.98  ;
+    2    ao    1  D1001   AO2    2    1.575    26.98  ;
+    3    ao    1  D1001   AO3    3    1.575    26.98  ;
+    4    ao    1  D1001   AO4    4    1.575    26.98  ;
+    5    st    1  D1001   ST1    5    2.1      28.09  ;
+    6    st    1  D1001   ST2    6    2.1      28.09  ;
+    7    st    1  D1001   ST3    7    2.1      28.09  ;
+    8    st    1  D1001   ST4    8    2.1      28.09  ;
+    9    ob    1  D1001   OB1    9   -1.05     16     ;
+   10    ob    1  D1001   OB2   10   -1.05     16     ;
+   11    ob    1  D1001   OB3   11   -1.05     16     ;
+   12    ob    1  D1001   OB4   12   -1.05     16     ;
+   13    ob    1  D1001   OB5   13   -1.05     16     ;
+   14    ob    1  D1001   OB6   14   -1.05     16     ;
+   15    ob    1  D1001   OB7   15   -1.05     16     ;
+   16    ob    1  D1001   OB8   16   -1.05     16     ;
+   17    ob    1  D1001   OB9   17   -1.05     16     ;
+   18    ob    1  D1001  OB10   18   -1.05     16     ;
+   19    oh    1  D1001   OH1   19   -0.95     16     ;
+   20    oh    1  D1001   OH2   20   -0.95     16     ;
+   21    oh    1  D1001   OH3   21   -0.95     16     ;
+   22    oh    1  D1001   OH4   22   -0.95     16     ;
+   23    oh    1  D1001   OH5   23   -0.95     16     ;
+   24    oh    1  D1001   OH6   24   -0.95     16     ;
+   25    oh    1  D1001   OH7   25   -0.95     16     ;
+   26    oh    1  D1001   OH8   26   -0.95     16     ;
+   27    ho    1  D1001   HO1   27    0.425     1.008 ;
+   28    ho    1  D1001   HO2   28    0.425     1.008 ;
+   29    ho    1  D1001   HO3   29    0.425     1.008 ;
+   30    ho    1  D1001   HO4   30    0.425     1.008 ;
+   31    ho    1  D1001   HO5   31    0.425     1.008 ;
+   32    ho    1  D1001   HO6   32    0.425     1.008 ;
+   33    ho    1  D1001   HO7   33    0.425     1.008 ;
+   34    ho    1  D1001   HO8   34    0.425     1.008 ;
+
+[ bonds ]
+; i  j  funct  length  force.c.
+19 31 1 0.1 463532.808
+20 32 1 0.1 463532.808
+21 33 1 0.1 463532.808
+22 30 1 0.1 463532.808
+23 27 1 0.1 463532.808
+24 28 1 0.1 463532.808
+25 29 1 0.1 463532.808
+26 34 1 0.1 463532.808
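The new D11 .gro files follow the standard GROMACS layout: a title line, the atom count, one fixed-width record per atom (residue number and name, atom name, atom index, then x/y/z in nm), and the box vectors on the last line. A rough reader for records like the ones above (field widths assume the usual %5d%-5s%5s%5d%8.3f .gro convention):

    def read_gro(path):
        # Title line, atom count, fixed-width atom records, box vector line.
        with open(path) as f:
            lines = f.read().splitlines()
        n_atoms = int(lines[1])
        atoms = []
        for line in lines[2 : 2 + n_atoms]:
            atoms.append(
                {
                    "residue": line[:10].strip(),  # e.g. '1D101'
                    "atom": line[10:15].strip(),   # e.g. 'AO1'
                    "index": int(line[15:20]),
                    "xyz": tuple(float(v) for v in line[20:].split()[:3]),
                }
            )
        box = [float(v) for v in lines[2 + n_atoms].split()]
        return atoms, box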
diff --git a/package/ClayCode/data/data/UCS/D11/D1002.gro b/package/ClayCode/data/data/UCS/D11/D1002.gro
new file mode 100644
index 00000000..4031fda2
--- /dev/null
+++ b/package/ClayCode/data/data/UCS/D11/D1002.gro
@@ -0,0 +1,37 @@
+Dioctahedral 1:1 unit cell 2
+   34
+    1D102  AO1    1   0.061   0.433   0.332
+    1D102  AO2    2   0.321   0.283   0.332
+    1D102  AO3    3   0.320   0.880   0.332
+    1D102  AO4    4   0.064   0.730   0.332
+    1D102  ST1    5   0.237   0.749   0.065
+    1D102  ST2    6   0.500   0.594   0.067
+    1D102  ST3    7   0.493   0.301   0.065
+    1D102  AT4    8   0.244   0.147   0.067
+    1D102  OB1    9   0.225   0.751   0.226
+    1D102 OBT2   10   0.255   0.135   0.227
+    1D102 OBT3   11   0.258   0.000   0.000
+    1D102  OB4   12   0.359   0.651   0.021
+    1D102 OBT5   13   0.360   0.236   0.001
+    1D102  OB6   14   0.480   0.304   0.226
+    1D102  OB7   15   0.510   0.582   0.227
+    1D102  OB8   16   0.002   0.447   0.000
+    1D102 OBT9   17   0.100   0.204   0.021
+    1D102 OB10   18   0.104   0.683   0.001
+    1D102  OH1   19   0.223   0.413   0.232
+    1D102  OH2   20   0.123   0.581   0.433
+    1D102  OH3   21   0.164   0.855   0.431
+    1D102  OH4   22   0.162   0.306   0.434
+    1D102  OH5   23   0.480   0.860   0.232
+    1D102  OH6   24   0.379   0.134   0.433
+    1D102  OH7   25   0.420   0.408   0.431
+    1D102  OH8   26   0.420   0.753   0.434
+    1D102  HO1   27   0.530   0.940   0.233
+    1D102  HO2   28   0.410   0.129   0.527
+    1D102  HO3   29   0.400   0.434   0.522
+    1D102  HO4   30   0.137   0.264   0.519
+    1D102  HO5   31   0.272   0.497   0.233
+    1D102  HO6   32   0.150   0.576   0.527
+    1D102  HO7   33   0.136   0.880   0.522
+    1D102  HO8   34   0.400   0.712   0.519
+   0.51540   0.89420   0.63910
diff --git a/package/ClayCode/data/data/UCS/D11/D1002.itp b/package/ClayCode/data/data/UCS/D11/D1002.itp
new file mode 100644
index 00000000..b820021f
--- /dev/null
+++ b/package/ClayCode/data/data/UCS/D11/D1002.itp
@@ -0,0 +1,54 @@
+;
+;
+[ moleculetype ]
+; name  nrexcl
+  D1002      1
+
+[ atoms ]
+;  nr  type  resnr  residue  atom  cgnr  charge  mass  typeB  chargeB  massB
+; residue 1 KAO rtp KAO q 0.0
+    1    ao    1  D1002   AO1    1    1.575    26.98  ;
+    2    ao    1  D1002   AO2    2    1.575    26.98  ;
+    3    ao    1  D1002   AO3    3    1.575    26.98  ;
+    4    ao    1  D1002   AO4    4    1.575    26.98  ;
+    5    st    1  D1002   ST1    5    2.1      28.09  ;
+    6    st    1  D1002   ST2    6    2.1      28.09  ;
+    7    st    1  D1002   ST3    7    2.1      28.09  ;
+    8    at    1  D1002   AT4    8    1.575    26.98  ; substitution for Td Al
+    9    ob    1  D1002   OB1    9   -1.05     16     ;
+   10  obts    1  D1002  OBT2   10   -1.16875  16     ; obts near Td Al
+   11  obts    1  D1002  OBT3   11   -1.16875  16     ; obts near Td Al
+   12    ob    1  D1002   OB4   12   -1.05     16     ;
+   13  obts    1  D1002  OBT5   13   -1.16875  16     ; obts near Td Al
+   14    ob    1  D1002   OB6   14   -1.05     16     ;
+   15    ob    1  D1002   OB7   15   -1.05     16     ;
+   16    ob    1  D1002   OB8   16   -1.05     16     ;
+   17  obts    1  D1002  OBT9   17   -1.16875  16     ; obts near Td Al
+   18    ob    1  D1002  OB10   18   -1.05     16     ;
+   19    oh    1  D1002   OH1   19   -0.95     16     ;
+   20    oh    1  D1002   OH2   20   -0.95     16     ;
+   21    oh    1  D1002   OH3   21   -0.95     16     ;
+   22    oh    1  D1002   OH4   22   -0.95     16     ;
+   23    oh    1  D1002   OH5   23   -0.95     16     ;
+   24    oh    1  D1002   OH6   24   -0.95     16     ;
+   25    oh    1  D1002   OH7   25   -0.95     16     ;
+   26    oh    1  D1002   OH8   26   -0.95     16     ;
+   27    ho    1  D1002   HO1   27    0.425     1.008 ;
+   28    ho    1  D1002   HO2   28    0.425     1.008 ;
+   29    ho    1  D1002   HO3   29    0.425     1.008 ;
+   30    ho    1  D1002   HO4   30    0.425     1.008 ;
+   31    ho    1  D1002   HO5   31    0.425     1.008 ;
+   32    ho    1  D1002   HO6   32    0.425     1.008 ;
+   33    ho    1  D1002   HO7   33    0.425     1.008 ;
+   34    ho    1  D1002   HO8   34    0.425     1.008 ;
+
+[ bonds ]
+; i  j  funct  length  force.c.
+19 31 1 0.1 463532.808
+20 32 1 0.1 463532.808
+21 33 1 0.1 463532.808
+22 30 1 0.1 463532.808
+23 27 1 0.1 463532.808
+24 28 1 0.1 463532.808
+25 29 1 0.1 463532.808
+26 34 1 0.1 463532.808
diff --git a/package/ClayCode/data/data/UCS/D11/D1003.gro b/package/ClayCode/data/data/UCS/D11/D1003.gro
new file mode 100644
index 00000000..722f775a
--- /dev/null
+++ b/package/ClayCode/data/data/UCS/D11/D1003.gro
@@ -0,0 +1,37 @@
+Dioctahedral 1:1 unit cell 3
+   34
+    1D103  AO1    1   0.061   0.433   0.332
+    1D103 MGO2    2   0.321   0.283   0.332
+    1D103  AO3    3   0.320   0.880   0.332
+    1D103  AO4    4   0.064   0.730   0.332
+    1D103  ST1    5   0.237   0.749   0.065
+    1D103  ST2    6   0.500   0.594   0.067
+    1D103  ST3    7   0.493   0.301   0.065
+    1D103  ST4    8   0.244   0.147   0.067
+    1D103  OB1    9   0.225   0.751   0.226
+    1D103 OBO2   10   0.255   0.135   0.227
+    1D103  OB3   11   0.258   0.000   0.000
+    1D103  OB4   12   0.359   0.651   0.021
+    1D103  OB5   13   0.360   0.236   0.001
+    1D103 OBO6   14   0.480   0.304   0.226
+    1D103  OB7   15   0.510   0.582   0.227
+    1D103  OB8   16   0.002   0.447   0.000
+    1D103  OB9   17   0.100   0.204   0.021
+    1D103 OB10   18   0.104   0.683   0.001
+    1D103 OHS1   19   0.223   0.413   0.232
+    1D103  OH2   20   0.123   0.581   0.433
+    1D103  OH3   21   0.164   0.855   0.431
+    1D103 OHS4   22   0.162   0.306   0.434
+    1D103  OH5   23   0.480   0.860   0.232
+    1D103 OHS6   24   0.379   0.134   0.433
+    1D103 OHS7   25   0.420   0.408   0.431
+    1D103  OH8   26   0.420   0.753   0.434
+    1D103  HO1   27   0.530   0.940   0.233
+    1D103  HO2   28   0.410   0.129   0.527
+    1D103  HO3   29   0.400   0.434   0.522
+    1D103  HO4   30   0.137   0.264   0.519
+    1D103  HO5   31   0.272   0.497   0.233
+    1D103  HO6   32   0.150   0.576   0.527
+    1D103  HO7   33   0.136   0.880   0.522
+    1D103  HO8   34   0.400   0.712   0.519
+   0.51540   0.89420   0.63910
diff --git a/package/ClayCode/data/data/UCS/D11/D1003.itp b/package/ClayCode/data/data/UCS/D11/D1003.itp
new file mode 100644
index 00000000..466cc415
--- /dev/null
+++ b/package/ClayCode/data/data/UCS/D11/D1003.itp
@@ -0,0 +1,54 @@
+;
+;
+[ moleculetype ]
+; name  nrexcl
+  D1003      1
+
+[ atoms ]
+;  nr  type  resnr  residue  atom  cgnr  charge  mass  typeB  chargeB  massB
+; residue 1 KAO rtp KAO q 0.0
+    1    ao    1  D1003   AO1    1    1.575    26.98  ;
+    2   mgo    1  D1003  MGO2    2    1.360    24.31  ; substitution for oct Mg
+    3    ao    1  D1003   AO3    3    1.575    26.98  ;
+    4    ao    1  D1003   AO4    4    1.575    26.98  ;
+    5    st    1  D1003   ST1    5    2.1      28.09  ;
+    6    st    1  D1003   ST2    6    2.1      28.09  ;
+    7    st    1  D1003   ST3    7    2.1      28.09  ;
+    8    st    1  D1003   ST4    8    2.1      28.09  ;
+    9    ob    1  D1003   OB1    9   -1.05     16     ;
+   10  obos    1  D1003  OBO2   10   -1.1808   16     ; obos near oct Mg
+   11    ob    1  D1003   OB3   11   -1.05     16     ;
+   12    ob    1  D1003   OB4   12   -1.05     16     ;
+   13    ob    1  D1003   OB5   13   -1.05     16     ;
+   14  obos    1  D1003  OBO6   14   -1.1808   16     ; obos near oct Mg
+   15    ob    1  D1003   OB7   15   -1.05     16     ;
+   16    ob    1  D1003   OB8   16   -1.05     16     ;
+   17    ob    1  D1003   OB9   17   -1.05     16     ;
+   18    ob    1  D1003  OB10   18   -1.05     16     ;
+   19   ohs    1  D1003  OHS1   19   -1.08085  16     ; ohs near oct Mg
+   20    oh    1  D1003   OH2   20   -0.95     16     ;
+   21    oh    1  D1003   OH3   21   -0.95     16     ;
+   22   ohs    1  D1003  OHS4   22   -1.08085  16     ; ohs near oct Mg
+   23    oh    1  D1003   OH5   23   -0.95     16     ;
+   24   ohs    1  D1003  OHS6   24   -1.08085  16     ; ohs near oct Mg
+   25   ohs    1  D1003  OHS7   25   -1.08085  16     ; ohs near oct Mg
+   26    oh    1  D1003   OH8   26   -0.95     16     ;
+   27    ho    1  D1003   HO1   27    0.425     1.008 ;
+   28    ho    1  D1003   HO2   28    0.425     1.008 ;
+   29    ho    1  D1003   HO3   29    0.425     1.008 ;
+   30    ho    1  D1003   HO4   30    0.425     1.008 ;
+   31    ho    1  D1003   HO5   31    0.425     1.008 ;
+   32    ho    1  D1003   HO6   32    0.425     1.008 ;
+   33    ho    1  D1003   HO7   33    0.425     1.008 ;
+   34    ho    1  D1003   HO8   34    0.425     1.008 ;
+
+[ bonds ]
+; i  j  funct  length  force.c.
+19 31 1 0.1 463532.808
+20 32 1 0.1 463532.808
+21 33 1 0.1 463532.808
+22 30 1 0.1 463532.808
+23 27 1 0.1 463532.808
+24 28 1 0.1 463532.808
+25 29 1 0.1 463532.808
+26 34 1 0.1 463532.808
diff --git a/package/ClayCode/data/data/UCS/charge_occ.csv b/package/ClayCode/data/data/UCS/charge_occ.csv
index b9cdc84b..3199ef17 100755
--- a/package/ClayCode/data/data/UCS/charge_occ.csv
+++ b/package/ClayCode/data/data/UCS/charge_occ.csv
@@ -17,4 +17,5 @@ T,CD21,4,8,True
 O,CD21,3,4,True
 T,CD11,4,4,False
 O,CD11,3,4,False
-
+T,TD21,4,8,True
+O,TD21,3,4,True
diff --git a/pyproject.toml b/pyproject.toml
index 1e670a52..1641e46c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,6 +49,7 @@ include-package-data = true
 "ClayCode.siminp.config" = ["defaults.yaml"]
 "ClayCode.siminp.scripts" = ["*.sh"]
 "ClayCode.addmols.config" = ["defaults.yaml", "addtypes.yaml"]
+"ClayCode.analysis.data" = ['peaks_edges/*.p']

 [project.urls]
 "Homepage" = "https://github.com/Erastova-group/ClayCode.git"
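As a consistency check on the three new D11 cells, summing the [ atoms ] charges reproduces the totals that the itp_charge bookkeeping in claycomp.py expects: 0 e for the unsubstituted KAO-type cell D1001, and -1 e each for D1002 (Td Al for Si, with four obts oxygens) and D1003 (oct Mg for Al, with two obos and four ohs oxygens). A quick back-of-the-envelope in Python, with the charges copied from the .itp files above:

    d1001 = 4 * 1.575 + 4 * 2.1 + 10 * -1.05 + 8 * -0.95 + 8 * 0.425
    d1002 = (4 * 1.575 + 3 * 2.1 + 1.575
             + 6 * -1.05 + 4 * -1.16875 + 8 * -0.95 + 8 * 0.425)
    d1003 = (3 * 1.575 + 1.36 + 4 * 2.1
             + 8 * -1.05 + 2 * -1.1808 + 4 * -1.08085 + 4 * -0.95 + 8 * 0.425)
    print(round(d1001, 6), round(d1002, 6), round(d1003, 6))  # 0.0 -1.0 -1.0

These integral totals are exactly what the np.rint check in __get_full_df enforces when it validates the cached full_df.csv.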