From 4a7de9c9bc7586a0a54af3ffe7498489e3702e8b Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:33:35 -0400 Subject: [PATCH 1/6] initial stuff does not work --- pyproject.toml | 2 + src/tdastro/astro_utils/dust_map.py | 182 ++++++++++++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 src/tdastro/astro_utils/dust_map.py diff --git a/pyproject.toml b/pyproject.toml index 0ef4981e..b0895c0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,8 @@ dynamic = ["version"] requires-python = ">=3.9" dependencies = [ "astropy", + "dust_extinction", + "dustmaps", "jax", "numpy", "pandas", diff --git a/src/tdastro/astro_utils/dust_map.py b/src/tdastro/astro_utils/dust_map.py new file mode 100644 index 00000000..d7c3b952 --- /dev/null +++ b/src/tdastro/astro_utils/dust_map.py @@ -0,0 +1,182 @@ +"""A wrapper for querying dust maps and then applying the corresponding +extinction functions. + +This module is a wrapper for the following libraries: + * dustmaps - https://github.com/gregreen/dustmaps + * dust_extinction - https://github.com/karllark/dust_extinction +""" + + +from astropy.coordinates import SkyCoord +from dustmaps.config import config +from pathlib import Path + +import dustmaps + +class DustExtinctionEffect(): + f"""A general dust extinction model. + + Attributes + ---------- + dust_map : ``dustmaps.DustMap`` + The dust map. + extinction_func : `function` + The extinction function to use. + r_v : `float`, optional + The ratio of total extinction to selective extinction to pass to the + extinction function. See: https://extinction.readthedocs.io/ + Set to ``None`` if the parameter is not used. + + Parameters + ---------- + dust_map_name : `str` + The name of the dustmap to use. Valid options include + + """ + + def __init__(self, dust_map, extinction_type, r_v=None, **kwargs): + """Create a dust extinction model. + + Parameters + ---------- + dust_map : `dustmaps.DustMap` + The dust map. + extinction_type : `str` + The extinction function to use. Must be one of: + "ccm89", "odonnell94", "calzetti00", "fitzpatrick99", or "fm07" + r_v : `float`, optional + The ratio of total extinction to selective extinction to pass to the + extinction function. See: https://extinction.readthedocs.io/ + Can be set to ``None`` when using "fitzpatrick99" or "fm07" + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + super().__init__(**kwargs) + self.dust_map = dust_map + self.r_v = r_v + + if extinction_type == "ccm89": + self.extinction_func = extinction.ccm89 + if r_v is None: + raise ValueError("r_v must be set for ccm89") + elif extinction_type == "odonnell94": + self.extinction_func = extinction.odonnell94 + if r_v is None: + raise ValueError("r_v must be set for odonnell94") + elif extinction_type == "calzetti00": + self.extinction_func = extinction.calzetti00 + if r_v is None: + raise ValueError("r_v must be set for calzetti00") + elif extinction_type == "fitzpatrick99": + self.extinction_func = extinction.fitzpatrick99 + elif extinction_type == "fm07": + self.extinction_func = extinction.fm07 + + + @staticmethod + def download_map(map_name, cache_path="../data_cache/dust_maps"): + """Download a dust map given it's name and using the dustmaps package. + + Parameters + ---------- + map_name : `str` + The name of the dust map. Should be one of: std, csfd, planck + planck_GNILC, bayestar, iphas, marshall, chen2014, lenz2017, + pg2010, leike_ensslin_2019, + + """ + cache_dir = Path(cache_path) + cache_dir.mkdir(exist_ok=True, parents=True) + + if map_name == "sfd": + import dustmaps.sfd + dustmaps.sfd.fetch() + elif map_name + +import dustmaps.csfd +dustmaps.csfd.fetch() + +import dustmaps.planck +dustmaps.planck.fetch() + +import dustmaps.planck +dustmaps.planck.fetch(which='GNILC') + +import dustmaps.bayestar +dustmaps.bayestar.fetch() + +import dustmaps.iphas +dustmaps.iphas.fetch() + +import dustmaps.marshall +dustmaps.marshall.fetch() + +import dustmaps.chen2014 +dustmaps.chen2014.fetch() + +import dustmaps.lenz2017 +dustmaps.lenz2017.fetch() + +import dustmaps.pg2010 +dustmaps.pg2010.fetch() + +import dustmaps.leike_ensslin_2019 +dustmaps.leike_ensslin_2019.fetch() + +import dustmaps.leike2020 +dustmaps.leike2020.fetch() + +import dustmaps.edenhofer2023 +dustmaps.edenhofer2023.fetch() + +import dustmaps.gaia_tge +dustmaps.gaia_tge.fetch() + +Path(__file__).parent / "passbands" / self.survey + + + + +from dustmaps.config import config +config['data_dir'] = '/path/to/store/maps/in' + + + + def apply(self, flux_density, wavelengths=None, graph_state=None, **kwargs): + """Apply the effect to observations (flux_density values) + + Parameters + ---------- + flux_density : `numpy.ndarray` + An array of flux density values. + wavelengths : `numpy.ndarray`, optional + An array of wavelengths. + graph_state : `GraphState` + An object mapping graph parameters to their values. + **kwargs : `dict`, optional + Any additional keyword arguments. + + Returns + ------- + flux_density : `numpy.ndarray` + The results. + """ + # Get the extinction value at the object's location. + if physical_model is None: + raise ValueError("physical_model cannot be None") + if physical_model.distance is None: + dist = 1.0 + else: + dist = physical_model.distance + coord = SkyCoord(physical_model.ra, physical_model.dec, dist, frame="icrs", unit="deg") + ebv = self.dust_map.query(coord) + + # Apply the extinction. + if wavelengths is None: + raise ValueError("wavelengths cannot be None") + + if self.r_v is None: + ext = self.extinction_func(wavelengths, ebv) + else: + ext = self.extinction_func(wavelengths, ebv, self.r_v) + return extinction.apply(ext, flux_density) From d2aa42e1424947068c2071e74089a597ff8a933c Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 9 Sep 2024 09:58:55 -0400 Subject: [PATCH 2/6] More partial work --- src/tdastro/astro_utils/dust_map.py | 169 +++------------------ tests/tdastro/astro_utils/test_dust_map.py | 55 +++++++ 2 files changed, 77 insertions(+), 147 deletions(-) create mode 100644 tests/tdastro/astro_utils/test_dust_map.py diff --git a/src/tdastro/astro_utils/dust_map.py b/src/tdastro/astro_utils/dust_map.py index d7c3b952..56b15f75 100644 --- a/src/tdastro/astro_utils/dust_map.py +++ b/src/tdastro/astro_utils/dust_map.py @@ -2,147 +2,34 @@ extinction functions. This module is a wrapper for the following libraries: - * dustmaps - https://github.com/gregreen/dustmaps - * dust_extinction - https://github.com/karllark/dust_extinction -""" + * dustmaps: + Green 2018, JOSS, 3(26), 695. + https://github.com/gregreen/dustmaps + * dust_extinction: + Gordon 2024, JOSS, 9(100), 7023. + https://github.com/karllark/dust_extinction +""" from astropy.coordinates import SkyCoord -from dustmaps.config import config -from pathlib import Path -import dustmaps class DustExtinctionEffect(): - f"""A general dust extinction model. + """A general dust extinction model. Attributes ---------- - dust_map : ``dustmaps.DustMap`` + dust_map : `dustmaps.DustMap` The dust map. - extinction_func : `function` + extinction_model : `function` The extinction function to use. - r_v : `float`, optional - The ratio of total extinction to selective extinction to pass to the - extinction function. See: https://extinction.readthedocs.io/ - Set to ``None`` if the parameter is not used. - - Parameters - ---------- - dust_map_name : `str` - The name of the dustmap to use. Valid options include - """ - def __init__(self, dust_map, extinction_type, r_v=None, **kwargs): - """Create a dust extinction model. - - Parameters - ---------- - dust_map : `dustmaps.DustMap` - The dust map. - extinction_type : `str` - The extinction function to use. Must be one of: - "ccm89", "odonnell94", "calzetti00", "fitzpatrick99", or "fm07" - r_v : `float`, optional - The ratio of total extinction to selective extinction to pass to the - extinction function. See: https://extinction.readthedocs.io/ - Can be set to ``None`` when using "fitzpatrick99" or "fm07" - **kwargs : `dict`, optional - Any additional keyword arguments. - """ - super().__init__(**kwargs) + def __init__(self, dust_map, extinction_model): self.dust_map = dust_map - self.r_v = r_v - - if extinction_type == "ccm89": - self.extinction_func = extinction.ccm89 - if r_v is None: - raise ValueError("r_v must be set for ccm89") - elif extinction_type == "odonnell94": - self.extinction_func = extinction.odonnell94 - if r_v is None: - raise ValueError("r_v must be set for odonnell94") - elif extinction_type == "calzetti00": - self.extinction_func = extinction.calzetti00 - if r_v is None: - raise ValueError("r_v must be set for calzetti00") - elif extinction_type == "fitzpatrick99": - self.extinction_func = extinction.fitzpatrick99 - elif extinction_type == "fm07": - self.extinction_func = extinction.fm07 - - - @staticmethod - def download_map(map_name, cache_path="../data_cache/dust_maps"): - """Download a dust map given it's name and using the dustmaps package. - - Parameters - ---------- - map_name : `str` - The name of the dust map. Should be one of: std, csfd, planck - planck_GNILC, bayestar, iphas, marshall, chen2014, lenz2017, - pg2010, leike_ensslin_2019, - - """ - cache_dir = Path(cache_path) - cache_dir.mkdir(exist_ok=True, parents=True) - - if map_name == "sfd": - import dustmaps.sfd - dustmaps.sfd.fetch() - elif map_name - -import dustmaps.csfd -dustmaps.csfd.fetch() - -import dustmaps.planck -dustmaps.planck.fetch() + self.extinction_model = extinction_model -import dustmaps.planck -dustmaps.planck.fetch(which='GNILC') - -import dustmaps.bayestar -dustmaps.bayestar.fetch() - -import dustmaps.iphas -dustmaps.iphas.fetch() - -import dustmaps.marshall -dustmaps.marshall.fetch() - -import dustmaps.chen2014 -dustmaps.chen2014.fetch() - -import dustmaps.lenz2017 -dustmaps.lenz2017.fetch() - -import dustmaps.pg2010 -dustmaps.pg2010.fetch() - -import dustmaps.leike_ensslin_2019 -dustmaps.leike_ensslin_2019.fetch() - -import dustmaps.leike2020 -dustmaps.leike2020.fetch() - -import dustmaps.edenhofer2023 -dustmaps.edenhofer2023.fetch() - -import dustmaps.gaia_tge -dustmaps.gaia_tge.fetch() - -Path(__file__).parent / "passbands" / self.survey - - - - -from dustmaps.config import config -config['data_dir'] = '/path/to/store/maps/in' - - - - def apply(self, flux_density, wavelengths=None, graph_state=None, **kwargs): + def apply(self, flux_density, wavelengths, ra, dec, dist=1.0): """Apply the effect to observations (flux_density values) Parameters @@ -151,10 +38,13 @@ def apply(self, flux_density, wavelengths=None, graph_state=None, **kwargs): An array of flux density values. wavelengths : `numpy.ndarray`, optional An array of wavelengths. - graph_state : `GraphState` - An object mapping graph parameters to their values. - **kwargs : `dict`, optional - Any additional keyword arguments. + ra : `float` + The object's right ascension (in degrees). + dec : `float` + The object's declination (in degrees). + dist : `float` + The object's distance (in ?). + Default = 1.0 Returns ------- @@ -162,21 +52,6 @@ def apply(self, flux_density, wavelengths=None, graph_state=None, **kwargs): The results. """ # Get the extinction value at the object's location. - if physical_model is None: - raise ValueError("physical_model cannot be None") - if physical_model.distance is None: - dist = 1.0 - else: - dist = physical_model.distance - coord = SkyCoord(physical_model.ra, physical_model.dec, dist, frame="icrs", unit="deg") + coord = SkyCoord(ra, dec, dist, frame="icrs", unit="deg") ebv = self.dust_map.query(coord) - - # Apply the extinction. - if wavelengths is None: - raise ValueError("wavelengths cannot be None") - - if self.r_v is None: - ext = self.extinction_func(wavelengths, ebv) - else: - ext = self.extinction_func(wavelengths, ebv, self.r_v) - return extinction.apply(ext, flux_density) + return flux_density * self.extinction_model.extinguish(wavelengths, Ebv=ebv) diff --git a/tests/tdastro/astro_utils/test_dust_map.py b/tests/tdastro/astro_utils/test_dust_map.py new file mode 100644 index 00000000..d42a2ffb --- /dev/null +++ b/tests/tdastro/astro_utils/test_dust_map.py @@ -0,0 +1,55 @@ +import numpy as np + +from dust_extinction.parameter_averages import CCM89 +from dustmaps.map_base import DustMap + +from tdastro.astro_utils.dust_map import DustExtinctionEffect +from tdastro.sources.static_source import StaticSource + + +class ConstantDustMap(DustMap): + """A DustMap with a constant value. Used for testing. + + Attributes + ---------- + ebv : `float` + The DustMap's ebv value at all + """ + def __init__(self, ebv): + self.ebv = ebv + + def query(self, coords): + """Returns reddening at the requested coordinates. + + Parameters + ---------- + coords : `astropy.coordinates.SkyCoord` + The coordinates of the query or queries. + + Returns + ------- + Reddening : `float` or `numpy.ndarray` + The result of the query. + """ + if coords.isscalar: + return self.ebv + return np.full((len(coords)), self.ebv) + + +def test_constant_dust_extinction(): + """Test that we can create and sample a DustExtinctionEffect object.""" + times = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + wavelengths = np.array([800.0, 900.0, 1000.0, 1000.0, 900.0]) + + # Create a model without any dust. + model_clean = StaticSource(brightness=100.0, ra=0.0, dec=40.0) + state = model_clean.sample_parameters() + fluxes_clean = model_clean.evaluate(times, wavelengths, state) + assert len(fluxes_clean) == 5 + assert np.all(fluxes_clean == 100.0) + + # Create a model with ccm89 extinction at r_v = 3.1. + dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), CCM89(Rv=3.1)) + fluxes_ccm98 = dust_effect.apply(fluxes_clean, wavelengths, 0.0, 40.0, 1.0) + assert len(fluxes_ccm98) == 5 + assert np.all(fluxes_ccm98 < 100.0) From b060a7a0360e291fa3ebd593f959de0eea9567b3 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 9 Sep 2024 11:26:26 -0400 Subject: [PATCH 3/6] Working v1 --- src/tdastro/astro_utils/dust_map.py | 15 ++++++++++----- tests/tdastro/astro_utils/test_dust_map.py | 15 +++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/tdastro/astro_utils/dust_map.py b/src/tdastro/astro_utils/dust_map.py index 56b15f75..c4a8ae32 100644 --- a/src/tdastro/astro_utils/dust_map.py +++ b/src/tdastro/astro_utils/dust_map.py @@ -11,10 +11,11 @@ """ +import astropy.units as u from astropy.coordinates import SkyCoord -class DustExtinctionEffect(): +class DustExtinctionEffect: """A general dust extinction model. Attributes @@ -35,9 +36,9 @@ def apply(self, flux_density, wavelengths, ra, dec, dist=1.0): Parameters ---------- flux_density : `numpy.ndarray` - An array of flux density values. + An array of flux density values (in nJy). wavelengths : `numpy.ndarray`, optional - An array of wavelengths. + An array of wavelengths (in angstroms). ra : `float` The object's right ascension (in degrees). dec : `float` @@ -49,9 +50,13 @@ def apply(self, flux_density, wavelengths, ra, dec, dist=1.0): Returns ------- flux_density : `numpy.ndarray` - The results. + The results (in nJy). """ # Get the extinction value at the object's location. coord = SkyCoord(ra, dec, dist, frame="icrs", unit="deg") ebv = self.dust_map.query(coord) - return flux_density * self.extinction_model.extinguish(wavelengths, Ebv=ebv) + + # Do we need to convert ebv by a factor from this table: + # https://iopscience.iop.org/article/10.1088/0004-637X/737/2/103#apj398709t6 + + return flux_density * self.extinction_model.extinguish(wavelengths * u.angstrom, Ebv=ebv) diff --git a/tests/tdastro/astro_utils/test_dust_map.py b/tests/tdastro/astro_utils/test_dust_map.py index d42a2ffb..c74d5a69 100644 --- a/tests/tdastro/astro_utils/test_dust_map.py +++ b/tests/tdastro/astro_utils/test_dust_map.py @@ -1,26 +1,25 @@ import numpy as np - from dust_extinction.parameter_averages import CCM89 from dustmaps.map_base import DustMap - from tdastro.astro_utils.dust_map import DustExtinctionEffect from tdastro.sources.static_source import StaticSource class ConstantDustMap(DustMap): """A DustMap with a constant value. Used for testing. - + Attributes ---------- ebv : `float` - The DustMap's ebv value at all + The DustMap's ebv value at all points in the sky. """ + def __init__(self, ebv): self.ebv = ebv def query(self, coords): """Returns reddening at the requested coordinates. - + Parameters ---------- coords : `astropy.coordinates.SkyCoord` @@ -39,17 +38,17 @@ def query(self, coords): def test_constant_dust_extinction(): """Test that we can create and sample a DustExtinctionEffect object.""" times = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) - wavelengths = np.array([800.0, 900.0, 1000.0, 1000.0, 900.0]) + wavelengths = np.array([7000.0, 5200.0, 4800.0]) # Red, green, blue # Create a model without any dust. model_clean = StaticSource(brightness=100.0, ra=0.0, dec=40.0) state = model_clean.sample_parameters() fluxes_clean = model_clean.evaluate(times, wavelengths, state) - assert len(fluxes_clean) == 5 + assert fluxes_clean.shape == (5, 3) assert np.all(fluxes_clean == 100.0) # Create a model with ccm89 extinction at r_v = 3.1. dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), CCM89(Rv=3.1)) fluxes_ccm98 = dust_effect.apply(fluxes_clean, wavelengths, 0.0, 40.0, 1.0) - assert len(fluxes_ccm98) == 5 + assert fluxes_ccm98.shape == (5, 3) assert np.all(fluxes_ccm98 < 100.0) From ad9f7182184284aaa7b8d807c49e3d9233010b06 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 23 Sep 2024 10:36:49 -0400 Subject: [PATCH 4/6] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1e20c191..e2d948b9 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,4 @@ _html/ # Default cached location for passband tables src/tdastro/astro_utils/passbands/* +/data/sfd From 58a2401a55518f2a7c70c5422a159ee4fdd9d9cb Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Tue, 29 Oct 2024 13:33:30 -0400 Subject: [PATCH 5/6] Allow loading dust_maps and models by name --- .gitignore | 4 +- src/tdastro/__init__.py | 1 + src/tdastro/astro_utils/dust_map.py | 91 +++++++++++++++++++++++++++-- 3 files changed, 88 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index e2d948b9..5212a886 100644 --- a/.gitignore +++ b/.gitignore @@ -149,6 +149,6 @@ _html/ # Project initialization script .initialize_new_project.sh -# Default cached location for passband tables +# Default cached location for various data. src/tdastro/astro_utils/passbands/* -/data/sfd +/data_cache/* diff --git a/src/tdastro/__init__.py b/src/tdastro/__init__.py index 4f5f9584..8d8fbbd5 100644 --- a/src/tdastro/__init__.py +++ b/src/tdastro/__init__.py @@ -4,3 +4,4 @@ _TDASTRO_BASE_DIR = Path(__file__).parent.parent.parent _TDASTRO_TEST_DIR = _TDASTRO_BASE_DIR / "tests" / "tdastro" _TDASTRO_TEST_DATA_DIR = _TDASTRO_TEST_DIR / "data" +_TDASTRO_CACHE_DATA_DIR = _TDASTRO_BASE_DIR / "data_cache" diff --git a/src/tdastro/astro_utils/dust_map.py b/src/tdastro/astro_utils/dust_map.py index c4a8ae32..b9f55858 100644 --- a/src/tdastro/astro_utils/dust_map.py +++ b/src/tdastro/astro_utils/dust_map.py @@ -11,8 +11,16 @@ """ +import importlib +from pkgutil import iter_modules + import astropy.units as u +import dust_extinction +import dustmaps from astropy.coordinates import SkyCoord +from dustmaps.config import config as dm_config + +from tdastro import _TDASTRO_CACHE_DATA_DIR class DustExtinctionEffect: @@ -20,15 +28,86 @@ class DustExtinctionEffect: Attributes ---------- - dust_map : `dustmaps.DustMap` - The dust map. - extinction_model : `function` - The extinction function to use. + dust_map : `dustmaps.DustMap` or `str` + The dust map or its name. + ext_model : `function` or `str` + The extinction function to use or its name. """ - def __init__(self, dust_map, extinction_model): + def __init__(self, dust_map, ext_model, **kwargs): + if isinstance(dust_map, str): + dust_map = DustExtinctionEffect.load_dustmap(dust_map) self.dust_map = dust_map - self.extinction_model = extinction_model + + if isinstance(ext_model, str): + ext_model = DustExtinctionEffect.load_extinction_model(ext_model, **kwargs) + self.extinction_model = ext_model + + @staticmethod + def load_dustmap(name): + """Load a dustmap from files, downloading it if needed. + + Parameters + ---------- + name : str + The name of the dustmap. + Must be one of: bayestar, chen2014, csfd, edenhofer2023, iphas, + leike_ensslin_2019, leike2020, lenz2017, marshall, pg2010, planck, + or sfd. + + Returns + ------- + dust_map : `dustmaps.DustMap` + A "query" object for the requested dustmap. + """ + # Find the correct submodule within dustmaps and load it. + dm_module = None + for submodule in iter_modules(dustmaps.__path__): + if name == submodule.name: + dm_module = importlib.import_module(f"dustmaps.{name}") + if dm_module is None: + raise KeyError(f"Invalid dustmap '{name}'") + + # Fetch the data to TDAstro's cache directory. + dm_config["data_dir"] = str(_TDASTRO_CACHE_DATA_DIR / "dustmaps") + dm_module.fetch() + + # Get the query object by searching for a class using the {Module}Query + # naming convention. + target_name = f"{name}query" + query_class_name = None + for attr in dir(dm_module): + if attr.lower() == target_name: + query_class_name = attr + if query_class_name is None: + raise ValueError(f"Unable to find query class within module dustmaps.{name}") + + # Get the class, create a query object, and return that object. + dm_class = getattr(dm_module, query_class_name) + return dm_class() + + @staticmethod + def load_extinction_model(name, **kwargs): + """Load the extinction model. + + Parameters + ---------- + name : str + The name of the extinction model to use. + **kwargs : dict + Any additional keyword arguments needed to create that argument. + + Returns + ------- + ext_obj + A extinction object. + """ + for submodule in iter_modules(dust_extinction.__path__): + ext_module = importlib.import_module(f"dust_extinction.{submodule.name}") + if ext_module is not None and name in dir(ext_module): + ext_class = getattr(ext_module, name) + return ext_class(**kwargs) + raise KeyError(f"Invalid dust extinction model '{name}'") def apply(self, flux_density, wavelengths, ra, dec, dist=1.0): """Apply the effect to observations (flux_density values) From a0894f5f944acf2430236ec3d66f16fa23f8abe0 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:04:08 -0500 Subject: [PATCH 6/6] Address comments in the PR --- pyproject.toml | 1 + src/tdastro/astro_utils/dust_map.py | 161 ++++++++++++--------- tests/tdastro/astro_utils/test_dust_map.py | 59 +++++++- 3 files changed, 153 insertions(+), 68 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6df13565..e5b86cc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,6 +110,7 @@ ignore = [ "E721", # Allow direct type comparison "N803", # Allow arguments to start with a capital letter "N806", # Allow variables to use non-lowercase letter + "SIM108" # Allow if-else blocks that could be ternary expressions ] [tool.coverage.run] diff --git a/src/tdastro/astro_utils/dust_map.py b/src/tdastro/astro_utils/dust_map.py index b9f55858..1fea1873 100644 --- a/src/tdastro/astro_utils/dust_map.py +++ b/src/tdastro/astro_utils/dust_map.py @@ -8,7 +8,6 @@ * dust_extinction: Gordon 2024, JOSS, 9(100), 7023. https://github.com/karllark/dust_extinction - """ import importlib @@ -16,7 +15,6 @@ import astropy.units as u import dust_extinction -import dustmaps from astropy.coordinates import SkyCoord from dustmaps.config import config as dm_config @@ -28,63 +26,72 @@ class DustExtinctionEffect: Attributes ---------- - dust_map : `dustmaps.DustMap` or `str` - The dust map or its name. - ext_model : `function` or `str` - The extinction function to use or its name. + dust_map : dustmaps.DustMap or str + The dust map or its name. Since different dustmap's query function + may produce different outputs, you should include the corresponding + ebv_func to transform the result into ebv if needed. + extinction_model : function or str + The extinction object from the dust_extinction library or its name. + If a string is provided, the code will find a matching extinction + function in the dust_extinction package and use that. + ebv_func : function + A function to translate the result of the dustmap query into an ebv. + **kwargs : `dict`, optional + Any additional keyword arguments. """ - def __init__(self, dust_map, ext_model, **kwargs): - if isinstance(dust_map, str): - dust_map = DustExtinctionEffect.load_dustmap(dust_map) - self.dust_map = dust_map - - if isinstance(ext_model, str): - ext_model = DustExtinctionEffect.load_extinction_model(ext_model, **kwargs) - self.extinction_model = ext_model - - @staticmethod - def load_dustmap(name): - """Load a dustmap from files, downloading it if needed. + def __init__(self, dust_map, extinction_model, ebv_func=None, **kwargs): + self.ebv_func = ebv_func - Parameters - ---------- - name : str - The name of the dustmap. - Must be one of: bayestar, chen2014, csfd, edenhofer2023, iphas, - leike_ensslin_2019, leike2020, lenz2017, marshall, pg2010, planck, - or sfd. + if isinstance(dust_map, str): + # Initially we only support loading the SFD dustmap by string. + # But we can expand this as needed. + dust_map = dust_map.lower() + if dust_map == "sfd": + self._load_sfd_dustmap() + else: + raise ValueError("Unsupported load from dustmap {dust_map}") + else: + # If given a dustmap, use that directly. + self.dust_map = dust_map + + if isinstance(extinction_model, str): + extinction_model = DustExtinctionEffect.load_extinction_model(extinction_model, **kwargs) + self.extinction_model = extinction_model + + def _load_sfd_dustmap(self): + """Load the SFD dustmap, downloading it if needed. + + Uses data from: + 1. Schlegel, Finkbeiner, and Davis + The Astrophysical Journal, Volume 500, Issue 2, pp. 525-553. + https://ui.adsabs.harvard.edu/abs/1998ApJ...500..525S/abstract + + 2. Schlegel and Finkbeiner + The Astrophysical Journal, Volume 737, Issue 2, article id. 103, 13 pp. (2011). + https://ui.adsabs.harvard.edu/abs/2011ApJ...737..103S/abstract Returns ------- - dust_map : `dustmaps.DustMap` - A "query" object for the requested dustmap. + sfd_query : SFDQuery + The "query" object for the requested dustmap. """ - # Find the correct submodule within dustmaps and load it. - dm_module = None - for submodule in iter_modules(dustmaps.__path__): - if name == submodule.name: - dm_module = importlib.import_module(f"dustmaps.{name}") - if dm_module is None: - raise KeyError(f"Invalid dustmap '{name}'") - - # Fetch the data to TDAstro's cache directory. + import dustmaps.sfd + + # Download the dustmap if needed. dm_config["data_dir"] = str(_TDASTRO_CACHE_DATA_DIR / "dustmaps") - dm_module.fetch() - - # Get the query object by searching for a class using the {Module}Query - # naming convention. - target_name = f"{name}query" - query_class_name = None - for attr in dir(dm_module): - if attr.lower() == target_name: - query_class_name = attr - if query_class_name is None: - raise ValueError(f"Unable to find query class within module dustmaps.{name}") - - # Get the class, create a query object, and return that object. - dm_class = getattr(dm_module, query_class_name) - return dm_class() + dustmaps.sfd.fetch() + + # Load the dustmap. + from dustmaps.sfd import SFDQuery + + self.dust_map = SFDQuery() + + # Add the correction function. + def _sfd_scale_ebv(input, **kwargs): + """Scale the result of the SFD query.""" + + self.ebv_func = _sfd_scale_ebv @staticmethod def load_extinction_model(name, **kwargs): @@ -109,33 +116,53 @@ def load_extinction_model(name, **kwargs): return ext_class(**kwargs) raise KeyError(f"Invalid dust extinction model '{name}'") - def apply(self, flux_density, wavelengths, ra, dec, dist=1.0): - """Apply the effect to observations (flux_density values) + def apply(self, flux_density, wavelengths, ebv=None, ra=None, dec=None, dist=None, **kwargs): + """Apply the effect to observations (flux_density values). The user can either + provide a ebv value directly or (RA, dec) and distance information that can + be used in the dustmap query. Parameters ---------- - flux_density : `numpy.ndarray` + flux_density : numpy.ndarray An array of flux density values (in nJy). - wavelengths : `numpy.ndarray`, optional + wavelengths : numpy.ndarray, optional An array of wavelengths (in angstroms). - ra : `float` + ebv : float or np.array + A given ebv value or array of values. If present then this is used + instead of looking it up in the dust map. + ra : float, optional The object's right ascension (in degrees). - dec : `float` + dec : float, optional The object's declination (in degrees). - dist : `float` - The object's distance (in ?). - Default = 1.0 + dist : float, optional + The object's distance (in parsecs). + **kwargs : `dict`, optional + Any additional keyword arguments. Returns ------- - flux_density : `numpy.ndarray` + flux_density : numpy.ndarray The results (in nJy). """ - # Get the extinction value at the object's location. - coord = SkyCoord(ra, dec, dist, frame="icrs", unit="deg") - ebv = self.dust_map.query(coord) - - # Do we need to convert ebv by a factor from this table: - # https://iopscience.iop.org/article/10.1088/0004-637X/737/2/103#apj398709t6 + if ebv is None: + if self.dust_map is None: + raise ValueError("If ebv=None then a dust map must be provided.") + if ra is None or dec is None: + raise ValueError("If ebv=None then ra, dec must be provided for a lookup.") + + # Get the extinction value at the object's location. + if dist is not None: + coord = SkyCoord(ra, dec, dist, frame="icrs", unit="deg") + else: + coord = SkyCoord(ra, dec, frame="icrs", unit="deg") + dustmap_value = self.dust_map.query(coord) + + # Perform any corrections needed for this dust map. + if self.ebv_func is not None: + ebv = self.ebv_func(dustmap_value, **kwargs) + else: + ebv = dustmap_value + + print(f"Using ebv={ebv}") return flux_density * self.extinction_model.extinguish(wavelengths * u.angstrom, Ebv=ebv) diff --git a/tests/tdastro/astro_utils/test_dust_map.py b/tests/tdastro/astro_utils/test_dust_map.py index c74d5a69..8b160c9a 100644 --- a/tests/tdastro/astro_utils/test_dust_map.py +++ b/tests/tdastro/astro_utils/test_dust_map.py @@ -35,6 +35,45 @@ def query(self, coords): return np.full((len(coords)), self.ebv) +class TestExtinction: + """An extinction function that computes the scaling factor + as a constant times the ebv. + + Attributes + ---------- + scale : `float` + The extinction function's multiplicative scaling. + """ + + def __init__(self, scale): + self.scale = scale + + def extinguish(self, wavelengths, Ebv=1.0): + """The extinguish function + + Parameters + ---------- + wavelengths : numpy.ndarray + The array of wavelengths + Ebv : float + The Ebv to use in the scaling. + """ + return Ebv * self.scale + + +def test_load_extinction_model(): + """Load an extinction model by string.""" + g23_model = DustExtinctionEffect.load_extinction_model("G23", Rv=3.1) + assert g23_model is not None + assert hasattr(g23_model, "extinguish") + + # Load through the DustExtinctionEffect constructor. + const_map = ConstantDustMap(0.5) + dust_effect = DustExtinctionEffect(const_map, "CCM89", Rv=3.1) + assert dust_effect.extinction_model is not None + assert hasattr(dust_effect.extinction_model, "extinguish") + + def test_constant_dust_extinction(): """Test that we can create and sample a DustExtinctionEffect object.""" times = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) @@ -47,8 +86,26 @@ def test_constant_dust_extinction(): assert fluxes_clean.shape == (5, 3) assert np.all(fluxes_clean == 100.0) + # Create DustExtinctionEffect effects with constant ebvs and constant + # multiplicative extinction functions. + dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.1)) + fluxes = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=100.0) + assert fluxes.shape == (5, 3) + assert np.all(fluxes == 5.0) + + dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.3)) + fluxes = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=100.0) + assert fluxes.shape == (5, 3) + assert np.all(fluxes == 15.0) + + # Use a manual ebv, which overrides the dustmap. + dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), TestExtinction(0.1)) + fluxes = dust_effect.apply(fluxes_clean, wavelengths, ebv=1.0, ra=0.0, dec=40.0, dist=100.0) + assert fluxes.shape == (5, 3) + assert np.all(fluxes == 10.0) + # Create a model with ccm89 extinction at r_v = 3.1. dust_effect = DustExtinctionEffect(ConstantDustMap(0.5), CCM89(Rv=3.1)) - fluxes_ccm98 = dust_effect.apply(fluxes_clean, wavelengths, 0.0, 40.0, 1.0) + fluxes_ccm98 = dust_effect.apply(fluxes_clean, wavelengths, ra=0.0, dec=40.0, dist=1.0) assert fluxes_ccm98.shape == (5, 3) assert np.all(fluxes_ccm98 < 100.0)