Skip to content

Commit

Permalink
Update tests and improve syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
ybillchen committed Aug 20, 2024
1 parent 15daccf commit 968a974
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 59 deletions.
1 change: 0 additions & 1 deletion galpy/df/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
osipkovmerrittdf = osipkovmerrittdf.osipkovmerrittdf
osipkovmerrittNFWdf = osipkovmerrittNFWdf.osipkovmerrittNFWdf
constantbetadf = constantbetadf.constantbetadf
basestreamspraydf = streamspraydf.basestreamspraydf
chen24spraydf = streamspraydf.chen24spraydf
fardal15spraydf = streamspraydf.fardal15spraydf
streamspraydf = streamspraydf.streamspraydf
30 changes: 18 additions & 12 deletions galpy/df/streamspraydf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
ro=self._ro,
vo=self._vo,
)
self._pot = flatten_potential([self._pot, progtrajpot])
self._pot = self._pot + progtrajpot

return None

Expand Down Expand Up @@ -295,7 +295,7 @@ def _setup_rot(self, dt):
rot_inv = numpy.einsum("ijk,ikl->ijl", z_rot_inv, pa_rot_inv)
return (rot, rot_inv)

def _calc_rtide_vcs(self, Rpt, phipt, Zpt, dt):
def _calc_rtide(self, Rpt, phipt, Zpt, dt):
try:
rtides = rtide(
self._rtpot,
Expand All @@ -306,12 +306,6 @@ def _calc_rtide_vcs(self, Rpt, phipt, Zpt, dt):
M=self._progenitor_mass,
use_physical=False,
)
vcs = numpy.sqrt(
-Rpt
* evaluateRforces(
self._rtpot, Rpt, Zpt, phi=phipt, t=-dt, use_physical=False
)
)
except (ValueError, TypeError):
rtides = numpy.array(
[
Expand All @@ -327,6 +321,17 @@ def _calc_rtide_vcs(self, Rpt, phipt, Zpt, dt):
for ii in range(len(Rpt))
]
)
return rtides

def _calc_vc(self, Rpt, phipt, Zpt, dt):
try:
vcs = numpy.sqrt(
-Rpt
* evaluateRforces(
self._rtpot, Rpt, Zpt, phi=phipt, t=-dt, use_physical=False
)
)
except (ValueError, TypeError):
vcs = numpy.array(
[
numpy.sqrt(
Expand All @@ -343,7 +348,7 @@ def _calc_rtide_vcs(self, Rpt, phipt, Zpt, dt):
for ii in range(len(Rpt))
]
)
return rtides, vcs
return vcs

def spray_df(self, xyzpt, vxyzpt, dt):
"""
Expand All @@ -366,7 +371,7 @@ def spray_df(self, xyzpt, vxyzpt, dt):
vxst, vyst, vzst : array, shape (N,)
Velocities of points on the stream in the progenitor coordinates.
"""
warnings.warn("Not implemented!", RuntimeWarning, stacklevel=1)
warnings.warn("Not implemented!", NotImplementedError, stacklevel=1)
pass

Check warning on line 375 in galpy/df/streamspraydf.py

View check run for this annotation

Codecov / codecov/patch

galpy/df/streamspraydf.py#L374-L375

Added lines #L374 - L375 were not covered by tests


Expand Down Expand Up @@ -477,7 +482,7 @@ def spray_df(self, xyzpt, vxyzpt, dt):
Velocities of points on the stream in the progenitor coordinates.
"""
Rpt, phipt, Zpt = coords.rect_to_cyl(xyzpt[:, 0], xyzpt[:, 1], xyzpt[:, 2])
rtides, vcs = self._calc_rtide_vcs(Rpt, phipt, Zpt, dt)
rtides = self._calc_rtide(Rpt, phipt, Zpt, dt)

# Sample positions and velocities in the instantaneous frame
posvel = numpy.random.multivariate_normal(self._mean, self._cov, size=len(dt))
Expand Down Expand Up @@ -603,7 +608,8 @@ def spray_df(self, xyzpt, vxyzpt, dt):
Velocities of points on the stream in the progenitor coordinates.
"""
Rpt, phipt, Zpt = coords.rect_to_cyl(xyzpt[:, 0], xyzpt[:, 1], xyzpt[:, 2])
rtides, vcs = self._calc_rtide_vcs(Rpt, phipt, Zpt, dt)
rtides = self._calc_rtide(Rpt, phipt, Zpt, dt)
vcs = self._calc_vc(Rpt, phipt, Zpt, dt)
rtides_as_frac = rtides / Rpt

vRpt, vTpt, vZpt = coords.rect_to_cyl_vec(
Expand Down
81 changes: 35 additions & 46 deletions tests/test_streamspraydf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@
import pytest

from galpy.actionAngle import actionAngleIsochroneApprox
from galpy.df import (
basestreamspraydf,
chen24spraydf,
fardal15spraydf,
streamdf,
streamspraydf,
)
from galpy.df import chen24spraydf, fardal15spraydf, streamdf, streamspraydf
from galpy.orbit import Orbit
from galpy.potential import (
ChandrasekharDynamicalFrictionForce,
Expand All @@ -27,23 +21,6 @@
################################ Tests against streamdf ######################


def test_basestreamspraydf_abstract_method():
# Check if the abstract method raises the correct warning
lp = LogarithmicHaloPotential(normalize=1.0, q=0.9)
obs = Orbit(
[1.56148083, 0.35081535, -1.15481504, 0.88719443, -0.47713334, 0.12019596]
)
ro, vo = 8.0, 220.0
spdf = basestreamspraydf(
2 * 10.0**4.0 / conversion.mass_in_msol(vo, ro),
progenitor=obs,
pot=lp,
tdisrupt=4.5 / conversion.time_in_Gyr(vo, ro),
)
with pytest.warns(RuntimeWarning):
spdf.spray_df(None, None, None)


def test_streamspraydf_deprecation():
# Check if the deprecating class raises the correct warning
lp = LogarithmicHaloPotential(normalize=1.0, q=0.9)
Expand Down Expand Up @@ -468,29 +445,41 @@ def test_sample_orbit_rovoetc():
return None


def test_integrate_with_prog(setup_testStreamsprayAgainstStreamdf):
def test_integrate_with_prog():
# Test integrating orbits with the progenitor's potential
# input progenitor
_, spdfs_bovy14 = setup_testStreamsprayAgainstStreamdf
for spdf_bovy14 in spdfs_bovy14:
# Without the progenitor's potential
numpy.random.seed(4)
RvR, dt = spdf_bovy14.sample(
n=100, return_orbit=False, returndt=True, integrate=True
)
# With the progenitor's potential, but set to zero-mass
numpy.random.seed(4)
pot_prog = PlummerPotential(0, 0)
RvR_withprog, dt_withprog = spdf_bovy14.sample(
n=100, return_orbit=False, returndt=True, integrate=True, pot_prog=pot_prog
)
# Should agree
assert (
numpy.amax(numpy.fabs(dt - dt_withprog)) < 1e-10
), "Times not the same when sampling with and without prognitor's potential"
assert (
numpy.amax(numpy.fabs(RvR - RvR_withprog)) < 1e-7
), "Phase-space points not the same when sampling with and without prognitor's potential"
lp = LogarithmicHaloPotential(normalize=1.0, q=0.9)
obs = Orbit(
[1.56148083, 0.35081535, -1.15481504, 0.88719443, -0.47713334, 0.12019596]
)
ro, vo = 8.0, 220.0
# Without the progenitor's potential
spdf = chen24spraydf(
2 * 10.0**4.0 / conversion.mass_in_msol(vo, ro),
progenitor=obs,
pot=lp,
tdisrupt=4.5 / conversion.time_in_Gyr(vo, ro),
)
numpy.random.seed(4)
RvR, dt = spdf.sample(n=100, return_orbit=False, returndt=True, integrate=True)
# With the progenitor's potential, but set to zero-mass
spdf = chen24spraydf(
2 * 10.0**4.0 / conversion.mass_in_msol(vo, ro),
progenitor=obs,
pot=lp,
tdisrupt=4.5 / conversion.time_in_Gyr(vo, ro),
progpot=PlummerPotential(0, 0),
)
numpy.random.seed(4)
RvR_withprog, dt_withprog = spdf.sample(
n=100, return_orbit=False, returndt=True, integrate=True
)
# Should agree
assert (
numpy.amax(numpy.fabs(dt - dt_withprog)) < 1e-10
), "Times not the same when sampling with and without prognitor's potential"
assert (
numpy.amax(numpy.fabs(RvR - RvR_withprog)) < 1e-7
), "Phase-space points not the same when sampling with and without prognitor's potential"
return None


Expand Down

0 comments on commit 968a974

Please sign in to comment.