Analysis fixes #70

Merged
merged 2 commits on Feb 29, 2024
263 changes: 220 additions & 43 deletions package/ClayCode/builder/assembly.py
@@ -4,6 +4,7 @@
==============================================================="""
from __future__ import annotations

import copy
import itertools
import logging
import math
@@ -21,11 +22,24 @@
import unicodeit
from ClayCode.builder.claycomp import InterlayerIons, UCData
from ClayCode.builder.topology import TopologyConstructor
from ClayCode.core.classes import Dir, FileFactory, GROFile, TOPFile
from ClayCode.core.consts import ANGSTROM
from ClayCode.core.gmx import GMXCommands, add_gmx_args, check_box_lengths
from ClayCode.core.classes import (
Dir,
FileFactory,
GROFile,
TOPFile,
set_mdp_freeze_groups,
set_mdp_parameter,
)
from ClayCode.core.consts import ANGSTROM, LINE_LENGTH
from ClayCode.core.gmx import (
GMXCommands,
add_gmx_args,
check_box_lengths,
gmx_command_wrapper,
)
from ClayCode.core.lib import (
add_ions_n_mols,
add_ions_neutral,
add_resnum,
center_clay,
check_insert_numbers,
@@ -35,7 +49,7 @@
write_insert_dat,
)
from ClayCode.core.utils import backup_files, get_header, get_subheader
from ClayCode.data.consts import GRO_FMT
from ClayCode.data.consts import FF, GRO_FMT, MDP, MDP_DEFAULTS
from MDAnalysis import AtomGroup, Merge, ResidueGroup, Universe
from MDAnalysis.lib.mdamath import triclinic_box, triclinic_vectors
from MDAnalysis.units import constants
@@ -131,7 +145,7 @@ def get_il_ions(self, sheet_id=None):
def extended_box(self) -> bool:
"""
:return: Whether a bulk space has been added to the clay stack.
:rtype: bool"""
:rtype: Bool"""
return self.__box_ext

def solvate_clay_sheets(self, backup: bool = False) -> None:
@@ -635,7 +649,7 @@ def add_bulk_ions(self, backup=False) -> None:
if excess_charge != 0:
neutral_bulk_ions = InterlayerIons(
excess_charge,
ion_ratios=self.args.bulk_ions,
ion_ratios=self.args.bulk_ions.df["conc"].to_dict(),
n_ucs=1,
neutral=True,
)
@@ -1288,7 +1302,42 @@ def get_charge_groups(self) -> Tuple[int, NDArray]:
if n_ucs != 0:
yield charge_group, n_ucs, uc_array

def get_occ_counts(self, axis_id: int, free: NDArray) -> NDArray:
occ_cols = np.select([free], [1], 0)
occ_counts = np.sum(occ_cols, axis=0)
diag_counts = np.sum(
np.array(
[
*np.fromiter(
self.get_all_diagonals(occ_cols), dtype=np.ndarray
)
]
),
axis=1,
)
opposite_diag_counts = np.flip(
np.sum(
np.array(
[
*np.fromiter(
self.get_all_diagonals(np.flip(occ_cols, axis=1)),
dtype=np.ndarray,
)
]
),
axis=1,
)
)
return occ_counts, diag_counts, opposite_diag_counts
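As a side note on the new get_occ_counts helper above: it turns a boolean mask into a 0/1 array and tallies how many set cells sit in each column and on each wrap-around diagonal of the sheet. The standalone sketch below illustrates that bookkeeping on a small example; the array contents, sizes, and the wrapped_diagonal_sums helper are made up for illustration (the PR itself delegates the diagonal part to the class method get_all_diagonals shown further down in this diff).

import numpy as np

# Illustrative 3 x 4 mask (1 = cell set); shapes and values are made up.
mask = np.array(
    [
        [1, 0, 0, 1],
        [0, 1, 0, 0],
        [1, 0, 1, 0],
    ]
)

def wrapped_diagonal_sums(arr):
    # Pad the columns with wrap-around so every diagonal spans all rows,
    # then sum each of the y_dim wrapped diagonals.
    x_dim, y_dim = arr.shape
    arr_p = np.pad(arr, ((0, 0), (0, x_dim)), mode="wrap")
    return np.array([arr_p.diagonal(offset=k).sum() for k in range(y_dim)])

col_counts = mask.sum(axis=0)                # counts per column
diag_counts = wrapped_diagonal_sums(mask)    # counts per "right" diagonal
opp_diag_counts = np.flip(
    wrapped_diagonal_sums(np.flip(mask, axis=1))
)                                            # counts per "left" diagonal

# Every cell lies on exactly one diagonal of each family, so totals agree.
assert col_counts.sum() == diag_counts.sum() == opp_diag_counts.sum() == mask.sum()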

def get_uc_sheet_array(self):
pm = "\u00B1"
logdict = {
0: "right diagonal",
1: "left diagonal",
2: "columns",
}
# TODO: move counts to bottom and initialise with 0 for all
max_dict = {self.x_cells: 0, self.y_cells: 1}
max_ax_len = max(max_dict.keys())
other_ax_len = min(max_dict.keys())
@@ -1388,13 +1437,19 @@ def get_uc_sheet_array(self):
)
init_i = np.array([0, 0, 0])
minmax = itertools.cycle([min, max])
cycle_count = 0
# increase max allowed counts if not enough free columns
while free_cols[free_cols].size < n_add_ucs:
free_cols = np.logical_and(
free,
combined_counts
< n_col_ucs + next(minmax)(1, per_col_remainder),
)
cycle_count += 1
if cycle_count > 4:
return False
extra_remainder = np.zeros_like(init_i, dtype=np.int32)
# while init occ < max allowed occ and no or not enough idxs selected
while np.any(
np.less(
init_i, extra_remainder + per_col_remainder + n_per_col
@@ -1403,6 +1458,7 @@
idx_choices is None
or (idx_choices.flatten().size < n_add_ucs)
):
# number of free cols == n_add_ucs (don't look further)
if free[free].flatten().size == n_add_ucs:
idx_choices = np.argwhere(free).flatten()
break
@@ -1411,28 +1467,13 @@
occ_devs = np.std(counts, axis=1)
order = np.argsort(occ_devs)[::-1]
if self.debug:
logger.finfo(
"Occupancy deviations:",
logger.finfo("Occupancy deviations:")
self._log_occ_devs(
logdict, pm, counts, occ_devs, order
)
pm = "\u00B1"
logdict = {
0: "right diagonal",
1: "left diagonal",
2: "columns",
}
logstr = list(
map(
lambda x, y, z: f"\t{logdict[x]:15}: {y:.1f} {pm} {z:.1f}",
np.sort(order),
np.mean(counts[np.argsort(order)], axis=1),
occ_devs[np.argsort(order)],
)
)
logger.finfo("\n".join(logstr))
prev_choices = None
for occ_id in order: # self.random_generator.choice(
# [0, 1, 2], 3, replace=False
# ):
# get allowed idxs starting from count with highest deviation
for occ_id in order:
(
idx_choices,
allowed_cols,
@@ -1448,21 +1489,29 @@
remainder=per_col_remainder
+ extra_remainder[occ_id],
)
# found idxs length == n_add_ucs
if idx_choices.flatten().size == n_add_ucs:
# print(
# occ_id,
# f": stopping with {idx_choices.flatten()}, n_add_ucs = {n_add_ucs}",
# )
break
# if found idxs length < n_add_ucs, use previous idxs if available
elif (
idx_choices.flatten().size < n_add_ucs
idx_choices.flatten().size
< n_add_ucs
<= prev_choices.flatten().size
and prev_choices is not None
and prev_choices.flatten().size >= n_add_ucs
):
idx_choices = prev_choices

prev_choices = idx_choices
if (
# if still not enough idxs found, abort
if idx_choices.flatten().size < n_add_ucs:
return False
# if more idxs found than necessary and not first row, select idxs with lowest counts
# across all occ counters and try to use only idxs with lowest counts if possible
elif (
idx_choices.flatten().size > n_add_ucs
and np.unique(counts).size > 1
):
@@ -1475,21 +1524,14 @@
and min_count.size != combined_counts.size
and np.all(min_count != idx_choices)
):
pass
intersect_idxs = np.intersect1d(
idx_choices, min_count
).flatten()
if intersect_idxs.size >= n_add_ucs:
idx_choices = intersect_idxs
except ValueError:
pass
if (
min_count.size >= n_add_ucs
and min_count.size != combined_counts.size
and np.all(min_count != idx_choices)
):
intersect_idxs = np.intersect1d(
idx_choices, min_count
).flatten()
if intersect_idxs.size >= n_add_ucs:
idx_choices = intersect_idxs
elif idx_choices.flatten().size < n_add_ucs:
return False

# if idx_choices.flatten().size > n_add_ucs:
# if prev.flatten().size != 0:
# _, prev_idxs, _ = np.intersect1d(
@@ -1498,6 +1540,7 @@
# if prev_idxs.flatten().size != 0:
# remove_idxs = np.random.choice(prev_idxs, idx_choices.size - n_add_ucs, replace=False)
# idx_choices = np.delete(idx_choices, remove_idxs)
# break if enough idxs found
if idx_choices.flatten().size >= n_add_ucs:
if self.debug:
logger.finfo(f"Row {axis_id}:", indent="\t")
@@ -1624,6 +1667,23 @@ def get_uc_sheet_array(self):
symbol_arr[axis_id, idx_sel] = symbol
if idx_sel.size != 0:
prev = np.sort(idx_sel)
counts = np.array(
[
np.roll(diag_counts, axis_id),
np.roll(opposite_diag_counts, -axis_id),
occ_counts,
]
)
occ_devs = np.std(counts, axis=1)
# make sure charges are evenly distributed
if np.any(occ_devs > 1):
return False
if self.debug:
order = np.argsort(occ_devs)[::-1]
logger.finfo(
f"Mean occupancies for charge group {charge_group_id} (q = {charge:+2.1f}):"
)
self._log_occ_devs(logdict, pm, counts, occ_devs, order)
if max_dict[max_ax_len] == 1:
uc_ids = uc_ids.T
else:
@@ -1639,6 +1699,112 @@
logger.finfo(" ".join(line), indent="\t\t")
return uc_ids
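With this PR, get_uc_sheet_array returns False on its new early-exit paths (too many relaxation cycles, too few candidate indices, occupancy deviations above 1) instead of a unit-cell index array. The caller is not part of this diff, so the loop below is only a hypothetical usage sketch (build_uc_sheet, sheet, and the retry policy are assumptions, not the PR's code); the point is that each attempt reshuffles the unit cells via the random generator, so retrying until the evenness check passes is a natural way to consume the new return value.

def build_uc_sheet(sheet, max_attempts=100):
    # `sheet` is a hypothetical object exposing get_uc_sheet_array();
    # each call reshuffles the unit cells internally, so retrying can succeed.
    for _ in range(max_attempts):
        uc_ids = sheet.get_uc_sheet_array()
        if uc_ids is not False:  # a NumPy index array signals success
            return uc_ids
    raise RuntimeError(
        f"No even charge distribution found after {max_attempts} attempts."
    )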

def _log_occ_devs(self, logdict, pm, counts, occ_devs, order):
logstr = list(
map(
lambda x, y, z: f"\t{logdict[x]:15}: {y:.1f} {pm} {z:.1f}",
np.sort(order),
np.mean(counts[np.argsort(order)], axis=1),
occ_devs[np.argsort(order)],
)
)
logger.finfo("\n".join(logstr))

@staticmethod
def _get_order(counts):
occ_devs = np.std(counts, axis=1)
order = np.argsort(occ_devs)[::-1]
# if self.debug:
logger.finfo("Occupancy deviations:")
pm = "\u00B1"
logdict = {
0: "right diagonal",
1: "left diagonal",
2: "columns",
}
logstr = list(
map(
lambda x, y, z: f"\t{logdict[x]:15}: {y:.1f} {pm} {z:.1f}",
np.sort(order),
np.mean(counts[np.argsort(order)], axis=1),
occ_devs[np.argsort(order)],
)
)
return logstr, order

def _get_counts(self, axis_id, charge, idxs_mask):
free = np.isnan(idxs_mask[axis_id])
occ_cols = np.select([idxs_mask == charge], [1], 0)
occ_counts = np.sum(occ_cols, axis=0)
diag_counts = np.sum(
np.array(
[
*np.fromiter(
self.get_all_diagonals(occ_cols),
dtype=np.ndarray,
)
]
),
axis=1,
)
opposite_diag_counts = np.flip(
np.sum(
np.array(
[
*np.fromiter(
self.get_all_diagonals(np.flip(occ_cols, axis=1)),
dtype=np.ndarray,
)
]
),
axis=1,
)
)
idx_choices = None
counts = np.array(
[
np.roll(diag_counts, axis_id),
np.roll(opposite_diag_counts, -axis_id),
occ_counts,
]
)
combined_counts = np.rint(
np.mean(
[
np.roll(diag_counts, axis_id),
np.roll(opposite_diag_counts, -axis_id),
occ_counts,
],
axis=0,
)
)
return combined_counts, counts, free, idx_choices
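One non-obvious detail in _get_counts (and in the equivalent block inside get_uc_sheet_array) is the np.roll of the diagonal counts by the row index. Assuming get_all_diagonals orders the wrapped diagonals by their starting column in row 0, which is what the wrap-around padding suggests, the roll lines each diagonal count up with the column that diagonal passes through in the current row. A tiny self-check under that assumption (the numbers are dummies):

import numpy as np

# If diagonal k starts at column k in row 0 and wraps as (i, (k + i) % n_cols),
# the diagonal through (row, col) has index (col - row) % n_cols, which is
# exactly what np.roll(diag_counts, row)[col] picks out.
n_cols, row = 5, 2
diag_counts = np.arange(10, 10 + n_cols)  # dummy per-diagonal counts
rolled = np.roll(diag_counts, row)
for col in range(n_cols):
    assert rolled[col] == diag_counts[(col - row) % n_cols]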

def _init_uc_array(
self,
charge_group,
charge_group_id,
charge_group_n_ucs,
lines,
max_ax_len,
other_ax_len,
remainder_choices,
remaining_add,
):
uc_array = charge_group.copy()
self.random_generator.shuffle(uc_array)
remaining_add[charge_group_id] = 0
n_per_line = charge_group_n_ucs // max_ax_len
per_line_remainder = charge_group_n_ucs % max_ax_len
n_per_col = charge_group_n_ucs // other_ax_len
per_col_remainder = charge_group_n_ucs % other_ax_len
# per_diag_col_remainder = charge_group_n_ucs % (other_ax_len)
# per_opp_diag_col_remainder = charge_group_n_ucs % (other_ax_len)
lines[charge_group_id], remainder_choices = np.split(
remainder_choices, [per_line_remainder]
)
return n_per_col, n_per_line, per_col_remainder, uc_array
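The quota variables above are plain integer division: each row (or column) gets the floor share of the charge group, and the remainder rows (columns) receive one extra unit cell. A worked example with made-up numbers, 7 charged unit cells on a 5 x 4 sheet:

charge_group_n_ucs, max_ax_len, other_ax_len = 7, 5, 4  # illustrative only

n_per_line = charge_group_n_ucs // max_ax_len         # 1 -> every row gets one
per_line_remainder = charge_group_n_ucs % max_ax_len  # 2 -> two rows get a second
n_per_col = charge_group_n_ucs // other_ax_len         # 1 -> every column gets one
per_col_remainder = charge_group_n_ucs % other_ax_len  # 3 -> three columns get a second

assert n_per_line * max_ax_len + per_line_remainder == charge_group_n_ucs
assert n_per_col * other_ax_len + per_col_remainder == charge_group_n_ucs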

def get_all_diagonals(self, arr):
x_dim, y_dim = arr.shape
arr_p = np.pad(arr, ((0, 0), (0, x_dim)), mode="wrap")
@@ -1977,7 +2143,7 @@ def write(
spc_top: TOPFile = spc_gro.top
spc_gro.universe = Universe.empty(n_atoms=0)
spc_gro.write(topology=topology)
logger.finfo(f"Adding interlayer solvent:")
logger.finfo(f"Adding interlayer solvent to {spc_name.name!r}:")
while True:
if self._z_padding > 5:
raise Exception(
@@ -2033,3 +2199,14 @@ def check_solvent_nummols(self, solvate_stderr: str) -> None:
f"insert {added_wat} instead of {self.n_mols} water "
f"molecules."
)


#
# if __name__ == "__main__":
# gc = GMXCommands(gmx_alias="gmx_mpi")
# gc.run_gmx_make_ndx_with_new_sel(
# f=Path("/storage/new_clays/Na/NAu-1-fe/NAu-1-fe_7_5_solv_ions.gro"),
# o=Path("index.ndx"),
# sel_str="r T2* & ! a OH* HO*",
# sel_name="new_sel",
# )