Skip to content

Commit

Permalink
Make MontecarloTransport and montecarlo_numba track the last line int…
Browse files Browse the repository at this point in the history
…eraction shell ids for real and virtual packets (#2344)

* Add in/out last line interaction shell ids for real and virtual packets

* Update mailmap

* Do not handle last_line_interaction_in_shell_id and last_line_interaction_out_shell_id separately
  • Loading branch information
AyushiDaksh committed Jun 30, 2023
1 parent 9a3b7ac commit c6e752d
Show file tree
Hide file tree
Showing 13 changed files with 81 additions and 5 deletions.
5 changes: 5 additions & 0 deletions docs/io/output/vpacket_logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,9 @@ After running the simulation, the following information can be retrieved:
- Numpy array
- | If the last interaction was a line interaction, the
| line_interaction_out_id for that interaction
| (see :doc:`physical_quantities`)
* - ``transport.virt_packet_last_line_interaction_shell_id``
- Numpy array
- | If the last interaction was a line interaction, the
| line_interaction_shell_id for that interaction
| (see :doc:`physical_quantities`)
15 changes: 15 additions & 0 deletions tardis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
self._wavelength_end = np.inf * u.angstrom
self._atomic_number = None
self._ion_number = None
self._shell = None
self.packet_filter_mode = packet_filter_mode
self.update_last_interaction_filter()

Expand Down Expand Up @@ -97,6 +98,15 @@ def ion_number(self, value):
self._ion_number = value
self.update_last_interaction_filter()

@property
def shell(self):
return self._shell

@shell.setter
def shell(self, value):
self._shell = value
self.update_last_interaction_filter()

def update_last_interaction_filter(self):
if self.packet_filter_mode == "packet_out_nu":
packet_filter = (
Expand All @@ -122,6 +132,11 @@ def update_last_interaction_filter(self):
"allowed are: packet_out_nu, packet_in_nu, line_in_nu"
)

if self.shell is not None:
packet_filter = packet_filter & (
self.last_line_interaction_shell_id == self.shell
)

self.last_line_in = self.lines.iloc[
self.last_line_interaction_in_id[packet_filter]
]
Expand Down
4 changes: 4 additions & 0 deletions tardis/io/model_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ def transport_to_dict(transport):
"virt_packet_last_interaction_type": transport.virt_packet_last_interaction_type,
"virt_packet_last_line_interaction_in_id": transport.virt_packet_last_line_interaction_in_id,
"virt_packet_last_line_interaction_out_id": transport.virt_packet_last_line_interaction_out_id,
"virt_packet_last_line_interaction_shell_id": transport.virt_packet_last_line_interaction_shell_id,
"virt_packet_nus": transport.virt_packet_nus,
"volume_cgs": transport.volume,
}
Expand Down Expand Up @@ -798,6 +799,9 @@ def transport_from_hdf(fname):
new_transport.virt_packet_last_line_interaction_out_id = d[
"virt_packet_last_line_interaction_out_id"
]
new_transport.virt_packet_last_line_interaction_shell_id = d[
"virt_packet_last_line_interaction_shell_id"
]
new_transport.virt_packet_nus = d["virt_packet_nus"]
new_transport.volume = d["volume_cgs"]

Expand Down
4 changes: 4 additions & 0 deletions tardis/io/tests/test_model_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,10 @@ def test_store_transport_to_hdf(simulation_verysimple, tmp_path):
f["transport/virt_packet_last_line_interaction_out_id"],
transport_data["virt_packet_last_line_interaction_out_id"],
)
assert np.array_equal(
f["transport/virt_packet_last_line_interaction_shell_id"],
transport_data["virt_packet_last_line_interaction_shell_id"],
)
assert np.array_equal(
f["transport/virt_packet_nus"], transport_data["virt_packet_nus"]
)
Expand Down
2 changes: 2 additions & 0 deletions tardis/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class MontecarloTransport(HDFWriterMixin):
"virt_packet_last_interaction_type",
"virt_packet_last_line_interaction_in_id",
"virt_packet_last_line_interaction_out_id",
"virt_packet_last_line_interaction_shell_id",
]

hdf_name = "transport"
Expand Down Expand Up @@ -120,6 +121,7 @@ def __init__(
self.virt_packet_last_interaction_in_nu = np.ones(2) * -1.0
self.virt_packet_last_line_interaction_in_id = np.ones(2) * -1
self.virt_packet_last_line_interaction_out_id = np.ones(2) * -1
self.virt_packet_last_line_interaction_shell_id = np.ones(2) * -1
self.virt_packet_nus = np.ones(2) * -1.0
self.virt_packet_energies = np.ones(2) * -1.0
self.virt_packet_initial_rs = np.ones(2) * -1.0
Expand Down
23 changes: 22 additions & 1 deletion tardis/montecarlo/montecarlo_numba/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def montecarlo_radial1d(
last_interaction_in_nu,
last_line_interaction_in_id,
last_line_interaction_out_id,
last_line_interaction_shell_id,
virt_packet_nus,
virt_packet_energies,
virt_packet_initial_mus,
Expand All @@ -88,6 +89,7 @@ def montecarlo_radial1d(
virt_packet_last_interaction_type,
virt_packet_last_line_interaction_in_id,
virt_packet_last_line_interaction_out_id,
virt_packet_last_line_interaction_shell_id,
rpacket_trackers,
) = montecarlo_main_loop(
packet_collection,
Expand All @@ -109,6 +111,7 @@ def montecarlo_radial1d(
transport.last_interaction_in_nu = last_interaction_in_nu
transport.last_line_interaction_in_id = last_line_interaction_in_id
transport.last_line_interaction_out_id = last_line_interaction_out_id
transport.last_line_interaction_shell_id = last_line_interaction_shell_id

if montecarlo_configuration.VPACKET_LOGGING and number_of_vpackets > 0:
transport.virt_packet_nus = np.concatenate(virt_packet_nus).ravel()
Expand All @@ -133,6 +136,9 @@ def montecarlo_radial1d(
transport.virt_packet_last_line_interaction_out_id = np.concatenate(
virt_packet_last_line_interaction_out_id
).ravel()
transport.virt_packet_last_line_interaction_shell_id = np.concatenate(
virt_packet_last_line_interaction_shell_id
).ravel()
update_iterations_pbar(1)
refresh_packet_pbar()
# Condition for Checking if RPacket Tracking is enabled
Expand Down Expand Up @@ -188,6 +194,9 @@ def montecarlo_main_loop(
last_line_interaction_out_ids = (
np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1
)
last_line_interaction_shell_ids = (
np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1
)

v_packets_energy_hist = np.zeros_like(spectrum_frequency)
delta_nu = spectrum_frequency[1] - spectrum_frequency[0]
Expand Down Expand Up @@ -239,10 +248,10 @@ def montecarlo_main_loop(
virt_packet_last_interaction_type = []
virt_packet_last_line_interaction_in_id = []
virt_packet_last_line_interaction_out_id = []
virt_packet_last_line_interaction_shell_id = []
for i in prange(len(output_nus)):
tid = get_thread_id()
if show_progress_bars:

if tid == main_thread_id:
with objmode:
update_amount = 1 * n_threads
Expand Down Expand Up @@ -281,6 +290,9 @@ def montecarlo_main_loop(
last_interaction_in_nus[i] = r_packet.last_interaction_in_nu
last_line_interaction_in_ids[i] = r_packet.last_line_interaction_in_id
last_line_interaction_out_ids[i] = r_packet.last_line_interaction_out_id
last_line_interaction_shell_ids[
i
] = r_packet.last_line_interaction_shell_id

if r_packet.status == PacketStatus.REABSORBED:
output_energies[i] = -r_packet.energy
Expand Down Expand Up @@ -360,6 +372,13 @@ def montecarlo_main_loop(
]
)
)
virt_packet_last_line_interaction_shell_id.append(
np.ascontiguousarray(
vpacket_collection.last_interaction_shell_id[
: vpacket_collection.idx
]
)
)

if montecarlo_configuration.RPACKET_TRACKING:
for rpacket_tracker in rpacket_trackers:
Expand All @@ -373,6 +392,7 @@ def montecarlo_main_loop(
last_interaction_in_nus,
last_line_interaction_in_ids,
last_line_interaction_out_ids,
last_line_interaction_shell_ids,
virt_packet_nus,
virt_packet_energies,
virt_packet_initial_mus,
Expand All @@ -381,5 +401,6 @@ def montecarlo_main_loop(
virt_packet_last_interaction_type,
virt_packet_last_line_interaction_in_id,
virt_packet_last_line_interaction_out_id,
virt_packet_last_line_interaction_shell_id,
rpacket_trackers,
)
1 change: 1 addition & 0 deletions tardis/montecarlo/montecarlo_numba/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def line_emission(r_packet, emission_line_id, time_explosion, numba_plasma):
"""

r_packet.last_line_interaction_out_id = emission_line_id
r_packet.last_line_interaction_shell_id = r_packet.current_shell_id

if emission_line_id != r_packet.next_line_id:
pass
Expand Down
13 changes: 13 additions & 0 deletions tardis/montecarlo/montecarlo_numba/numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def __init__(
("last_interaction_type", int64[:]),
("last_interaction_in_id", int64[:]),
("last_interaction_out_id", int64[:]),
("last_interaction_shell_id", int64[:]),
]


Expand Down Expand Up @@ -344,6 +345,9 @@ def __init__(
self.last_interaction_out_id = -1 * np.ones(
temporary_v_packet_bins, dtype=np.int64
)
self.last_interaction_shell_id = -1 * np.ones(
temporary_v_packet_bins, dtype=np.int64
)
self.idx = 0
self.rpacket_index = rpacket_index
self.length = temporary_v_packet_bins
Expand All @@ -358,6 +362,7 @@ def set_properties(
last_interaction_type,
last_interaction_in_id,
last_interaction_out_id,
last_interaction_shell_id,
):
if self.idx >= self.length:
temp_length = self.length * 2 + self.number_of_vpackets
Expand All @@ -371,6 +376,9 @@ def set_properties(
temp_last_interaction_type = np.empty(temp_length, dtype=np.int64)
temp_last_interaction_in_id = np.empty(temp_length, dtype=np.int64)
temp_last_interaction_out_id = np.empty(temp_length, dtype=np.int64)
temp_last_interaction_shell_id = np.empty(
temp_length, dtype=np.int64
)

temp_nus[: self.length] = self.nus
temp_energies[: self.length] = self.energies
Expand All @@ -388,6 +396,9 @@ def set_properties(
temp_last_interaction_out_id[
: self.length
] = self.last_interaction_out_id
temp_last_interaction_shell_id[
: self.length
] = self.last_interaction_shell_id

self.nus = temp_nus
self.energies = temp_energies
Expand All @@ -397,6 +408,7 @@ def set_properties(
self.last_interaction_type = temp_last_interaction_type
self.last_interaction_in_id = temp_last_interaction_in_id
self.last_interaction_out_id = temp_last_interaction_out_id
self.last_interaction_shell_id = temp_last_interaction_shell_id
self.length = temp_length

self.nus[self.idx] = nu
Expand All @@ -407,6 +419,7 @@ def set_properties(
self.last_interaction_type[self.idx] = last_interaction_type
self.last_interaction_in_id[self.idx] = last_interaction_in_id
self.last_interaction_out_id[self.idx] = last_interaction_out_id
self.last_interaction_shell_id[self.idx] = last_interaction_shell_id
self.idx += 1


Expand Down
2 changes: 2 additions & 0 deletions tardis/montecarlo/montecarlo_numba/r_packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class PacketStatus(IntEnum):
("last_interaction_in_nu", float64),
("last_line_interaction_in_id", int64),
("last_line_interaction_out_id", int64),
("last_line_interaction_shell_id", int64),
]


Expand All @@ -61,6 +62,7 @@ def __init__(self, r, mu, nu, energy, seed, index=0):
self.last_interaction_in_nu = 0.0
self.last_line_interaction_in_id = -1
self.last_line_interaction_out_id = -1
self.last_line_interaction_shell_id = -1

def initialize_line_id(self, numba_plasma, numba_model):
inverse_line_list_nu = numba_plasma.line_list_nu[::-1]
Expand Down
11 changes: 10 additions & 1 deletion tardis/montecarlo/montecarlo_numba/tests/test_numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def test_configuration_initialize():


def test_VPacketCollection_set_properties(verysimple_3vpacket_collection):

assert verysimple_3vpacket_collection.length == 0

nus = [3.0e15, 0.0, 1e15, 1e5]
Expand All @@ -73,6 +72,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection):
last_interaction_types = np.array([1, 1, 3, 2], dtype=np.int64)
last_interaction_in_ids = np.array([100, 0, 1, 1000], dtype=np.int64)
last_interaction_out_ids = np.array([1201, 123, 545, 1232], dtype=np.int64)
last_interaction_shell_ids = np.array([2, -1, 6, 0], dtype=np.int64)

for (
nu,
Expand All @@ -83,6 +83,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection):
last_interaction_type,
last_interaction_in_id,
last_interaction_out_id,
last_interaction_shell_id,
) in zip(
nus,
energies,
Expand All @@ -92,6 +93,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection):
last_interaction_types,
last_interaction_in_ids,
last_interaction_out_ids,
last_interaction_shell_ids,
):
verysimple_3vpacket_collection.set_properties(
nu,
Expand All @@ -102,6 +104,7 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection):
last_interaction_type,
last_interaction_in_id,
last_interaction_out_id,
last_interaction_shell_id,
)

npt.assert_array_equal(
Expand Down Expand Up @@ -152,4 +155,10 @@ def test_VPacketCollection_set_properties(verysimple_3vpacket_collection):
],
last_interaction_out_ids,
)
npt.assert_array_equal(
verysimple_3vpacket_collection.last_interaction_shell_id[
: verysimple_3vpacket_collection.idx
],
last_interaction_shell_ids,
)
assert verysimple_3vpacket_collection.length == 9
2 changes: 1 addition & 1 deletion tardis/montecarlo/montecarlo_numba/vpacket.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ def trace_vpacket_volley(
if (r_packet.nu < vpacket_collection.v_packet_spawn_start_frequency) or (
r_packet.nu > vpacket_collection.v_packet_spawn_end_frequency
):

return

no_of_vpackets = vpacket_collection.number_of_vpackets
Expand Down Expand Up @@ -335,4 +334,5 @@ def trace_vpacket_volley(
r_packet.last_interaction_type,
r_packet.last_line_interaction_in_id,
r_packet.last_line_interaction_out_id,
r_packet.last_line_interaction_shell_id,
)
1 change: 1 addition & 0 deletions tardis/tests/test_tardis_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_transport_properties(self, transport):
("virt_packet_last_interaction_type", virt_type),
("virt_packet_last_line_interaction_in_id", virt_type),
("virt_packet_last_line_interaction_out_id", virt_type),
("virt_packet_last_line_interaction_shell_id", virt_type),
("virt_packet_last_interaction_in_nu", virt_type),
("virt_packet_nus", virt_type),
("virt_packet_energies", virt_type),
Expand Down
3 changes: 1 addition & 2 deletions tardis/transport/r_packet_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def trace_packet(
# - do not remove
last_line_id = len(numba_plasma.line_list_nu) - 1
for cur_line_id in range(start_line_id, len(numba_plasma.line_list_nu)):

# Going through the lines
nu_line = numba_plasma.line_list_nu[cur_line_id]

Expand Down Expand Up @@ -107,7 +106,6 @@ def trace_packet(
distance = min(distance_trace, distance_boundary, distance_continuum)

if distance_trace != 0:

if distance == distance_boundary:
interaction_type = InteractionType.BOUNDARY # BOUNDARY
r_packet.next_line_id = cur_line_id
Expand Down Expand Up @@ -143,6 +141,7 @@ def trace_packet(
interaction_type = InteractionType.LINE # Line
r_packet.last_interaction_in_nu = r_packet.nu
r_packet.last_line_interaction_in_id = cur_line_id
r_packet.last_line_interaction_shell_id = r_packet.current_shell_id
r_packet.next_line_id = cur_line_id
distance = distance_trace
break
Expand Down

0 comments on commit c6e752d

Please sign in to comment.