diff --git a/CHANGES.rst b/CHANGES.rst index 2a14d2298..c6285215f 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,8 @@ New Features - Added a ``.guiding_center()`` method to ``PhaseSpacePosition`` and ``Orbit`` to compute the guiding center radius. +- Added a way to convert Gala potential instances to Agama potential instances. + Bug fixes --------- @@ -24,6 +26,10 @@ Bug fixes API changes ----------- +- Changed the way potential interoperability is done with other Galactic dynamics + packages (Agama, galpy, etc.). It is now handled by the ``Potential.as_interop()`` + method on all potential class instances. + 1.7.1 (2023-08-05) ================== diff --git a/gala/potential/potential/core.py b/gala/potential/potential/core.py index 551d05861..7d0c48465 100644 --- a/gala/potential/potential/core.py +++ b/gala/potential/potential/core.py @@ -1,15 +1,17 @@ # Standard library import abc -from collections import OrderedDict import copy as pycopy -import warnings import uuid +import warnings +from collections import OrderedDict + +import astropy.units as u # Third-party import numpy as np from astropy.constants import G -import astropy.units as u from astropy.utils import isiterable +from astropy.utils.decorators import deprecated try: from scipy.spatial.transform import Rotation @@ -21,9 +23,10 @@ # Project from gala.util import GalaDeprecationWarning -from ..common import CommonBase -from ...util import ImmutableDict, atleast_2d + from ...units import DimensionlessUnitSystem +from ...util import ImmutableDict, atleast_2d +from ..common import CommonBase __all__ = ["PotentialBase", "CompositePotential"] @@ -897,6 +900,11 @@ def value(self, *args, **kwargs): ########################################################################### # Interoperability with other packages # + @deprecated( + since="v1.8", + message="This has been replaced by a more general interoperability framework.", + alternative="interop", + ) def to_galpy_potential(self, ro=None, vo=None): """Convert a Gala potential to a Galpy potential instance @@ -905,9 +913,36 @@ def to_galpy_potential(self, ro=None, vo=None): ro : quantity-like (optional) vo : quantity-like (optional) """ - from .interop import gala_to_galpy_potential + return self.as_interop("galpy", ro=ro, vo=vo) + + def as_interop(self, package, **kwargs): + """Interoperability with other Galactic dynamics packages + + Parameters + ---------- + package : str + The package to export the potential to. Currently supported packages are + ``"galpy"`` and ``"agama"``. + kwargs + Any additional keyword arguments are passed to the interop function. + """ + if package == "galpy": + from .interop import gala_to_galpy_potential + + kwargs.setdefault("ro", None) + kwargs.setdefault("vo", None) + return gala_to_galpy_potential(self, **kwargs) + elif package == "agama": + import agama - return gala_to_galpy_potential(self, ro=ro, vo=vo) + from .interop import gala_to_agama_potential + + agama_pot = gala_to_agama_potential(self, **kwargs) + if not isinstance(agama_pot, agama.Potential): + agama_pot = agama.Potential(*agama_pot) + return agama_pot + else: + raise ValueError(f"Unsupported package: {package}") class CompositePotential(PotentialBase, OrderedDict): diff --git a/gala/potential/potential/interop.py b/gala/potential/potential/interop.py index 748501b4b..585a138d2 100644 --- a/gala/potential/potential/interop.py +++ b/gala/potential/potential/interop.py @@ -3,208 +3,253 @@ import inspect import warnings -from astropy.constants import G import astropy.units as u import numpy as np +from astropy.constants import G import gala.potential.potential.builtin as gp from gala.potential.potential.ccompositepotential import CCompositePotential from gala.potential.potential.core import CompositePotential +from gala.tests.optional_deps import HAS_AGAMA, HAS_GALPY from gala.units import galactic -from gala.tests.optional_deps import HAS_GALPY -__all__ = ['gala_to_galpy_potential', 'galpy_to_gala_potential'] +__all__ = [ + "gala_to_galpy_potential", + "galpy_to_gala_potential", + "gala_to_agama_potential", +] ############################################################################### # Galpy interoperability # if HAS_GALPY: - from scipy.special import gamma import galpy.potential as galpy_gp + from scipy.special import gamma def _powerlaw_amp_to_galpy(pars, ro, vo): # I don't really remember why this is like this, but it might be related # to the difference between GSL gamma and scipy gamma?? - fac = ((1/(2*np.pi) * pars['r_c'].to_value(ro)**(pars['alpha'] - 3) / - (gamma(3/2 - pars['alpha']/2)))) - amp = fac * (G * pars['m']).to_value(vo**2 * ro) + fac = ( + 1 + / (2 * np.pi) + * pars["r_c"].to_value(ro) ** (pars["alpha"] - 3) + / (gamma(3 / 2 - pars["alpha"] / 2)) + ) + amp = fac * (G * pars["m"]).to_value(vo**2 * ro) return amp def _powerlaw_m_from_galpy(pars, ro, vo): # See note above! - fac = ((1/(2*np.pi) * pars['rc']**(pars['alpha'] - 3) / - (gamma(3/2 - pars['alpha']/2)))) - amp = pars['amp'] * vo**2 * ro + fac = ( + 1 + / (2 * np.pi) + * pars["rc"] ** (pars["alpha"] - 3) + / (gamma(3 / 2 - pars["alpha"] / 2)) + ) + amp = pars["amp"] * vo**2 * ro m = amp / G / fac return m def _mn3_amp_to_galpy(pars, ro, vo): - num = (G * pars['m']).to_value(ro * vo**2) - den = (4*np.pi * pars['h_R'].to_value(ro)**2 * pars['h_z'].to_value(ro)) + num = (G * pars["m"]).to_value(ro * vo**2) + den = 4 * np.pi * pars["h_R"].to_value(ro) ** 2 * pars["h_z"].to_value(ro) return num / den # TODO: some potential conversions drop parameters. Might want to add an # option for a custom validator function or something to raise warnings? _gala_to_galpy = { gp.HernquistPotential: ( - galpy_gp.HernquistPotential, { - 'a': 'c', - 'amp': lambda pars, ro, vo: (G*2*pars['m']).to_value(ro*vo**2) - } - ), - gp.IsochronePotential: ( - galpy_gp.IsochronePotential, { - 'b': 'b' - } - ), - gp.JaffePotential: ( - galpy_gp.JaffePotential, { - 'a': 'c' - } + galpy_gp.HernquistPotential, + { + "a": "c", + "amp": lambda pars, ro, vo: (G * 2 * pars["m"]).to_value(ro * vo**2), + }, ), + gp.IsochronePotential: (galpy_gp.IsochronePotential, {"b": "b"}), + gp.JaffePotential: (galpy_gp.JaffePotential, {"a": "c"}), gp.KeplerPotential: (galpy_gp.KeplerPotential, {}), gp.KuzminPotential: ( - galpy_gp.KuzminDiskPotential, { - 'a': 'a', - } + galpy_gp.KuzminDiskPotential, + { + "a": "a", + }, ), gp.LogarithmicPotential: ( - galpy_gp.LogarithmicHaloPotential, { - 'amp': lambda pars, ro, vo: pars['v_c'].to_value(vo)**2, - 'core': 'r_h', - 'q': 'q3' - } + galpy_gp.LogarithmicHaloPotential, + { + "amp": lambda pars, ro, vo: pars["v_c"].to_value(vo) ** 2, + "core": "r_h", + "q": "q3", + }, ), gp.LongMuraliBarPotential: ( - galpy_gp.SoftenedNeedleBarPotential, { - 'a': 'a', - 'b': 'b', - 'c': 'c', - 'pa': 'alpha' - } + galpy_gp.SoftenedNeedleBarPotential, + {"a": "a", "b": "b", "c": "c", "pa": "alpha"}, ), gp.MiyamotoNagaiPotential: ( - galpy_gp.MiyamotoNagaiPotential, { - 'a': 'a', - 'b': 'b' - } + galpy_gp.MiyamotoNagaiPotential, + {"a": "a", "b": "b"}, ), gp.MN3ExponentialDiskPotential: ( - galpy_gp.MN3ExponentialDiskPotential, { - 'amp': _mn3_amp_to_galpy, - 'hr': 'h_R', - 'hz': 'h_z', - 'posdens': 'positive_density', - 'sech': 'sech2_z' - } + galpy_gp.MN3ExponentialDiskPotential, + { + "amp": _mn3_amp_to_galpy, + "hr": "h_R", + "hz": "h_z", + "posdens": "positive_density", + "sech": "sech2_z", + }, ), gp.NFWPotential: ( - galpy_gp.TriaxialNFWPotential, { - 'a': 'r_s', - 'b': lambda pars, *_: pars['b'] / pars['a'], - 'c': lambda pars, *_: pars['c'] / pars['a'], - } - ), - gp.PlummerPotential: ( - galpy_gp.PlummerPotential, { - 'b': 'b' - } + galpy_gp.TriaxialNFWPotential, + { + "a": "r_s", + "b": lambda pars, *_: pars["b"] / pars["a"], + "c": lambda pars, *_: pars["c"] / pars["a"], + }, ), + gp.PlummerPotential: (galpy_gp.PlummerPotential, {"b": "b"}), gp.PowerLawCutoffPotential: ( - galpy_gp.PowerSphericalPotentialwCutoff, { - 'amp': _powerlaw_amp_to_galpy, - 'rc': 'r_c', - 'alpha': 'alpha' - } + galpy_gp.PowerSphericalPotentialwCutoff, + {"amp": _powerlaw_amp_to_galpy, "rc": "r_c", "alpha": "alpha"}, ), } _galpy_to_gala = {} for gala_cls, (galpy_cls, pars) in _gala_to_galpy.items(): - galpy_pars = {v: k for k, v in pars.items() - if isinstance(v, (str, int, float, np.ndarray))} + galpy_pars = { + v: k + for k, v in pars.items() + if isinstance(v, (str, int, float, np.ndarray)) + } _galpy_to_gala[galpy_cls] = (gala_cls, galpy_pars) # Special cases: - _galpy_to_gala[galpy_gp.HernquistPotential][1]['m'] = \ - lambda pars, ro, vo: (pars['amp'] * ro * vo**2 / G / 2) + _galpy_to_gala[galpy_gp.HernquistPotential][1]["m"] = lambda pars, ro, vo: ( + pars["amp"] * ro * vo**2 / G / 2 + ) - _galpy_to_gala[galpy_gp.LogarithmicHaloPotential][1]['v_c'] = \ - lambda pars, ro, vo: np.sqrt(pars['amp'] * vo**2) + _galpy_to_gala[galpy_gp.LogarithmicHaloPotential][1][ + "v_c" + ] = lambda pars, ro, vo: np.sqrt(pars["amp"] * vo**2) - _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1]['m'] = \ - lambda pars, ro, vo: ( - pars['amp'] * ro * vo**2 / G * 4*np.pi*pars['a']**3) - _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1]['a'] = 1. - _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1]['b'] = 'b' - _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1]['c'] = 'c' + _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1]["m"] = lambda pars, ro, vo: ( + pars["amp"] * ro * vo**2 / G * 4 * np.pi * pars["a"] ** 3 + ) + _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1]["a"] = 1.0 + _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1]["b"] = "b" + _galpy_to_gala[galpy_gp.TriaxialNFWPotential][1]["c"] = "c" - _galpy_to_gala[galpy_gp.PowerSphericalPotentialwCutoff][1]['m'] = \ - _powerlaw_m_from_galpy + _galpy_to_gala[galpy_gp.PowerSphericalPotentialwCutoff][1][ + "m" + ] = _powerlaw_m_from_galpy _galpy_to_gala[galpy_gp.NFWPotential] = ( - gp.NFWPotential, { - 'r_s': 'a', - } + gp.NFWPotential, + { + "r_s": "a", + }, ) +if HAS_AGAMA: + # TODO: some potential conversions drop parameters. Might want to add an + # option for a custom validator function or something to raise warnings? + _gala_to_agama = { + gp.HernquistPotential: { + "type": "dehnen", + "mass": "m", + "scaleradius": "c", + "gamma": 1.0, + }, + gp.IsochronePotential: {"type": "isochrone", "mass": "m", "scaleradius": "b"}, + gp.JaffePotential: { + "type": "dehnen", + "mass": "m", + "scaleradius": "c", + "gamma": 2.0, + }, + # gp.KeplerPotential: {}, + # gp.KuzminPotential: {}, + gp.LogarithmicPotential: { + "type": "logarithmic", + "v0": "v_c", + "scaleradius": "r_h", + "axisRatioY": "q2", + "axisRatioZ": "q3", + }, + # gp.LongMuraliBarPotential: {}, + gp.MiyamotoNagaiPotential: { + "type": "miyamotonagai", + "mass": "m", + "scaleradius": "a", + "scaleheight": "b", + }, + # gp.MN3ExponentialDiskPotential: {}, # Special cased below + gp.NFWPotential: {"type": "nfw", "mass": "m", "scaleradius": "r_s"}, + gp.PlummerPotential: {"type": "plummer", "mass": "m", "scaleradius": "b"}, + # gp.PowerLawCutoffPotential: {} + } + def _get_ro_vo(ro, vo): # If not specified, get the default ro, vo from Galpy if ro is None or vo is None: from galpy.potential import Force + f = Force() if ro is None: ro = f._ro * u.kpc if vo is None: - vo = f._vo * u.km/u.s + vo = f._vo * u.km / u.s return u.Quantity(ro), u.Quantity(vo) def gala_to_galpy_potential(potential, ro=None, vo=None): - if not HAS_GALPY: raise ImportError( - "Failed to import galpy.potential: Converting a potential to a " - "galpy potential requires galpy to be installed.") + "Failed to import galpy.potential: Converting a potential to a galpy " + "potential requires galpy to be installed." + ) ro, vo = _get_ro_vo(ro, vo) if isinstance(potential, CompositePotential): pot = [] for k in potential.keys(): - pot.append( - gala_to_galpy_potential(potential[k], ro, vo)) + pot.append(gala_to_galpy_potential(potential[k], ro, vo)) else: if potential.__class__ not in _gala_to_galpy: raise TypeError( f"Converting potential class {potential.__class__.__name__} " - "to galpy is currently not supported") + "to galpy is currently not supported" + ) galpy_cls, converters = _gala_to_galpy[potential.__class__] gala_pars = potential.parameters.copy() galpy_pars = {} - if 'amp' not in converters and 'm' not in gala_pars: - raise ValueError("Gala potential has no mass parameter, so " - "converting to a Galpy potential is currently " - "not supported.") + if "amp" not in converters and "m" not in gala_pars: + raise ValueError( + "Gala potential has no mass parameter, so converting to a Galpy " + "potential is currently not supported." + ) if isinstance(potential, gp.MN3ExponentialDiskPotential): - gala_pars['positive_density'] = potential.positive_density - gala_pars['sech2_z'] = potential.sech2_z + gala_pars["positive_density"] = potential.positive_density + gala_pars["sech2_z"] = potential.sech2_z converters.setdefault( - 'amp', lambda pars, ro, vo: (G * pars['m']).to_value(ro * vo**2)) + "amp", lambda pars, ro, vo: (G * pars["m"]).to_value(ro * vo**2) + ) for galpy_par_name, conv in converters.items(): if isinstance(conv, str): galpy_pars[galpy_par_name] = gala_pars[conv] - elif hasattr(conv, '__call__'): + elif hasattr(conv, "__call__"): galpy_pars[galpy_par_name] = conv(gala_pars, ro, vo) elif isinstance(conv, (int, float, u.Quantity, np.ndarray)): galpy_pars[galpy_par_name] = conv @@ -213,21 +258,21 @@ def gala_to_galpy_potential(potential, ro=None, vo=None): print(f"FAIL: {galpy_par_name}, {conv}") par = galpy_pars[galpy_par_name] - if hasattr(par, 'unit'): - if par.unit.physical_type == 'length': + if hasattr(par, "unit"): + if par.unit.physical_type == "length": galpy_pars[galpy_par_name] = par.to_value(ro) - elif par.unit.physical_type == 'speed': + elif par.unit.physical_type == "speed": galpy_pars[galpy_par_name] = par.to_value(vo) - elif par.unit.physical_type == 'dimensionless': + elif par.unit.physical_type == "dimensionless": galpy_pars[galpy_par_name] = par.value - elif par.unit.physical_type == 'angle': + elif par.unit.physical_type == "angle": galpy_pars[galpy_par_name] = par.to_value(u.rad) else: warnings.warn( f"Unknown unit physical type '{par.unit.physical_type}'" " - this should have a custom unit converter. Please " "make a GitHub issue!", - RuntimeWarning + RuntimeWarning, ) galpy_pars[galpy_par_name] = par.value @@ -237,18 +282,18 @@ def gala_to_galpy_potential(potential, ro=None, vo=None): def galpy_to_gala_potential(potential, ro=None, vo=None, units=galactic): - if not HAS_GALPY: raise ImportError( "Failed to import galpy.potential: Converting a potential to a " - "gala potential requires galpy to be installed.") + "gala potential requires galpy to be installed." + ) ro, vo = _get_ro_vo(ro, vo) if potential._roSet: ro = potential._ro * u.kpc if potential._voSet: - vo = potential._vo * u.km/u.s + vo = potential._vo * u.km / u.s if isinstance(potential, list): pot = CCompositePotential() @@ -259,45 +304,46 @@ def galpy_to_gala_potential(potential, ro=None, vo=None, units=galactic): if potential.__class__ not in _galpy_to_gala: raise TypeError( f"Converting galpy potential {potential.__class__.__name__} " - "to gala is currently not supported") + "to gala is currently not supported" + ) elif isinstance(potential, galpy_gp.MN3ExponentialDiskPotential): warnings.warn( "For the MN3ExponentialDiskPotential, galpy does not store " "information to fully reconstruct the potential, so the " "default gala choices will be adopted for the " "'positive_density' and 'sech2_z' potential arguments", - RuntimeWarning + RuntimeWarning, ) gala_cls, converters = _galpy_to_gala[potential.__class__] - exclude = ['self', 'normalize', 'ro', 'vo'] + exclude = ["self", "normalize", "ro", "vo"] spec = inspect.getfullargspec(potential.__class__) par_names = [arg for arg in spec.args if arg not in exclude] # UGH! galpy_pars = {} for name in par_names: - galpy_pars[name] = getattr(potential, - '_' + name, - getattr(potential, name, None)) + galpy_pars[name] = getattr( + potential, "_" + name, getattr(potential, name, None) + ) if isinstance(potential, galpy_gp.LogarithmicHaloPotential): - galpy_pars['core'] = np.sqrt(potential._core2) + galpy_pars["core"] = np.sqrt(potential._core2) elif isinstance(potential, galpy_gp.SoftenedNeedleBarPotential): - galpy_pars['c'] = np.sqrt(potential._c2) + galpy_pars["c"] = np.sqrt(potential._c2) - if 'm' in inspect.getfullargspec(gala_cls).args: + if "m" in inspect.getfullargspec(gala_cls).args: converters.setdefault( - 'm', lambda pars, ro, vo: pars['amp'] * ro * vo**2 / G + "m", lambda pars, ro, vo: pars["amp"] * ro * vo**2 / G ) gala_pars = {} for gala_par_name, conv in converters.items(): if isinstance(conv, str): gala_pars[gala_par_name] = galpy_pars[conv] - elif hasattr(conv, '__call__'): + elif hasattr(conv, "__call__"): gala_pars[gala_par_name] = conv(galpy_pars, ro, vo) elif isinstance(conv, (int, float, u.Quantity, np.ndarray)): gala_pars[gala_par_name] = conv @@ -305,22 +351,22 @@ def galpy_to_gala_potential(potential, ro=None, vo=None, units=galactic): # TODO: invalid parameter?? print(f"FAIL: {gala_par_name}, {conv}") - if hasattr(gala_pars[gala_par_name], 'unit'): + if hasattr(gala_pars[gala_par_name], "unit"): continue if gala_par_name not in gala_cls._parameters: continue gala_par = gala_cls._parameters[gala_par_name] - if gala_par.physical_type == 'mass': + if gala_par.physical_type == "mass": gala_pars[gala_par_name] = gala_pars[gala_par_name] * u.Msun - elif gala_par.physical_type == 'length': + elif gala_par.physical_type == "length": gala_pars[gala_par_name] = gala_pars[gala_par_name] * ro - elif gala_par.physical_type == 'speed': + elif gala_par.physical_type == "speed": gala_pars[gala_par_name] = gala_pars[gala_par_name] * vo - elif gala_par.physical_type == 'angle': + elif gala_par.physical_type == "angle": gala_pars[gala_par_name] = gala_pars[gala_par_name] * u.radian - elif gala_par.physical_type == 'dimensionless': + elif gala_par.physical_type == "dimensionless": pass else: print("TODO") @@ -328,3 +374,61 @@ def galpy_to_gala_potential(potential, ro=None, vo=None, units=galactic): pot = gala_cls(**gala_pars, units=units) return pot + + +def gala_to_agama_potential(potential): + if not HAS_AGAMA: + raise ImportError( + "Failed to import agama: Converting a potential to an Agama potential " + "requires Agama to be installed." + ) + + import agama + + agama.setUnits(**{k: potential.units[k] for k in ["length", "mass", "time"]}) + + if isinstance(potential, CompositePotential): + pot = [] + for k in potential.keys(): + agama_pot = gala_to_agama_potential(potential[k]) + if isinstance(agama_pot, list): + pot.extend(agama_pot) + else: + pot.append(agama_pot) + + elif isinstance(potential, gp.MN3ExponentialDiskPotential): + pot = [] + for disk in potential.get_three_potentials().values(): + pot.append(gala_to_agama_potential(disk)) + + else: + if potential.__class__ not in _gala_to_agama: + raise TypeError( + f"Converting potential class {potential.__class__.__name__} " + "to agama is currently not supported" + ) + + agama_spec = _gala_to_agama[potential.__class__] + gala_pars = potential.parameters.copy() + + agama_pars = {"type": agama_spec["type"]} + for agama_par_name, conv in agama_spec.items(): + if agama_par_name == "type": + continue + elif isinstance(conv, str): + agama_pars[agama_par_name] = gala_pars[conv] + # elif hasattr(conv, "__call__"): + # agama_pars[agama_par_name] = conv(gala_pars) + elif isinstance(conv, (int, float, u.Quantity, np.ndarray)): + agama_pars[agama_par_name] = conv + else: + # TODO: invalid parameter?? + print(f"FAIL: {agama_par_name}, {conv}") + + for k, v in agama_pars.items(): + if hasattr(v, "unit"): + agama_pars[k] = v.decompose(potential.units).value + + pot = agama.Potential(**agama_pars) + + return pot diff --git a/gala/potential/potential/tests/test_interop_agama.py b/gala/potential/potential/tests/test_interop_agama.py new file mode 100644 index 000000000..6728fb140 --- /dev/null +++ b/gala/potential/potential/tests/test_interop_agama.py @@ -0,0 +1,108 @@ +""" +Test converting the builtin Potential classes to Agama +""" + +# Third-party +import astropy.units as u +import numpy as np +import pytest + +# This project +from gala.potential import JaffePotential, LogarithmicPotential, MiyamotoNagaiPotential +from gala.tests.optional_deps import HAS_AGAMA +from gala.units import galactic + +if HAS_AGAMA: + from gala.potential.potential.interop import _gala_to_agama + + +def pytest_generate_tests(metafunc): + # Some magic, semi-random numbers below! + gala_pots = [] + other_pots = [] + + if not HAS_AGAMA: + return + + # Test the Gala -> Agama direction + for Potential in _gala_to_agama.keys(): + init = {} + len_scale = 1.0 + for k, par in Potential._parameters.items(): + if k == "m": + val = 1.43e10 * u.Msun + elif par.physical_type == "length": + val = 5.12 * u.kpc * len_scale + len_scale *= 0.5 + elif par.physical_type == "dimensionless": + val = 1.0 + elif par.physical_type == "speed": + val = 201.41 * u.km / u.s + else: + continue + + init[k] = val + + pot = Potential(**init, units=galactic) + other_pot = pot.as_interop("agama") + + gala_pots.append(pot) + other_pots.append(other_pot) + + # Make a composite potential too: + gala_pots.append(gala_pots[0] + gala_pots[1]) + other_pots.append(gala_pots[-1].as_interop("agama")) + + test_names = [ + f"{g1.__class__.__name__}:{g2.__class__.__name__}" + for g1, g2 in zip(gala_pots, other_pots) + ] + + metafunc.parametrize( + ["gala_pot", "other_pot"], list(zip(gala_pots, other_pots)), ids=test_names + ) + + +@pytest.mark.skipif( + not HAS_AGAMA, reason="must have agama installed to run these tests" +) +class TestAgamaInterop: + def setup_method(self): + # Test points: + rng = np.random.default_rng(42) + ntest = 4 + + xyz = rng.uniform(-25, 25, size=(3, ntest)) * u.kpc + self.xyz = xyz.copy() + + def test_density(self, gala_pot, other_pot): + gala_val = gala_pot.density(self.xyz).decompose(gala_pot.units).value + other_val = other_pot.density(self.xyz.decompose(gala_pot.units).value.T) + assert np.allclose(gala_val, other_val) + + def test_energy(self, gala_pot, other_pot): + if isinstance(gala_pot, LogarithmicPotential): + # TODO: Agama has an inconsistency with Gala's log potential energy + pytest.skip() + gala_val = gala_pot.energy(self.xyz).decompose(gala_pot.units).value + other_val = other_pot.potential(self.xyz.decompose(gala_pot.units).value.T) + assert np.allclose(gala_val, other_val) + + def test_acc(self, gala_pot, other_pot): + gala_val = gala_pot.acceleration(self.xyz).decompose(gala_pot.units).value + other_val = other_pot.force(self.xyz.decompose(gala_pot.units).value.T).T + assert np.allclose(gala_val, other_val) + + def test_Menc(self, gala_pot, other_pot): + if isinstance( + gala_pot, (LogarithmicPotential, JaffePotential, MiyamotoNagaiPotential) + ): + # TODO: Agama has an inconsistency with Gala's log potential energy + pytest.skip() + + grid = np.zeros((3, 128)) + grid[0] = np.geomspace(1e-3, 100.0, 128) + + gala_val = gala_pot.mass_enclosed(grid).value + agama_val = other_pot.enclosedMass(grid[0]) + assert np.allclose(gala_val, agama_val) diff --git a/gala/potential/potential/tests/test_interop_galpy.py b/gala/potential/potential/tests/test_interop_galpy.py index 2e85c3971..abe31f953 100644 --- a/gala/potential/potential/tests/test_interop_galpy.py +++ b/gala/potential/potential/tests/test_interop_galpy.py @@ -3,17 +3,17 @@ """ # Third-party -from astropy.coordinates import CylindricalRepresentation -from astropy.tests.helper import catch_warnings import astropy.units as u import numpy as np import pytest +from astropy.coordinates import CylindricalRepresentation +from astropy.tests.helper import catch_warnings # This project import gala.potential as gp -from gala.units import galactic -from gala.tests.optional_deps import HAS_GALPY from gala.potential.potential.interop import galpy_to_gala_potential +from gala.tests.optional_deps import HAS_GALPY +from gala.units import galactic # Set these globally! ro = 8.122 * u.kpc @@ -53,7 +53,7 @@ def pytest_generate_tests(metafunc): init[k] = val pot = Potential(**init, units=galactic) - galpy_pot = pot.to_galpy_potential(ro=ro, vo=vo) + galpy_pot = pot.as_interop("galpy", ro=ro, vo=vo) gala_pots.append(pot) galpy_pots.append(galpy_pot) diff --git a/gala/tests/optional_deps.py b/gala/tests/optional_deps.py index 2be401576..dea3c3dc4 100644 --- a/gala/tests/optional_deps.py +++ b/gala/tests/optional_deps.py @@ -2,18 +2,20 @@ `PEP 562 `_. """ import importlib +import io from collections.abc import Sequence +from contextlib import redirect_stdout # First, the top-level packages: # TODO: This list is a duplicate of the dependencies in setup.cfg "all", but # some of the package names are different from the pip-install name (e.g., # beautifulsoup4 -> bs4). -_optional_deps = ['h5py', 'sympy', 'tqdm', 'twobody'] +_optional_deps = ["h5py", "sympy", "tqdm", "twobody", "agama"] _deps = {k.upper(): k for k in _optional_deps} # Any subpackages that have different import behavior: -_deps['MATPLOTLIB'] = ('matplotlib', 'matplotlib.pyplot') -_deps['GALPY'] = ('galpy', 'galpy.orbit', 'galpy.potential') +_deps["MATPLOTLIB"] = ("matplotlib", "matplotlib.pyplot") +_deps["GALPY"] = ("galpy", "galpy.orbit", "galpy.potential") __all__ = [f"HAS_{pkg}" for pkg in _deps] @@ -28,7 +30,8 @@ def __getattr__(name): for module in modules: try: - importlib.import_module(module) + with redirect_stdout(io.StringIO()): + importlib.import_module(module) except (ImportError, ModuleNotFoundError): return False return True