-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from lincc-frameworks/more_complex_model
Add a spline-based model
- Loading branch information
Showing
3 changed files
with
133 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ dynamic = ["version"] | |
requires-python = ">=3.9" | ||
dependencies = [ | ||
"numpy", | ||
"scipy", | ||
] | ||
|
||
[project.urls] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
"""The SplineModel represents SED functions as a two dimensional grid | ||
of (time, wavelength) -> flux value that is interpolated using a 2D spline. | ||
It is adapted from sncosmo's TimeSeriesSource model: | ||
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py | ||
""" | ||
|
||
from scipy.interpolate import RectBivariateSpline | ||
|
||
from tdastro.base_models import PhysicalModel | ||
|
||
|
||
class SplineModel(PhysicalModel): | ||
"""A time series model defined by sample points where the intermediate | ||
points are fit by a spline. Based on sncosmo's TimeSeriesSource: | ||
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py | ||
Attributes | ||
---------- | ||
_times : `numpy.ndarray` | ||
A length T array containing the times at which the data was sampled. | ||
_wavelengths : `numpy.ndarray` | ||
A length W array containing the wavelengths at which the data was sampled. | ||
_spline : `RectBivariateSpline` | ||
The spline object for predicting the flux from a given (time, wavelength). | ||
name : `str` | ||
The name of the model being used. | ||
amplitude : `float` | ||
A unitless scaling parameter for the flux density values. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
times, | ||
wavelengths, | ||
flux, | ||
amplitude=1.0, | ||
time_degree=3, | ||
wave_degree=3, | ||
name=None, | ||
**kwargs, | ||
): | ||
"""Create the SplineModel from a grid of (timestep, wavelength, flux) points. | ||
Parameters | ||
---------- | ||
times : `numpy.ndarray` | ||
A length T array containing the times at which the data was sampled. | ||
wavelengths : `numpy.ndarray` | ||
A length W array containing the wavelengths at which the data was sampled. | ||
flux : `numpy.ndarray` | ||
A shape (T, W) matrix with flux values for each pair of time and wavelength. | ||
Fluxes provided in erg / s / cm^2 / Angstrom. | ||
amplitude : `float` | ||
A unitless scaling parameter for the flux density values. Default = 1.0 | ||
time_degree : `int` | ||
The polynomial degree to use in the time dimension. | ||
wave_degree : `int` | ||
The polynomial degree to use in the wavelength dimension. | ||
name : `str`, optional | ||
The name of the model. | ||
**kwargs : `dict`, optional | ||
Any additional keyword arguments. | ||
""" | ||
super().__init__(**kwargs) | ||
|
||
self.name = name | ||
self.amplitude = amplitude | ||
self._times = times | ||
self._wavelengths = wavelengths | ||
self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree) | ||
|
||
def _evaluate(self, times, wavelengths, **kwargs): | ||
"""Draw effect-free observations for this object. | ||
Parameters | ||
---------- | ||
times : `numpy.ndarray` | ||
A length T array of timestamps. | ||
wavelengths : `numpy.ndarray`, optional | ||
A length N array of wavelengths. | ||
**kwargs : `dict`, optional | ||
Any additional keyword arguments. | ||
Returns | ||
------- | ||
flux_density : `numpy.ndarray` | ||
A length T x N matrix of SED values. | ||
""" | ||
return self.amplitude * self._spline(times, wavelengths, grid=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import numpy as np | ||
from tdastro.sources.spline_model import SplineModel | ||
|
||
|
||
def test_spline_model_flat() -> None: | ||
"""Test that we can sample and create a flat SplineModel object.""" | ||
times = np.linspace(1.0, 5.0, 20) | ||
wavelengths = np.linspace(100.0, 500.0, 25) | ||
fluxes = np.full((len(times), len(wavelengths)), 1.0) | ||
model = SplineModel(times, wavelengths, fluxes) | ||
|
||
test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0]) | ||
test_waves = np.array([0.0, 100.0, 200.0, 1000.0]) | ||
|
||
values = model.evaluate(test_times, test_waves) | ||
assert values.shape == (5, 4) | ||
expected = np.full_like(values, 1.0) | ||
np.testing.assert_array_almost_equal(values, expected) | ||
|
||
model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0) | ||
values2 = model2.evaluate(test_times, test_waves) | ||
assert values2.shape == (5, 4) | ||
expected2 = np.full_like(values2, 5.0) | ||
np.testing.assert_array_almost_equal(values2, expected2) | ||
|
||
|
||
def test_spline_model_interesting() -> None: | ||
"""Test that we can sample and create a flat SplineModel object.""" | ||
times = np.array([1.0, 2.0, 3.0]) | ||
wavelengths = np.array([10.0, 20.0, 30.0]) | ||
fluxes = np.array([[1.0, 5.0, 1.0], [5.0, 10.0, 5.0], [1.0, 5.0, 3.0]]) | ||
model = SplineModel(times, wavelengths, fluxes, time_degree=1, wave_degree=1) | ||
|
||
test_times = np.array([1.0, 1.5, 2.0, 3.0]) | ||
test_waves = np.array([10.0, 15.0, 20.0, 30.0]) | ||
values = model.evaluate(test_times, test_waves) | ||
assert values.shape == (4, 4) | ||
|
||
expected = np.array( | ||
[[1.0, 3.0, 5.0, 1.0], [3.0, 5.25, 7.5, 3.0], [5.0, 7.5, 10.0, 5.0], [1.0, 3.0, 5.0, 3.0]] | ||
) | ||
np.testing.assert_array_almost_equal(values, expected) |