Skip to content

Commit

Permalink
Account for Pynbody 2.0 API change for how to use the direct gravity …
Browse files Browse the repository at this point in the history
…calculation
  • Loading branch information
jobovy committed Sep 14, 2024
1 parent 6c9bb85 commit 2d78942
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
16 changes: 10 additions & 6 deletions galpy/potential/SnapshotRZPotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy
from scipy import interpolate

from ..util._optional_deps import _PYNBODY_LOADED
from ..util._optional_deps import _PYNBODY_GE_20, _PYNBODY_LOADED
from .interpRZPotential import (
calc_2dsplinecoeffs_c,
interpRZPotential,
Expand All @@ -16,7 +16,11 @@

if _PYNBODY_LOADED:
import pynbody
from pynbody import gravity

if _PYNBODY_GE_20: # pragma: no cover
from pynbody import gravity as pynbody_gravity_calc
else:
from pynbody.gravity import calc as pynbody_gravity_calc
from pynbody.units import NoUnit


Expand Down Expand Up @@ -111,7 +115,7 @@ def _setup_potential(self, R, z, use_pkdgrav=False):
).T

points_new = points.reshape(points.size // 3, 3)
pot, acc = gravity.calc.direct(
pot, acc = pynbody_gravity_calc.direct(
self._s, points_new, num_threads=self._num_threads
)

Expand Down Expand Up @@ -393,7 +397,7 @@ def _setup_potential(self, R, z, use_pkdgrav=False, dr=0.0001):

else:
if self._interpPot:
pot, acc = gravity.calc.direct(
pot, acc = pynbody_gravity_calc.direct(
self._s, points_new, num_threads=self._numcores
)

Expand Down Expand Up @@ -426,7 +430,7 @@ def _setup_potential(self, R, z, use_pkdgrav=False, dr=0.0001):

# first get the accelerations
if self._interpverticalfreq:
zgrad_pot, zgrad_acc = gravity.calc.direct(
zgrad_pot, zgrad_acc = pynbody_gravity_calc.direct(
self._s, zgrad_points, num_threads=self._numcores
)
# each point from the points used above for pot and acc is straddled by
Expand All @@ -452,7 +456,7 @@ def _setup_potential(self, R, z, use_pkdgrav=False, dr=0.0001):

# do the same for the radial component
if self._interpepifreq:
rgrad_pot, rgrad_acc = gravity.calc.direct(
rgrad_pot, rgrad_acc = pynbody_gravity_calc.direct(
self._s, rgrad_points, num_threads=self._numcores
)
rgrad = numpy.zeros(len(points_new))
Expand Down
6 changes: 6 additions & 0 deletions galpy/util/_optional_deps.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Central place to process optional dependencies
from packaging.version import Version
from packaging.version import parse as parse_version

# astropy
Expand Down Expand Up @@ -62,7 +63,12 @@

# pynbody
_PYNBODY_LOADED = True
_PYNBODY_GE_20 = None
try:
import pynbody
except ImportError: # pragma: no cover
_PYNBODY_LOADED = False
else:
_PYNBODY_GE_20 = Version(
parse_version(pynbody.__version__).base_version
) >= Version("2.0.0")
1 change: 1 addition & 0 deletions tests/test_snapshotpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def test_interpsnapshotKeplerPotential_normalize_units():
s["mass"] = 4.0
s["eps"] = 0.0
s["pos"].units = "kpc"
s["eps"].units = "kpc"
s["vel"].units = "km s**-1"
sp = potential.InterpSnapshotRZPotential(
s,
Expand Down

0 comments on commit 2d78942

Please sign in to comment.