Skip to content

Commit

Permalink
Merge pull request #84 from ddudt/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
dpanici authored May 30, 2021
2 parents 16d2f19 + e04b557 commit c68fa66
Show file tree
Hide file tree
Showing 17 changed files with 1,606 additions and 147 deletions.
89 changes: 81 additions & 8 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
desc.__version__, jax.__version__, jaxlib.__version__, y.dtype
)
)
del x,y
del x, y
except:
jnp = np
x = jnp.linspace(0, 5)
Expand All @@ -60,6 +60,9 @@
if use_jax:
jit = jax.jit
fori_loop = jax.lax.fori_loop
cond = jax.lax.cond
switch = jax.lax.switch
while_loop = jax.lax.while_loop
from jax.scipy.linalg import cho_factor, cho_solve, qr, solve_triangular

def put(arr, inds, vals):
Expand Down Expand Up @@ -118,13 +121,6 @@ def fori_loop(lower, upper, body_fun, init_val):
"""Loop from lower to upper, applying body_fun to init_val
This version is for the numpy backend, for jax backend see jax.lax.fori_loop
The semantics of ``fori_loop`` are given by this Python implementation::
def fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
Parameters
----------
Expand All @@ -147,3 +143,80 @@ def fori_loop(lower, upper, body_fun, init_val):
for i in np.arange(lower, upper):
val = body_fun(i, val)
return val

def cond(pred, true_fun, false_fun, operand):
"""Conditionally apply true_fun or false_fun.
This version is for the numpy backend, for jax backend see jax.lax.cond
Parameters
----------
pred: bool
which branch function to apply.
true_fun: callable
Function (A -> B), to be applied if pred is True.
false_fun: callable
Function (A -> B), to be applied if pred is False.
operand: any
input to either branch depending on pred. The type can be a scalar, array,
or any pytree (nested Python tuple/list/dict) thereof.
Returns
-------
value: any
value of either true_fun(operand) or false_fun(operand), depending on the
value of pred. The type can be a scalar, array, or any pytree (nested
Python tuple/list/dict) thereof.
"""
if pred:
return true_fun(operand)
else:
return false_fun(operand)

def switch(index, branches, operand):
"""Apply exactly one of branches given by index.
If index is out of bounds, it is clamped to within bounds.
Parameters
----------
index: int
which branch function to apply.
branches: Sequence[Callable]
sequence of functions (A -> B) to be applied based on index.
operand: any
input to whichever branch is applied.
Returns
-------
value: any
output of branches[index](operand)
"""
index = np.clip(index, 0, len(branches) - 1)
return branches[index](operand)

def while_loop(cond_fun, body_fun, init_val):
"""Call body_fun repeatedly in a loop while cond_fun is True.
Parameters
----------
cond_fun: callable
function of type a -> bool.
body_fun: callable
function of type a -> a.
init_val: any
value of type a, a type that can be a scalar, array, or any pytree (nested
Python tuple/list/dict) thereof, representing the initial loop carry value.
Returns
-------
value: any
The output from the final iteration of body_fun, of type a.
"""
val = init_val
while cond_fun(val):
val = body_fun(val)
return val
77 changes: 77 additions & 0 deletions desc/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from abc import ABC
from shapely.geometry import LineString, MultiLineString

from desc.backend import jnp, put
from desc.io import IOAble
from desc.utils import unpack_state, copy_coeffs
from desc.grid import Grid, LinearGrid, ConcentricGrid, QuadratureGrid
Expand Down Expand Up @@ -1231,6 +1232,82 @@ def compute_dW(self, grid=None):
dW = obj.hess_x(x, self.Rb_lmn, self.Zb_lmn, self.p_l, self.i_l, self.Psi)
return dW

def compute_flux_coords(self, real_coords, tol=1e-6, maxiter=20, rhomin=1e-6):
"""Finds the flux coordinates (rho, theta, zeta) that correspond to a set of
real space coordinates (R,phi,Z)
Parameters
----------
real_coords : ndarray, shape(k,3)
2d array of real space coordinates [R,phi,Z]. Each row is a different coordinate.
tol : float
Stopping tolerance. Iterations stop when sqrt((R-Ri)**2 + (Z-Zi)**2) < tol
maxiter : int > 0
maximum number of Newton iterations
rhomin : float
minimum allowable value of rho (to avoid singularity at rho=0)
Returns
-------
flux_coords : ndarray, shape(k,3)
flux coordinates [rho,theta,zeta]. If Newton method doesn't converge for
a given coordinate (often because it is outside the plasma boundary),
nan will be returned for those values
"""

R = real_coords[:, 0]
phi = real_coords[:, 1]
Z = real_coords[:, 2]
if maxiter <= 0:
raise ValueError(f"maxiter must be a positive integer, got{maxiter}")
if jnp.any(R) <= 0:
raise ValueError("R values must be positive")

R0, Z0 = self.compute_axis_location(zeta=phi)
theta = jnp.arctan2(Z - Z0, R - R0)
rho = 0.5 * jnp.ones_like(theta) # TODO: better initial guess
grid = Grid(jnp.vstack([rho, theta, phi]).T, sort=False)

R_transform = Transform(grid, self.R_basis, derivs=1, method="direct1")
Z_transform = Transform(grid, self.Z_basis, derivs=1, method="direct1")

Rk = R_transform.transform(self.R_lmn)
Zk = Z_transform.transform(self.Z_lmn)
eR = R - Rk
eZ = Z - Zk

k = 0
while jnp.any(jnp.sqrt((eR) ** 2 + (eZ) ** 2) > tol) and k < maxiter:
Rr = R_transform.transform(self.R_lmn, 1, 0, 0)
Rt = R_transform.transform(self.R_lmn, 0, 1, 0)
Zr = Z_transform.transform(self.Z_lmn, 1, 0, 0)
Zt = Z_transform.transform(self.Z_lmn, 0, 1, 0)

tau = Rt * Zr - Rr * Zt
theta += (Zr * eR - Rr * eZ) / tau
rho += (Rt * eZ - Zt * eR) / tau
# negative rho -> rotate theta instead
theta = jnp.where(rho < 0, -theta % (2 * np.pi), theta % (2 * np.pi))
rho = jnp.clip(rho, rhomin, 1)

grid = Grid(jnp.vstack([rho, theta, phi]).T, sort=False)
R_transform = Transform(grid, self.R_basis, derivs=1, method="direct1")
Z_transform = Transform(grid, self.Z_basis, derivs=1, method="direct1")

Rk = R_transform.transform(self.R_lmn)
Zk = Z_transform.transform(self.Z_lmn)
eR = R - Rk
eZ = Z - Zk
k += 1

if k >= maxiter: # didn't converge for all, mark those as nan
i = np.where(jnp.sqrt((eR) ** 2 + (eZ) ** 2) > tol)
rho = put(rho, i, np.nan)
theta = put(theta, i, np.nan)
phi = put(phi, i, np.nan)

return jnp.vstack([rho, theta, phi]).T

def compute_axis_location(self, zeta=0):
"""Find the axis location on specified zeta plane(s).
Expand Down
7 changes: 5 additions & 2 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class Grid(IOAble):
----------
nodes : ndarray of float, size(num_nodes,3)
node coordinates, in (rho,theta,zeta)
sort : bool
whether to sort the nodes for use with FFT method.
"""

Expand All @@ -34,7 +36,7 @@ class Grid(IOAble):
"_node_pattern",
]

def __init__(self, nodes):
def __init__(self, nodes, sort=True):

self._L = np.unique(nodes[:, 0]).size
self._M = np.unique(nodes[:, 1]).size
Expand All @@ -46,7 +48,8 @@ def __init__(self, nodes):
self._nodes, self._weights = self._create_nodes(nodes)

self._enforce_symmetry()
self._sort_nodes()
if sort:
self._sort_nodes()
self._find_axis()

def __eq__(self, other):
Expand Down
Loading

0 comments on commit c68fa66

Please sign in to comment.