From 4401047fd96f6af4f223da97f11186cc3666d0bb Mon Sep 17 00:00:00 2001 From: Andrew Fullard Date: Mon, 29 Jul 2024 16:10:22 -0400 Subject: [PATCH 1/4] Factor out luminosity calculation, add setup solver method --- tardis/simulation/base.py | 46 +++++++++++++++---------- tardis/spectrum/base.py | 64 ++++++++++------------------------- tardis/spectrum/luminosity.py | 56 ++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 64 deletions(-) create mode 100644 tardis/spectrum/luminosity.py diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 599fe2b0caf..9059c28127e 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -20,6 +20,10 @@ from tardis.simulation.convergence import ConvergenceSolver from tardis.spectrum.base import SpectrumSolver from tardis.spectrum.formal_integral import FormalIntegrator +from tardis.spectrum.luminosity import ( + calculate_emitted_luminosity, + calculate_reabsorbed_luminosity, +) from tardis.transport.montecarlo.base import MonteCarloTransportSolver from tardis.transport.montecarlo.configuration import montecarlo_globals from tardis.transport.montecarlo.estimators.continuum_radfield_properties import ( @@ -454,25 +458,23 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): show_progress_bars=self.show_progress_bars, ) - # Set up spectrum solver - self.spectrum_solver.transport_state = transport_state - self.spectrum_solver._montecarlo_virtual_luminosity.value[ - : - ] = v_packets_energy_hist - output_energy = ( self.transport.transport_state.packet_collection.output_energies ) if np.sum(output_energy < 0) == len(output_energy): logger.critical("No r-packet escaped through the outer boundary.") - emitted_luminosity = self.spectrum_solver.calculate_emitted_luminosity( - self.luminosity_nu_start, self.luminosity_nu_end + emitted_luminosity = calculate_emitted_luminosity( + transport_state.emitted_packet_nu, + transport_state.emitted_packet_luminosity, + self.luminosity_nu_start, + self.luminosity_nu_end, ) - reabsorbed_luminosity = ( - self.spectrum_solver.calculate_reabsorbed_luminosity( - self.luminosity_nu_start, self.luminosity_nu_end - ) + reabsorbed_luminosity = calculate_reabsorbed_luminosity( + transport_state.reabsorbed_packet_nu, + transport_state.reabsorbed_packet_luminosity, + self.luminosity_nu_start, + self.luminosity_nu_end, ) if hasattr(self, "convergence_plots"): self.convergence_plots.fetch_data( @@ -493,7 +495,7 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): self.log_run_results(emitted_luminosity, reabsorbed_luminosity) self.iterations_executed += 1 - return emitted_luminosity + return emitted_luminosity, v_packets_energy_hist def run_convergence(self): """ @@ -508,7 +510,9 @@ def run_convergence(self): self.plasma.electron_densities, self.simulation_state.t_inner, ) - emitted_luminosity = self.iterate(self.no_of_packets) + emitted_luminosity, v_packets_energy_hist = self.iterate( + self.no_of_packets + ) self.converged = self.advance_state(emitted_luminosity) if hasattr(self, "convergence_plots"): self.convergence_plots.update() @@ -533,11 +537,17 @@ def run_final(self): self.plasma.electron_densities, self.simulation_state.t_inner, ) - self.iterate(self.last_no_of_packets, self.no_of_virtual_packets) + emitted_luminosity, v_packets_energy_hist = self.iterate( + self.last_no_of_packets, self.no_of_virtual_packets + ) - # Set up spectrum solver integrator - self.spectrum_solver._integrator = FormalIntegrator( - self.simulation_state, self.plasma, self.transport + # Set up spectrum solver integrator and virtual spectrum + self.spectrum_solver.setup_optional_spectra( + self.transport.transport_state, + v_packets_energy_hist, + FormalIntegrator( + self.simulation_state, self.plasma, self.transport + ), ) self.reshape_plasma_state_store(self.iterations_executed) diff --git a/tardis/spectrum/base.py b/tardis/spectrum/base.py index 4ef1a913258..2780e893783 100644 --- a/tardis/spectrum/base.py +++ b/tardis/spectrum/base.py @@ -34,6 +34,24 @@ def __init__( self.integrator_settings = integrator_settings self._spectrum_integrated = None + def setup_optional_spectra( + self, transport_state, virtual_packet_luminosity=None, integrator=None + ): + """Set up the solver to handle virtual and integrated spectra + + Parameters + ---------- + virtual_packet_luminosity : np.ndarray, optional + Virtual packet luminosity, unnormalized, by default None + integrator : FormalIntegrator, optional + Integrator to compute the integrated spectrum with, by default None + """ + self.transport_state = transport_state + self._montecarlo_virtual_luminosity = ( + virtual_packet_luminosity * u.erg / u.s + ) + self._integrator = integrator + @property def spectrum_real_packets(self): return TARDISSpectrum( @@ -136,52 +154,6 @@ def montecarlo_virtual_luminosity(self): / self.transport_state.time_of_simulation.value ) - def calculate_emitted_luminosity( - self, luminosity_nu_start, luminosity_nu_end - ): - """ - Calculate emitted luminosity. - - Parameters - ---------- - luminosity_nu_start : astropy.units.Quantity - luminosity_nu_end : astropy.units.Quantity - - Returns - ------- - astropy.units.Quantity - """ - luminosity_wavelength_filter = ( - self.transport_state.emitted_packet_nu > luminosity_nu_start - ) & (self.transport_state.emitted_packet_nu < luminosity_nu_end) - - return self.transport_state.emitted_packet_luminosity[ - luminosity_wavelength_filter - ].sum() - - def calculate_reabsorbed_luminosity( - self, luminosity_nu_start, luminosity_nu_end - ): - """ - Calculate reabsorbed luminosity. - - Parameters - ---------- - luminosity_nu_start : astropy.units.Quantity - luminosity_nu_end : astropy.units.Quantity - - Returns - ------- - astropy.units.Quantity - """ - luminosity_wavelength_filter = ( - self.transport_state.reabsorbed_packet_nu > luminosity_nu_start - ) & (self.transport_state.reabsorbed_packet_nu < luminosity_nu_end) - - return self.transport_state.reabsorbed_packet_luminosity[ - luminosity_wavelength_filter - ].sum() - def solve(self, transport_state): """Solve the spectra diff --git a/tardis/spectrum/luminosity.py b/tardis/spectrum/luminosity.py new file mode 100644 index 00000000000..830b2650238 --- /dev/null +++ b/tardis/spectrum/luminosity.py @@ -0,0 +1,56 @@ +import astropy.units as u +import numpy as np + + +def calculate_emitted_luminosity( + emitted_packet_nu, + emitted_packet_luminosity, + luminosity_nu_start=0 * u.Hz, + luminosity_nu_end=np.inf * u.Hz, +): + """ + Calculate emitted luminosity. + + Parameters + ---------- + emitted_packet_nu : + emitted_packet_luminosity : + luminosity_nu_start : astropy.units.Quantity + luminosity_nu_end : astropy.units.Quantity + + Returns + ------- + astropy.units.Quantity + """ + luminosity_wavelength_filter = (emitted_packet_nu > luminosity_nu_start) & ( + emitted_packet_nu < luminosity_nu_end + ) + + return emitted_packet_luminosity[luminosity_wavelength_filter].sum() + + +def calculate_reabsorbed_luminosity( + reabsorbed_packet_nu, + reabsorbed_packet_luminosity, + luminosity_nu_start=0 * u.Hz, + luminosity_nu_end=np.inf * u.Hz, +): + """ + Calculate reabsorbed luminosity. + + Parameters + ---------- + reabsorbed_packet_nu : + reabsorbed_packet_luminosity : + luminosity_nu_start : astropy.units.Quantity + luminosity_nu_end : astropy.units.Quantity + + Returns + ------- + astropy.units.Quantity + """ + luminosity_wavelength_filter = ( + reabsorbed_packet_nu > luminosity_nu_start + ) & (reabsorbed_packet_nu < luminosity_nu_end) + + return reabsorbed_packet_luminosity[luminosity_wavelength_filter].sum() From dba6d3523371a283d191d902ce5fbb284902b384 Mon Sep 17 00:00:00 2001 From: Andrew Fullard Date: Mon, 29 Jul 2024 17:06:25 -0400 Subject: [PATCH 2/4] Fixes docs --- docs/physics/update_and_conv/update_and_conv.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/physics/update_and_conv/update_and_conv.ipynb b/docs/physics/update_and_conv/update_and_conv.ipynb index cbc5c0c9618..da9db07d389 100644 --- a/docs/physics/update_and_conv/update_and_conv.ipynb +++ b/docs/physics/update_and_conv/update_and_conv.ipynb @@ -462,7 +462,8 @@ "#nu_lower = tardis_config.supernova.luminosity_wavelength_end.to(u.Hz, u.spectral)\n", "#nu_upper = tardis_config.supernova.luminosity_wavelength_start.to(u.Hz, u.spectral)\n", "\n", - "L_output = sim.spectrum_solver.calculate_emitted_luminosity(0,np.inf)\n", + "from tardis.spectrum.luminosity import calculate_emitted_luminosity\n", + "L_output = calculate_emitted_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n", "L_output" ] }, From 273bf66d79f7b275b21d1f6f06b2f1e3dfaf7e70 Mon Sep 17 00:00:00 2001 From: Andrew Fullard Date: Mon, 29 Jul 2024 17:19:42 -0400 Subject: [PATCH 3/4] Further cleanup and tests --- .../update_and_conv/update_and_conv.ipynb | 4 +- tardis/simulation/base.py | 7 +-- tardis/spectrum/luminosity.py | 45 +++----------- tardis/spectrum/tests/test_luminosity.py | 60 +++++++++++++++++++ 4 files changed, 74 insertions(+), 42 deletions(-) create mode 100644 tardis/spectrum/tests/test_luminosity.py diff --git a/docs/physics/update_and_conv/update_and_conv.ipynb b/docs/physics/update_and_conv/update_and_conv.ipynb index da9db07d389..b6d9094c3cd 100644 --- a/docs/physics/update_and_conv/update_and_conv.ipynb +++ b/docs/physics/update_and_conv/update_and_conv.ipynb @@ -462,8 +462,8 @@ "#nu_lower = tardis_config.supernova.luminosity_wavelength_end.to(u.Hz, u.spectral)\n", "#nu_upper = tardis_config.supernova.luminosity_wavelength_start.to(u.Hz, u.spectral)\n", "\n", - "from tardis.spectrum.luminosity import calculate_emitted_luminosity\n", - "L_output = calculate_emitted_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n", + "from tardis.spectrum.luminosity import calculate_filtered_luminosity\n", + "L_output = calculate_filtered_luminosity(transport.transport_state.emitted_packet_nu, transport.transport_state.emitted_packet_luminosity, nu_lower, nu_upper)\n", "L_output" ] }, diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 9059c28127e..c4c6b3c00e6 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -21,8 +21,7 @@ from tardis.spectrum.base import SpectrumSolver from tardis.spectrum.formal_integral import FormalIntegrator from tardis.spectrum.luminosity import ( - calculate_emitted_luminosity, - calculate_reabsorbed_luminosity, + calculate_filtered_luminosity, ) from tardis.transport.montecarlo.base import MonteCarloTransportSolver from tardis.transport.montecarlo.configuration import montecarlo_globals @@ -464,13 +463,13 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0): if np.sum(output_energy < 0) == len(output_energy): logger.critical("No r-packet escaped through the outer boundary.") - emitted_luminosity = calculate_emitted_luminosity( + emitted_luminosity = calculate_filtered_luminosity( transport_state.emitted_packet_nu, transport_state.emitted_packet_luminosity, self.luminosity_nu_start, self.luminosity_nu_end, ) - reabsorbed_luminosity = calculate_reabsorbed_luminosity( + reabsorbed_luminosity = calculate_filtered_luminosity( transport_state.reabsorbed_packet_nu, transport_state.reabsorbed_packet_luminosity, self.luminosity_nu_start, diff --git a/tardis/spectrum/luminosity.py b/tardis/spectrum/luminosity.py index 830b2650238..6e7f214d50c 100644 --- a/tardis/spectrum/luminosity.py +++ b/tardis/spectrum/luminosity.py @@ -2,19 +2,19 @@ import numpy as np -def calculate_emitted_luminosity( - emitted_packet_nu, - emitted_packet_luminosity, +def calculate_filtered_luminosity( + packet_nu, + packet_luminosity, luminosity_nu_start=0 * u.Hz, luminosity_nu_end=np.inf * u.Hz, ): """ - Calculate emitted luminosity. + Calculate total luminosity within a filter range. Parameters ---------- - emitted_packet_nu : - emitted_packet_luminosity : + packet_nu : astropy.units.Quantity + packet_luminosity : astropy.units.Quantity luminosity_nu_start : astropy.units.Quantity luminosity_nu_end : astropy.units.Quantity @@ -22,35 +22,8 @@ def calculate_emitted_luminosity( ------- astropy.units.Quantity """ - luminosity_wavelength_filter = (emitted_packet_nu > luminosity_nu_start) & ( - emitted_packet_nu < luminosity_nu_end + luminosity_wavelength_filter = (packet_nu > luminosity_nu_start) & ( + packet_nu < luminosity_nu_end ) - return emitted_packet_luminosity[luminosity_wavelength_filter].sum() - - -def calculate_reabsorbed_luminosity( - reabsorbed_packet_nu, - reabsorbed_packet_luminosity, - luminosity_nu_start=0 * u.Hz, - luminosity_nu_end=np.inf * u.Hz, -): - """ - Calculate reabsorbed luminosity. - - Parameters - ---------- - reabsorbed_packet_nu : - reabsorbed_packet_luminosity : - luminosity_nu_start : astropy.units.Quantity - luminosity_nu_end : astropy.units.Quantity - - Returns - ------- - astropy.units.Quantity - """ - luminosity_wavelength_filter = ( - reabsorbed_packet_nu > luminosity_nu_start - ) & (reabsorbed_packet_nu < luminosity_nu_end) - - return reabsorbed_packet_luminosity[luminosity_wavelength_filter].sum() + return packet_luminosity[luminosity_wavelength_filter].sum() diff --git a/tardis/spectrum/tests/test_luminosity.py b/tardis/spectrum/tests/test_luminosity.py new file mode 100644 index 00000000000..9acad9bf6f8 --- /dev/null +++ b/tardis/spectrum/tests/test_luminosity.py @@ -0,0 +1,60 @@ +import astropy.units as u +import numpy as np +import pytest + +from tardis.spectrum.luminosity import ( + calculate_filtered_luminosity, +) + + +@pytest.mark.parametrize( + "packet_nu, packet_luminosity, luminosity_nu_start, luminosity_nu_end, expected", + [ + # All frequencies within the range + ( + np.array([1, 2, 3]) * u.Hz, + np.array([10, 20, 30]) * u.erg / u.s, + 0 * u.Hz, + 4 * u.Hz, + 60 * u.erg / u.s, + ), + # All frequencies outside the range + ( + np.array([1, 2, 3]) * u.Hz, + np.array([10, 20, 30]) * u.erg / u.s, + 4 * u.Hz, + 5 * u.Hz, + 0 * u.erg / u.s, + ), + # Mix of frequencies within and outside the range + ( + np.array([1, 2, 3, 4]) * u.Hz, + np.array([10, 20, 30, 40]) * u.erg / u.s, + 2 * u.Hz, + 4 * u.Hz, + 30 * u.erg / u.s, + ), + # Edge case: Frequencies exactly on the boundary + ( + np.array([1, 2, 3, 4]) * u.Hz, + np.array([10, 20, 30, 40]) * u.erg / u.s, + 2 * u.Hz, + 3 * u.Hz, + 0 * u.erg / u.s, + ), + ], +) +def test_calculate_filtered_luminosity( + packet_nu, + packet_luminosity, + luminosity_nu_start, + luminosity_nu_end, + expected, +): + result = calculate_filtered_luminosity( + packet_nu, + packet_luminosity, + luminosity_nu_start, + luminosity_nu_end, + ) + assert result == expected From 8f750b2af8878abf754940ba86790b8af118704a Mon Sep 17 00:00:00 2001 From: Andrew Fullard Date: Tue, 6 Aug 2024 11:28:52 -0400 Subject: [PATCH 4/4] Fix formal integral benchmark --- benchmarks/spectrum_formal_integral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/spectrum_formal_integral.py b/benchmarks/spectrum_formal_integral.py index 8c69ea71bce..f557a6dffbf 100644 --- a/benchmarks/spectrum_formal_integral.py +++ b/benchmarks/spectrum_formal_integral.py @@ -35,10 +35,10 @@ def time_intensity_black_body(self): # Benchmark for functions in FormalIntegrator class def time_FormalIntegrator_functions(self): self.FormalIntegrator.calculate_spectrum( - self.sim.spectrum_solver.spectrum_real_packets.frequency + self.sim.spectrum_solver.spectrum_frequency_grid ) self.FormalIntegrator.make_source_function() self.FormalIntegrator.generate_numba_objects() self.FormalIntegrator.formal_integral( - self.sim.spectrum_solver.spectrum_real_packets.frequency, 1000 + self.sim.spectrum_solver.spectrum_frequency_grid, 1000 )