Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure Trackers #2719

Merged
merged 28 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
eb16122
Restructure Trackers
Sumit112192 Jul 22, 2024
75f129a
Remove ENABLE_RPACKET_TRACKING global variable from montecarlo_main_loop
Sumit112192 Jul 22, 2024
ac4b6c2
Import Numba List
Sumit112192 Jul 22, 2024
a8b3051
Import trackers
Sumit112192 Jul 22, 2024
83817ee
Make RPacketLastInteractionTracker Compatible with RPacketTracker
Sumit112192 Jul 22, 2024
a820b17
Fix Typo
Sumit112192 Jul 22, 2024
95917aa
Numba loop to pure python approach
Sumit112192 Jul 23, 2024
ee579bc
Rebase
Sumit112192 Jul 23, 2024
912f404
Resolving conflicts
Sumit112192 Jul 23, 2024
f6f9983
Resolving conflicts
Sumit112192 Jul 23, 2024
c1f7e06
Rebase
Sumit112192 Jul 23, 2024
0aad3b7
Remove Unused import
Sumit112192 Jul 23, 2024
44cb334
Black
Sumit112192 Jul 23, 2024
d935e23
Import numba.typed.List
Sumit112192 Jul 23, 2024
0f67ff2
Rename function name
Sumit112192 Jul 23, 2024
c9fdf8a
Fix montecarlo_main_loop benchmark
Sumit112192 Jul 23, 2024
0843aff
Remove code
Sumit112192 Jul 26, 2024
c18503a
Add docstring
Sumit112192 Jul 26, 2024
8f81e54
Add tests for the tracker utility functions
Sumit112192 Jul 27, 2024
4b69372
Black
Sumit112192 Jul 27, 2024
49c0ca6
Fix Typo
Sumit112192 Jul 27, 2024
f04bb84
update benchmark
Sumit112192 Jul 27, 2024
0c1c35a
Update benchmark
Sumit112192 Jul 27, 2024
3c5ead5
Fix Typo
Sumit112192 Jul 27, 2024
afb7be3
Checking if there is circular imports
Sumit112192 Jul 27, 2024
c6f1f71
Update Benchmark
Sumit112192 Jul 28, 2024
2ae3a98
Fix Typo
Sumit112192 Jul 28, 2024
55b291e
Add benchmark for the utility functions
Sumit112192 Jul 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@
from tardis.transport.montecarlo.packet_collections import (
VPacketCollection,
)
from tardis.transport.montecarlo.packet_trackers import RPacketTracker
from tardis.transport.montecarlo.packet_trackers import (
RPacketTracker,
generate_rpacket_last_interaction_tracker_list,
)


class BenchmarkBase:
Expand Down Expand Up @@ -239,7 +242,9 @@ def packet(self):

@property
def verysimple_packet_collection(self):
return self.nb_simulation_verysimple.transport.transport_state.packet_collection
return (
self.nb_simulation_verysimple.transport.transport_state.packet_collection
)

@property
def nb_simulation_verysimple(self):
Expand Down Expand Up @@ -269,19 +274,25 @@ def verysimple_enable_full_relativity(self):

@property
def verysimple_disable_line_scattering(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING
)

@property
def verysimple_continuum_processes_enabled(self):
return montecarlo_globals.CONTINUUM_PROCESSES_ENABLED

@property
def verysimple_tau_russian(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
)

@property
def verysimple_survival_probability(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
)

@property
def static_packet(self):
Expand All @@ -303,7 +314,9 @@ def set_seed(value):

@property
def verysimple_3vpacket_collection(self):
spectrum_frequency_grid = self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
spectrum_frequency_grid = (
self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
)
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency_grid=spectrum_frequency_grid,
Expand Down Expand Up @@ -404,3 +417,8 @@ def estimators(self):
stim_recomb_cooling_estimator=np.empty((0, 0), dtype=np.float64),
photo_ion_estimator_statistics=np.empty((0, 0), dtype=np.int64),
)

@property
def rpacket_tracker_list(self):
no_of_packets = len(self.transport_state.packet_collection.initial_nus)
return generate_rpacket_last_interaction_tracker_list(no_of_packets)
1 change: 1 addition & 0 deletions benchmarks/transport_montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def time_montecarlo_main_loop(self):
self.montecarlo_configuration,
self.transport_state.radfield_mc_estimators,
self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value,
self.rpacket_tracker_list,
self.montecarlo_configuration.NUMBER_OF_VPACKETS,
iteration=0,
show_progress_bars=False,
Expand Down
22 changes: 19 additions & 3 deletions benchmarks/transport_montecarlo_packet_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from benchmarks.benchmark_base import BenchmarkBase
from tardis.transport.montecarlo.packet_trackers import (
rpacket_trackers_to_dataframe,
generate_rpacket_tracker_list,
generate_rpacket_last_interaction_tracker_list,
)


Expand All @@ -15,6 +17,20 @@ class BenchmarkTransportMontecarloPacketTrackers(BenchmarkBase):
def time_rpacket_trackers_to_dataframe(self):
sim = self.simulation_rpacket_tracking_enabled
transport_state = sim.transport.transport_state
rpacket_trackers_to_dataframe(
transport_state.rpacket_tracker
)
rpacket_trackers_to_dataframe(transport_state.rpacket_tracker)

def time_generate_rpacket_tracker_list(self, no_of_packets, length):
generate_rpacket_tracker_list(no_of_packets, length)

def time_generate_rpacket_last_interaction_tracker_list(
self, no_of_packets
):
generate_rpacket_last_interaction_tracker_list(no_of_packets)

time_generate_rpacket_tracker_list.params = ([1, 10, 50], [1, 10, 50])
time_generate_rpacket_tracker_list.param_names = ["no_of_packets", "length"]

time_generate_rpacket_last_interaction_tracker_list.params = [10, 100, 1000]
time_generate_rpacket_last_interaction_tracker_list.param_names = [
"no_of_packets"
]
19 changes: 16 additions & 3 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
opacity_state_initialize,
)
from tardis.transport.montecarlo.packet_trackers import (
generate_rpacket_tracker_list,
generate_rpacket_last_interaction_tracker_list,
rpacket_trackers_to_dataframe,
)
from tardis.util.base import (
Expand Down Expand Up @@ -158,12 +160,24 @@ def run(
self.transport_state = transport_state

number_of_vpackets = self.montecarlo_configuration.NUMBER_OF_VPACKETS
number_of_rpackets = len(transport_state.packet_collection.initial_nus)
andrewfullard marked this conversation as resolved.
Show resolved Hide resolved

if self.enable_rpacket_tracking:
transport_state.rpacket_tracker = generate_rpacket_tracker_list(
number_of_rpackets,
self.montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH,
)
else:
transport_state.rpacket_tracker = (
generate_rpacket_last_interaction_tracker_list(
number_of_rpackets
)
)

(
v_packets_energy_hist,
last_interaction_tracker,
vpacket_tracker,
rpacket_trackers,
) = montecarlo_main_loop(
transport_state.packet_collection,
transport_state.geometry_state,
Expand All @@ -172,6 +186,7 @@ def run(
self.montecarlo_configuration,
transport_state.radfield_mc_estimators,
self.spectrum_frequency_grid.value,
transport_state.rpacket_tracker,
number_of_vpackets,
iteration=iteration,
show_progress_bars=show_progress_bars,
Expand Down Expand Up @@ -199,8 +214,6 @@ def run(
update_iterations_pbar(1)
refresh_packet_pbar()

transport_state.rpacket_tracker = rpacket_trackers

# Need to change the implementation of rpacket_trackers_to_dataframe
# Such that it also takes of the case of
# RPacketLastInteractionTracker
Expand Down
3 changes: 0 additions & 3 deletions tardis/transport/montecarlo/configuration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,3 @@ def configuration_initialize(config, transport, number_of_vpackets):
montecarlo_globals.ENABLE_RPACKET_TRACKING = (
transport.enable_rpacket_tracking
)
montecarlo_main_loop.ENABLE_RPACKET_TRACKING = (
transport.enable_rpacket_tracking
)
24 changes: 2 additions & 22 deletions tardis/transport/montecarlo/montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@
consolidate_vpacket_tracker,
initialize_last_interaction_tracker,
)
import tardis.transport.montecarlo.montecarlo_main_loop as montecarlo_loop
from tardis.transport.montecarlo.packet_trackers import (
RPacketTracker,
RPacketLastInteractionTracker,
)
from tardis.transport.montecarlo.r_packet import (
PacketStatus,
RPacket,
Expand All @@ -24,8 +19,6 @@
)
from tardis.util.base import update_packet_pbar

ENABLE_RPACKET_TRACKING = False


@njit(**njit_dict)
def montecarlo_main_loop(
Expand All @@ -36,6 +29,7 @@
montecarlo_configuration,
estimators,
spectrum_frequency_grid,
rpacket_trackers,
number_of_vpackets,
iteration,
show_progress_bars,
Expand Down Expand Up @@ -77,19 +71,6 @@

# Pre-allocate a list of vpacket collections for later storage
vpacket_collections = List()
# Configuring the Tracking for R_Packets
rpacket_trackers = List()
if ENABLE_RPACKET_TRACKING:
for i in range(no_of_packets):
rpacket_trackers.append(
RPacketTracker(
montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH
)
)
else:
for i in range(no_of_packets):
rpacket_trackers.append(RPacketLastInteractionTracker())

for i in range(no_of_packets):
vpacket_collections.append(
VPacketCollection(
Expand Down Expand Up @@ -197,13 +178,12 @@
1,
)

if ENABLE_RPACKET_TRACKING:
if montecarlo_globals.ENABLE_RPACKET_TRACKING:

Check warning on line 181 in tardis/transport/montecarlo/montecarlo_main_loop.py

View check run for this annotation

Codecov / codecov/patch

tardis/transport/montecarlo/montecarlo_main_loop.py#L181

Added line #L181 was not covered by tests
for rpacket_tracker in rpacket_trackers:
rpacket_tracker.finalize_array()

return (
v_packets_energy_hist,
last_interaction_tracker,
vpacket_tracker,
rpacket_trackers,
)
42 changes: 41 additions & 1 deletion tardis/transport/montecarlo/packet_trackers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from numba import float64, int64
from numba import float64, int64, njit
from numba.experimental import jitclass
from numba.typed import List
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -203,3 +204,42 @@
self.energy = r_packet.energy
self.shell_id = r_packet.current_shell_id
self.interaction_type = r_packet.last_interaction_type

# To make it compatible with RPacketTracker
def finalize_array(self):
pass

Check warning on line 210 in tardis/transport/montecarlo/packet_trackers.py

View check run for this annotation

Codecov / codecov/patch

tardis/transport/montecarlo/packet_trackers.py#L210

Added line #L210 was not covered by tests


@njit
def generate_rpacket_tracker_list(no_of_packets, length):
"""
Parameters
----------
no_of_packets : The count of RPackets that are sent in the ejecta
length : initial length of the tracking array

Returns
-------
A list containing RPacketTracker for each RPacket
"""
rpacket_trackers = List()
andrewfullard marked this conversation as resolved.
Show resolved Hide resolved
for i in range(no_of_packets):
rpacket_trackers.append(RPacketTracker(length))
return rpacket_trackers

Check warning on line 228 in tardis/transport/montecarlo/packet_trackers.py

View check run for this annotation

Codecov / codecov/patch

tardis/transport/montecarlo/packet_trackers.py#L225-L228

Added lines #L225 - L228 were not covered by tests


@njit
def generate_rpacket_last_interaction_tracker_list(no_of_packets):
"""
andrewfullard marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
no_of_packets : The count of RPackets that are sent in the ejecta

Returns
-------
A list containing RPacketLastInteractionTracker for each RPacket
"""
rpacket_trackers = List()
andrewfullard marked this conversation as resolved.
Show resolved Hide resolved
for i in range(no_of_packets):
rpacket_trackers.append(RPacketLastInteractionTracker())
return rpacket_trackers

Check warning on line 245 in tardis/transport/montecarlo/packet_trackers.py

View check run for this annotation

Codecov / codecov/patch

tardis/transport/montecarlo/packet_trackers.py#L242-L245

Added lines #L242 - L245 were not covered by tests
38 changes: 38 additions & 0 deletions tardis/transport/montecarlo/tests/test_tracker_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import numpy as np
from numba import typeof

from tardis.transport.montecarlo.packet_trackers import (
RPacketTracker,
RPacketLastInteractionTracker,
generate_rpacket_tracker_list,
generate_rpacket_last_interaction_tracker_list,
)


def test_generate_rpacket_tracker_list():
no_of_packets = 10
length = 10
random_index = np.random.randint(0, no_of_packets)

rpacket_tracker_list = generate_rpacket_tracker_list(no_of_packets, length)

assert len(rpacket_tracker_list) == no_of_packets
assert len(rpacket_tracker_list[random_index].shell_id) == length
assert typeof(rpacket_tracker_list[random_index]) == typeof(
RPacketTracker(length)
)


def test_generate_rpacket_last_interaction_tracker_list():
no_of_packets = 50
random_index = np.random.randint(0, no_of_packets)
Copy link
Contributor Author

@Sumit112192 Sumit112192 Jul 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding the time usage of the random number generator

import numpy as np
no_of_packets = 50
%timeit np.random.randint(0, no_of_packets)
2.89 μs ± 199 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


rpacket_last_interaction_tracker_list = (
generate_rpacket_last_interaction_tracker_list(no_of_packets)
)

assert len(rpacket_last_interaction_tracker_list) == no_of_packets
assert typeof(
rpacket_last_interaction_tracker_list[random_index]
) == typeof(RPacketLastInteractionTracker())
Loading