Skip to content

Commit

Permalink
Restructure Trackers (#2719)
Browse files Browse the repository at this point in the history
* Restructure Trackers

* Remove ENABLE_RPACKET_TRACKING global variable from montecarlo_main_loop

* Import Numba List

* Import trackers

* Make RPacketLastInteractionTracker Compatible with RPacketTracker

* Fix Typo

* Numba loop to pure python approach

* Rebase

* Resolving conflicts

* Resolving conflicts

* Rebase

* Remove Unused import

* Black

* Import numba.typed.List

* Rename function name

* Fix montecarlo_main_loop benchmark

* Remove code

* Add docstring

* Add tests for the tracker utility functions

* Black

* Fix Typo

* update benchmark

* Update benchmark

* Fix Typo

* Checking if there is circular imports

* Update Benchmark

* Fix Typo

* Add benchmark for the utility functions
  • Loading branch information
Sumit112192 authored Jul 29, 2024
1 parent 5e7a23b commit 7231707
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 38 deletions.
30 changes: 24 additions & 6 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from tardis.transport.montecarlo.packet_collections import (
VPacketCollection,
)
from tardis.transport.montecarlo.packet_trackers import RPacketTracker
from tardis.transport.montecarlo.packet_trackers import (
RPacketTracker,
generate_rpacket_last_interaction_tracker_list,
)


class BenchmarkBase:
Expand Down Expand Up @@ -239,7 +242,9 @@ def packet(self):

@property
def verysimple_packet_collection(self):
return self.nb_simulation_verysimple.transport.transport_state.packet_collection
return (
self.nb_simulation_verysimple.transport.transport_state.packet_collection
)

@property
def nb_simulation_verysimple(self):
Expand Down Expand Up @@ -269,19 +274,25 @@ def verysimple_enable_full_relativity(self):

@property
def verysimple_disable_line_scattering(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING
)

@property
def verysimple_continuum_processes_enabled(self):
return montecarlo_globals.CONTINUUM_PROCESSES_ENABLED

@property
def verysimple_tau_russian(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
)

@property
def verysimple_survival_probability(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
)

@property
def static_packet(self):
Expand All @@ -303,7 +314,9 @@ def set_seed(value):

@property
def verysimple_3vpacket_collection(self):
spectrum_frequency_grid = self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
spectrum_frequency_grid = (
self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
)
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency_grid=spectrum_frequency_grid,
Expand Down Expand Up @@ -404,3 +417,8 @@ def estimators(self):
stim_recomb_cooling_estimator=np.empty((0, 0), dtype=np.float64),
photo_ion_estimator_statistics=np.empty((0, 0), dtype=np.int64),
)

@property
def rpacket_tracker_list(self):
no_of_packets = len(self.transport_state.packet_collection.initial_nus)
return generate_rpacket_last_interaction_tracker_list(no_of_packets)
1 change: 1 addition & 0 deletions benchmarks/transport_montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def time_montecarlo_main_loop(self):
self.montecarlo_configuration,
self.transport_state.radfield_mc_estimators,
self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value,
self.rpacket_tracker_list,
self.montecarlo_configuration.NUMBER_OF_VPACKETS,
iteration=0,
show_progress_bars=False,
Expand Down
22 changes: 19 additions & 3 deletions benchmarks/transport_montecarlo_packet_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from benchmarks.benchmark_base import BenchmarkBase
from tardis.transport.montecarlo.packet_trackers import (
rpacket_trackers_to_dataframe,
generate_rpacket_tracker_list,
generate_rpacket_last_interaction_tracker_list,
)


Expand All @@ -15,6 +17,20 @@ class BenchmarkTransportMontecarloPacketTrackers(BenchmarkBase):
def time_rpacket_trackers_to_dataframe(self):
sim = self.simulation_rpacket_tracking_enabled
transport_state = sim.transport.transport_state
rpacket_trackers_to_dataframe(
transport_state.rpacket_tracker
)
rpacket_trackers_to_dataframe(transport_state.rpacket_tracker)

def time_generate_rpacket_tracker_list(self, no_of_packets, length):
generate_rpacket_tracker_list(no_of_packets, length)

def time_generate_rpacket_last_interaction_tracker_list(
self, no_of_packets
):
generate_rpacket_last_interaction_tracker_list(no_of_packets)

time_generate_rpacket_tracker_list.params = ([1, 10, 50], [1, 10, 50])
time_generate_rpacket_tracker_list.param_names = ["no_of_packets", "length"]

time_generate_rpacket_last_interaction_tracker_list.params = [10, 100, 1000]
time_generate_rpacket_last_interaction_tracker_list.param_names = [
"no_of_packets"
]
19 changes: 16 additions & 3 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
opacity_state_initialize,
)
from tardis.transport.montecarlo.packet_trackers import (
generate_rpacket_tracker_list,
generate_rpacket_last_interaction_tracker_list,
rpacket_trackers_to_dataframe,
)
from tardis.util.base import (
Expand Down Expand Up @@ -158,12 +160,24 @@ def run(
self.transport_state = transport_state

number_of_vpackets = self.montecarlo_configuration.NUMBER_OF_VPACKETS
number_of_rpackets = len(transport_state.packet_collection.initial_nus)

if self.enable_rpacket_tracking:
transport_state.rpacket_tracker = generate_rpacket_tracker_list(
number_of_rpackets,
self.montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH,
)
else:
transport_state.rpacket_tracker = (
generate_rpacket_last_interaction_tracker_list(
number_of_rpackets
)
)

(
v_packets_energy_hist,
last_interaction_tracker,
vpacket_tracker,
rpacket_trackers,
) = montecarlo_main_loop(
transport_state.packet_collection,
transport_state.geometry_state,
Expand All @@ -172,6 +186,7 @@ def run(
self.montecarlo_configuration,
transport_state.radfield_mc_estimators,
self.spectrum_frequency_grid.value,
transport_state.rpacket_tracker,
number_of_vpackets,
iteration=iteration,
show_progress_bars=show_progress_bars,
Expand Down Expand Up @@ -199,8 +214,6 @@ def run(
update_iterations_pbar(1)
refresh_packet_pbar()

transport_state.rpacket_tracker = rpacket_trackers

# Need to change the implementation of rpacket_trackers_to_dataframe
# Such that it also takes of the case of
# RPacketLastInteractionTracker
Expand Down
3 changes: 0 additions & 3 deletions tardis/transport/montecarlo/configuration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,3 @@ def configuration_initialize(config, transport, number_of_vpackets):
montecarlo_globals.ENABLE_RPACKET_TRACKING = (
transport.enable_rpacket_tracking
)
montecarlo_main_loop.ENABLE_RPACKET_TRACKING = (
transport.enable_rpacket_tracking
)
24 changes: 2 additions & 22 deletions tardis/transport/montecarlo/montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
consolidate_vpacket_tracker,
initialize_last_interaction_tracker,
)
import tardis.transport.montecarlo.montecarlo_main_loop as montecarlo_loop
from tardis.transport.montecarlo.packet_trackers import (
RPacketTracker,
RPacketLastInteractionTracker,
)
from tardis.transport.montecarlo.r_packet import (
PacketStatus,
RPacket,
Expand All @@ -24,8 +19,6 @@
)
from tardis.util.base import update_packet_pbar

ENABLE_RPACKET_TRACKING = False


@njit(**njit_dict)
def montecarlo_main_loop(
Expand All @@ -36,6 +29,7 @@ def montecarlo_main_loop(
montecarlo_configuration,
estimators,
spectrum_frequency_grid,
rpacket_trackers,
number_of_vpackets,
iteration,
show_progress_bars,
Expand Down Expand Up @@ -77,19 +71,6 @@ def montecarlo_main_loop(

# Pre-allocate a list of vpacket collections for later storage
vpacket_collections = List()
# Configuring the Tracking for R_Packets
rpacket_trackers = List()
if ENABLE_RPACKET_TRACKING:
for i in range(no_of_packets):
rpacket_trackers.append(
RPacketTracker(
montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH
)
)
else:
for i in range(no_of_packets):
rpacket_trackers.append(RPacketLastInteractionTracker())

for i in range(no_of_packets):
vpacket_collections.append(
VPacketCollection(
Expand Down Expand Up @@ -197,13 +178,12 @@ def montecarlo_main_loop(
1,
)

if ENABLE_RPACKET_TRACKING:
if montecarlo_globals.ENABLE_RPACKET_TRACKING:
for rpacket_tracker in rpacket_trackers:
rpacket_tracker.finalize_array()

return (
v_packets_energy_hist,
last_interaction_tracker,
vpacket_tracker,
rpacket_trackers,
)
42 changes: 41 additions & 1 deletion tardis/transport/montecarlo/packet_trackers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from numba import float64, int64
from numba import float64, int64, njit
from numba.experimental import jitclass
from numba.typed import List
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -203,3 +204,42 @@ def track(self, r_packet):
self.energy = r_packet.energy
self.shell_id = r_packet.current_shell_id
self.interaction_type = r_packet.last_interaction_type

# To make it compatible with RPacketTracker
def finalize_array(self):
pass


@njit
def generate_rpacket_tracker_list(no_of_packets, length):
"""
Parameters
----------
no_of_packets : The count of RPackets that are sent in the ejecta
length : initial length of the tracking array
Returns
-------
A list containing RPacketTracker for each RPacket
"""
rpacket_trackers = List()
for i in range(no_of_packets):
rpacket_trackers.append(RPacketTracker(length))
return rpacket_trackers


@njit
def generate_rpacket_last_interaction_tracker_list(no_of_packets):
"""
Parameters
----------
no_of_packets : The count of RPackets that are sent in the ejecta
Returns
-------
A list containing RPacketLastInteractionTracker for each RPacket
"""
rpacket_trackers = List()
for i in range(no_of_packets):
rpacket_trackers.append(RPacketLastInteractionTracker())
return rpacket_trackers
38 changes: 38 additions & 0 deletions tardis/transport/montecarlo/tests/test_tracker_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import numpy as np
from numba import typeof

from tardis.transport.montecarlo.packet_trackers import (
RPacketTracker,
RPacketLastInteractionTracker,
generate_rpacket_tracker_list,
generate_rpacket_last_interaction_tracker_list,
)


def test_generate_rpacket_tracker_list():
no_of_packets = 10
length = 10
random_index = np.random.randint(0, no_of_packets)

rpacket_tracker_list = generate_rpacket_tracker_list(no_of_packets, length)

assert len(rpacket_tracker_list) == no_of_packets
assert len(rpacket_tracker_list[random_index].shell_id) == length
assert typeof(rpacket_tracker_list[random_index]) == typeof(
RPacketTracker(length)
)


def test_generate_rpacket_last_interaction_tracker_list():
no_of_packets = 50
random_index = np.random.randint(0, no_of_packets)

rpacket_last_interaction_tracker_list = (
generate_rpacket_last_interaction_tracker_list(no_of_packets)
)

assert len(rpacket_last_interaction_tracker_list) == no_of_packets
assert typeof(
rpacket_last_interaction_tracker_list[random_index]
) == typeof(RPacketLastInteractionTracker())

0 comments on commit 7231707

Please sign in to comment.