Skip to content

Commit c584b0b

Browse files
committed
Adds custom aggregegator for InterRDF_s
1 parent de5d5e0 commit c584b0b

File tree

1 file changed

+98
-0
lines changed
  • package/MDAnalysis/analysis

1 file changed

+98
-0
lines changed

package/MDAnalysis/analysis/rdf.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,22 @@ class InterRDF_s(AnalysisBase):
577577
.. deprecated:: 2.3.0
578578
The `universe` parameter is superflous.
579579
"""
580+
@classmethod
581+
def get_supported_backends(cls):
582+
return ('serial', 'multiprocessing', 'dask',)
583+
584+
_analysis_algorithm_is_parallelizable = True
585+
586+
587+
def _get_aggregator(self):
588+
return ResultsGroup(
589+
lookup={
590+
'count': self._flattened_ndarray_sum,
591+
'volume_cum': ResultsGroup.ndarray_sum,
592+
'bins': ResultsGroup.ndarray_sum,
593+
'edges': ResultsGroup.ndarray_sum,
594+
}
595+
)
580596

581597

582598
def __init__(
@@ -630,6 +646,7 @@ def _prepare(self):
630646

631647
if self.norm == "rdf":
632648
# Cumulative volume for rdf normalization
649+
self.results.volume_cum = 0
633650
self.volume_cum = 0
634651
self._maxrange = self.rdf_settings["range"][1]
635652

@@ -647,8 +664,88 @@ def _single_frame(self):
647664
self.results.count[i][idx1, idx2, :] += count
648665

649666
if self.norm == "rdf":
667+
self.results.volume_cum += self._ts.volume
650668
self.volume_cum += self._ts.volume
651669

670+
@staticmethod
671+
def arr_resize(arr):
672+
if arr.ndim == 2: # If shape is (x, y)
673+
return arr[np.newaxis, ...] # Add a new axis at the beginning
674+
elif arr.ndim == 3 and arr.shape[0] == 1: # If shape is already (1, x, y)
675+
return arr
676+
else:
677+
raise ValueError("Array has an invalid shape")
678+
679+
# @staticmethod
680+
# def custom_aggregate(combined_arr):
681+
# arr1 = combined_arr[0][0]
682+
# arr2 = combined_arr[1][0]
683+
# arr3 = combined_arr[1][1][0]
684+
# arr4 = combined_arr[1][1][1]
685+
686+
# arr1 = InterRDF_s.arr_resize(arr1)
687+
# arr2 = InterRDF_s.arr_resize(arr2)
688+
# arr3 = InterRDF_s.arr_resize(arr3)
689+
# arr4 = InterRDF_s.arr_resize(arr4)
690+
691+
692+
# print(arr1.shape, arr2.shape, arr3.shape, arr4.shape)
693+
694+
695+
696+
# arr01 = arr1 + arr2
697+
# arr02 = np.vstack((arr3, arr4))
698+
# print("New shape", arr01.shape, arr02.shape)
699+
700+
# arr = [arr01, arr02]
701+
# # arr should be [(1,2,75), (2,2,75)]
702+
# return arr
703+
704+
705+
# #TODO: check shapes without parallelization then emulate that in custom_aggregate
706+
707+
# def _get_aggregator(self):
708+
# return ResultsGroup(lookup={'count': self.custom_aggregate,
709+
# 'volume_cum': ResultsGroup.ndarray_sum,
710+
# 'bins': ResultsGroup.ndarray_sum,
711+
# 'edges': ResultsGroup.ndarray_sum})
712+
713+
@staticmethod
714+
def _flattened_ndarray_sum(arrs):
715+
"""Custom aggregator for nested count arrays
716+
717+
Parameters
718+
----------
719+
arrs : list
720+
List of arrays or nested lists of arrays to sum
721+
722+
Returns
723+
-------
724+
ndarray
725+
Sum of all arrays after flattening nested structure
726+
"""
727+
# Handle nested list/array structures
728+
def flatten(arr):
729+
if isinstance(arr, (list, tuple)):
730+
return [item for sublist in arr for item in flatten(sublist)]
731+
return [arr]
732+
733+
# Flatten and sum arrays
734+
flat = flatten(arrs)
735+
if not flat:
736+
return None
737+
738+
f1 = np.zeros_like(flat[0])
739+
f2 = np.zeros_like(flat[1])
740+
# print(flat, "SIZE:", len(flat))
741+
for i in range(len(flat)//2):
742+
f1 += flat[2*i]
743+
f2 += flat[2*i+1]
744+
array1 = [f1, f2]
745+
# print("ARRAY", array1)
746+
return array1
747+
748+
652749
def _conclude(self):
653750
norm = self.n_frames
654751
if self.norm in ["rdf", "density"]:
@@ -658,6 +755,7 @@ def _conclude(self):
658755

659756
if self.norm == "rdf":
660757
# Average number density
758+
self.volume_cum = self.results.volume_cum
661759
norm *= 1 / (self.volume_cum / self.n_frames)
662760

663761
# Empty lists to restore indices, RDF

0 commit comments

Comments
 (0)