diff --git a/tardis/transport/montecarlo/base.py b/tardis/transport/montecarlo/base.py index 6638f43e917..5662a6be9c8 100644 --- a/tardis/transport/montecarlo/base.py +++ b/tardis/transport/montecarlo/base.py @@ -110,9 +110,10 @@ def initialize_transport_state( self.line_interaction_type, self.montecarlo_configuration.DISABLE_LINE_SCATTERING, ) - opacity_state = opacity_state[ - simulation_state.geometry.v_inner_boundary_index : simulation_state.geometry.v_outer_boundary_index - ] + opacity_state = opacity_state.slice( + simulation_state.geometry.v_inner_boundary_index, + simulation_state.geometry.v_outer_boundary_index + ) estimators = initialize_estimator_statistics( opacity_state.tau_sobolev.shape, gamma_shape ) diff --git a/tardis/transport/montecarlo/numba_interface.py b/tardis/transport/montecarlo/numba_interface.py index 7499f8fdb3f..2496f230520 100644 --- a/tardis/transport/montecarlo/numba_interface.py +++ b/tardis/transport/montecarlo/numba_interface.py @@ -119,8 +119,9 @@ def __getitem__(self, i): i (int, slice): shell index or slice Returns: - OpacityState a shallow copy of the current instance + OpacityState : a shallow copy of the current instance """ + #NOTE: This currently will not work with continuum processes since it does not slice those arrays return OpacityState( self.electron_density[i], self.t_electrons[i], @@ -146,6 +147,40 @@ def __getitem__(self, i): self.k_packet_idx, ) + def slice(self, i): + """Get a shell or slice of shells of the attributes of the opacity state + + Args: + i (int, slice): shell index or slice + + Returns: + OpacityState : a shallow copy of the current instance + """ + #NOTE: This currently will not work with continuum processes since it does not slice those arrays + return OpacityState( + self.electron_density[i], + self.t_electrons[i], + self.line_list_nu, + self.tau_sobolev[:, i], + self.transition_probabilities[:, i], + self.line2macro_level_upper, + self.macro_block_references, + self.transition_type, + self.destination_level_id, + self.transition_line_id, + self.bf_threshold_list_nu, + self.p_fb_deactivation, + self.photo_ion_nu_threshold_mins, + self.photo_ion_nu_threshold_maxs, + self.photo_ion_block_references, + self.chi_bf, + self.x_sect, + self.phot_nus, + self.ff_opacity_factor, + self.emissivities, + self.photo_ion_activation_idx, + self.k_packet_idx, + ) def opacity_state_initialize( plasma, diff --git a/tardis/transport/montecarlo/tests/test_numba_interface.py b/tardis/transport/montecarlo/tests/test_numba_interface.py index 25d907049f6..10830cb1463 100644 --- a/tardis/transport/montecarlo/tests/test_numba_interface.py +++ b/tardis/transport/montecarlo/tests/test_numba_interface.py @@ -4,8 +4,9 @@ import numpy as np -@pytest.mark.parametrize("input_params", ["scatter", "macroatom", "downbranch"]) -def test_opacity_state_initialize(nb_simulation_verysimple, input_params): + +@pytest.mark.parametrize("input_params,sliced", [("scatter", False), ("macroatom", False), ("macroatom", True), ("downbranch", False), ("downbranch", True)]) +def test_opacity_state_initialize(nb_simulation_verysimple, input_params, sliced): line_interaction_type = input_params plasma = nb_simulation_verysimple.plasma actual = numba_interface.opacity_state_initialize( @@ -13,12 +14,19 @@ def test_opacity_state_initialize(nb_simulation_verysimple, input_params): line_interaction_type, disable_line_scattering=False, ) + print(dir(actual)) + if sliced: + index = slice(2, 5) + actual = actual.slice(index) + else: + index = ... npt.assert_allclose( - actual.electron_density, plasma.electron_densities.values + actual.electron_density, plasma.electron_densities.values[index] ) npt.assert_allclose(actual.line_list_nu, plasma.atomic_data.lines.nu.values) - npt.assert_allclose(actual.tau_sobolev, plasma.tau_sobolevs.values) + print(actual.tau_sobolev.shape, plasma.tau_sobolevs.values[:, index].shape) + npt.assert_allclose(actual.tau_sobolev, plasma.tau_sobolevs.values[:, index]) if line_interaction_type == "scatter": empty = np.zeros(1, dtype=np.int64) npt.assert_allclose( @@ -32,7 +40,7 @@ def test_opacity_state_initialize(nb_simulation_verysimple, input_params): else: npt.assert_allclose( actual.transition_probabilities, - plasma.transition_probabilities.values, + plasma.transition_probabilities.values[:, index], ) npt.assert_allclose( actual.line2macro_level_upper,