diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packet_trackers.py index 7a10a992f05..466fad27843 100644 --- a/tardis/transport/montecarlo/packet_trackers.py +++ b/tardis/transport/montecarlo/packet_trackers.py @@ -16,6 +16,7 @@ ("shell_id", int64[:]), ("interaction_type", int64[:]), ("num_interactions", int64), + ("extend_factor", int64), ] @@ -47,6 +48,8 @@ class RPacketTracker(object): Type of interaction the rpacket undergoes num_interactions : int Internal counter for the interactions that a particular RPacket undergoes + extend_factor : int + The factor by which to extend the properties array when the size limit is reached """ def __init__(self, length): @@ -61,34 +64,27 @@ def __init__(self, length): self.shell_id = np.empty(self.length, dtype=np.int64) self.interaction_type = np.empty(self.length, dtype=np.int64) self.num_interactions = 0 + self.extend_factor = 2 + + def extend_array(self, array, array_length): + temp_array = np.empty( + array_length * self.extend_factor, dtype=array.dtype + ) + temp_array[:array_length] = array + return temp_array def track(self, r_packet): if self.num_interactions >= self.length: - temp_length = self.length * 2 - temp_status = np.empty(temp_length, dtype=np.int64) - temp_r = np.empty(temp_length, dtype=np.float64) - temp_nu = np.empty(temp_length, dtype=np.float64) - temp_mu = np.empty(temp_length, dtype=np.float64) - temp_energy = np.empty(temp_length, dtype=np.float64) - temp_shell_id = np.empty(temp_length, dtype=np.int64) - temp_interaction_type = np.empty(temp_length, dtype=np.int64) - - temp_status[: self.length] = self.status - temp_r[: self.length] = self.r - temp_nu[: self.length] = self.nu - temp_mu[: self.length] = self.mu - temp_energy[: self.length] = self.energy - temp_shell_id[: self.length] = self.shell_id - temp_interaction_type[: self.length] = self.interaction_type - - self.status = temp_status - self.r = temp_r - self.nu = temp_nu - self.mu = temp_mu - self.energy = temp_energy - self.shell_id = temp_shell_id - self.interaction_type = temp_interaction_type - self.length = temp_length + self.status = self.extend_array(self.status, self.length) + self.r = self.extend_array(self.r, self.length) + self.nu = self.extend_array(self.nu, self.length) + self.mu = self.extend_array(self.mu, self.length) + self.energy = self.extend_array(self.energy, self.length) + self.shell_id = self.extend_array(self.shell_id, self.length) + self.interaction_type = self.extend_array( + self.interaction_type, self.length + ) + self.length = self.length * self.extend_factor self.index = r_packet.index self.seed = r_packet.seed diff --git a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py b/tardis/transport/montecarlo/tests/test_rpacket_tracker.py index 9a42da459e5..7dbfc03784c 100644 --- a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py +++ b/tardis/transport/montecarlo/tests/test_rpacket_tracker.py @@ -4,6 +4,7 @@ from tardis.transport.montecarlo.r_packet import InteractionType from tardis.transport.montecarlo.packet_trackers import ( + RPacketTracker, rpacket_trackers_to_dataframe, ) @@ -102,6 +103,17 @@ def nu_rpacket_tracker(rpacket_tracker): return nu +def test_extend_array(): + rpacket_tracker = RPacketTracker(10) + array = np.array([1, 2, 3, 4, 5], dtype=np.int64) + + new_array = rpacket_tracker.extend_array(array, array.size) + + assert new_array.size == array.size * rpacket_tracker.extend_factor + assert new_array.dtype == array.dtype + npt.assert_allclose(array, new_array[: array.size]) + + @pytest.mark.parametrize( "expected,obtained", [