Skip to content

Commit

Permalink
TARDIS Full Formal Integral Regression Tests (#2805)
Browse files Browse the repository at this point in the history
* Initial commit

* Rest works

* Fix for spectrum integrated issue

* Reformat using Ruff and Black

* Remove tardis_ref_data fixture

* Reformat using Ruff and Black
  • Loading branch information
atharva-2001 authored Aug 22, 2024
1 parent 1358da9 commit b6f1185
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 74 deletions.
13 changes: 1 addition & 12 deletions tardis/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import os
from pathlib import Path

import pandas as pd
import pytest
from astropy.version import version as astropy_version

from tardis import run_tardis
from tardis.io.configuration.config_reader import Configuration
from tardis.io.util import YAMLLoader, yaml_load_file
from tardis.simulation import Simulation
from tardis.tests.fixtures.atom_data import *
from tardis.tests.fixtures.regression_data import regression_data
from tardis import run_tardis

# ensuring that regression_data is not removed by ruff
assert regression_data is not None
Expand Down Expand Up @@ -166,16 +165,6 @@ def tardis_snapshot_path(request):
)


@pytest.yield_fixture(scope="session")
def tardis_ref_data(tardis_ref_path, generate_reference):
if generate_reference:
mode = "w"
else:
mode = "r"
with pd.HDFStore(tardis_ref_path / "unit_test_data.h5", mode=mode) as store:
yield store


@pytest.fixture(scope="function")
def tardis_config_verysimple():
return yaml_load_file(
Expand Down
10 changes: 5 additions & 5 deletions tardis/gui/tests/test_gui.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os

import pytest

from tardis.io.configuration.config_reader import Configuration
from tardis.simulation import Simulation
import astropy.units as u

if "QT_API" in os.environ:
from PyQt5 import QtWidgets
from tardis.gui.widgets import Tardis

from tardis.gui.datahandler import SimpleTableModel
from tardis.gui.widgets import Tardis


@pytest.fixture(scope="module")
Expand All @@ -18,9 +20,7 @@ def config():


@pytest.fixture(scope="module")
def simulation_one_loop(
atomic_data_fname, config, tardis_ref_data, generate_reference
):
def simulation_one_loop(atomic_data_fname, config):
config.atom_data = atomic_data_fname
config.montecarlo.iterations = 2
config.montecarlo.no_of_packets = int(4e4)
Expand Down
4 changes: 1 addition & 3 deletions tardis/spectrum/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,10 @@ def spectrum_integrated(self):
"This RETURNS AN EMPTY SPECTRUM!",
UserWarning,
)
return TARDISSpectrum(
self._spectrum_integrated = TARDISSpectrum(
np.array([np.nan, np.nan]) * u.Hz,
np.array([np.nan]) * u.erg / u.s,
)
else:
self._spectrum_integrated = None
return self._spectrum_integrated

@property
Expand Down
98 changes: 44 additions & 54 deletions tardis/tests/test_tardis_full_formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from astropy.tests.helper import assert_quantity_allclose

from tardis.io.configuration.config_reader import Configuration
from tardis.io.util import HDFWriterMixin
from tardis.simulation.base import Simulation
from tardis.tests.fixtures.regression_data import RegressionData

config_line_modes = ["downbranch", "macroatom"]
interpolate_shells = [-1, 30]
Expand All @@ -17,7 +19,6 @@ def base_config(request, example_configuration_dir: Path):
config = Configuration.from_yaml(
example_configuration_dir / "tardis_configv1_verysimple.yml"
)

config["plasma"]["line_interaction_type"] = request.param
config["montecarlo"]["no_of_packets"] = 4.0e4
config["montecarlo"]["last_no_of_packets"] = 1.0e5
Expand All @@ -35,6 +36,14 @@ def config(base_config, request):
return base_config


class SimulationContainer(HDFWriterMixin):
hdf_properties = ["spectrum_solver", "transport"]

def __init__(self, simulation):
self.spectrum_solver = simulation.spectrum_solver
self.transport = simulation.transport


class TestTransportSimpleFormalIntegral:
"""
Very simple run with the formal integral spectral synthesis method
Expand All @@ -43,9 +52,7 @@ class TestTransportSimpleFormalIntegral:
_name = "test_transport_simple_integral"

@pytest.fixture(scope="class")
def simulation(
self, config, atomic_data_fname, tardis_ref_data, generate_reference
):
def simulation(self, config, atomic_data_fname):
config.atom_data = atomic_data_fname

self.name = self._name + f"_{config.plasma.line_interaction_type:s}"
Expand All @@ -55,54 +62,37 @@ def simulation(
simulation = Simulation.from_config(config)
simulation.run_convergence()
simulation.run_final()

if not generate_reference:
return simulation
else:
simulation.spectrum_solver.hdf_properties = [
"spectrum_real_packets",
"spectrum_integrated",
]
simulation.spectrum_solver.to_hdf(
tardis_ref_data, "", self.name, overwrite=True
)
simulation.transport.hdf_properties = ["transport_state"]
simulation.transport.to_hdf(
tardis_ref_data, "", self.name, overwrite=True
)
pytest.skip("Reference data was generated during this run.")

@pytest.fixture(scope="class")
def refdata(self, tardis_ref_data):
def get_ref_data(key):
return tardis_ref_data[f"{self.name}/{key}"]

return get_ref_data

def test_j_blue_estimators(self, simulation, refdata):
j_blue_estimator = refdata("transport_state/j_blue_estimator").values

npt.assert_allclose(
simulation.transport.transport_state.radfield_mc_estimators.j_blue_estimator,
j_blue_estimator,
)

def test_spectrum(self, simulation, refdata):
luminosity = u.Quantity(
refdata("spectrum_real_packets/luminosity"), "erg /s"
)

assert_quantity_allclose(
simulation.spectrum_solver.spectrum_real_packets.luminosity,
luminosity,
)

def test_spectrum_integrated(self, simulation, refdata):
luminosity = u.Quantity(
refdata("spectrum_integrated/luminosity"), "erg /s"
)

assert_quantity_allclose(
simulation.spectrum_solver.spectrum_integrated.luminosity,
luminosity,
simulation.spectrum_solver.hdf_properties = [
"spectrum_real_packets",
"spectrum_integrated",
]
simulation.transport.hdf_properties = ["transport_state"]

return simulation

def test_simulation(self, simulation, request):
regression_data = RegressionData(request)
container = SimulationContainer(simulation)
regression_data.sync_hdf_store(container)

def test_j_blue_estimators(self, simulation, request):
regression_data = RegressionData(request)
j_blue_estimator = (
simulation.transport.transport_state.radfield_mc_estimators.j_blue_estimator
)
expected = regression_data.sync_ndarray(j_blue_estimator)
npt.assert_allclose(j_blue_estimator, expected)

def test_spectrum(self, simulation, request):
regression_data = RegressionData(request)
luminosity = simulation.spectrum_solver.spectrum_real_packets.luminosity
expected = regression_data.sync_ndarray(luminosity.cgs.value)
expected = u.Quantity(expected, "erg /s")
assert_quantity_allclose(luminosity, expected)

def test_spectrum_integrated(self, simulation, request):
regression_data = RegressionData(request)
luminosity = simulation.spectrum_solver.spectrum_integrated.luminosity
expected = regression_data.sync_ndarray(luminosity.cgs.value)
expected = u.Quantity(expected, "erg /s")
assert_quantity_allclose(luminosity, expected)

0 comments on commit b6f1185

Please sign in to comment.