From 703a288411e48e1618c7e30673e34df59fd0848c Mon Sep 17 00:00:00 2001 From: Sumit112192 Date: Wed, 31 Jul 2024 15:11:53 +0530 Subject: [PATCH 1/6] Add Extend Array Function --- .../transport/montecarlo/packet_trackers.py | 39 +++++++------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packet_trackers.py index 7a10a992f05..2f352c95535 100644 --- a/tardis/transport/montecarlo/packet_trackers.py +++ b/tardis/transport/montecarlo/packet_trackers.py @@ -62,33 +62,22 @@ def __init__(self, length): self.interaction_type = np.empty(self.length, dtype=np.int64) self.num_interactions = 0 + def extend_array(self, array, array_length, dtype): + temp_length = array_length * 2 + temp_array = np.empty(temp_length, dtype=dtype) + temp_array[: self.array_length] = self.array + array = temp_array + def track(self, r_packet): if self.num_interactions >= self.length: - temp_length = self.length * 2 - temp_status = np.empty(temp_length, dtype=np.int64) - temp_r = np.empty(temp_length, dtype=np.float64) - temp_nu = np.empty(temp_length, dtype=np.float64) - temp_mu = np.empty(temp_length, dtype=np.float64) - temp_energy = np.empty(temp_length, dtype=np.float64) - temp_shell_id = np.empty(temp_length, dtype=np.int64) - temp_interaction_type = np.empty(temp_length, dtype=np.int64) - - temp_status[: self.length] = self.status - temp_r[: self.length] = self.r - temp_nu[: self.length] = self.nu - temp_mu[: self.length] = self.mu - temp_energy[: self.length] = self.energy - temp_shell_id[: self.length] = self.shell_id - temp_interaction_type[: self.length] = self.interaction_type - - self.status = temp_status - self.r = temp_r - self.nu = temp_nu - self.mu = temp_mu - self.energy = temp_energy - self.shell_id = temp_shell_id - self.interaction_type = temp_interaction_type - self.length = temp_length + self.extend_array(self.status, length, np.int64) + self.extend_array(self.r, length, np.float64) + self.extend_array(self.nu, length, np.int64) + self.extend_array(self.mu, length, np.int64) + self.extend_array(self.energy, length, np.int64) + self.extend_array(self.shell_id, length, np.int64) + self.extend_array(self.interaction_type, length, np.int64) + self.length = self.length * 2 self.index = r_packet.index self.seed = r_packet.seed From b2f8e9217d4a3f4ad3df351edcf4b9fc944f40bc Mon Sep 17 00:00:00 2001 From: Sumit112192 Date: Wed, 31 Jul 2024 15:16:17 +0530 Subject: [PATCH 2/6] Fix Typo and restructure --- tardis/transport/montecarlo/packet_trackers.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packet_trackers.py index 2f352c95535..438d1c30504 100644 --- a/tardis/transport/montecarlo/packet_trackers.py +++ b/tardis/transport/montecarlo/packet_trackers.py @@ -62,14 +62,14 @@ def __init__(self, length): self.interaction_type = np.empty(self.length, dtype=np.int64) self.num_interactions = 0 - def extend_array(self, array, array_length, dtype): - temp_length = array_length * 2 - temp_array = np.empty(temp_length, dtype=dtype) - temp_array[: self.array_length] = self.array + def extend_array(self, array, array_length, dtype, extend_factor=2): + temp_array = np.empty(array_length, dtype=dtype) + temp_array[: array_length / extend_factor] = array array = temp_array def track(self, r_packet): if self.num_interactions >= self.length: + self.length = self.length * 2 self.extend_array(self.status, length, np.int64) self.extend_array(self.r, length, np.float64) self.extend_array(self.nu, length, np.int64) @@ -77,7 +77,6 @@ def track(self, r_packet): self.extend_array(self.energy, length, np.int64) self.extend_array(self.shell_id, length, np.int64) self.extend_array(self.interaction_type, length, np.int64) - self.length = self.length * 2 self.index = r_packet.index self.seed = r_packet.seed From bcd25c1e16c594f63b79d3cdeb67d336d0ac4a2d Mon Sep 17 00:00:00 2001 From: Sumit112192 Date: Wed, 31 Jul 2024 18:13:30 +0530 Subject: [PATCH 3/6] Restructure --- .../transport/montecarlo/packet_trackers.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packet_trackers.py index 438d1c30504..2b933b8f3d1 100644 --- a/tardis/transport/montecarlo/packet_trackers.py +++ b/tardis/transport/montecarlo/packet_trackers.py @@ -16,6 +16,7 @@ ("shell_id", int64[:]), ("interaction_type", int64[:]), ("num_interactions", int64), + ("extend_factor", int64), ] @@ -61,22 +62,23 @@ def __init__(self, length): self.shell_id = np.empty(self.length, dtype=np.int64) self.interaction_type = np.empty(self.length, dtype=np.int64) self.num_interactions = 0 + self.extend_factor = 2 - def extend_array(self, array, array_length, dtype, extend_factor=2): - temp_array = np.empty(array_length, dtype=dtype) - temp_array[: array_length / extend_factor] = array + def extend_array(self, array, array_length, dtype): + temp_array = np.empty(array_length * self.extend_factor, dtype=dtype) + temp_array[:array_length] = array array = temp_array def track(self, r_packet): if self.num_interactions >= self.length: - self.length = self.length * 2 - self.extend_array(self.status, length, np.int64) - self.extend_array(self.r, length, np.float64) - self.extend_array(self.nu, length, np.int64) - self.extend_array(self.mu, length, np.int64) - self.extend_array(self.energy, length, np.int64) - self.extend_array(self.shell_id, length, np.int64) - self.extend_array(self.interaction_type, length, np.int64) + self.extend_array(self.status, self.length, np.int64) + self.extend_array(self.r, self.length, np.float64) + self.extend_array(self.nu, self.length, np.int64) + self.extend_array(self.mu, self.length, np.int64) + self.extend_array(self.energy, self.length, np.int64) + self.extend_array(self.shell_id, self.length, np.int64) + self.extend_array(self.interaction_type, self.length, np.int64) + self.length = self.length * self.extend_factor self.index = r_packet.index self.seed = r_packet.seed From 0a51d77e9c42e00a8a55549ac7fc4890ca2d185c Mon Sep 17 00:00:00 2001 From: Sumit112192 Date: Wed, 31 Jul 2024 18:14:17 +0530 Subject: [PATCH 4/6] Doc string update --- tardis/transport/montecarlo/packet_trackers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packet_trackers.py index 2b933b8f3d1..ccfa114cde7 100644 --- a/tardis/transport/montecarlo/packet_trackers.py +++ b/tardis/transport/montecarlo/packet_trackers.py @@ -48,6 +48,8 @@ class RPacketTracker(object): Type of interaction the rpacket undergoes num_interactions : int Internal counter for the interactions that a particular RPacket undergoes + extend_factor : int + The factor by which to extend the properties array when the size limit is reached """ def __init__(self, length): From 2d0a44e6302f91e8e4b580f7ef974169cdff3017 Mon Sep 17 00:00:00 2001 From: Sumit112192 Date: Wed, 31 Jul 2024 18:51:09 +0530 Subject: [PATCH 5/6] Was expecting numpy arrays to be passed by reference but it is not so updating it --- .../transport/montecarlo/packet_trackers.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packet_trackers.py index ccfa114cde7..466fad27843 100644 --- a/tardis/transport/montecarlo/packet_trackers.py +++ b/tardis/transport/montecarlo/packet_trackers.py @@ -66,20 +66,24 @@ def __init__(self, length): self.num_interactions = 0 self.extend_factor = 2 - def extend_array(self, array, array_length, dtype): - temp_array = np.empty(array_length * self.extend_factor, dtype=dtype) + def extend_array(self, array, array_length): + temp_array = np.empty( + array_length * self.extend_factor, dtype=array.dtype + ) temp_array[:array_length] = array - array = temp_array + return temp_array def track(self, r_packet): if self.num_interactions >= self.length: - self.extend_array(self.status, self.length, np.int64) - self.extend_array(self.r, self.length, np.float64) - self.extend_array(self.nu, self.length, np.int64) - self.extend_array(self.mu, self.length, np.int64) - self.extend_array(self.energy, self.length, np.int64) - self.extend_array(self.shell_id, self.length, np.int64) - self.extend_array(self.interaction_type, self.length, np.int64) + self.status = self.extend_array(self.status, self.length) + self.r = self.extend_array(self.r, self.length) + self.nu = self.extend_array(self.nu, self.length) + self.mu = self.extend_array(self.mu, self.length) + self.energy = self.extend_array(self.energy, self.length) + self.shell_id = self.extend_array(self.shell_id, self.length) + self.interaction_type = self.extend_array( + self.interaction_type, self.length + ) self.length = self.length * self.extend_factor self.index = r_packet.index From 0d6dcecc596948df7f8de44e395d73ab4f285f34 Mon Sep 17 00:00:00 2001 From: Sumit112192 Date: Wed, 31 Jul 2024 21:25:56 +0530 Subject: [PATCH 6/6] Add test for the extend array function --- .../montecarlo/tests/test_rpacket_tracker.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py b/tardis/transport/montecarlo/tests/test_rpacket_tracker.py index 9a42da459e5..7dbfc03784c 100644 --- a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py +++ b/tardis/transport/montecarlo/tests/test_rpacket_tracker.py @@ -4,6 +4,7 @@ from tardis.transport.montecarlo.r_packet import InteractionType from tardis.transport.montecarlo.packet_trackers import ( + RPacketTracker, rpacket_trackers_to_dataframe, ) @@ -102,6 +103,17 @@ def nu_rpacket_tracker(rpacket_tracker): return nu +def test_extend_array(): + rpacket_tracker = RPacketTracker(10) + array = np.array([1, 2, 3, 4, 5], dtype=np.int64) + + new_array = rpacket_tracker.extend_array(array, array.size) + + assert new_array.size == array.size * rpacket_tracker.extend_factor + assert new_array.dtype == array.dtype + npt.assert_allclose(array, new_array[: array.size]) + + @pytest.mark.parametrize( "expected,obtained", [