Skip to content

Commit

Permalink
Merge branch 'master' into basis_vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
f0uriest committed Aug 10, 2023
2 parents 11ce811 + 6adcb4c commit b8b53ba
Show file tree
Hide file tree
Showing 17 changed files with 401 additions and 402 deletions.
4 changes: 2 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
bincount = jnp.bincount
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln
from jax.scipy.special import gammaln, logsumexp
from jax.tree_util import register_pytree_node

def put(arr, inds, vals):
Expand Down Expand Up @@ -127,7 +127,7 @@ def sign(x):
qr,
solve_triangular,
)
from scipy.special import gammaln # noqa: F401
from scipy.special import gammaln, logsumexp # noqa: F401

def register_pytree_node(foo, *args):
"""Dummy decorator for non-jax pytrees."""
Expand Down
58 changes: 26 additions & 32 deletions desc/compute/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,39 +278,33 @@ def j_dot_B_Redl(
J_dot_B = dnds_term + dTeds_term + dTids_term

# Store all results in the J_dot_B_data dictionary:
nu_e_star = nu_e # noqa: F841
nu_i_star = nu_i # noqa: F841
variables = [
"rho",
"ne",
"ni",
"Zeff",
"Te",
"Ti",
"d_ne_d_s",
"d_Te_d_s",
"d_Ti_d_s",
"ln_Lambda_e",
"ln_Lambda_ii",
"nu_e_star",
"nu_i_star",
"X31",
"X32e",
"X32ei",
"F32ee",
"F32ei",
"L31",
"L32",
"L34",
"alpha0",
"alpha",
"dnds_term",
"dTeds_term",
"dTids_term",
]
J_dot_B_data = geom_data.copy()
for v in variables:
J_dot_B_data[v] = eval(v)
J_dot_B_data["rho"] = rho
J_dot_B_data["ne"] = ne
J_dot_B_data["ni"] = ni
J_dot_B_data["Zeff"] = Zeff
J_dot_B_data["Te"] = Te
J_dot_B_data["Ti"] = Ti
J_dot_B_data["d_ne_d_s"] = d_ne_d_s
J_dot_B_data["d_Te_d_s"] = d_Te_d_s
J_dot_B_data["d_Ti_d_s"] = d_Ti_d_s
J_dot_B_data["ln_Lambda_e"] = ln_Lambda_e
J_dot_B_data["ln_Lambda_ii"] = ln_Lambda_ii
J_dot_B_data["nu_e_star"] = nu_e
J_dot_B_data["nu_i_star"] = nu_i
J_dot_B_data["X31"] = X31
J_dot_B_data["X32e"] = X32e
J_dot_B_data["X32ei"] = X32ei
J_dot_B_data["F32ee"] = F32ee
J_dot_B_data["F32ei"] = F32ei
J_dot_B_data["L31"] = L31
J_dot_B_data["L32"] = L32
J_dot_B_data["L34"] = L34
J_dot_B_data["alpha0"] = alpha0
J_dot_B_data["alpha"] = alpha
J_dot_B_data["dnds_term"] = dnds_term
J_dot_B_data["dTeds_term"] = dTeds_term
J_dot_B_data["dTids_term"] = dTids_term
J_dot_B_data["<J*B>"] = J_dot_B
return J_dot_B_data

Expand Down
22 changes: 7 additions & 15 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from desc.backend import jnp

from .data_index import register_compute_fun
from .geom_utils import rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from .geom_utils import (
_rotation_matrix_from_normal,
rpz2xyz,
rpz2xyz_vec,
xyz2rpz,
xyz2rpz_vec,
)
from .utils import cross, dot


Expand All @@ -24,20 +30,6 @@ def _s(params, transforms, profiles, data, **kwargs):
return data


def _rotation_matrix_from_normal(normal):
nx, ny, nz = normal
nxny = jnp.sqrt(nx**2 + ny**2)
R = jnp.array(
[
[ny / nxny, -nx / nxny, 0],
[nx * nx / nxny, ny * nz / nxny, -nxny],
[nx, ny, nz],
]
).T
R = jnp.where(nxny == 0, jnp.eye(3), R)
return R


@register_compute_fun(
name="x",
label="\\mathbf{r}",
Expand Down
14 changes: 14 additions & 0 deletions desc/compute/geom_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,17 @@ def rpz2xyz_vec(vec, x=None, y=None, phi=None):
rot = rot.T
cart = jnp.matmul(rot, vec.reshape((-1, 3, 1)))
return cart.reshape((-1, 3))


def _rotation_matrix_from_normal(normal):
nx, ny, nz = normal
nxny = jnp.sqrt(nx**2 + ny**2)
R = jnp.array(
[
[ny / nxny, -nx / nxny, 0],
[nx * nx / nxny, ny * nz / nxny, -nxny],
[nx, ny, nz],
]
).T
R = jnp.where(nxny == 0, jnp.eye(3), R)
return R
13 changes: 7 additions & 6 deletions desc/objectives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@
QuasisymmetryTwoTerm,
)
from ._stability import MagneticWell, MercierStability
from .getters import (
get_equilibrium_objective,
get_fixed_axis_constraints,
get_fixed_boundary_constraints,
get_NAE_constraints,
maybe_add_self_consistency,
)
from .linear_objectives import (
AxisRSelfConsistency,
AxisZSelfConsistency,
Expand All @@ -55,9 +62,3 @@
FixThetaSFL,
)
from .objective_funs import ObjectiveFunction
from .utils import (
get_equilibrium_objective,
get_fixed_axis_constraints,
get_fixed_boundary_constraints,
get_NAE_constraints,
)
4 changes: 2 additions & 2 deletions desc/objectives/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .normalization import compute_scaling_factors
from .objective_funs import _Objective
from .utils import jax_softmin
from .utils import softmin


class AspectRatio(_Objective):
Expand Down Expand Up @@ -613,7 +613,7 @@ def compute(self, *args, **kwargs):
plasma_coords[:, None, :] - constants["surface_coords"][None, :, :], axis=-1
)
if self._use_softmin: # do softmin
return jnp.apply_along_axis(jax_softmin, 0, d, self._alpha)
return jnp.apply_along_axis(softmin, 0, d, self._alpha)
else: # do hardmin
return d.min(axis=0)

Expand Down
Loading

0 comments on commit b8b53ba

Please sign in to comment.