Skip to content

Commit

Permalink
Add Extend Array Function (#2771)
Browse files Browse the repository at this point in the history
* Add Extend Array Function

* Fix Typo and restructure

* Restructure

* Doc string update

* Was expecting numpy arrays to be passed by reference but it is not so updating it

* Add test for the extend array function
  • Loading branch information
Sumit112192 authored Aug 1, 2024
1 parent 642e191 commit 88835ff
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 25 deletions.
46 changes: 21 additions & 25 deletions tardis/transport/montecarlo/packet_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
("shell_id", int64[:]),
("interaction_type", int64[:]),
("num_interactions", int64),
("extend_factor", int64),
]


Expand Down Expand Up @@ -47,6 +48,8 @@ class RPacketTracker(object):
Type of interaction the rpacket undergoes
num_interactions : int
Internal counter for the interactions that a particular RPacket undergoes
extend_factor : int
The factor by which to extend the properties array when the size limit is reached
"""

def __init__(self, length):
Expand All @@ -61,34 +64,27 @@ def __init__(self, length):
self.shell_id = np.empty(self.length, dtype=np.int64)
self.interaction_type = np.empty(self.length, dtype=np.int64)
self.num_interactions = 0
self.extend_factor = 2

def extend_array(self, array, array_length):
temp_array = np.empty(
array_length * self.extend_factor, dtype=array.dtype
)
temp_array[:array_length] = array
return temp_array

def track(self, r_packet):
if self.num_interactions >= self.length:
temp_length = self.length * 2
temp_status = np.empty(temp_length, dtype=np.int64)
temp_r = np.empty(temp_length, dtype=np.float64)
temp_nu = np.empty(temp_length, dtype=np.float64)
temp_mu = np.empty(temp_length, dtype=np.float64)
temp_energy = np.empty(temp_length, dtype=np.float64)
temp_shell_id = np.empty(temp_length, dtype=np.int64)
temp_interaction_type = np.empty(temp_length, dtype=np.int64)

temp_status[: self.length] = self.status
temp_r[: self.length] = self.r
temp_nu[: self.length] = self.nu
temp_mu[: self.length] = self.mu
temp_energy[: self.length] = self.energy
temp_shell_id[: self.length] = self.shell_id
temp_interaction_type[: self.length] = self.interaction_type

self.status = temp_status
self.r = temp_r
self.nu = temp_nu
self.mu = temp_mu
self.energy = temp_energy
self.shell_id = temp_shell_id
self.interaction_type = temp_interaction_type
self.length = temp_length
self.status = self.extend_array(self.status, self.length)
self.r = self.extend_array(self.r, self.length)
self.nu = self.extend_array(self.nu, self.length)
self.mu = self.extend_array(self.mu, self.length)
self.energy = self.extend_array(self.energy, self.length)
self.shell_id = self.extend_array(self.shell_id, self.length)
self.interaction_type = self.extend_array(
self.interaction_type, self.length
)
self.length = self.length * self.extend_factor

self.index = r_packet.index
self.seed = r_packet.seed
Expand Down
12 changes: 12 additions & 0 deletions tardis/transport/montecarlo/tests/test_rpacket_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from tardis.transport.montecarlo.r_packet import InteractionType
from tardis.transport.montecarlo.packet_trackers import (
RPacketTracker,
rpacket_trackers_to_dataframe,
)

Expand Down Expand Up @@ -102,6 +103,17 @@ def nu_rpacket_tracker(rpacket_tracker):
return nu


def test_extend_array():
rpacket_tracker = RPacketTracker(10)
array = np.array([1, 2, 3, 4, 5], dtype=np.int64)

new_array = rpacket_tracker.extend_array(array, array.size)

assert new_array.size == array.size * rpacket_tracker.extend_factor
assert new_array.dtype == array.dtype
npt.assert_allclose(array, new_array[: array.size])


@pytest.mark.parametrize(
"expected,obtained",
[
Expand Down

0 comments on commit 88835ff

Please sign in to comment.