From d0258eb60d423c82f9821e0694aa5e7c7b3b6599 Mon Sep 17 00:00:00 2001 From: s2123329 Date: Thu, 29 Feb 2024 15:10:04 +0000 Subject: [PATCH 1/2] Update EM freeze index generation --- package/ClayCode/builder/assembly.py | 262 ++++++++++++++++++++++----- package/ClayCode/core/gmx.py | 34 +++- tests/ClayCodeTests/__main__.py | 0 tests/__init__.py | 3 - 4 files changed, 250 insertions(+), 49 deletions(-) create mode 100644 tests/ClayCodeTests/__main__.py delete mode 100644 tests/__init__.py diff --git a/package/ClayCode/builder/assembly.py b/package/ClayCode/builder/assembly.py index f8976518..7ea8941d 100644 --- a/package/ClayCode/builder/assembly.py +++ b/package/ClayCode/builder/assembly.py @@ -4,6 +4,7 @@ ===============================================================""" from __future__ import annotations +import copy import itertools import logging import math @@ -21,11 +22,24 @@ import unicodeit from ClayCode.builder.claycomp import InterlayerIons, UCData from ClayCode.builder.topology import TopologyConstructor -from ClayCode.core.classes import Dir, FileFactory, GROFile, TOPFile -from ClayCode.core.consts import ANGSTROM -from ClayCode.core.gmx import GMXCommands, add_gmx_args, check_box_lengths +from ClayCode.core.classes import ( + Dir, + FileFactory, + GROFile, + TOPFile, + set_mdp_freeze_groups, + set_mdp_parameter, +) +from ClayCode.core.consts import ANGSTROM, LINE_LENGTH +from ClayCode.core.gmx import ( + GMXCommands, + add_gmx_args, + check_box_lengths, + gmx_command_wrapper, +) from ClayCode.core.lib import ( add_ions_n_mols, + add_ions_neutral, add_resnum, center_clay, check_insert_numbers, @@ -35,7 +49,7 @@ write_insert_dat, ) from ClayCode.core.utils import backup_files, get_header, get_subheader -from ClayCode.data.consts import GRO_FMT +from ClayCode.data.consts import FF, GRO_FMT, MDP, MDP_DEFAULTS from MDAnalysis import AtomGroup, Merge, ResidueGroup, Universe from MDAnalysis.lib.mdamath import triclinic_box, triclinic_vectors from MDAnalysis.units import constants @@ -131,7 +145,7 @@ def get_il_ions(self, sheet_id=None): def extended_box(self) -> bool: """ :return: Whether a bulk space has been added to the clay stack. - :rtype: bool""" + :rtype: Bool""" return self.__box_ext def solvate_clay_sheets(self, backup: bool = False) -> None: @@ -635,7 +649,7 @@ def add_bulk_ions(self, backup=False) -> None: if excess_charge != 0: neutral_bulk_ions = InterlayerIons( excess_charge, - ion_ratios=self.args.bulk_ions, + ion_ratios=self.args.bulk_ions.df["conc"].to_dict(), n_ucs=1, neutral=True, ) @@ -1288,7 +1302,42 @@ def get_charge_groups(self) -> Tuple[int, NDArray]: if n_ucs != 0: yield charge_group, n_ucs, uc_array + def get_occ_counts(self, axis_id: int, free: NDArray) -> NDArray: + occ_cols = np.select([free], [1], 0) + occ_counts = np.sum(occ_cols, axis=0) + diag_counts = np.sum( + np.array( + [ + *np.fromiter( + self.get_all_diagonals(occ_cols), dtype=np.ndarray + ) + ] + ), + axis=1, + ) + opposite_diag_counts = np.flip( + np.sum( + np.array( + [ + *np.fromiter( + self.get_all_diagonals(np.flip(occ_cols, axis=1)), + dtype=np.ndarray, + ) + ] + ), + axis=1, + ) + ) + return occ_counts, diag_counts, opposite_diag_counts + def get_uc_sheet_array(self): + pm = "\u00B1" + logdict = { + 0: "right diagonal", + 1: "left diagonal", + 2: "columns", + } + # TODO: move counts to bottom and initialise with 0 for all max_dict = {self.x_cells: 0, self.y_cells: 1} max_ax_len = max(max_dict.keys()) other_ax_len = min(max_dict.keys()) @@ -1388,13 +1437,19 @@ def get_uc_sheet_array(self): ) init_i = np.array([0, 0, 0]) minmax = itertools.cycle([min, max]) + cycle_count = 0 + # increase max allowed counts if not enough free columns while free_cols[free_cols].size < n_add_ucs: free_cols = np.logical_and( free, combined_counts < n_col_ucs + next(minmax)(1, per_col_remainder), ) + cycle_count += 1 + if cycle_count > 4: + return False extra_remainder = np.zeros_like(init_i, dtype=np.int32) + # while init occ < max allowed occ and no or not enough idxs selected while np.any( np.less( init_i, extra_remainder + per_col_remainder + n_per_col @@ -1403,6 +1458,7 @@ def get_uc_sheet_array(self): idx_choices is None or (idx_choices.flatten().size < n_add_ucs) ): + # number of free cols == n_add_ucs (don't look further) if free[free].flatten().size == n_add_ucs: idx_choices = np.argwhere(free).flatten() break @@ -1411,28 +1467,13 @@ def get_uc_sheet_array(self): occ_devs = np.std(counts, axis=1) order = np.argsort(occ_devs)[::-1] if self.debug: - logger.finfo( - "Occupancy deviations:", + logger.finfo("Occupancy deviations:") + self._log_occ_devs( + logdict, pm, counts, occ_devs, order ) - pm = "\u00B1" - logdict = { - 0: "right diagonal", - 1: "left diagonal", - 2: "columns", - } - logstr = list( - map( - lambda x, y, z: f"\t{logdict[x]:15}: {y:.1f} {pm} {z:.1f}", - np.sort(order), - np.mean(counts[np.argsort(order)], axis=1), - occ_devs[np.argsort(order)], - ) - ) - logger.finfo("\n".join(logstr)) prev_choices = None - for occ_id in order: # self.random_generator.choice( - # [0, 1, 2], 3, replace=False - # ): + # get allowed idxs starting from count with highest deviation + for occ_id in order: ( idx_choices, allowed_cols, @@ -1448,21 +1489,29 @@ def get_uc_sheet_array(self): remainder=per_col_remainder + extra_remainder[occ_id], ) + # found idxs length == n_add_ucs if idx_choices.flatten().size == n_add_ucs: # print( # occ_id, # f": stopping with {idx_choices.flatten()}, n_add_ucs = {n_add_ucs}", # ) break + # if found idxs length < n_add_ucs, use previous idxs if available elif ( - idx_choices.flatten().size < n_add_ucs + idx_choices.flatten().size + < n_add_ucs + <= prev_choices.flatten().size and prev_choices is not None - and prev_choices.flatten().size >= n_add_ucs ): idx_choices = prev_choices prev_choices = idx_choices - if ( + # if still not enough idxs found, abort + if idx_choices.flatten().size < n_add_ucs: + return False + # if more idxs found than necessary and not first row, select idxs with lowest counts + # across all occ counters and try to use only idxs with lowest counts if possible + elif ( idx_choices.flatten().size > n_add_ucs and np.unique(counts).size > 1 ): @@ -1475,21 +1524,14 @@ def get_uc_sheet_array(self): and min_count.size != combined_counts.size and np.all(min_count != idx_choices) ): - pass + intersect_idxs = np.intersect1d( + idx_choices, min_count + ).flatten() + if intersect_idxs.size >= n_add_ucs: + idx_choices = intersect_idxs except ValueError: pass - if ( - min_count.size >= n_add_ucs - and min_count.size != combined_counts.size - and np.all(min_count != idx_choices) - ): - intersect_idxs = np.intersect1d( - idx_choices, min_count - ).flatten() - if intersect_idxs.size >= n_add_ucs: - idx_choices = intersect_idxs - elif idx_choices.flatten().size < n_add_ucs: - return False + # if idx_choices.flatten().size > n_add_ucs: # if prev.flatten().size != 0: # _, prev_idxs, _ = np.intersect1d( @@ -1498,6 +1540,7 @@ def get_uc_sheet_array(self): # if prev_idxs.flatten().size != 0: # remove_idxs = np.random.choice(prev_idxs, idx_choices.size - n_add_ucs, replace=False) # idx_choices = np.delete(idx_choices, remove_idxs) + # break if enough idxs found if idx_choices.flatten().size >= n_add_ucs: if self.debug: logger.finfo(f"Row {axis_id}:", indent="\t") @@ -1624,6 +1667,23 @@ def get_uc_sheet_array(self): symbol_arr[axis_id, idx_sel] = symbol if idx_sel.size != 0: prev = np.sort(idx_sel) + counts = np.array( + [ + np.roll(diag_counts, axis_id), + np.roll(opposite_diag_counts, -axis_id), + occ_counts, + ] + ) + occ_devs = np.std(counts, axis=1) + # make sure charges are evenly distributed + if np.any(occ_devs > 1): + return False + if self.debug: + order = np.argsort(occ_devs)[::-1] + logger.finfo( + f"Mean occupancies for charge group {charge_group_id} (q = {charge:+2.1f}):" + ) + self._log_occ_devs(logdict, pm, counts, occ_devs, order) if max_dict[max_ax_len] == 1: uc_ids = uc_ids.T else: @@ -1639,6 +1699,112 @@ def get_uc_sheet_array(self): logger.finfo(" ".join(line), indent="\t\t") return uc_ids + def _log_occ_devs(self, logdict, pm, counts, occ_devs, order): + logstr = list( + map( + lambda x, y, z: f"\t{logdict[x]:15}: {y:.1f} {pm} {z:.1f}", + np.sort(order), + np.mean(counts[np.argsort(order)], axis=1), + occ_devs[np.argsort(order)], + ) + ) + logger.finfo("\n".join(logstr)) + + @staticmethod + def _get_order(counts): + occ_devs = np.std(counts, axis=1) + order = np.argsort(occ_devs)[::-1] + # if self.debug: + logger.finfo("Occupancy deviations:") + pm = "\u00B1" + logdict = { + 0: "right diagonal", + 1: "left diagonal", + 2: "columns", + } + logstr = list( + map( + lambda x, y, z: f"\t{logdict[x]:15}: {y:.1f} {pm} {z:.1f}", + np.sort(order), + np.mean(counts[np.argsort(order)], axis=1), + occ_devs[np.argsort(order)], + ) + ) + return logstr, order + + def _get_counts(self, axis_id, charge, idxs_mask): + free = np.isnan(idxs_mask[axis_id]) + occ_cols = np.select([idxs_mask == charge], [1], 0) + occ_counts = np.sum(occ_cols, axis=0) + diag_counts = np.sum( + np.array( + [ + *np.fromiter( + self.get_all_diagonals(occ_cols), + dtype=np.ndarray, + ) + ] + ), + axis=1, + ) + opposite_diag_counts = np.flip( + np.sum( + np.array( + [ + *np.fromiter( + self.get_all_diagonals(np.flip(occ_cols, axis=1)), + dtype=np.ndarray, + ) + ] + ), + axis=1, + ) + ) + idx_choices = None + counts = np.array( + [ + np.roll(diag_counts, axis_id), + np.roll(opposite_diag_counts, -axis_id), + occ_counts, + ] + ) + combined_counts = np.rint( + np.mean( + [ + np.roll(diag_counts, axis_id), + np.roll(opposite_diag_counts, -axis_id), + occ_counts, + ], + axis=0, + ) + ) + return combined_counts, counts, free, idx_choices + + def _init_uc_array( + self, + charge_group, + charge_group_id, + charge_group_n_ucs, + lines, + max_ax_len, + other_ax_len, + remainder_choices, + remaining_add, + ): + uc_array = charge_group.copy() + self.random_generator.shuffle(uc_array) + remaining_add[charge_group_id] = 0 + n_per_line = charge_group_n_ucs // max_ax_len + per_line_remainder = charge_group_n_ucs % max_ax_len + n_per_col = charge_group_n_ucs // other_ax_len + per_col_remainder = charge_group_n_ucs % other_ax_len + # per_diag_col_remainder = charge_group_n_ucs % (other_ax_len) + # per_opp_diag_col_remainder = charge_group_n_ucs % (other_ax_len) + lines[charge_group_id], remainder_choices = np.split( + remainder_choices, [per_line_remainder] + ) + return n_per_col, n_per_line, per_col_remainder, uc_array + def get_all_diagonals(self, arr): x_dim, y_dim = arr.shape arr_p = np.pad(arr, ((0, 0), (0, x_dim)), mode="wrap") @@ -1977,7 +2143,7 @@ def write( spc_top: TOPFile = spc_gro.top spc_gro.universe = Universe.empty(n_atoms=0) spc_gro.write(topology=topology) - logger.finfo(f"Adding interlayer solvent:") + logger.finfo(f"Adding interlayer solvent to {spc_name.name!r}:") while True: if self._z_padding > 5: raise Exception( @@ -2033,3 +2199,13 @@ def check_solvent_nummols(self, solvate_stderr: str) -> None: f"insert {added_wat} instead of {self.n_mols} water " f"molecules." ) + + +if __name__ == "__main__": + gc = GMXCommands(gmx_alias="gmx_mpi") + gc.run_gmx_make_ndx_with_new_sel( + f=Path("/storage/new_clays/Na/NAu-1-fe/NAu-1-fe_7_5_solv_ions.gro"), + o=Path("index.ndx"), + sel_str="r T2* & ! a OH* HO*", + sel_name="new_sel", + ) diff --git a/package/ClayCode/core/gmx.py b/package/ClayCode/core/gmx.py index ca5cc9f9..dc1a7461 100644 --- a/package/ClayCode/core/gmx.py +++ b/package/ClayCode/core/gmx.py @@ -754,8 +754,33 @@ def run_gmx_make_ndx_with_new_sel( logger.ferror(f"No index file {o!r} was written.") sys.exit(3) out, err = output.stdout, output.stderr + # group_outp_error = re.search( + # r"\n>.*\n\s*\d+\s+(.*)\s*:\s+(\d+)\s+atoms\s*\n.*>", + # out, + # flags=re.MULTILINE | re.DOTALL, + # ) + no_residue = re.search( + "\n>\s*?\n.*?\nFound 0 atoms with (.*?)\s*?\n", + out, + flags=re.MULTILINE | re.DOTALL, + ) + if no_residue is not None: + logger.ferror( + f"Invalid group selector: {sel_str}.\n No atoms with {no_residue.group(1)} were found." + ) + sys.exit(3) + syntax_error = re.search( + "\n>\s*?\n.*?\nSyntax error: (.*?)\n.*?\n>", + out, + flags=re.MULTILINE | re.DOTALL, + ) + if syntax_error is not None: + logger.ferror( + f"Invalid group selector: {sel_str}.\n {syntax_error.group(1)}" + ) + sys.exit(3) group_outp = re.search( - r"\n>.*\n\s*\d+\s+(.*)\s*:\s+(\d+)\s+atoms\s*\n.*>", + r"\n>\s*?\n.*?(\d+)\s*?(atoms)\s*?\n+?>", out, flags=re.MULTILINE | re.DOTALL, ) @@ -765,8 +790,11 @@ def run_gmx_make_ndx_with_new_sel( ) sys.exit(3) else: - group_name = group_outp.group(1) - group_n_atoms = int(group_outp.group(2)) + group_name = re.sub("\s+", "_", sel_str) + group_name = re.sub("[rati]_", "", group_name) + group_name = re.sub("!_", "!", group_name) + # group_name = group_outp.group(1) + group_n_atoms = int(group_outp.group(1)) if sel_name is not None: with open(o, "r") as ndx_file: ndx_str = ndx_file.read() diff --git a/tests/ClayCodeTests/__main__.py b/tests/ClayCodeTests/__main__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 9c244d29..00000000 --- a/tests/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from pathlib import Path - -DATA = Path(__file__).parent / "data" From 4570efedb2312753d82e2563865de9ca4230a17a Mon Sep 17 00:00:00 2001 From: s2123329 Date: Thu, 29 Feb 2024 15:10:42 +0000 Subject: [PATCH 2/2] Update EM freeze index generation --- package/ClayCode/builder/assembly.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/package/ClayCode/builder/assembly.py b/package/ClayCode/builder/assembly.py index 7ea8941d..72ad94e3 100644 --- a/package/ClayCode/builder/assembly.py +++ b/package/ClayCode/builder/assembly.py @@ -2201,11 +2201,12 @@ def check_solvent_nummols(self, solvate_stderr: str) -> None: ) -if __name__ == "__main__": - gc = GMXCommands(gmx_alias="gmx_mpi") - gc.run_gmx_make_ndx_with_new_sel( - f=Path("/storage/new_clays/Na/NAu-1-fe/NAu-1-fe_7_5_solv_ions.gro"), - o=Path("index.ndx"), - sel_str="r T2* & ! a OH* HO*", - sel_name="new_sel", - ) +# +# if __name__ == "__main__": +# gc = GMXCommands(gmx_alias="gmx_mpi") +# gc.run_gmx_make_ndx_with_new_sel( +# f=Path("/storage/new_clays/Na/NAu-1-fe/NAu-1-fe_7_5_solv_ions.gro"), +# o=Path("index.ndx"), +# sel_str="r T2* & ! a OH* HO*", +# sel_name="new_sel", +# )