Skip to content

Commit

Permalink
Merge pull request #355 from nstarman/deprecate-integrator-run
Browse files Browse the repository at this point in the history
deprecate Integrator.run for integrator.call
  • Loading branch information
adrn authored Aug 22, 2024
2 parents 4df4a57 + ac1edf0 commit 66443cf
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ Bug fixes
API changes
-----------

- Deprecated ``gala.integrate.Integrator.run`` for
``gala.integrate.Integrator.__call__``. The old method will raise a warning
and will be removed in a future release.


1.8.1 (2023-12-31)
==================
Expand Down
2 changes: 1 addition & 1 deletion gala/dynamics/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def lyapunov_max(w0, integrator, dt, n_steps, d0=1e-5, n_steps_per_pullback=10,
for i in range(1, niter+1):
ii = i * n_steps_per_pullback

orbit = integrator.run(all_w0, dt=dt, n_steps=n_steps_per_pullback, t1=time)
orbit = integrator(all_w0, dt=dt, n_steps=n_steps_per_pullback, t1=time)
tt = orbit.t.value
ww = orbit.w(units)
time += dt*n_steps_per_pullback
Expand Down
2 changes: 1 addition & 1 deletion gala/dynamics/tests/test_nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def setup_method(self):

def test_integrate_orbit(self, tmpdir):
integrator = DOPRI853Integrator(self.F_max, func_args=self.par)
orbit = integrator.run(self.w0, dt=self.dt, n_steps=self.n_steps)
orbit = integrator(self.w0, dt=self.dt, n_steps=self.n_steps)

def test_lyapunov_max(self, tmpdir):
n_steps_per_pullback = 10
Expand Down
17 changes: 15 additions & 2 deletions gala/integrate/core.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
""" Base class for integrators. """

from abc import ABCMeta, abstractmethod

# Third-party
import numpy as np
from astropy.utils.decorators import deprecated

# This project
from gala.units import UnitSystem, DimensionlessUnitSystem

__all__ = ["Integrator"]


class Integrator(object):
class Integrator(metaclass=ABCMeta):
def __init__(
self,
func,
Expand Down Expand Up @@ -111,7 +114,17 @@ def _handle_output(self, w0, t, w):
)
return orbit

def run(self):
@deprecated("1.9", alternative="Integrator call method")
def run(self, w0, mmap=None, **time_spec):
"""Run the integrator starting from the specified phase-space position.
.. deprecated:: 1.9
Use the ``__call__`` method instead.
"""
return self(w0, mmap=mmap, **time_spec)

@abstractmethod
def __call__(self, w0, mmap=None, **time_spec):
"""
Run the integrator starting from the specified phase-space position.
The initial conditions ``w0`` should be a
Expand Down
3 changes: 1 addition & 2 deletions gala/integrate/pyintegrators/dopri853.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def __init__(
)
self._ode_kwargs = kwargs

def run(self, w0, mmap=None, **time_spec):

def __call__(self, w0, mmap=None, **time_spec):
# generate the array of times
times = parse_time_specification(self._func_units, **time_spec)
n_steps = len(times) - 1
Expand Down
5 changes: 2 additions & 3 deletions gala/integrate/pyintegrators/leapfrog.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def F(t, w):
of initial conditions::
integrator = LeapfrogIntegrator(acceleration)
times, ws = integrator.run(w0=[1., 0.], dt=0.1, n_steps=1000)
times, ws = integrator(w0=[1., 0.], dt=0.1, n_steps=1000)
.. note::
Expand Down Expand Up @@ -132,8 +132,7 @@ def _init_v(self, t, w0, dt):

return v_1_2

def run(self, w0, mmap=None, **time_spec):

def __call__(self, w0, mmap=None, **time_spec):
# generate the array of times
times = parse_time_specification(self._func_units, **time_spec)
n_steps = len(times) - 1
Expand Down
3 changes: 1 addition & 2 deletions gala/integrate/pyintegrators/rk5.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def step(self, t, w, dt):

return w + dw

def run(self, w0, mmap=None, **time_spec):

def __call__(self, w0, mmap=None, **time_spec):
# generate the array of times
times = parse_time_specification(self._func_units, **time_spec)
n_steps = len(times)-1
Expand Down
5 changes: 2 additions & 3 deletions gala/integrate/pyintegrators/ruth4.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def F(t, w):
of initial conditions::
integrator = Ruth4Integrator(acceleration)
times, ws = integrator.run(w0=[1., 0.], dt=0.1, n_steps=1000)
times, ws = integrator(w0=[1., 0.], dt=0.1, n_steps=1000)
.. note::
Expand Down Expand Up @@ -122,8 +122,7 @@ def step(self, t, w, dt):

return w_i

def run(self, w0, mmap=None, **time_spec):

def __call__(self, w0, mmap=None, **time_spec):
# generate the array of times
times = parse_time_specification(self._func_units, **time_spec)
n_steps = len(times) - 1
Expand Down
4 changes: 2 additions & 2 deletions gala/integrate/tests/test_cyintegrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def F(t, w):
cy_w = np.rollaxis(cy_w, -1)

integrator = Integrator(F)
orbit = integrator.run(py_w0, dt=dt, n_steps=n_steps)
orbit = integrator(py_w0, dt=dt, n_steps=n_steps)

py_t = orbit.t.value
py_w = orbit.w()
Expand Down Expand Up @@ -135,7 +135,7 @@ def F(t, w):
# time the Python integration
t0 = time.time()
integrator = Integrator(F)
orbit = integrator.run(py_w0, dt=dt, n_steps=n_steps)
orbit = integrator(py_w0, dt=dt, n_steps=n_steps)
py_times.append(time.time() - t0)

# pl.loglog(x, cy_times, linestyle='-', lw=2., c=c, marker='',
Expand Down
37 changes: 27 additions & 10 deletions gala/integrate/tests/test_pyintegrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# Third-party
import pytest
import numpy as np
from astropy.utils.exceptions import AstropyDeprecationWarning

# Project
from .. import (
Expand Down Expand Up @@ -70,19 +71,35 @@ def test_sho_forward_backward(Integrator):
dt = 1e-4
n_steps = 10_000

forw = integrator.run([0.0, 1.0], dt=dt, n_steps=n_steps)
back = integrator.run([0.0, 1.0], dt=-dt, n_steps=n_steps)
forw = integrator([0.0, 1.0], dt=dt, n_steps=n_steps)
back = integrator([0.0, 1.0], dt=-dt, n_steps=n_steps)

assert np.allclose(forw.w()[:, -1], back.w()[:, -1], atol=1e-6)


@pytest.mark.parametrize("Integrator", integrator_list)
def test_deprecated_run_method(Integrator):
"""Test the deprecated run method."""
integrator = Integrator(sho_F, func_args=(1.0,))

dt = 1e-4
n_steps = 10_000

with pytest.warns(AstropyDeprecationWarning):
run = integrator.run([0.0, 1.0], dt=dt, n_steps=n_steps)

call = integrator([0.0, 1.0], dt=dt, n_steps=n_steps)

assert np.allclose(run.w()[:, -1], call.w()[:, -1], atol=1e-6)


@pytest.mark.parametrize("Integrator", integrator_list)
def test_point_mass(Integrator):
q0 = np.array([1.0, 0.0])
p0 = np.array([0.0, 1.0])

integrator = Integrator(ptmass_F)
orbit = integrator.run(np.append(q0, p0), t1=0.0, t2=2 * np.pi, n_steps=1e4)
orbit = integrator(np.append(q0, p0), t1=0.0, t2=2 * np.pi, n_steps=1e4)

assert np.allclose(orbit.w()[:, 0], orbit.w()[:, -1], atol=1e-6)

Expand All @@ -94,7 +111,7 @@ def test_progress(Integrator):
p0 = np.array([0.0, 1.0])

integrator = Integrator(ptmass_F, progress=True)
_ = integrator.run(np.append(q0, p0), t1=0.0, t2=2 * np.pi, n_steps=1e2)
_ = integrator(np.append(q0, p0), t1=0.0, t2=2 * np.pi, n_steps=1e2)


@pytest.mark.parametrize("Integrator", integrator_list)
Expand All @@ -104,21 +121,21 @@ def test_point_mass_multiple(Integrator):
).T

integrator = Integrator(ptmass_F)
_ = integrator.run(w0, dt=1e-3, n_steps=1e4)
_ = integrator(w0, dt=1e-3, n_steps=1e4)


@pytest.mark.parametrize("Integrator", integrator_list)
def test_driven_pendulum(Integrator):
integrator = Integrator(forced_sho_F, func_args=(0.07, 0.75))
_ = integrator.run([3.0, 0.0], dt=1e-2, n_steps=1e4)
_ = integrator([3.0, 0.0], dt=1e-2, n_steps=1e4)


@pytest.mark.parametrize("Integrator", integrator_list)
def test_lorenz(Integrator):
sigma, rho, beta = 10.0, 28.0, 8 / 3.0
integrator = Integrator(lorenz_F, func_args=(sigma, rho, beta))

_ = integrator.run([0.5, 0.5, 0.5, 0, 0, 0], dt=1e-2, n_steps=1e4)
_ = integrator([0.5, 0.5, 0.5, 0, 0, 0], dt=1e-2, n_steps=1e4)


@pytest.mark.parametrize("Integrator", integrator_list)
Expand All @@ -134,7 +151,7 @@ def test_memmap(tmpdir, Integrator):

integrator = Integrator(sho_F, func_args=(1.0,))

_ = integrator.run(w0, dt=dt, n_steps=n_steps, mmap=mmap)
_ = integrator(w0, dt=dt, n_steps=n_steps, mmap=mmap)


@pytest.mark.parametrize("Integrator", integrator_list)
Expand All @@ -145,7 +162,7 @@ def test_py_store_all(Integrator):
dt = 1e-4
n_steps = 10_000

out_all = integrator_all.run([0.0, 1.0], dt=dt, n_steps=n_steps)
out_final = integrator_final.run([0.0, 1.0], dt=dt, n_steps=n_steps)
out_all = integrator_all([0.0, 1.0], dt=dt, n_steps=n_steps)
out_final = integrator_final([0.0, 1.0], dt=dt, n_steps=n_steps)

assert np.allclose(out_all.w()[:, -1], out_final.w()[:, 0])
2 changes: 1 addition & 1 deletion gala/potential/potential/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def plot_rotation_curve(self, R_grid, t=0.0, ax=None, labels=None, **plot_kwargs
def integrate_orbit(self, *args, **kwargs):
"""
Integrate an orbit in the current potential using the integrator class
provided. Uses same time specification as `Integrator.run()` -- see
provided. Uses same time specification as `Integrator()` -- see
the documentation for `gala.integrate` for more information.
Parameters
Expand Down

0 comments on commit 66443cf

Please sign in to comment.