Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track boundary interaction #2736

Merged
merged 36 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
adb441b
Import Numba List
Sumit112192 Jul 22, 2024
21d230d
Resolving conflicts
Sumit112192 Jul 23, 2024
4b8d71e
Resolving conflicts
Sumit112192 Jul 23, 2024
b200eb5
Rebase
Sumit112192 Jul 23, 2024
92a0e31
Remove Unused import
Sumit112192 Jul 23, 2024
17c49cd
Add boundary interaction to packet_trackers
Sumit112192 Jul 24, 2024
2eeb98b
Fix errors
Sumit112192 Jul 24, 2024
c8ba959
Add track_boundary functionality
Sumit112192 Jul 28, 2024
0a92b03
Add ENUM PacketEjectaStatus
Sumit112192 Jul 29, 2024
3cb968f
Add Enum
Sumit112192 Jul 29, 2024
76a1808
Trigger Build
Sumit112192 Jul 29, 2024
f85a271
Move OUTSIDE_EJECTA to PacketStatus ENUM
Sumit112192 Jul 30, 2024
c716200
Add Track Line Interaction
Sumit112192 Jul 30, 2024
ae6dc64
Initializae self.line_interaction_array_length
Sumit112192 Jul 30, 2024
593297e
Fix Typo
Sumit112192 Jul 30, 2024
04cf212
Rename interaction_id to event_id
Sumit112192 Jul 30, 2024
be2326d
Add doc strings to the functions
Sumit112192 Jul 30, 2024
4289974
Add more doc strings
Sumit112192 Jul 30, 2024
675deb7
Remove Line Interaction from this PR
Sumit112192 Jul 31, 2024
ebba27a
Remove Track Line Interaction From this PR
Sumit112192 Jul 31, 2024
9883bc7
Use Extend array function
Sumit112192 Jul 31, 2024
7d799dc
Fix Typo
Sumit112192 Jul 31, 2024
1024f57
Remove length attribute from the class as it was clustering the code
Sumit112192 Jul 31, 2024
3c2d733
Remove line interaction from this PR
Sumit112192 Jul 31, 2024
d31b199
Add tests
Sumit112192 Aug 2, 2024
939fbbe
Remove print statement
Sumit112192 Aug 2, 2024
f4ac83e
Update test
Sumit112192 Aug 2, 2024
c1a0b92
Trigger Build
Sumit112192 Aug 2, 2024
8eafee1
Rename track_boundary_interaction to boundary_interactions_track
Sumit112192 Aug 5, 2024
e657f95
Rename from boundary_interactions_track to get_boundary_data
Sumit112192 Aug 6, 2024
42c460d
Rename from get_boundary_data to get_boundary_interaction
Sumit112192 Aug 6, 2024
2ce9cb9
Rename function name in RPacketLastInteractionTracker to make it same…
Sumit112192 Aug 6, 2024
826a2d9
Change the way RPacketTracker is imported
Sumit112192 Aug 6, 2024
3d2ff15
Use RPacketLastInteractionTracker since it is set by default
Sumit112192 Aug 6, 2024
2b5c1f8
Remove ENUM
Sumit112192 Aug 7, 2024
699cd4e
Max column 60
Sumit112192 Aug 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from tardis.transport.montecarlo.estimators import radfield_mc_estimators
from tardis.transport.montecarlo.numba_interface import opacity_state_initialize
from tardis.transport.montecarlo.packet_collections import VPacketCollection
from tardis.transport.montecarlo.packet_trackers import RPacketTracker


class BenchmarkBase:
Expand Down Expand Up @@ -62,8 +61,7 @@ def config_rpacket_tracking(self):
@functools.cached_property
def tardis_ref_path(self):
ref_data_path = Path(
Path(__file__).parent.parent,
"tardis-refdata"
Path(__file__).parent.parent, "tardis-refdata"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That needs to be parameterized and not fixed. The ref-data directory needs to be a config option.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't write that function. I just ran black on that file, and that was the change it made.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change I made is from line 204 to 207.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay - talked to @officialasishkumar

).resolve()
return ref_data_path

Expand Down Expand Up @@ -124,7 +122,9 @@ def packet(self):

@functools.cached_property
def verysimple_packet_collection(self):
return self.nb_simulation_verysimple.transport.transport_state.packet_collection
return (
self.nb_simulation_verysimple.transport.transport_state.packet_collection
)

@functools.cached_property
def nb_simulation_verysimple(self):
Expand Down Expand Up @@ -154,11 +154,15 @@ def verysimple_enable_full_relativity(self):

@functools.cached_property
def verysimple_tau_russian(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN
)

@functools.cached_property
def verysimple_survival_probability(self):
return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
return (
self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY
)

@functools.cached_property
def static_packet(self):
Expand All @@ -173,7 +177,9 @@ def static_packet(self):

@functools.cached_property
def verysimple_3vpacket_collection(self):
spectrum_frequency_grid = self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
spectrum_frequency_grid = (
self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value
)
return VPacketCollection(
source_rpacket_index=0,
spectrum_frequency_grid=spectrum_frequency_grid,
Expand All @@ -195,7 +201,10 @@ def montecarlo_configuration(self):

@functools.cached_property
def rpacket_tracker(self):
return RPacketTracker(0)
# Do not use RPacketTracker or RPacketLastInteraction directly
# Use it by importing packet_trackers
# functions with name track_* function is used by ASV
return packet_trackers.RPacketLastInteractionTracker()

@functools.cached_property
def transport_state(self):
Expand Down
109 changes: 89 additions & 20 deletions tardis/transport/montecarlo/packet_trackers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from numba import float64, int64, njit
from numba import float64, int64, njit, from_dtype
from numba.experimental import jitclass
from numba.typed import List
import numpy as np
import pandas as pd


boundary_interaction_dtype = np.dtype(
[
("event_id", "int64"),
("current_shell_id", "int64"),
("next_shell_id", "int64"),
]
)


rpacket_tracker_spec = [
("length", int64),
("seed", int64),
("index", int64),
("status", int64[:]),
Expand All @@ -15,7 +24,10 @@
("energy", float64[:]),
("shell_id", int64[:]),
("interaction_type", int64[:]),
("boundary_interaction", from_dtype(boundary_interaction_dtype)[:]),
("num_interactions", int64),
("boundary_interactions_index", int64),
("event_id", int64),
("extend_factor", int64),
]

Expand Down Expand Up @@ -53,17 +65,25 @@ class RPacketTracker(object):
"""

def __init__(self, length):
self.length = length
"""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this length variable, since If I kept this length, I would have to add another length variable for every track_interaction function.

Initialize the variables with default value
"""
self.seed = np.int64(0)
self.index = np.int64(0)
self.status = np.empty(self.length, dtype=np.int64)
self.r = np.empty(self.length, dtype=np.float64)
self.nu = np.empty(self.length, dtype=np.float64)
self.mu = np.empty(self.length, dtype=np.float64)
self.energy = np.empty(self.length, dtype=np.float64)
self.shell_id = np.empty(self.length, dtype=np.int64)
self.interaction_type = np.empty(self.length, dtype=np.int64)
self.status = np.empty(length, dtype=np.int64)
self.r = np.empty(length, dtype=np.float64)
self.nu = np.empty(length, dtype=np.float64)
self.mu = np.empty(length, dtype=np.float64)
self.energy = np.empty(length, dtype=np.float64)
self.shell_id = np.empty(length, dtype=np.int64)
self.interaction_type = np.empty(length, dtype=np.int64)
self.boundary_interaction = np.empty(
length,
dtype=boundary_interaction_dtype,
)
self.num_interactions = 0
self.boundary_interactions_index = 0
self.event_id = 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wkerzendorf If I keep the name num_boundary_interactions it seems similar to the variable num_interactions and can be used to obtain the number of boundary interactions whereas with the current name, it seems that it will only be used for array purposes. Should I rename it to num_boundary_interactions

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be used for indexing purposes - essentially all of the analysis will be done on the dataframes

self.extend_factor = 2

def extend_array(self, array, array_length):
Expand All @@ -74,17 +94,19 @@ def extend_array(self, array, array_length):
return temp_array

def track(self, r_packet):
if self.num_interactions >= self.length:
self.status = self.extend_array(self.status, self.length)
self.r = self.extend_array(self.r, self.length)
self.nu = self.extend_array(self.nu, self.length)
self.mu = self.extend_array(self.mu, self.length)
self.energy = self.extend_array(self.energy, self.length)
self.shell_id = self.extend_array(self.shell_id, self.length)
"""
Track important properties of RPacket
"""
if self.num_interactions >= self.status.size:
self.status = self.extend_array(self.status, self.status.size)
self.r = self.extend_array(self.r, self.r.size)
self.nu = self.extend_array(self.nu, self.nu.size)
self.mu = self.extend_array(self.mu, self.mu.size)
self.energy = self.extend_array(self.energy, self.energy.size)
self.shell_id = self.extend_array(self.shell_id, self.shell_id.size)
self.interaction_type = self.extend_array(
self.interaction_type, self.length
self.interaction_type, self.interaction_type.size
)
self.length = self.length * self.extend_factor

self.index = r_packet.index
self.seed = r_packet.seed
Expand All @@ -99,14 +121,46 @@ def track(self, r_packet):
] = r_packet.last_interaction_type
self.num_interactions += 1

def track_boundary_interaction(self, current_shell_id, next_shell_id):
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This

Copy link
Contributor Author

@Sumit112192 Sumit112192 Aug 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only function defined. I added that function.

Track boundary interaction properties
"""
if self.boundary_interactions_index >= self.boundary_interaction.size:
self.boundary_interaction = self.extend_array(
self.boundary_interaction,
self.boundary_interaction.size,
)

self.boundary_interaction[self.boundary_interactions_index][
"event_id"
] = self.event_id
self.event_id += 1

self.boundary_interaction[self.boundary_interactions_index][
"current_shell_id"
] = current_shell_id

self.boundary_interaction[self.boundary_interactions_index][
"next_shell_id"
] = next_shell_id

self.boundary_interactions_index += 1

def finalize_array(self):
"""
Change the size of the array from length ( or multiple of length ) to
the actual number of interactions
"""
self.status = self.status[: self.num_interactions]
self.r = self.r[: self.num_interactions]
self.nu = self.nu[: self.num_interactions]
self.mu = self.mu[: self.num_interactions]
self.energy = self.energy[: self.num_interactions]
self.shell_id = self.shell_id[: self.num_interactions]
self.interaction_type = self.interaction_type[: self.num_interactions]
self.boundary_interaction = self.boundary_interaction[
: self.boundary_interactions_index
]


def rpacket_trackers_to_dataframe(rpacket_trackers):
Expand Down Expand Up @@ -186,6 +240,9 @@ class RPacketLastInteractionTracker(object):
"""

def __init__(self):
"""
Initialize properties with default values
"""
self.index = -1
self.r = -1.0
self.nu = 0.0
Expand All @@ -194,15 +251,27 @@ def __init__(self):
self.interaction_type = -1

def track(self, r_packet):
"""
Track properties of RPacket and override the previous values
"""
self.index = r_packet.index
self.r = r_packet.r
self.nu = r_packet.nu
self.energy = r_packet.energy
self.shell_id = r_packet.current_shell_id
self.interaction_type = r_packet.last_interaction_type

# To make it compatible with RPacketTracker
def finalize_array(self):
"""
Added to make RPacketLastInteractionTracker compatible with RPacketTracker
"""
pass

# To make it compatible with RPacketTracker
def track_boundary_interaction(self, current_shell_id, next_shell_id):
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it is a function of different class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This

In that case, we also had two finalize_array, track functions before but the benchmark passed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it is a function of different class.

oh sorry didn't notice that

Added to make RPacketLastInteractionTracker compatible with RPacketTracker
"""
pass


Expand Down
8 changes: 7 additions & 1 deletion tardis/transport/montecarlo/single_packet_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def single_packet_loop(
# If continuum processes: update continuum estimators

if interaction_type == InteractionType.BOUNDARY:
rpacket_tracker.track_boundary_interaction(
r_packet.current_shell_id,
r_packet.current_shell_id + delta_shell,
)
move_r_packet(
r_packet,
distance,
Expand All @@ -166,7 +170,9 @@ def single_packet_loop(
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
)
move_packet_across_shell_boundary(
r_packet, delta_shell, len(numba_radial_1d_geometry.r_inner)
r_packet,
delta_shell,
len(numba_radial_1d_geometry.r_inner),
)

elif interaction_type == InteractionType.LINE:
Expand Down
24 changes: 24 additions & 0 deletions tardis/transport/montecarlo/tests/test_rpacket_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,30 @@ def test_rpacket_tracker_properties(expected, obtained, request):
npt.assert_allclose(expected, obtained)


def test_boundary_interactions(rpacket_tracker, regression_data):
no_of_packets = len(rpacket_tracker)

# Hard coding the number of columns
# Based on the largest size of boundary_interaction array (60)
obtained_boundary_interaction = np.full(
(no_of_packets, 64),
[-1],
dtype=rpacket_tracker[0].boundary_interaction.dtype,
)

for i, tracker in enumerate(rpacket_tracker):
obtained_boundary_interaction[
i, : tracker.boundary_interaction.size
] = tracker.boundary_interaction

expected_boundary_interaction = regression_data.sync_ndarray(
obtained_boundary_interaction
)
npt.assert_array_equal(
obtained_boundary_interaction, expected_boundary_interaction
)


def test_rpacket_trackers_to_dataframe(simulation_rpacket_tracking):
transport_state = simulation_rpacket_tracking.transport.transport_state
rtracker_df = rpacket_trackers_to_dataframe(transport_state.rpacket_tracker)
Expand Down
Loading