diff --git a/.gitignore b/.gitignore index 1e20c19..5212a88 100644 --- a/.gitignore +++ b/.gitignore @@ -149,5 +149,6 @@ _html/ # Project initialization script .initialize_new_project.sh -# Default cached location for passband tables +# Default cached location for various data. src/tdastro/astro_utils/passbands/* +/data_cache/* diff --git a/pyproject.toml b/pyproject.toml index 4be8788..e5b86cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ dynamic = ["version"] requires-python = ">=3.9" dependencies = [ "astropy", + "dust_extinction", + "dustmaps", "jax", "numpy", "pandas", @@ -108,6 +110,7 @@ ignore = [ "E721", # Allow direct type comparison "N803", # Allow arguments to start with a capital letter "N806", # Allow variables to use non-lowercase letter + "SIM108" # Allow if-else blocks that could be ternary expressions ] [tool.coverage.run] diff --git a/src/tdastro/__init__.py b/src/tdastro/__init__.py index 4f5f958..8d8fbbd 100644 --- a/src/tdastro/__init__.py +++ b/src/tdastro/__init__.py @@ -4,3 +4,4 @@ _TDASTRO_BASE_DIR = Path(__file__).parent.parent.parent _TDASTRO_TEST_DIR = _TDASTRO_BASE_DIR / "tests" / "tdastro" _TDASTRO_TEST_DATA_DIR = _TDASTRO_TEST_DIR / "data" +_TDASTRO_CACHE_DATA_DIR = _TDASTRO_BASE_DIR / "data_cache" diff --git a/src/tdastro/astro_utils/dust_map.py b/src/tdastro/astro_utils/dust_map.py new file mode 100644 index 0000000..1fea187 --- /dev/null +++ b/src/tdastro/astro_utils/dust_map.py @@ -0,0 +1,168 @@ +"""A wrapper for querying dust maps and then applying the corresponding +extinction functions. + +This module is a wrapper for the following libraries: + * dustmaps: + Green 2018, JOSS, 3(26), 695. + https://github.com/gregreen/dustmaps + * dust_extinction: + Gordon 2024, JOSS, 9(100), 7023. + https://github.com/karllark/dust_extinction +""" + +import importlib +from pkgutil import iter_modules + +import astropy.units as u +import dust_extinction +from astropy.coordinates import SkyCoord +from dustmaps.config import config as dm_config + +from tdastro import _TDASTRO_CACHE_DATA_DIR + + +class DustExtinctionEffect: + """A general dust extinction model. + + Attributes + ---------- + dust_map : dustmaps.DustMap or str + The dust map or its name. Since different dustmap's query function + may produce different outputs, you should include the corresponding + ebv_func to transform the result into ebv if needed. + extinction_model : function or str + The extinction object from the dust_extinction library or its name. + If a string is provided, the code will find a matching extinction + function in the dust_extinction package and use that. + ebv_func : function + A function to translate the result of the dustmap query into an ebv. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + + def __init__(self, dust_map, extinction_model, ebv_func=None, **kwargs): + self.ebv_func = ebv_func + + if isinstance(dust_map, str): + # Initially we only support loading the SFD dustmap by string. + # But we can expand this as needed. + dust_map = dust_map.lower() + if dust_map == "sfd": + self._load_sfd_dustmap() + else: + raise ValueError("Unsupported load from dustmap {dust_map}") + else: + # If given a dustmap, use that directly. + self.dust_map = dust_map + + if isinstance(extinction_model, str): + extinction_model = DustExtinctionEffect.load_extinction_model(extinction_model, **kwargs) + self.extinction_model = extinction_model + + def _load_sfd_dustmap(self): + """Load the SFD dustmap, downloading it if needed. + + Uses data from: + 1. Schlegel, Finkbeiner, and Davis + The Astrophysical Journal, Volume 500, Issue 2, pp. 525-553. + https://ui.adsabs.harvard.edu/abs/1998ApJ...500..525S/abstract + + 2. Schlegel and Finkbeiner + The Astrophysical Journal, Volume 737, Issue 2, article id. 103, 13 pp. (2011). + https://ui.adsabs.harvard.edu/abs/2011ApJ...737..103S/abstract + + Returns + ------- + sfd_query : SFDQuery + The "query" object for the requested dustmap. + """ + import dustmaps.sfd + + # Download the dustmap if needed. + dm_config["data_dir"] = str(_TDASTRO_CACHE_DATA_DIR / "dustmaps") + dustmaps.sfd.fetch() + + # Load the dustmap. + from dustmaps.sfd import SFDQuery + + self.dust_map = SFDQuery() + + # Add the correction function. + def _sfd_scale_ebv(input, **kwargs): + """Scale the result of the SFD query.""" + + self.ebv_func = _sfd_scale_ebv + + @staticmethod + def load_extinction_model(name, **kwargs): + """Load the extinction model. + + Parameters + ---------- + name : str + The name of the extinction model to use. + **kwargs : dict + Any additional keyword arguments needed to create that argument. + + Returns + ------- + ext_obj + A extinction object. + """ + for submodule in iter_modules(dust_extinction.__path__): + ext_module = importlib.import_module(f"dust_extinction.{submodule.name}") + if ext_module is not None and name in dir(ext_module): + ext_class = getattr(ext_module, name) + return ext_class(**kwargs) + raise KeyError(f"Invalid dust extinction model '{name}'") + + def apply(self, flux_density, wavelengths, ebv=None, ra=None, dec=None, dist=None, **kwargs): + """Apply the effect to observations (flux_density values). The user can either + provide a ebv value directly or (RA, dec) and distance information that can + be used in the dustmap query. + + Parameters + ---------- + flux_density : numpy.ndarray + An array of flux density values (in nJy). + wavelengths : numpy.ndarray, optional + An array of wavelengths (in angstroms). + ebv : float or np.array + A given ebv value or array of values. If present then this is used + instead of looking it up in the dust map. + ra : float, optional + The object's right ascension (in degrees). + dec : float, optional + The object's declination (in degrees). + dist : float, optional + The object's distance (in parsecs). + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : numpy.ndarray + The results (in nJy). + """ + if ebv is None: + if self.dust_map is None: + raise ValueError("If ebv=None then a dust map must be provided.") + if ra is None or dec is None: + raise ValueError("If ebv=None then ra, dec must be provided for a lookup.") + + # Get the extinction value at the object's location. + if dist is not None: + coord = SkyCoord(ra, dec, dist, frame="icrs", unit="deg") + else: + coord = SkyCoord(ra, dec, frame="icrs", unit="deg") + dustmap_value = self.dust_map.query(coord) + + # Perform any corrections needed for this dust map. + if self.ebv_func is not None: + ebv = self.ebv_func(dustmap_value, **kwargs) + else: + ebv = dustmap_value + + print(f"Using ebv={ebv}") + + return flux_density * self.extinction_model.extinguish(wavelengths * u.angstrom, Ebv=ebv) diff --git a/tests/tdastro/astro_utils/test_dust_map.py b/tests/tdastro/astro_utils/test_dust_map.py new file mode 100644 index 0000000..8b160c9 --- /dev/null +++ b/tests/tdastro/astro_utils/test_dust_map.py @@ -0,0 +1,111 @@ +import numpy as np +from dust_extinction.parameter_averages import CCM89 +from dustmaps.map_base import DustMap +from tdastro.astro_utils.dust_map import DustExtinctionEffect +from tdastro.sources.static_source import StaticSource + + +class ConstantDustMap(DustMap): + """A DustMap with a constant value. Used for testing. + + Attributes + ---------- + ebv : `float` + The DustMap's ebv value at all points in the sky. + """ + + def __init__(self, ebv): + self.ebv = ebv + + def query(self, coords): + """Returns reddening at the requested coordinates. + + Parameters + ---------- + coords : `astropy.coordinates.SkyCoord` + The coordinates of the query or queries. + + Returns + ------- + Reddening : `float` or `numpy.ndarray` + The result of the query. + """ + if coords.isscalar: + return self.ebv + return np.full((len(coords)), self.ebv) + + +class TestExtinction: + """An extinction function that computes the scaling factor + as a constant times the ebv. + + Attributes + ---------- + scale : `float` + The extinction function's multiplicative scaling. + """ + + def __init__(self, scale): + self.scale = scale + + def extinguish(self, wavelengths, Ebv=1.0): + """The extinguish function + + Parameters + ---------- + wavelengths : numpy.ndarray + The array of wavelengths + Ebv : float + The Ebv to use in the scaling. + """ + return Ebv * self.scale + + +def test_load_extinction_model(): + """Load an extinction model by string.""" + g23_model = DustExtinctionEffect.load_extinction_model("G23", Rv=3.1) + assert g23_model is not None + assert hasattr(g23_model, "extinguish") + + # Load through the DustExtinctionEffect constructor. + const_map = ConstantDustMap(0.5) + dust_effect = DustExtinctionEffect(const_map, "CCM89", Rv=3.1) + assert dust_effect.extinction_model is not None + assert hasattr(dust_effect.extinction_model, "extinguish") + + +def test_constant_dust_extinction(): + """Test that we can create and sample a DustExtinctionEffect object.""" + times = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + wavelengths = np.array([7000.0, 5200.0, 4800.0]) # Red, green, blue + + # Create a model without any dust. + model_clean = StaticSource(brightness=100.0, ra=0.0, dec=40.0) + state = model_clean.sample_parameters() + fluxes_clean = model_clean.evaluate(times, wavelengths, state) + assert fluxes_clean.shape == (5, 3) + assert np.all(fluxes_clean == 100.0) + + # Create DustExtinctionEffect effects with constant ebvs and constant + # multiplicative extinction functions. + dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.1)) + fluxes = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=100.0) + assert fluxes.shape == (5, 3) + assert np.all(fluxes == 5.0) + + dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.3)) + fluxes = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=100.0) + assert fluxes.shape == (5, 3) + assert np.all(fluxes == 15.0) + + # Use a manual ebv, which overrides the dustmap. + dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.1)) + fluxes = dust_effect.apply(fluxes_clean, wavelengths, ebv=1.0, ra=0.0, dec=40.0, dist=100.0) + assert fluxes.shape == (5, 3) + assert np.all(fluxes == 10.0) + + # Create a model with ccm89 extinction at r_v = 3.1. + dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), CCM89(Rv=3.1)) + fluxes_ccm98 = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=1.0) + assert fluxes_ccm98.shape == (5, 3) + assert np.all(fluxes_ccm98 < 100.0)