Skip to content

Commit

Permalink
Merge pull request #12 from lincc-frameworks/more_complex_model
Browse files Browse the repository at this point in the history
Add a spline-based model
  • Loading branch information
jeremykubica authored Jun 27, 2024
2 parents a046276 + 06d5586 commit 5b0a26f
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"numpy",
"scipy",
]

[project.urls]
Expand Down
90 changes: 90 additions & 0 deletions src/tdastro/sources/spline_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""The SplineModel represents SED functions as a two dimensional grid
of (time, wavelength) -> flux value that is interpolated using a 2D spline.
It is adapted from sncosmo's TimeSeriesSource model:
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py
"""

from scipy.interpolate import RectBivariateSpline

from tdastro.base_models import PhysicalModel


class SplineModel(PhysicalModel):
"""A time series model defined by sample points where the intermediate
points are fit by a spline. Based on sncosmo's TimeSeriesSource:
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py
Attributes
----------
_times : `numpy.ndarray`
A length T array containing the times at which the data was sampled.
_wavelengths : `numpy.ndarray`
A length W array containing the wavelengths at which the data was sampled.
_spline : `RectBivariateSpline`
The spline object for predicting the flux from a given (time, wavelength).
name : `str`
The name of the model being used.
amplitude : `float`
A unitless scaling parameter for the flux density values.
"""

def __init__(
self,
times,
wavelengths,
flux,
amplitude=1.0,
time_degree=3,
wave_degree=3,
name=None,
**kwargs,
):
"""Create the SplineModel from a grid of (timestep, wavelength, flux) points.
Parameters
----------
times : `numpy.ndarray`
A length T array containing the times at which the data was sampled.
wavelengths : `numpy.ndarray`
A length W array containing the wavelengths at which the data was sampled.
flux : `numpy.ndarray`
A shape (T, W) matrix with flux values for each pair of time and wavelength.
Fluxes provided in erg / s / cm^2 / Angstrom.
amplitude : `float`
A unitless scaling parameter for the flux density values. Default = 1.0
time_degree : `int`
The polynomial degree to use in the time dimension.
wave_degree : `int`
The polynomial degree to use in the wavelength dimension.
name : `str`, optional
The name of the model.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""
super().__init__(**kwargs)

self.name = name
self.amplitude = amplitude
self._times = times
self._wavelengths = wavelengths
self._spline = RectBivariateSpline(times, wavelengths, flux, kx=time_degree, ky=wave_degree)

def _evaluate(self, times, wavelengths, **kwargs):
"""Draw effect-free observations for this object.
Parameters
----------
times : `numpy.ndarray`
A length T array of timestamps.
wavelengths : `numpy.ndarray`, optional
A length N array of wavelengths.
**kwargs : `dict`, optional
Any additional keyword arguments.
Returns
-------
flux_density : `numpy.ndarray`
A length T x N matrix of SED values.
"""
return self.amplitude * self._spline(times, wavelengths, grid=True)
42 changes: 42 additions & 0 deletions tests/tdastro/sources/test_spline_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np
from tdastro.sources.spline_model import SplineModel


def test_spline_model_flat() -> None:
"""Test that we can sample and create a flat SplineModel object."""
times = np.linspace(1.0, 5.0, 20)
wavelengths = np.linspace(100.0, 500.0, 25)
fluxes = np.full((len(times), len(wavelengths)), 1.0)
model = SplineModel(times, wavelengths, fluxes)

test_times = np.array([0.0, 1.0, 2.0, 3.0, 10.0])
test_waves = np.array([0.0, 100.0, 200.0, 1000.0])

values = model.evaluate(test_times, test_waves)
assert values.shape == (5, 4)
expected = np.full_like(values, 1.0)
np.testing.assert_array_almost_equal(values, expected)

model2 = SplineModel(times, wavelengths, fluxes, amplitude=5.0)
values2 = model2.evaluate(test_times, test_waves)
assert values2.shape == (5, 4)
expected2 = np.full_like(values2, 5.0)
np.testing.assert_array_almost_equal(values2, expected2)


def test_spline_model_interesting() -> None:
"""Test that we can sample and create a flat SplineModel object."""
times = np.array([1.0, 2.0, 3.0])
wavelengths = np.array([10.0, 20.0, 30.0])
fluxes = np.array([[1.0, 5.0, 1.0], [5.0, 10.0, 5.0], [1.0, 5.0, 3.0]])
model = SplineModel(times, wavelengths, fluxes, time_degree=1, wave_degree=1)

test_times = np.array([1.0, 1.5, 2.0, 3.0])
test_waves = np.array([10.0, 15.0, 20.0, 30.0])
values = model.evaluate(test_times, test_waves)
assert values.shape == (4, 4)

expected = np.array(
[[1.0, 3.0, 5.0, 1.0], [3.0, 5.25, 7.5, 3.0], [5.0, 7.5, 10.0, 5.0], [1.0, 3.0, 5.0, 3.0]]
)
np.testing.assert_array_almost_equal(values, expected)

0 comments on commit 5b0a26f

Please sign in to comment.