Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT PR] Possible dust map implementation #180

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,6 @@ _html/
# Project initialization script
.initialize_new_project.sh

# Default cached location for passband tables
# Default cached location for various data.
src/tdastro/astro_utils/passbands/*
/data_cache/*
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ dynamic = ["version"]
requires-python = ">=3.9"
dependencies = [
"astropy",
"dust_extinction",
"dustmaps",
"jax",
"numpy",
"pandas",
Expand Down Expand Up @@ -108,6 +110,7 @@ ignore = [
"E721", # Allow direct type comparison
"N803", # Allow arguments to start with a capital letter
"N806", # Allow variables to use non-lowercase letter
"SIM108" # Allow if-else blocks that could be ternary expressions
]

[tool.coverage.run]
Expand Down
1 change: 1 addition & 0 deletions src/tdastro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
_TDASTRO_BASE_DIR = Path(__file__).parent.parent.parent
_TDASTRO_TEST_DIR = _TDASTRO_BASE_DIR / "tests" / "tdastro"
_TDASTRO_TEST_DATA_DIR = _TDASTRO_TEST_DIR / "data"
_TDASTRO_CACHE_DATA_DIR = _TDASTRO_BASE_DIR / "data_cache"
168 changes: 168 additions & 0 deletions src/tdastro/astro_utils/dust_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""A wrapper for querying dust maps and then applying the corresponding
extinction functions.

This module is a wrapper for the following libraries:
* dustmaps:
Green 2018, JOSS, 3(26), 695.
https://github.com/gregreen/dustmaps
* dust_extinction:
Gordon 2024, JOSS, 9(100), 7023.
https://github.com/karllark/dust_extinction
"""

import importlib
from pkgutil import iter_modules

import astropy.units as u
import dust_extinction
from astropy.coordinates import SkyCoord
from dustmaps.config import config as dm_config

from tdastro import _TDASTRO_CACHE_DATA_DIR


class DustExtinctionEffect:
"""A general dust extinction model.

Attributes
----------
dust_map : dustmaps.DustMap or str
The dust map or its name. Since different dustmap's query function
may produce different outputs, you should include the corresponding
ebv_func to transform the result into ebv if needed.
extinction_model : function or str
The extinction object from the dust_extinction library or its name.
If a string is provided, the code will find a matching extinction
function in the dust_extinction package and use that.
ebv_func : function
A function to translate the result of the dustmap query into an ebv.
**kwargs : `dict`, optional
Any additional keyword arguments.
"""

def __init__(self, dust_map, extinction_model, ebv_func=None, **kwargs):
self.ebv_func = ebv_func

if isinstance(dust_map, str):
# Initially we only support loading the SFD dustmap by string.
# But we can expand this as needed.
dust_map = dust_map.lower()
if dust_map == "sfd":
self._load_sfd_dustmap()
else:
raise ValueError("Unsupported load from dustmap {dust_map}")
else:
# If given a dustmap, use that directly.
self.dust_map = dust_map

if isinstance(extinction_model, str):
extinction_model = DustExtinctionEffect.load_extinction_model(extinction_model, **kwargs)
self.extinction_model = extinction_model

def _load_sfd_dustmap(self):
"""Load the SFD dustmap, downloading it if needed.

Uses data from:
1. Schlegel, Finkbeiner, and Davis
The Astrophysical Journal, Volume 500, Issue 2, pp. 525-553.
https://ui.adsabs.harvard.edu/abs/1998ApJ...500..525S/abstract

2. Schlegel and Finkbeiner
The Astrophysical Journal, Volume 737, Issue 2, article id. 103, 13 pp. (2011).
https://ui.adsabs.harvard.edu/abs/2011ApJ...737..103S/abstract

Returns
-------
sfd_query : SFDQuery
The "query" object for the requested dustmap.
"""
import dustmaps.sfd

# Download the dustmap if needed.
dm_config["data_dir"] = str(_TDASTRO_CACHE_DATA_DIR / "dustmaps")
dustmaps.sfd.fetch()

# Load the dustmap.
from dustmaps.sfd import SFDQuery

self.dust_map = SFDQuery()

# Add the correction function.
def _sfd_scale_ebv(input, **kwargs):
"""Scale the result of the SFD query."""

self.ebv_func = _sfd_scale_ebv

@staticmethod
def load_extinction_model(name, **kwargs):
"""Load the extinction model.

Parameters
----------
name : str
The name of the extinction model to use.
**kwargs : dict
Any additional keyword arguments needed to create that argument.

Returns
-------
ext_obj
A extinction object.
"""
for submodule in iter_modules(dust_extinction.__path__):
ext_module = importlib.import_module(f"dust_extinction.{submodule.name}")
if ext_module is not None and name in dir(ext_module):
ext_class = getattr(ext_module, name)
return ext_class(**kwargs)
raise KeyError(f"Invalid dust extinction model '{name}'")

def apply(self, flux_density, wavelengths, ebv=None, ra=None, dec=None, dist=None, **kwargs):
"""Apply the effect to observations (flux_density values). The user can either
provide a ebv value directly or (RA, dec) and distance information that can
be used in the dustmap query.

Parameters
----------
flux_density : numpy.ndarray
An array of flux density values (in nJy).
wavelengths : numpy.ndarray, optional
An array of wavelengths (in angstroms).
ebv : float or np.array
A given ebv value or array of values. If present then this is used
instead of looking it up in the dust map.
ra : float, optional
The object's right ascension (in degrees).
dec : float, optional
The object's declination (in degrees).
dist : float, optional
The object's distance (in parsecs).
**kwargs : `dict`, optional
Any additional keyword arguments.

Returns
-------
flux_density : numpy.ndarray
The results (in nJy).
"""
if ebv is None:
if self.dust_map is None:
raise ValueError("If ebv=None then a dust map must be provided.")
if ra is None or dec is None:
raise ValueError("If ebv=None then ra, dec must be provided for a lookup.")

# Get the extinction value at the object's location.
if dist is not None:
coord = SkyCoord(ra, dec, dist, frame="icrs", unit="deg")
else:
coord = SkyCoord(ra, dec, frame="icrs", unit="deg")
dustmap_value = self.dust_map.query(coord)

# Perform any corrections needed for this dust map.
if self.ebv_func is not None:
ebv = self.ebv_func(dustmap_value, **kwargs)
else:
ebv = dustmap_value

print(f"Using ebv={ebv}")

return flux_density * self.extinction_model.extinguish(wavelengths * u.angstrom, Ebv=ebv)
jeremykubica marked this conversation as resolved.
Show resolved Hide resolved
111 changes: 111 additions & 0 deletions tests/tdastro/astro_utils/test_dust_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import numpy as np
from dust_extinction.parameter_averages import CCM89
from dustmaps.map_base import DustMap
from tdastro.astro_utils.dust_map import DustExtinctionEffect
from tdastro.sources.static_source import StaticSource


class ConstantDustMap(DustMap):
"""A DustMap with a constant value. Used for testing.

Attributes
----------
ebv : `float`
The DustMap's ebv value at all points in the sky.
"""

def __init__(self, ebv):
self.ebv = ebv

def query(self, coords):
"""Returns reddening at the requested coordinates.

Parameters
----------
coords : `astropy.coordinates.SkyCoord`
The coordinates of the query or queries.

Returns
-------
Reddening : `float` or `numpy.ndarray`
The result of the query.
"""
if coords.isscalar:
return self.ebv
return np.full((len(coords)), self.ebv)


class TestExtinction:
"""An extinction function that computes the scaling factor
as a constant times the ebv.

Attributes
----------
scale : `float`
The extinction function's multiplicative scaling.
"""

def __init__(self, scale):
self.scale = scale

def extinguish(self, wavelengths, Ebv=1.0):
"""The extinguish function

Parameters
----------
wavelengths : numpy.ndarray
The array of wavelengths
Ebv : float
The Ebv to use in the scaling.
"""
return Ebv * self.scale


def test_load_extinction_model():
"""Load an extinction model by string."""
g23_model = DustExtinctionEffect.load_extinction_model("G23", Rv=3.1)
assert g23_model is not None
assert hasattr(g23_model, "extinguish")

# Load through the DustExtinctionEffect constructor.
const_map = ConstantDustMap(0.5)
dust_effect = DustExtinctionEffect(const_map, "CCM89", Rv=3.1)
assert dust_effect.extinction_model is not None
assert hasattr(dust_effect.extinction_model, "extinguish")


def test_constant_dust_extinction():
"""Test that we can create and sample a DustExtinctionEffect object."""
times = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
wavelengths = np.array([7000.0, 5200.0, 4800.0]) # Red, green, blue

# Create a model without any dust.
model_clean = StaticSource(brightness=100.0, ra=0.0, dec=40.0)
state = model_clean.sample_parameters()
fluxes_clean = model_clean.evaluate(times, wavelengths, state)
assert fluxes_clean.shape == (5, 3)
assert np.all(fluxes_clean == 100.0)

# Create DustExtinctionEffect effects with constant ebvs and constant
# multiplicative extinction functions.
dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.1))
fluxes = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=100.0)
assert fluxes.shape == (5, 3)
assert np.all(fluxes == 5.0)

dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.3))
fluxes = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=100.0)
assert fluxes.shape == (5, 3)
assert np.all(fluxes == 15.0)

# Use a manual ebv, which overrides the dustmap.
dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.1))
fluxes = dust_effect.apply(fluxes_clean, wavelengths, ebv=1.0, ra=0.0, dec=40.0, dist=100.0)
assert fluxes.shape == (5, 3)
assert np.all(fluxes == 10.0)

# Create a model with ccm89 extinction at r_v = 3.1.
dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), CCM89(Rv=3.1))
fluxes_ccm98 = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=1.0)
assert fluxes_ccm98.shape == (5, 3)
assert np.all(fluxes_ccm98 < 100.0)
Loading