@@ -577,6 +577,22 @@ class InterRDF_s(AnalysisBase):
577
577
.. deprecated:: 2.3.0
578
578
The `universe` parameter is superflous.
579
579
"""
580
+ @classmethod
581
+ def get_supported_backends (cls ):
582
+ return ('serial' , 'multiprocessing' , 'dask' ,)
583
+
584
+ _analysis_algorithm_is_parallelizable = True
585
+
586
+
587
+ def _get_aggregator (self ):
588
+ return ResultsGroup (
589
+ lookup = {
590
+ 'count' : self ._flattened_ndarray_sum ,
591
+ 'volume_cum' : ResultsGroup .ndarray_sum ,
592
+ 'bins' : ResultsGroup .ndarray_sum ,
593
+ 'edges' : ResultsGroup .ndarray_sum ,
594
+ }
595
+ )
580
596
581
597
582
598
def __init__ (
@@ -630,6 +646,7 @@ def _prepare(self):
630
646
631
647
if self .norm == "rdf" :
632
648
# Cumulative volume for rdf normalization
649
+ self .results .volume_cum = 0
633
650
self .volume_cum = 0
634
651
self ._maxrange = self .rdf_settings ["range" ][1 ]
635
652
@@ -647,8 +664,88 @@ def _single_frame(self):
647
664
self .results .count [i ][idx1 , idx2 , :] += count
648
665
649
666
if self .norm == "rdf" :
667
+ self .results .volume_cum += self ._ts .volume
650
668
self .volume_cum += self ._ts .volume
651
669
670
+ @staticmethod
671
+ def arr_resize (arr ):
672
+ if arr .ndim == 2 : # If shape is (x, y)
673
+ return arr [np .newaxis , ...] # Add a new axis at the beginning
674
+ elif arr .ndim == 3 and arr .shape [0 ] == 1 : # If shape is already (1, x, y)
675
+ return arr
676
+ else :
677
+ raise ValueError ("Array has an invalid shape" )
678
+
679
+ # @staticmethod
680
+ # def custom_aggregate(combined_arr):
681
+ # arr1 = combined_arr[0][0]
682
+ # arr2 = combined_arr[1][0]
683
+ # arr3 = combined_arr[1][1][0]
684
+ # arr4 = combined_arr[1][1][1]
685
+
686
+ # arr1 = InterRDF_s.arr_resize(arr1)
687
+ # arr2 = InterRDF_s.arr_resize(arr2)
688
+ # arr3 = InterRDF_s.arr_resize(arr3)
689
+ # arr4 = InterRDF_s.arr_resize(arr4)
690
+
691
+
692
+ # print(arr1.shape, arr2.shape, arr3.shape, arr4.shape)
693
+
694
+
695
+
696
+ # arr01 = arr1 + arr2
697
+ # arr02 = np.vstack((arr3, arr4))
698
+ # print("New shape", arr01.shape, arr02.shape)
699
+
700
+ # arr = [arr01, arr02]
701
+ # # arr should be [(1,2,75), (2,2,75)]
702
+ # return arr
703
+
704
+
705
+ # #TODO: check shapes without parallelization then emulate that in custom_aggregate
706
+
707
+ # def _get_aggregator(self):
708
+ # return ResultsGroup(lookup={'count': self.custom_aggregate,
709
+ # 'volume_cum': ResultsGroup.ndarray_sum,
710
+ # 'bins': ResultsGroup.ndarray_sum,
711
+ # 'edges': ResultsGroup.ndarray_sum})
712
+
713
+ @staticmethod
714
+ def _flattened_ndarray_sum (arrs ):
715
+ """Custom aggregator for nested count arrays
716
+
717
+ Parameters
718
+ ----------
719
+ arrs : list
720
+ List of arrays or nested lists of arrays to sum
721
+
722
+ Returns
723
+ -------
724
+ ndarray
725
+ Sum of all arrays after flattening nested structure
726
+ """
727
+ # Handle nested list/array structures
728
+ def flatten (arr ):
729
+ if isinstance (arr , (list , tuple )):
730
+ return [item for sublist in arr for item in flatten (sublist )]
731
+ return [arr ]
732
+
733
+ # Flatten and sum arrays
734
+ flat = flatten (arrs )
735
+ if not flat :
736
+ return None
737
+
738
+ f1 = np .zeros_like (flat [0 ])
739
+ f2 = np .zeros_like (flat [1 ])
740
+ # print(flat, "SIZE:", len(flat))
741
+ for i in range (len (flat )// 2 ):
742
+ f1 += flat [2 * i ]
743
+ f2 += flat [2 * i + 1 ]
744
+ array1 = [f1 , f2 ]
745
+ # print("ARRAY", array1)
746
+ return array1
747
+
748
+
652
749
def _conclude (self ):
653
750
norm = self .n_frames
654
751
if self .norm in ["rdf" , "density" ]:
@@ -658,6 +755,7 @@ def _conclude(self):
658
755
659
756
if self .norm == "rdf" :
660
757
# Average number density
758
+ self .volume_cum = self .results .volume_cum
661
759
norm *= 1 / (self .volume_cum / self .n_frames )
662
760
663
761
# Empty lists to restore indices, RDF
0 commit comments