-
-
Notifications
You must be signed in to change notification settings - Fork 405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Weighted Packet Sampler #2718
Merged
Merged
Weighted Packet Sampler #2718
Changes from 13 commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
9d8edab
Added a class for a weighted packet sampler which samples packets uni…
Rodot- bb75e0f
Changed the hdf_name for the weighted source
Rodot- b4c4b1e
Added fixture for the weighted packet source and a integration test a…
Rodot- 8be6352
Testing readding a fixture in the weighted sampler tests
Rodot- 42175e3
moved the montecarlo config fixture to conftest.py
Rodot- 7ab8cb1
removed legacy mode from the simple weighted packet source fixture
Rodot- 397c5ce
updated path the the montecarlo main loop test data
Rodot- 4bd83cb
updated path the the montecarlo main loop test data part 2
Rodot- d0a23d9
fixed typo
Rodot- ebf3d4a
fixed typo 2
Rodot- cb2ad20
ran ruff
Rodot- d7a6ad3
ran black
Rodot- f61c6c8
Added the unit test for the blackbodyweightedsource
Rodot- 2b85f55
Fixed a typo in one of the fixtures
Rodot- 802fa15
Updated the documentation and cleaned up the implimentation of the we…
Rodot- 48c03be
fixed weighted packet source base clas
Rodot- fc5db3a
Use the rng for sampling rather than np.random as otherwise there are…
Rodot- 98a3fe0
Fixed issues with the rng and made the unit tests avoid legacy mode
Rodot- 40e14c2
ran black
Rodot- File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
tardis/transport/montecarlo/tests/test_weighted_packet_source.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
|
||
from astropy import units as u | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
from numpy.testing import assert_allclose | ||
|
||
from tardis.transport.montecarlo.weighted_packet_source import ( | ||
BlackBodyWeightedSource, | ||
) | ||
from tardis.tests.fixtures.regression_data import RegressionData | ||
|
||
|
||
class TestBlackBodyWeightedSource: | ||
@pytest.fixture(scope="class") | ||
def blackbodyweightedource(self, request): | ||
""" | ||
Create blackbodyweightedsource instance. | ||
|
||
Yields | ||
------- | ||
tardis.transport.montecarlo.packet_source.blackbodyweightedsource | ||
""" | ||
cls = type(self) | ||
bb = BlackBodyWeightedSource( | ||
radius=123, | ||
temperature=10000 * u.K, | ||
base_seed=1963, | ||
legacy_second_seed=2508, | ||
legacy_mode_enabled=True, | ||
) | ||
yield bb | ||
|
||
def test_bb_nus(self, regression_data, blackbodyweightedsource): | ||
actual_nus = blackbodyweightedsource.create_packet_nus(100).value | ||
expected_nus = regression_data.sync_ndarray(actual_nus) | ||
assert_allclose(actual_nus, expected_nus) | ||
|
||
def test_bb_mus(self, regression_data, blackbodyweightedsource): | ||
actual_mus = blackbodyweightedsource.create_packet_mus(100) | ||
expected_mus = regression_data.sync_ndarray(actual_mus) | ||
assert_allclose(actual_mus, expected_mus) | ||
|
||
def test_bb_energies(self, regression_data, blackbodyweightedsource): | ||
actual_unif_energies = blackbodyweightedsource.create_packet_energies( | ||
100 | ||
).value | ||
expected_unif_energies = regression_data.sync_ndarray( | ||
actual_unif_energies | ||
) | ||
assert_allclose(actual_unif_energies, expected_unif_energies) | ||
|
||
def test_bb_attributes(self, regression_data, blackbodyweightedsource): | ||
actual_bb = blackbodyweightedsource | ||
expected_bb = regression_data.sync_hdf_store(actual_bb)[ | ||
"/black_body_weighted_source/scalars" | ||
] | ||
assert_allclose(expected_bb.base_seed, actual_bb.base_seed) | ||
assert_allclose(expected_bb.temperature, actual_bb.temperature.value) | ||
assert_allclose(expected_bb.radius, actual_bb.radius) | ||
66 changes: 66 additions & 0 deletions
66
tardis/transport/montecarlo/tests/test_weighted_packet_source_integration.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from copy import deepcopy | ||
|
||
import numpy.testing as npt | ||
import pandas as pd | ||
|
||
from tardis.simulation import Simulation | ||
|
||
|
||
def test_montecarlo_main_loop_weighted( | ||
montecarlo_main_loop_config, | ||
regression_data, | ||
atomic_dataset, | ||
simple_weighted_packet_source, | ||
): | ||
atomic_dataset = deepcopy(atomic_dataset) | ||
montecarlo_main_loop_simulation_weighted = Simulation.from_config( | ||
montecarlo_main_loop_config, | ||
atom_data=atomic_dataset, | ||
virtual_packet_logging=False, | ||
legacy_mode_enabled=True, | ||
) | ||
montecarlo_main_loop_simulation_weighted.packet_source = ( | ||
simple_weighted_packet_source | ||
) | ||
montecarlo_main_loop_simulation_weighted.run_convergence() | ||
montecarlo_main_loop_simulation_weighted.run_final() | ||
|
||
# Get the montecarlo simple regression data | ||
regression_data_dir = ( | ||
regression_data.absolute_regression_data_dir.absolute().parents[0] | ||
/ "test_montecarlo_main_loop/test_montecarlo_main_loop.h5" | ||
) | ||
expected_hdf_store = pd.HDFStore(regression_data_dir, mode="r") | ||
|
||
# Load compare data from refdata | ||
|
||
expected_nu = expected_hdf_store[ | ||
"/simulation/transport/transport_state/output_nu" | ||
] | ||
expected_energy = expected_hdf_store[ | ||
"/simulation/transport/transport_state/output_energy" | ||
] | ||
expected_nu_bar_estimator = expected_hdf_store[ | ||
"/simulation/transport/transport_state/nu_bar_estimator" | ||
] | ||
expected_j_estimator = expected_hdf_store[ | ||
"/simulation/transport/transport_state/j_estimator" | ||
] | ||
expected_hdf_store.close() | ||
transport_state = ( | ||
montecarlo_main_loop_simulation_weighted.transport.transport_state | ||
) | ||
actual_energy = transport_state.packet_collection.output_energies | ||
actual_nu = transport_state.packet_collection.output_nus | ||
actual_nu_bar_estimator = ( | ||
transport_state.radfield_mc_estimators.nu_bar_estimator | ||
) | ||
actual_j_estimator = transport_state.radfield_mc_estimators.j_estimator | ||
|
||
# Compare | ||
npt.assert_allclose( | ||
actual_nu_bar_estimator, expected_nu_bar_estimator, rtol=1e-2 | ||
) | ||
npt.assert_allclose(actual_j_estimator, expected_j_estimator, rtol=1e-2) | ||
npt.assert_allclose(actual_energy, expected_energy, rtol=1e-2) | ||
npt.assert_allclose(actual_nu, expected_nu, rtol=1e-2) | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import numexpr as ne | ||
import numpy as np | ||
from astropy import constants as const | ||
from astropy import units as u | ||
|
||
from tardis.transport.montecarlo.packet_source import ( | ||
BasePacketSource, | ||
HDFWriterMixin, | ||
) | ||
from tardis.util.base import intensity_black_body | ||
|
||
|
||
class BlackBodyWeightedSource(BasePacketSource, HDFWriterMixin): | ||
""" | ||
Simple packet source that generates Blackbody packets for the Montecarlo | ||
part. | ||
|
||
Parameters | ||
---------- | ||
radius : astropy.units.Quantity | ||
Initial packet radius | ||
temperature : astropy.units.Quantity | ||
Absolute Temperature. | ||
base_seed : int | ||
Base Seed for random number generator | ||
legacy_secondary_seed : int | ||
Secondary seed for global numpy rng (Deprecated: Legacy reasons only) | ||
""" | ||
|
||
hdf_properties = ["radius", "temperature", "base_seed"] | ||
hdf_name = "black_body_weighted_source" | ||
|
||
@classmethod | ||
def from_simulation_state(cls, simulation_state, *args, **kwargs): | ||
return cls( | ||
simulation_state.r_inner[0], | ||
simulation_state.t_inner, | ||
*args, | ||
**kwargs, | ||
) | ||
|
||
def __init__(self, radius=None, temperature=None, **kwargs): | ||
self.radius = radius | ||
self.temperature = temperature | ||
super().__init__(**kwargs) | ||
|
||
def create_packets(self, no_of_packets, *args, **kwargs): | ||
if self.radius is None or self.temperature is None: | ||
raise ValueError("Black body Radius or Temperature isn't set") | ||
return super().create_packets(no_of_packets, *args, **kwargs) | ||
|
||
def create_packet_radii(self, no_of_packets): | ||
""" | ||
Create packet radii | ||
|
||
Parameters | ||
---------- | ||
no_of_packets : int | ||
number of packets to be created | ||
|
||
Returns | ||
------- | ||
Radii for packets | ||
numpy.ndarray | ||
""" | ||
return np.ones(no_of_packets) * self.radius.cgs | ||
|
||
def create_packet_nus(self, no_of_packets, l_samples=1000): | ||
Rodot- marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Create packet :math:`\\nu` distributed using the algorithm described in | ||
Bjorkman & Wood 2001 (page 4) which references | ||
Carter & Cashwell 1975: | ||
First, generate a uniform random number, :math:`\\xi_0 \\in [0, 1]` and | ||
determine the minimum value of | ||
:math:`l, l_{\\rm min}`, that satisfies the condition | ||
.. math:: | ||
\\sum_{i=1}^{l} i^{-4} \\ge {{\\pi^4}\\over{90}} m_0 \\;. | ||
Next obtain four additional uniform random numbers (in the range 0 | ||
to 1) :math:`\\xi_1, \\xi_2, \\xi_3, {\\rm and } \\xi_4`. | ||
Finally, the packet frequency is given by | ||
.. math:: | ||
x = -\\ln{(\\xi_1\\xi_2\\xi_3\\xi_4)}/l_{\\rm min}\\;. | ||
where :math:`x=h\\nu/kT` | ||
|
||
Parameters | ||
---------- | ||
no_of_packets : int | ||
l_samples : int | ||
number of l_samples needed in the algorithm | ||
|
||
Returns | ||
------- | ||
array of frequencies | ||
numpy.ndarray | ||
""" | ||
l_array = np.cumsum(np.arange(1, l_samples, dtype=np.float64) ** -4) | ||
l_coef = np.pi**4 / 90.0 | ||
|
||
# For testing purposes | ||
if self.legacy_mode_enabled: | ||
xis = np.random.random((5, no_of_packets)) | ||
else: | ||
xis = self.rng.random((5, no_of_packets)) | ||
|
||
l = l_array.searchsorted(xis[0] * l_coef) + 1.0 | ||
xis_prod = np.prod(xis[1:], 0) | ||
x = ne.evaluate("-log(xis_prod)/l") | ||
|
||
nus = (x * (const.k_B * self.temperature) / const.h).cgs | ||
|
||
nu_min = nus.min() | ||
nu_max = nus.max() | ||
|
||
self.nus = ( | ||
np.random.uniform(nu_min.cgs.value, nu_max.cgs.value, no_of_packets) | ||
* nus.unit | ||
) | ||
|
||
return self.nus | ||
|
||
def create_packet_mus(self, no_of_packets): | ||
""" | ||
Create zero-limb-darkening packet :math:`\\mu` distributed | ||
according to :math:`\\mu=\\sqrt{z}, z \\isin [0, 1]` | ||
|
||
Parameters | ||
---------- | ||
no_of_packets : int | ||
number of packets to be created | ||
|
||
Returns | ||
------- | ||
Directions for packets | ||
numpy.ndarray | ||
""" | ||
# For testing purposes | ||
if self.legacy_mode_enabled: | ||
return np.sqrt(np.random.random(no_of_packets)) | ||
else: | ||
return np.sqrt(self.rng.random(no_of_packets)) | ||
|
||
def create_packet_energies(self, no_of_packets): | ||
""" | ||
Rodot- marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Uniformly distribute energy in arbitrary units where the ensemble of | ||
packets has energy of 1. | ||
|
||
Parameters | ||
---------- | ||
no_of_packets : int | ||
number of packets | ||
|
||
Returns | ||
------- | ||
energies for packets | ||
numpy.ndarray | ||
""" | ||
try: | ||
self.nus | ||
except AttributeError: | ||
self.nus = self.create_packet_nus(no_of_packets) | ||
|
||
intensity = intensity_black_body(self.nus.cgs.value, self.temperature) | ||
return intensity / intensity.sum() * u.erg | ||
|
||
def set_temperature_from_luminosity(self, luminosity: u.Quantity): | ||
""" | ||
Rodot- marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Set blackbody packet source temperature from luminosity | ||
|
||
Parameters | ||
---------- | ||
luminosity : u.Quantity | ||
|
||
""" | ||
self.temperature = ( | ||
(luminosity / (4 * np.pi * self.radius**2 * const.sigma_sb)) | ||
** 0.25 | ||
).to("K") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are you using rtol upto two decimal places?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because it is checking consistency with the default sampler rather than with itself. It is to make sure we're getting approximately the same solution
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am happy to approve if the tests passes.