Skip to content

Commit

Permalink
format tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed Dec 31, 2023
1 parent 3479dd1 commit 19ff02a
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 149 deletions.
126 changes: 69 additions & 57 deletions docs/tutorials/Arbitrary-density-SCF.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import gala.potential as gp
from gala.potential.scf import compute_coeffs


# -

# ## SCF representation of an analytic density distribution
Expand All @@ -75,9 +74,10 @@
# coordinates (x, y, z) and returns the (scalar) value of the density at that
# location:


def density_func(x, y, z):
r = np.sqrt(x**2 + y**2 + z**2)
return 1 / (r**1.8 * (1 + r)**2.7)
return 1 / (r**1.8 * (1 + r) ** 2.7)


# Let's visualize this density function. For comparison, let's also over-plot
Expand All @@ -90,20 +90,20 @@ def density_func(x, y, z):

# +
x = np.logspace(-1, 1, 128)
plt.plot(x, density_func(x, 0, 0), marker='', label='custom density')
plt.plot(x, density_func(x, 0, 0), marker="", label="custom density")

# need a 3D grid for the potentials in Gala
xyz = np.zeros((3, len(x)))
xyz[0] = x
plt.plot(x, hern.density(xyz), marker='', label='Hernquist')
plt.plot(x, hern.density(xyz), marker="", label="Hernquist")

plt.xscale('log')
plt.yscale('log')
plt.xscale("log")
plt.yscale("log")

plt.xlabel('$r$')
plt.ylabel(r'$\rho(r)$')
plt.xlabel("$r$")
plt.ylabel(r"$\rho(r)$")

plt.legend(loc='best');
plt.legend(loc="best")
# -

# These functions are not *too* different, implying that we probably don't need
Expand All @@ -114,9 +114,9 @@ def density_func(x, y, z):
# m$ terms, so we set `lmax=0`. We can also neglect the sin() terms of the
# expansion ($T_{nlm}$):

(S, Serr), _ = compute_coeffs(density_func,
nmax=10, lmax=0,
M=1., r_s=1., S_only=True)
(S, Serr), _ = compute_coeffs(
density_func, nmax=10, lmax=0, M=1.0, r_s=1.0, S_only=True
)

# The above variable `S` will contain the expansion coefficients, and the
# variable `Serr` will contain an estimate of the error in this coefficient
Expand All @@ -125,27 +125,26 @@ def density_func(x, y, z):

S

pot = gp.SCFPotential(m=1., r_s=1,
Snlm=S, Tnlm=np.zeros_like(S))
pot = gp.SCFPotential(m=1.0, r_s=1, Snlm=S, Tnlm=np.zeros_like(S))

# Now let's visualize the SCF estimated density with the true density:

# +
x = np.logspace(-1, 1, 128)
plt.plot(x, density_func(x, 0, 0), marker='', label='custom density')
plt.plot(x, density_func(x, 0, 0), marker="", label="custom density")

# need a 3D grid for the potentials in Gala
xyz = np.zeros((3, len(x)))
xyz[0] = x
plt.plot(x, pot.density(xyz), marker='', label='SCF density')
plt.plot(x, pot.density(xyz), marker="", label="SCF density")

plt.xscale('log')
plt.yscale('log')
plt.xscale("log")
plt.yscale("log")

plt.xlabel('$r$')
plt.ylabel(r'$\rho(r)$')
plt.xlabel("$r$")
plt.ylabel(r"$\rho(r)$")

plt.legend(loc='best');
plt.legend(loc="best")


# -
Expand All @@ -172,9 +171,10 @@ def density_func(x, y, z):
# Cartesian coordinates (x, y, z) and returns the (scalar) value of the density
# at that location:


def density_func_flat(x, y, z, q):
r = np.sqrt(x**2 + y**2 + (z / q)**2)
return 1 / (r * (1 + r)**3) / (2*np.pi)
r = np.sqrt(x**2 + y**2 + (z / q) ** 2)
return 1 / (r * (1 + r) ** 3) / (2 * np.pi)


# Let's compute the density along a diagonal line for a few different
Expand All @@ -186,19 +186,23 @@ def density_func_flat(x, y, z, q):
xyz[0] = x
xyz[2] = x

for q in np.arange(0.6, 1+1e-3, 0.2):
plt.plot(x, density_func_flat(xyz[0], 0., xyz[2], q), marker='',
label=f'custom density: q={q}')
for q in np.arange(0.6, 1 + 1e-3, 0.2):
plt.plot(
x,
density_func_flat(xyz[0], 0.0, xyz[2], q),
marker="",
label=f"custom density: q={q}",
)

plt.plot(x, hern.density(xyz), marker='', ls='--', label='Hernquist')
plt.plot(x, hern.density(xyz), marker="", ls="--", label="Hernquist")

plt.xscale('log')
plt.yscale('log')
plt.xscale("log")
plt.yscale("log")

plt.xlabel('$r$')
plt.ylabel(r'$\rho(r)$')
plt.xlabel("$r$")
plt.ylabel(r"$\rho(r)$")

plt.legend(loc='best');
plt.legend(loc="best")
# -

# Because this is an axisymmetric density distribution, we need to also compute
Expand All @@ -208,32 +212,42 @@ def density_func_flat(x, y, z, q):
# pass `progress=True`, it will also display a progress bar:

q = 0.6
(S_flat, Serr_flat), _ = compute_coeffs(density_func_flat,
nmax=4, lmax=6, args=(q, ),
M=1., r_s=1., S_only=True,
skip_m=True, progress=True)

pot_flat = gp.SCFPotential(m=1., r_s=1,
Snlm=S_flat, Tnlm=np.zeros_like(S_flat))
(S_flat, Serr_flat), _ = compute_coeffs(
density_func_flat,
nmax=4,
lmax=6,
args=(q,),
M=1.0,
r_s=1.0,
S_only=True,
skip_m=True,
progress=True,
)

pot_flat = gp.SCFPotential(m=1.0, r_s=1, Snlm=S_flat, Tnlm=np.zeros_like(S_flat))

# +
x = np.logspace(-1, 1, 128)
xyz = np.zeros((3, len(x)))
xyz[0] = x
xyz[2] = x

plt.plot(x, density_func_flat(xyz[0], xyz[1], xyz[2], q), marker='',
label=f'true density q={q}')
plt.plot(
x,
density_func_flat(xyz[0], xyz[1], xyz[2], q),
marker="",
label=f"true density q={q}",
)

plt.plot(x, pot_flat.density(xyz), marker='', ls='--', label='SCF density')
plt.plot(x, pot_flat.density(xyz), marker="", ls="--", label="SCF density")

plt.xscale('log')
plt.yscale('log')
plt.xscale("log")
plt.yscale("log")

plt.xlabel('$r$')
plt.ylabel(r'$\rho(r)$')
plt.xlabel("$r$")
plt.ylabel(r"$\rho(r)$")

plt.legend(loc='best');
plt.legend(loc="best")
# -

# The SCF potential object acts like any other `gala.potential` object, meaning
Expand All @@ -242,25 +256,23 @@ def density_func_flat(x, y, z, q):
# +
grid = np.linspace(-8, 8, 128)

fig, axes = plt.subplots(1, 2, figsize=(10, 5),
sharex=True, sharey=True)
fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
_ = pot_flat.plot_contours((grid, grid, 0), ax=axes[0])
axes[0].set_xlabel('$x$')
axes[0].set_ylabel('$y$')
axes[0].set_xlabel("$x$")
axes[0].set_ylabel("$y$")

_ = pot_flat.plot_contours((grid, 0, grid), ax=axes[1])
axes[1].set_xlabel('$x$')
axes[1].set_ylabel('$z$')
axes[1].set_xlabel("$x$")
axes[1].set_ylabel("$z$")

for ax in axes:
ax.set_aspect('equal')
ax.set_aspect("equal")
# -

# And numerically integrate orbits by passing in initial conditions and
# integration parameters:

w0 = gd.PhaseSpacePosition(pos=[3.5, 0, 1],
vel=[0, 0.4, 0.05])
w0 = gd.PhaseSpacePosition(pos=[3.5, 0, 1], vel=[0, 0.4, 0.05])

orbit_flat = pot_flat.integrate_orbit(w0, dt=1., n_steps=5000)
orbit_flat = pot_flat.integrate_orbit(w0, dt=1.0, n_steps=5000)
_ = orbit_flat.plot()
Loading

0 comments on commit 19ff02a

Please sign in to comment.