Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactor of the spectrum solver #2759

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/spectrum_formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def time_intensity_black_body(self):
# Benchmark for functions in FormalIntegrator class
def time_FormalIntegrator_functions(self):
self.FormalIntegrator.calculate_spectrum(
self.sim.spectrum_solver.spectrum_real_packets.frequency
self.sim.spectrum_solver.spectrum_frequency_grid
)
self.FormalIntegrator.make_source_function()
self.FormalIntegrator.generate_numba_objects()
self.FormalIntegrator.formal_integral(
self.sim.spectrum_solver.spectrum_real_packets.frequency, 1000
self.sim.spectrum_solver.spectrum_frequency_grid, 1000
)
3 changes: 2 additions & 1 deletion docs/physics/update_and_conv/update_and_conv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@
"#nu_lower = tardis_config.supernova.luminosity_wavelength_end.to(u.Hz, u.spectral)\n",
"#nu_upper = tardis_config.supernova.luminosity_wavelength_start.to(u.Hz, u.spectral)\n",
"\n",
"L_output = sim.spectrum_solver.calculate_emitted_luminosity(0,np.inf)\n",
"from tardis.spectrum.luminosity import calculate_filtered_luminosity\n",
"L_output = calculate_filtered_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n",
"L_output"
]
},
Expand Down
45 changes: 27 additions & 18 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from tardis.simulation.convergence import ConvergenceSolver
from tardis.spectrum.base import SpectrumSolver
from tardis.spectrum.formal_integral import FormalIntegrator
from tardis.spectrum.luminosity import (
calculate_filtered_luminosity,
)
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.estimators.continuum_radfield_properties import (
Expand Down Expand Up @@ -454,25 +457,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0):
show_progress_bars=self.show_progress_bars,
)

# Set up spectrum solver
self.spectrum_solver.transport_state = transport_state
self.spectrum_solver._montecarlo_virtual_luminosity.value[
:
] = v_packets_energy_hist

output_energy = (
self.transport.transport_state.packet_collection.output_energies
)
if np.sum(output_energy < 0) == len(output_energy):
logger.critical("No r-packet escaped through the outer boundary.")

emitted_luminosity = self.spectrum_solver.calculate_emitted_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
emitted_luminosity = calculate_filtered_luminosity(
transport_state.emitted_packet_nu,
transport_state.emitted_packet_luminosity,
self.luminosity_nu_start,
self.luminosity_nu_end,
)
reabsorbed_luminosity = (
self.spectrum_solver.calculate_reabsorbed_luminosity(
self.luminosity_nu_start, self.luminosity_nu_end
)
reabsorbed_luminosity = calculate_filtered_luminosity(
transport_state.reabsorbed_packet_nu,
transport_state.reabsorbed_packet_luminosity,
self.luminosity_nu_start,
self.luminosity_nu_end,
)
if hasattr(self, "convergence_plots"):
self.convergence_plots.fetch_data(
Expand All @@ -493,7 +494,7 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0):

self.log_run_results(emitted_luminosity, reabsorbed_luminosity)
self.iterations_executed += 1
return emitted_luminosity
return emitted_luminosity, v_packets_energy_hist

def run_convergence(self):
"""
Expand All @@ -508,7 +509,9 @@ def run_convergence(self):
self.plasma.electron_densities,
self.simulation_state.t_inner,
)
emitted_luminosity = self.iterate(self.no_of_packets)
emitted_luminosity, v_packets_energy_hist = self.iterate(
self.no_of_packets
)
self.converged = self.advance_state(emitted_luminosity)
if hasattr(self, "convergence_plots"):
self.convergence_plots.update()
Expand All @@ -533,11 +536,17 @@ def run_final(self):
self.plasma.electron_densities,
self.simulation_state.t_inner,
)
self.iterate(self.last_no_of_packets, self.no_of_virtual_packets)
emitted_luminosity, v_packets_energy_hist = self.iterate(
self.last_no_of_packets, self.no_of_virtual_packets
)

# Set up spectrum solver integrator
self.spectrum_solver._integrator = FormalIntegrator(
self.simulation_state, self.plasma, self.transport
# Set up spectrum solver integrator and virtual spectrum
self.spectrum_solver.setup_optional_spectra(
self.transport.transport_state,
v_packets_energy_hist,
FormalIntegrator(
self.simulation_state, self.plasma, self.transport
),
)

self.reshape_plasma_state_store(self.iterations_executed)
Expand Down
64 changes: 18 additions & 46 deletions tardis/spectrum/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ def __init__(
self.integrator_settings = integrator_settings
self._spectrum_integrated = None

def setup_optional_spectra(
self, transport_state, virtual_packet_luminosity=None, integrator=None
):
"""Set up the solver to handle virtual and integrated spectra

Parameters
----------
virtual_packet_luminosity : np.ndarray, optional
Virtual packet luminosity, unnormalized, by default None
integrator : FormalIntegrator, optional
Integrator to compute the integrated spectrum with, by default None
"""
self.transport_state = transport_state
self._montecarlo_virtual_luminosity = (
virtual_packet_luminosity * u.erg / u.s
)
self._integrator = integrator

@property
def spectrum_real_packets(self):
return TARDISSpectrum(
Expand Down Expand Up @@ -136,52 +154,6 @@ def montecarlo_virtual_luminosity(self):
/ self.transport_state.time_of_simulation.value
)

def calculate_emitted_luminosity(
self, luminosity_nu_start, luminosity_nu_end
):
"""
Calculate emitted luminosity.

Parameters
----------
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity

Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (
self.transport_state.emitted_packet_nu > luminosity_nu_start
) & (self.transport_state.emitted_packet_nu < luminosity_nu_end)

return self.transport_state.emitted_packet_luminosity[
luminosity_wavelength_filter
].sum()

def calculate_reabsorbed_luminosity(
self, luminosity_nu_start, luminosity_nu_end
):
"""
Calculate reabsorbed luminosity.

Parameters
----------
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity

Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (
self.transport_state.reabsorbed_packet_nu > luminosity_nu_start
) & (self.transport_state.reabsorbed_packet_nu < luminosity_nu_end)

return self.transport_state.reabsorbed_packet_luminosity[
luminosity_wavelength_filter
].sum()

def solve(self, transport_state):
"""Solve the spectra

Expand Down
29 changes: 29 additions & 0 deletions tardis/spectrum/luminosity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import astropy.units as u
import numpy as np


def calculate_filtered_luminosity(
packet_nu,
packet_luminosity,
luminosity_nu_start=0 * u.Hz,
luminosity_nu_end=np.inf * u.Hz,
):
"""
Calculate total luminosity within a filter range.

Parameters
----------
packet_nu : astropy.units.Quantity
packet_luminosity : astropy.units.Quantity
luminosity_nu_start : astropy.units.Quantity
luminosity_nu_end : astropy.units.Quantity

Returns
-------
astropy.units.Quantity
"""
luminosity_wavelength_filter = (packet_nu > luminosity_nu_start) & (
packet_nu < luminosity_nu_end
)

return packet_luminosity[luminosity_wavelength_filter].sum()
60 changes: 60 additions & 0 deletions tardis/spectrum/tests/test_luminosity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import astropy.units as u
import numpy as np
import pytest

from tardis.spectrum.luminosity import (
calculate_filtered_luminosity,
)


@pytest.mark.parametrize(
"packet_nu, packet_luminosity, luminosity_nu_start, luminosity_nu_end, expected",
[
# All frequencies within the range
(
np.array([1, 2, 3]) * u.Hz,
np.array([10, 20, 30]) * u.erg / u.s,
0 * u.Hz,
4 * u.Hz,
60 * u.erg / u.s,
),
# All frequencies outside the range
(
np.array([1, 2, 3]) * u.Hz,
np.array([10, 20, 30]) * u.erg / u.s,
4 * u.Hz,
5 * u.Hz,
0 * u.erg / u.s,
),
# Mix of frequencies within and outside the range
(
np.array([1, 2, 3, 4]) * u.Hz,
np.array([10, 20, 30, 40]) * u.erg / u.s,
2 * u.Hz,
4 * u.Hz,
30 * u.erg / u.s,
),
# Edge case: Frequencies exactly on the boundary
(
np.array([1, 2, 3, 4]) * u.Hz,
np.array([10, 20, 30, 40]) * u.erg / u.s,
2 * u.Hz,
3 * u.Hz,
0 * u.erg / u.s,
),
],
)
def test_calculate_filtered_luminosity(
packet_nu,
packet_luminosity,
luminosity_nu_start,
luminosity_nu_end,
expected,
):
result = calculate_filtered_luminosity(
packet_nu,
packet_luminosity,
luminosity_nu_start,
luminosity_nu_end,
)
assert result == expected
Loading