Skip to content

Commit

Permalink
decoupled macro_atom from the opacity solver
Browse files Browse the repository at this point in the history
  • Loading branch information
Rodot- committed Aug 9, 2024
1 parent c03d1b0 commit 81f280a
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 57 deletions.
40 changes: 4 additions & 36 deletions tardis/opacities/opacity_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,6 @@ def __init__(

self.line_interaction_type = line_interaction_type
self.disable_line_scattering = disable_line_scattering
if self.line_interaction_type in (
"downbranch",
"macroatom",
):
if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
self.macro_atom_solver = (
None # in the future will be the MacroAtomContinuum solver
)
else:
self.macro_atom_solver = MacroAtomSolver()
else:
self.macro_atom_solver = None

def solve(self, legacy_plasma) -> OpacityState:
"""
Expand All @@ -63,43 +51,23 @@ def solve(self, legacy_plasma) -> OpacityState:
tau_sobolev = pd.DataFrame(
np.zeros(
(
legacy_plasma.atomic_data.lines.shape[
0
], # number of lines
atomic_data.lines.shape[0], # number of lines
legacy_plasma.abundance.shape[1], # number of shells
),
dtype=np.float64,
),
index=legacy_plasma.atomic_data.lines.index,
index=atomic_data.lines.index,
)
else:
tau_sobolev = calculate_sobolev_line_opacity(
legacy_plasma.atomic_data.lines,
atomic_data.lines,
legacy_plasma.level_number_density,
legacy_plasma.time_explosion,
legacy_plasma.stimulated_emission_factor,
)

if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
macroatom_state = MacroAtomState.from_legacy_plasma(
legacy_plasma
) # TODO: Impliment

elif self.line_interaction_type in (
"downbranch",
"macroatom",
):
macroatom_state = self.macro_atom_solver.solve(
legacy_plasma,
atomic_data,
tau_sobolev,
legacy_plasma.stimulated_emission_factor,
)
else:
macroatom_state = None

opacity_state = OpacityState.from_legacy_plasma(
legacy_plasma, tau_sobolev, macroatom_state
legacy_plasma, tau_sobolev
)

return opacity_state
28 changes: 10 additions & 18 deletions tardis/opacities/opacity_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def __init__(
t_electrons,
line_list_nu,
tau_sobolev,
macroatom_state,
continuum_state,
):
"""
Expand All @@ -27,7 +26,6 @@ def __init__(
t_electrons : numpy.ndarray
line_list_nu : pd.DataFrame
tau_sobolev : pd.DataFrame
macroatom_state: tardis.opacities.macro_atom.macroatom_state.MacroAtomState
continuum_state: tardis.opacities.continuum.continuum_state.ContinuumState
"""
self.electron_density = electron_density
Expand All @@ -38,10 +36,9 @@ def __init__(

# Continuum Opacity Data
self.continuum_state = continuum_state
self.macroatom_state = macroatom_state

@classmethod
def from_legacy_plasma(cls, plasma, tau_sobolev, macroatom_state=None):
def from_legacy_plasma(cls, plasma, tau_sobolev):
"""
Generates an OpacityStatePython object from a tardis BasePlasma
Expand Down Expand Up @@ -69,7 +66,6 @@ def from_legacy_plasma(cls, plasma, tau_sobolev, macroatom_state=None):
plasma.t_electrons,
atomic_data.lines.nu,
tau_sobolev,
macroatom_state,
continuum_state,
)

Expand Down Expand Up @@ -212,7 +208,9 @@ def __getitem__(self, i: slice):


def opacity_state_to_numba(
opacity_state: OpacityState, line_interaction_type
opacity_state: OpacityState,
macro_atom_state: MacroAtomState,
line_interaction_type,
) -> OpacityStateNumba:
"""
Initialize the OpacityStateNumba object and copy over the data over from OpacityState class
Expand Down Expand Up @@ -245,27 +243,21 @@ def opacity_state_to_numba(
transition_line_id = np.zeros(array_size, dtype=np.int64)
else:
transition_probabilities = np.ascontiguousarray(
opacity_state.macroatom_state.transition_probabilities.values.copy(),
macro_atom_state.transition_probabilities.values.copy(),
dtype=np.float64,
)
line2macro_level_upper = (
opacity_state.macroatom_state.line2macro_level_upper
)
line2macro_level_upper = macro_atom_state.line2macro_level_upper
# TODO: Fix setting of block references for non-continuum mode

macro_block_references = np.asarray(
opacity_state.macroatom_state.macro_block_references
macro_atom_state.macro_block_references
)

transition_type = opacity_state.macroatom_state.transition_type.values
transition_type = macro_atom_state.transition_type.values

# Destination level is not needed and/or generated for downbranch
destination_level_id = (
opacity_state.macroatom_state.destination_level_id.values
)
transition_line_id = (
opacity_state.macroatom_state.transition_line_id.values
)
destination_level_id = macro_atom_state.destination_level_id.values
transition_line_id = macro_atom_state.transition_line_id.values

if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
bf_threshold_list_nu = (
Expand Down
10 changes: 8 additions & 2 deletions tardis/opacities/tests/test_opacity_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy.testing as npt
import pandas.testing as pdt
from tardis.opacities.opacity_solver import OpacitySolver
from tardis.opacities.macro_atom.macroatom_solver import MacroAtomSolver
from tardis.opacities.opacity_state import OpacityState
from tardis.opacities.tau_sobolev import calculate_sobolev_line_opacity

Expand Down Expand Up @@ -38,9 +39,14 @@ def test_opacity_solver(
if not disable_line_scattering:
pdt.assert_frame_equal(actual.tau_sobolev, legacy_plasma.tau_sobolevs)
if line_interaction_type == "scatter":
assert actual.macroatom_state is None
pass
else:
macroatom_state = actual.macroatom_state
macroatom_state = MacroAtomSolver().solve(
legacy_plasma,
legacy_plasma.atomic_data,
actual.tau_sobolev,
actual.opacity_state.simulated_emission_factor,
)
pdt.assert_frame_equal(
macroatom_state.transition_probabilities,
legacy_plasma.transition_probabilities,
Expand Down
26 changes: 26 additions & 0 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
MCContinuumPropertiesSolver,
)
from tardis.opacities.opacity_solver import OpacitySolver
from tardis.opacities.macro_atom.macroatom_solver import MacroAtomSolver
from tardis.opacities.macro_atom.macroatom_state import MacroAtomState
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots

Expand Down Expand Up @@ -105,6 +107,7 @@ class Simulation(PlasmaStateStorerMixin, HDFWriterMixin):
plasma : tardis.plasma.BasePlasma
transport : tardis.transport.montecarlo.MontecarloTransport
opacity : tardis.opacities.opacity_solver.OpacitySolver
macro_atom : tardis.opacities.macro_atom.macroatom_solver.MacroAtomSolver
no_of_packets : int
last_no_of_packets : int
no_of_virtual_packets : int
Expand Down Expand Up @@ -133,6 +136,7 @@ def __init__(
plasma,
transport,
opacity,
macro_atom,
no_of_packets,
no_of_virtual_packets,
luminosity_nu_start,
Expand All @@ -156,6 +160,7 @@ def __init__(
self.plasma = plasma
self.transport = transport
self.opacity = opacity
self.macro_atom = macro_atom
self.no_of_packets = no_of_packets
self.last_no_of_packets = last_no_of_packets
self.no_of_virtual_packets = no_of_virtual_packets
Expand Down Expand Up @@ -444,10 +449,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0):
)

opacity_state = self.opacity.solve(self.plasma)
if self.macro_atom is not None:
if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
macro_atom_state = MacroAtomState.from_legacy_plasma(
self.plasma
) # TODO: Impliment
else:
macro_atom_state = self.macro_atom.solve(
self.plasma,
self.plasma.atomic_data,
opacity_state.tau_sobolev,
opacity_state.simulated_emission_factor,
)

transport_state = self.transport.initialize_transport_state(
self.simulation_state,
opacity_state,
macro_atom_state,
self.plasma,
no_of_packets,
no_of_virtual_packets=no_of_virtual_packets,
Expand Down Expand Up @@ -703,6 +721,7 @@ def from_config(
plasma=None,
transport=None,
opacity=None,
macro_atom=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -765,6 +784,12 @@ def from_config(
config.plasma.line_interaction_type,
config.plasma.disable_line_scattering,
)
if macro_atom is None:
if config.plasma.line_interaction_type in (
"downbranch",
"macroatom",
):
macro_atom = MacroAtomSolver()

convergence_plots_config_options = [
"plasma_plot_config",
Expand Down Expand Up @@ -805,6 +830,7 @@ def from_config(
plasma=plasma,
transport=transport,
opacity=opacity,
macro_atom=macro_atom,
show_convergence_plots=show_convergence_plots,
no_of_packets=int(config.montecarlo.no_of_packets),
no_of_virtual_packets=int(config.montecarlo.no_of_virtual_packets),
Expand Down
3 changes: 2 additions & 1 deletion tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def initialize_transport_state(
self,
simulation_state,
opacity_state,
macro_atom_state,
plasma,
no_of_packets,
no_of_virtual_packets=0,
Expand All @@ -115,7 +116,7 @@ def initialize_transport_state(
geometry_state = simulation_state.geometry.to_numba()

opacity_state_numba = opacity_state_to_numba(
opacity_state, self.line_interaction_type
opacity_state, macro_atom_state, self.line_interaction_type
)
opacity_state_numba = opacity_state_numba[
simulation_state.geometry.v_inner_boundary_index : simulation_state.geometry.v_outer_boundary_index
Expand Down

0 comments on commit 81f280a

Please sign in to comment.