diff --git a/benchmarks/benchmark_base.py b/benchmarks/benchmark_base.py index 16f8cbac31a..e64743a1f08 100644 --- a/benchmarks/benchmark_base.py +++ b/benchmarks/benchmark_base.py @@ -19,7 +19,6 @@ from tardis.transport.montecarlo.estimators import radfield_mc_estimators from tardis.transport.montecarlo.numba_interface import opacity_state_initialize from tardis.transport.montecarlo.packet_collections import VPacketCollection -from tardis.transport.montecarlo.packet_trackers import RPacketTracker class BenchmarkBase: @@ -62,8 +61,7 @@ def config_rpacket_tracking(self): @functools.cached_property def tardis_ref_path(self): ref_data_path = Path( - Path(__file__).parent.parent, - "tardis-refdata" + Path(__file__).parent.parent, "tardis-refdata" ).resolve() return ref_data_path @@ -124,7 +122,9 @@ def packet(self): @functools.cached_property def verysimple_packet_collection(self): - return self.nb_simulation_verysimple.transport.transport_state.packet_collection + return ( + self.nb_simulation_verysimple.transport.transport_state.packet_collection + ) @functools.cached_property def nb_simulation_verysimple(self): @@ -154,11 +154,15 @@ def verysimple_enable_full_relativity(self): @functools.cached_property def verysimple_tau_russian(self): - return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN + return ( + self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN + ) @functools.cached_property def verysimple_survival_probability(self): - return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY + return ( + self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY + ) @functools.cached_property def static_packet(self): @@ -173,7 +177,9 @@ def static_packet(self): @functools.cached_property def verysimple_3vpacket_collection(self): - spectrum_frequency_grid = self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value + spectrum_frequency_grid = ( + self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value + ) return VPacketCollection( source_rpacket_index=0, spectrum_frequency_grid=spectrum_frequency_grid, @@ -195,7 +201,10 @@ def montecarlo_configuration(self): @functools.cached_property def rpacket_tracker(self): - return RPacketTracker(0) + # Do not use RPacketTracker or RPacketLastInteraction directly + # Use it by importing packet_trackers + # functions with name track_* function is used by ASV + return packet_trackers.RPacketLastInteractionTracker() @functools.cached_property def transport_state(self): diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packet_trackers.py index 466fad27843..20db5244de6 100644 --- a/tardis/transport/montecarlo/packet_trackers.py +++ b/tardis/transport/montecarlo/packet_trackers.py @@ -1,11 +1,20 @@ -from numba import float64, int64, njit +from numba import float64, int64, njit, from_dtype from numba.experimental import jitclass from numba.typed import List import numpy as np import pandas as pd + +boundary_interaction_dtype = np.dtype( + [ + ("event_id", "int64"), + ("current_shell_id", "int64"), + ("next_shell_id", "int64"), + ] +) + + rpacket_tracker_spec = [ - ("length", int64), ("seed", int64), ("index", int64), ("status", int64[:]), @@ -15,7 +24,10 @@ ("energy", float64[:]), ("shell_id", int64[:]), ("interaction_type", int64[:]), + ("boundary_interaction", from_dtype(boundary_interaction_dtype)[:]), ("num_interactions", int64), + ("boundary_interactions_index", int64), + ("event_id", int64), ("extend_factor", int64), ] @@ -53,17 +65,25 @@ class RPacketTracker(object): """ def __init__(self, length): - self.length = length + """ + Initialize the variables with default value + """ self.seed = np.int64(0) self.index = np.int64(0) - self.status = np.empty(self.length, dtype=np.int64) - self.r = np.empty(self.length, dtype=np.float64) - self.nu = np.empty(self.length, dtype=np.float64) - self.mu = np.empty(self.length, dtype=np.float64) - self.energy = np.empty(self.length, dtype=np.float64) - self.shell_id = np.empty(self.length, dtype=np.int64) - self.interaction_type = np.empty(self.length, dtype=np.int64) + self.status = np.empty(length, dtype=np.int64) + self.r = np.empty(length, dtype=np.float64) + self.nu = np.empty(length, dtype=np.float64) + self.mu = np.empty(length, dtype=np.float64) + self.energy = np.empty(length, dtype=np.float64) + self.shell_id = np.empty(length, dtype=np.int64) + self.interaction_type = np.empty(length, dtype=np.int64) + self.boundary_interaction = np.empty( + length, + dtype=boundary_interaction_dtype, + ) self.num_interactions = 0 + self.boundary_interactions_index = 0 + self.event_id = 1 self.extend_factor = 2 def extend_array(self, array, array_length): @@ -74,17 +94,19 @@ def extend_array(self, array, array_length): return temp_array def track(self, r_packet): - if self.num_interactions >= self.length: - self.status = self.extend_array(self.status, self.length) - self.r = self.extend_array(self.r, self.length) - self.nu = self.extend_array(self.nu, self.length) - self.mu = self.extend_array(self.mu, self.length) - self.energy = self.extend_array(self.energy, self.length) - self.shell_id = self.extend_array(self.shell_id, self.length) + """ + Track important properties of RPacket + """ + if self.num_interactions >= self.status.size: + self.status = self.extend_array(self.status, self.status.size) + self.r = self.extend_array(self.r, self.r.size) + self.nu = self.extend_array(self.nu, self.nu.size) + self.mu = self.extend_array(self.mu, self.mu.size) + self.energy = self.extend_array(self.energy, self.energy.size) + self.shell_id = self.extend_array(self.shell_id, self.shell_id.size) self.interaction_type = self.extend_array( - self.interaction_type, self.length + self.interaction_type, self.interaction_type.size ) - self.length = self.length * self.extend_factor self.index = r_packet.index self.seed = r_packet.seed @@ -99,7 +121,36 @@ def track(self, r_packet): ] = r_packet.last_interaction_type self.num_interactions += 1 + def track_boundary_interaction(self, current_shell_id, next_shell_id): + """ + Track boundary interaction properties + """ + if self.boundary_interactions_index >= self.boundary_interaction.size: + self.boundary_interaction = self.extend_array( + self.boundary_interaction, + self.boundary_interaction.size, + ) + + self.boundary_interaction[self.boundary_interactions_index][ + "event_id" + ] = self.event_id + self.event_id += 1 + + self.boundary_interaction[self.boundary_interactions_index][ + "current_shell_id" + ] = current_shell_id + + self.boundary_interaction[self.boundary_interactions_index][ + "next_shell_id" + ] = next_shell_id + + self.boundary_interactions_index += 1 + def finalize_array(self): + """ + Change the size of the array from length ( or multiple of length ) to + the actual number of interactions + """ self.status = self.status[: self.num_interactions] self.r = self.r[: self.num_interactions] self.nu = self.nu[: self.num_interactions] @@ -107,6 +158,9 @@ def finalize_array(self): self.energy = self.energy[: self.num_interactions] self.shell_id = self.shell_id[: self.num_interactions] self.interaction_type = self.interaction_type[: self.num_interactions] + self.boundary_interaction = self.boundary_interaction[ + : self.boundary_interactions_index + ] def rpacket_trackers_to_dataframe(rpacket_trackers): @@ -186,6 +240,9 @@ class RPacketLastInteractionTracker(object): """ def __init__(self): + """ + Initialize properties with default values + """ self.index = -1 self.r = -1.0 self.nu = 0.0 @@ -194,6 +251,9 @@ def __init__(self): self.interaction_type = -1 def track(self, r_packet): + """ + Track properties of RPacket and override the previous values + """ self.index = r_packet.index self.r = r_packet.r self.nu = r_packet.nu @@ -201,8 +261,17 @@ def track(self, r_packet): self.shell_id = r_packet.current_shell_id self.interaction_type = r_packet.last_interaction_type - # To make it compatible with RPacketTracker def finalize_array(self): + """ + Added to make RPacketLastInteractionTracker compatible with RPacketTracker + """ + pass + + # To make it compatible with RPacketTracker + def track_boundary_interaction(self, current_shell_id, next_shell_id): + """ + Added to make RPacketLastInteractionTracker compatible with RPacketTracker + """ pass diff --git a/tardis/transport/montecarlo/single_packet_loop.py b/tardis/transport/montecarlo/single_packet_loop.py index 36763b03725..2871a5bac48 100644 --- a/tardis/transport/montecarlo/single_packet_loop.py +++ b/tardis/transport/montecarlo/single_packet_loop.py @@ -158,6 +158,10 @@ def single_packet_loop( # If continuum processes: update continuum estimators if interaction_type == InteractionType.BOUNDARY: + rpacket_tracker.track_boundary_interaction( + r_packet.current_shell_id, + r_packet.current_shell_id + delta_shell, + ) move_r_packet( r_packet, distance, @@ -166,7 +170,9 @@ def single_packet_loop( montecarlo_configuration.ENABLE_FULL_RELATIVITY, ) move_packet_across_shell_boundary( - r_packet, delta_shell, len(numba_radial_1d_geometry.r_inner) + r_packet, + delta_shell, + len(numba_radial_1d_geometry.r_inner), ) elif interaction_type == InteractionType.LINE: diff --git a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py b/tardis/transport/montecarlo/tests/test_rpacket_tracker.py index fe1e7614701..b9bafae3a0a 100644 --- a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py +++ b/tardis/transport/montecarlo/tests/test_rpacket_tracker.py @@ -135,6 +135,30 @@ def test_rpacket_tracker_properties(expected, obtained, request): npt.assert_allclose(expected, obtained) +def test_boundary_interactions(rpacket_tracker, regression_data): + no_of_packets = len(rpacket_tracker) + + # Hard coding the number of columns + # Based on the largest size of boundary_interaction array (60) + obtained_boundary_interaction = np.full( + (no_of_packets, 64), + [-1], + dtype=rpacket_tracker[0].boundary_interaction.dtype, + ) + + for i, tracker in enumerate(rpacket_tracker): + obtained_boundary_interaction[ + i, : tracker.boundary_interaction.size + ] = tracker.boundary_interaction + + expected_boundary_interaction = regression_data.sync_ndarray( + obtained_boundary_interaction + ) + npt.assert_array_equal( + obtained_boundary_interaction, expected_boundary_interaction + ) + + def test_rpacket_trackers_to_dataframe(simulation_rpacket_tracking): transport_state = simulation_rpacket_tracking.transport.transport_state rtracker_df = rpacket_trackers_to_dataframe(transport_state.rpacket_tracker)