Skip to content

Commit

Permalink
Added back the 'slice' method because __getitem__ is wacky
Browse files Browse the repository at this point in the history
  • Loading branch information
Rodot- committed Jul 29, 2024
1 parent c441e43 commit 04691d9
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
7 changes: 4 additions & 3 deletions tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ def initialize_transport_state(
self.line_interaction_type,
self.montecarlo_configuration.DISABLE_LINE_SCATTERING,
)
opacity_state = opacity_state[
simulation_state.geometry.v_inner_boundary_index : simulation_state.geometry.v_outer_boundary_index
]
opacity_state = opacity_state.slice(
simulation_state.geometry.v_inner_boundary_index,
simulation_state.geometry.v_outer_boundary_index
)
estimators = initialize_estimator_statistics(
opacity_state.tau_sobolev.shape, gamma_shape
)
Expand Down
37 changes: 36 additions & 1 deletion tardis/transport/montecarlo/numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ def __getitem__(self, i):
i (int, slice): shell index or slice
Returns:
OpacityState a shallow copy of the current instance
OpacityState : a shallow copy of the current instance
"""
#NOTE: This currently will not work with continuum processes since it does not slice those arrays
return OpacityState(
self.electron_density[i],
self.t_electrons[i],
Expand All @@ -146,6 +147,40 @@ def __getitem__(self, i):
self.k_packet_idx,
)

def slice(self, i):
"""Get a shell or slice of shells of the attributes of the opacity state
Args:
i (int, slice): shell index or slice
Returns:
OpacityState : a shallow copy of the current instance
"""
#NOTE: This currently will not work with continuum processes since it does not slice those arrays
return OpacityState(
self.electron_density[i],
self.t_electrons[i],
self.line_list_nu,
self.tau_sobolev[:, i],
self.transition_probabilities[:, i],
self.line2macro_level_upper,
self.macro_block_references,
self.transition_type,
self.destination_level_id,
self.transition_line_id,
self.bf_threshold_list_nu,
self.p_fb_deactivation,
self.photo_ion_nu_threshold_mins,
self.photo_ion_nu_threshold_maxs,
self.photo_ion_block_references,
self.chi_bf,
self.x_sect,
self.phot_nus,
self.ff_opacity_factor,
self.emissivities,
self.photo_ion_activation_idx,
self.k_packet_idx,
)

def opacity_state_initialize(
plasma,
Expand Down
18 changes: 13 additions & 5 deletions tardis/transport/montecarlo/tests/test_numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,29 @@
import numpy as np


@pytest.mark.parametrize("input_params", ["scatter", "macroatom", "downbranch"])
def test_opacity_state_initialize(nb_simulation_verysimple, input_params):

@pytest.mark.parametrize("input_params,sliced", [("scatter", False), ("macroatom", False), ("macroatom", True), ("downbranch", False), ("downbranch", True)])
def test_opacity_state_initialize(nb_simulation_verysimple, input_params, sliced):
line_interaction_type = input_params
plasma = nb_simulation_verysimple.plasma
actual = numba_interface.opacity_state_initialize(
plasma,
line_interaction_type,
disable_line_scattering=False,
)
print(dir(actual))
if sliced:
index = slice(2, 5)
actual = actual.slice(index)
else:
index = ...

npt.assert_allclose(
actual.electron_density, plasma.electron_densities.values
actual.electron_density, plasma.electron_densities.values[index]
)
npt.assert_allclose(actual.line_list_nu, plasma.atomic_data.lines.nu.values)
npt.assert_allclose(actual.tau_sobolev, plasma.tau_sobolevs.values)
print(actual.tau_sobolev.shape, plasma.tau_sobolevs.values[:, index].shape)
npt.assert_allclose(actual.tau_sobolev, plasma.tau_sobolevs.values[:, index])
if line_interaction_type == "scatter":
empty = np.zeros(1, dtype=np.int64)
npt.assert_allclose(
Expand All @@ -32,7 +40,7 @@ def test_opacity_state_initialize(nb_simulation_verysimple, input_params):
else:
npt.assert_allclose(
actual.transition_probabilities,
plasma.transition_probabilities.values,
plasma.transition_probabilities.values[:, index],
)
npt.assert_allclose(
actual.line2macro_level_upper,
Expand Down

0 comments on commit 04691d9

Please sign in to comment.