Skip to content

Commit

Permalink
Merge branch 'restructure/base_estimator_cleanup' into restructure/as…
Browse files Browse the repository at this point in the history
…semble_plasma_cleanup
  • Loading branch information
wkerzendorf committed Jul 29, 2024
2 parents 4ba4afb + 7d8e6e0 commit 03e1aae
Show file tree
Hide file tree
Showing 20 changed files with 93 additions and 47 deletions.
17 changes: 10 additions & 7 deletions .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ jobs:
with:
fetch-depth: 0

- name: Checkout pull/${{ github.event.number }}
- name: Checkout PR and master branch
uses: actions/checkout@v4
with:
ref: ${{ github.sha }}
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 0
if: github.event_name == 'pull_request_target'

- name: Fetch master branch
run: git fetch origin master:master
if: github.event_name == 'pull_request_target'

- name: Setup LFS
uses: ./.github/actions/setup_lfs

Expand Down Expand Up @@ -101,20 +105,19 @@ jobs:
- name: Run benchmarks for base and head commits of PR
if: github.event_name == 'pull_request_target'
run: |
echo ${{ github.event.pull_request.base.sha }} > commit_hashes.txt
echo ${{ github.event.pull_request.head.sha }} >> commit_hashes.txt
echo $(git rev-parse HEAD) > commit_hashes.txt
echo $(git rev-parse master) >> commit_hashes.txt
asv run -a repeat=2 -a rounds=1 HASHFILE:commit_hashes.txt | tee asv-output-PR.log
if grep -q failed asv-output-PR.log; then
echo "Some benchmarks have failed!"
exit 1
fi
- name: Compare Master and PR head
run: asv compare ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} --config asv.conf.json | tee asv-compare-output.log
run: asv compare origin/master HEAD --config asv.conf.json | tee asv-compare-output.log

- name: Compare Master and PR head but only show changed results
run: asv compare ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} --only-changed --config asv.conf.json | tee asv-compare-changed-output.log
run: asv compare origin/master HEAD --only-changed --config asv.conf.json | tee asv-compare-changed-output.log

- name: Benchmarks compare output
id: asv_pr_vs_master
Expand Down
6 changes: 2 additions & 4 deletions benchmarks/opacities_opacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from asv_runner.benchmarks.mark import parameterize

import tardis.opacities.compton_opacity_calculation
import tardis.opacities.opacities as calculate_opacity
from benchmarks.benchmark_base import BenchmarkBase
from tardis.opacities.opacities import compton_opacity_calculation


class BenchmarkMontecarloMontecarloNumbaOpacities(BenchmarkBase):
Expand All @@ -29,9 +29,7 @@ class BenchmarkMontecarloMontecarloNumbaOpacities(BenchmarkBase):
}
)
def time_compton_opacity_calculation(self, electron_number_density, energy):
tardis.opacities.compton_opacity_calculation.compton_opacity_calculation(
energy, electron_number_density
)
compton_opacity_calculation(energy, electron_number_density)

@parameterize(
{
Expand Down
4 changes: 1 addition & 3 deletions tardis/energy_input/gamma_packet_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
doppler_factor_3d,
get_index,
)
from tardis.opacities.compton_opacity_calculation import (
compton_opacity_calculation,
)
from tardis.opacities.opacities import (
SIGMA_T,
compton_opacity_calculation,
kappa_calculation,
pair_creation_opacity_artis,
pair_creation_opacity_calculation,
Expand Down
4 changes: 1 addition & 3 deletions tardis/energy_input/gamma_ray_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
angle_aberration_gamma,
doppler_factor_3d,
)
from tardis.opacities.compton_opacity_calculation import (
compton_opacity_calculation,
)
from tardis.opacities.opacities import (
SIGMA_T,
compton_opacity_calculation,
kappa_calculation,
photoabsorption_opacity_calculation,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

H = const.h.cgs.value


class RawCollIonTransProbs(TransitionProbabilitiesProperty, IndexSetterMixin):
"""
Attributes
Expand Down
2 changes: 1 addition & 1 deletion tardis/plasma/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def write_to_dot(self, fname, args=None, latex_label=True):
edge labels into the file.
"""
try:
pass
import pygraphviz
except:
logger.warning(
"pygraphviz missing. Plasma graph will not be " "generated."
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from tardis.plasma.properties.base import Input
from tardis.plasma.properties.continuum_processes.rates import H


class PhotoIonRateCoeff(Input):
Expand Down
1 change: 1 addition & 0 deletions tardis/plasma/properties/continuum_processes/rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def calculate(
)
return gamma_corr


class StimRecombCoolingRateCoeffEstimator(Input):
"""
Attributes
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import numpy as np
import pandas as pd

import tardis.constants as const
from tardis.plasma.properties.base import Input, ProcessingPlasmaProperty
from tardis.plasma.properties.continuum_processes.rates import C, H
from tardis.transport.montecarlo.estimators.util import (
integrate_array_by_blocks,
)

C = const.c.cgs.value


class StimRecombRateFactor(Input):
"""
Expand Down
1 change: 1 addition & 0 deletions tardis/plasma/properties/plasma_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class LinkTRadTElectron(Input):
class HeliumTreatment(Input):
outputs = ("helium_treatment",)


class ContinuumInteractionSpecies(Input):
"""
Attributes
Expand Down
2 changes: 1 addition & 1 deletion tardis/plasma/properties/property_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class PlasmaPropertyCollection(list):
basic_inputs = PlasmaPropertyCollection(
[
DilutePlanckianRadField,
DilutePlanckianRadField,
Abundance,
NumberDensity,
TimeExplosion,
AtomicData,
Expand Down
4 changes: 2 additions & 2 deletions tardis/plasma/radiation_field/planck_rad_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class DilutePlanckianRadiationField:
"""
Represents the state of a dilute thermal radiation field.
Represents the state of a dilute planckian radiation field.
Parameters
Expand Down Expand Up @@ -79,7 +79,7 @@ def to_planckian_radiation_field(self):

class PlanckianRadiationField:
"""
Represents the state of a dilute thermal radiation field.
Represents the state of a planckian radiation field.
Parameters
Expand Down
10 changes: 3 additions & 7 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ def __init__(
convergence_plots_kwargs,
show_progress_bars,
spectrum_solver,
integrator_settings,
):
super(Simulation, self).__init__(
iterations, simulation_state.no_of_shells
Expand All @@ -160,7 +159,6 @@ def __init__(
self.luminosity_nu_end = luminosity_nu_end
self.luminosity_requested = luminosity_requested
self.spectrum_solver = spectrum_solver
self.integrator_settings = integrator_settings
self.show_progress_bars = show_progress_bars
self.version = tardis.__version__

Expand Down Expand Up @@ -458,9 +456,9 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0):

# Set up spectrum solver
self.spectrum_solver.transport_state = transport_state
self.spectrum_solver._montecarlo_virtual_luminosity.value[:] = (
v_packets_energy_hist
)
self.spectrum_solver._montecarlo_virtual_luminosity.value[
:
] = v_packets_energy_hist

output_energy = (
self.transport.transport_state.packet_collection.output_energies
Expand Down Expand Up @@ -538,7 +536,6 @@ def run_final(self):
self.iterate(self.last_no_of_packets, self.no_of_virtual_packets)

# Set up spectrum solver integrator
self.spectrum_solver.integrator_settings = self.integrator_settings
self.spectrum_solver._integrator = FormalIntegrator(
self.simulation_state, self.plasma, self.transport
)
Expand Down Expand Up @@ -804,6 +801,5 @@ def from_config(
convergence_strategy=config.montecarlo.convergence_strategy,
convergence_plots_kwargs=convergence_plots_kwargs,
show_progress_bars=show_progress_bars,
integrator_settings=config.spectrum.integrated,
spectrum_solver=spectrum_solver,
)
34 changes: 30 additions & 4 deletions tardis/spectrum/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ class SpectrumSolver(HDFWriterMixin):

hdf_name = "spectrum"

def __init__(self, transport_state, spectrum_frequency_grid):
def __init__(
self, transport_state, spectrum_frequency_grid, integrator_settings=None
):
self.transport_state = transport_state
self.spectrum_frequency_grid = spectrum_frequency_grid
self._montecarlo_virtual_luminosity = u.Quantity(
np.zeros_like(self.spectrum_frequency_grid.value), "erg / s"
) # should be init with v_packets_energy_hist
self._integrator = None
self.integrator_settings = None
self.integrator_settings = integrator_settings
self._spectrum_integrated = None

@property
Expand Down Expand Up @@ -60,7 +62,7 @@ def spectrum_virtual_packets(self):

@property
def spectrum_integrated(self):
if self._spectrum_integrated is None:
if self._spectrum_integrated is None and self.integrator is not None:
# This was changed from unpacking to specific attributes as compute
# is not used in calculate_spectrum
try:
Expand All @@ -83,13 +85,15 @@ def spectrum_integrated(self):
np.array([np.nan, np.nan]) * u.Hz,
np.array([np.nan]) * u.erg / u.s,
)
else:
self._spectrum_integrated = None
return self._spectrum_integrated

@property
def integrator(self):
if self._integrator is None:
warnings.warn(
"MontecarloTransport.integrator: "
"SpectrumSolver.integrator: "
"The FormalIntegrator is not yet available."
"Please run the montecarlo simulation at least once.",
UserWarning,
Expand Down Expand Up @@ -178,6 +182,27 @@ def calculate_reabsorbed_luminosity(
luminosity_wavelength_filter
].sum()

def solve(self, transport_state):
"""Solve the spectra
Parameters
----------
transport_state: MonteCarloTransportState
The transport state to be used to compute the spectra
Returns
-------
tuple(TARDISSpectrum)
Real, virtual and integrated spectra, if available
"""
self.transport_state = transport_state

return (
self.spectrum_real_packets,
self.spectrum_virtual_packets,
self.spectrum_integrated,
)

@classmethod
def from_config(cls, config):
spectrum_frequency_grid = quantity_linspace(
Expand All @@ -189,4 +214,5 @@ def from_config(cls, config):
return cls(
transport_state=None,
spectrum_frequency_grid=spectrum_frequency_grid,
integrator_settings=config.spectrum.integrated,
)
27 changes: 27 additions & 0 deletions tardis/spectrum/tests/test_spectrum_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,30 @@ def test_spectrum_real_packets_reabsorbed(self, simulation):
result,
luminosity,
)

def test_solve(self, simulation):
transport_state = simulation.transport.transport_state
spectrum_frequency_grid = simulation.transport.spectrum_frequency_grid

solver = SpectrumSolver(transport_state, spectrum_frequency_grid)
result_real, result_virtual, result_integrated = solver.solve(
transport_state
)
key_real = "simulation/spectrum_solver/spectrum_real_packets/luminosity"
expected_real = self.get_expected_data(key_real)

luminosity_real = u.Quantity(expected_real, "erg /s")

assert_quantity_allclose(
result_real.luminosity,
luminosity_real,
)

assert_quantity_allclose(
result_virtual.luminosity,
u.Quantity(
np.zeros_like(spectrum_frequency_grid.value)[:-1], "erg / s"
),
)

assert result_integrated is None
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
MonteCarloConfiguration,
configuration_initialize,
)
from tardis.transport.montecarlo.estimators.dilute_blackbody_properties import (
from tardis.transport.montecarlo.estimators.mc_rad_field_solver import (
MCRadiationFieldPropertiesSolver,
)
from tardis.transport.montecarlo.estimators.radfield_mc_estimators import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ def solve(
photo_ion_rate_coeff = self.calculate_photo_ionization_rate_coefficient(
mean_intensity_photo_ion_df
)
stimulated_recomb_rate_coeff = (
stimulated_recomb_rate_factor = (
self.calculate_stimulated_recomb_rate_factor(
mean_intensity_photo_ion_df,
photo_ion_boltzmann_factor,
)
)

return ContinuumProperties(
stimulated_recomb_rate_coeff, photo_ion_rate_coeff
stimulated_recomb_rate_factor, photo_ion_rate_coeff
)

def calculate_photo_ionization_rate_coefficient(
Expand Down Expand Up @@ -214,14 +214,13 @@ def calculate_mean_intensity_photo_ion_table(
self.atom_data.photoionization_data.nu.values
)
)
mean_intensity_df = pd.DataFrame(
return pd.DataFrame(
mean_intensity,
index=self.atom_data.photoionization_data.index,
columns=np.arange(
len(dilute_blackbody_radiationfield_state.temperature)
),
)
return mean_intensity_df


@dataclass
Expand Down
7 changes: 0 additions & 7 deletions tardis/transport/montecarlo/montecarlo_transport_state.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import warnings

import numpy as np
from astropy import units as u

from tardis.io.util import HDFWriterMixin
from tardis.transport.montecarlo.estimators.dilute_blackbody_properties import (
MCRadiationFieldPropertiesSolver,
)
from tardis.spectrum.formal_integral import IntegrationError
from tardis.spectrum.spectrum import TARDISSpectrum


class MonteCarloTransportState(HDFWriterMixin):
Expand Down Expand Up @@ -63,7 +57,6 @@ def __init__(
rpacket_tracker=None,
vpacket_tracker=None,
):
self.time_explosion = time_explosion
self.packet_collection = packet_collection
self.radfield_mc_estimators = radfield_mc_estimators
self.enable_full_relativity = False
Expand Down
Loading

0 comments on commit 03e1aae

Please sign in to comment.