Skip to content

Commit

Permalink
Move basis change to main compute function (#1027)
Browse files Browse the repository at this point in the history
All compute quantities are now returned in toroidal $(R,\phi,Z)$
coordinates.

Resolves #992
Resolves #1088
  • Loading branch information
f0uriest authored Jul 11, 2024
2 parents 56befac + c6792ff commit 0fbee18
Show file tree
Hide file tree
Showing 18 changed files with 433 additions and 681 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ Changelog

New Features

- All vector variables are now computed in toroidal (R,phi,Z) coordinates by default.
Cartesian (X,Y,Z) coordinates can be requested with the compute keyword ``basis='xyz'``.
- Add method ``from_values`` to ``FourierRZCurve`` to allow fitting of data points
to a ``FourierRZCurve`` object, and ``to_FourierRZCurve`` methods to ``Curve`` class.
- Adds the objective `CoilsetMinDistance`, which returns the minimum distance to another
Expand Down
13 changes: 9 additions & 4 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,10 @@ def compute(
"""
if params is None:
params = [get_params(names, coil) for coil in self]
params = [
get_params(names, coil, basis=kwargs.get("basis", "rpz"))
for coil in self
]
if data is None:
data = [{}] * len(self)

Expand Down Expand Up @@ -940,9 +943,9 @@ def _compute_position(self, params=None, grid=None, **kwargs):
Coil positions, in [R,phi,Z] or [X,Y,Z] coordinates.
"""
if params is None:
params = [get_params("x", coil) for coil in self]
basis = kwargs.pop("basis", "xyz")
if params is None:
params = [get_params("x", coil, basis=basis) for coil in self]
data = self.compute("x", grid=grid, params=params, basis=basis, **kwargs)
data = tree_leaves(data, is_leaf=lambda x: isinstance(x, dict))
x = jnp.dstack([d["x"].T for d in data]).T # shape=(ncoils,num_nodes,3)
Expand Down Expand Up @@ -1009,7 +1012,9 @@ def compute_magnetic_field(
assert basis.lower() in ["rpz", "xyz"]
coords = jnp.atleast_2d(jnp.asarray(coords))
if params is None:
params = [get_params(["x_s", "x", "s", "ds"], coil) for coil in self]
params = [
get_params(["x_s", "x", "s", "ds"], coil, basis=basis) for coil in self
]
for par, coil in zip(params, self):
par["current"] = coil.current

Expand Down
14 changes: 7 additions & 7 deletions desc/compute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def _build_data_index():
for p in data_index:
for key in data_index[p]:
full = {
"data": get_data_deps(key, p, has_axis=False),
"transforms": get_derivs(key, p, has_axis=False),
"params": get_params(key, p, has_axis=False),
"profiles": get_profiles(key, p, has_axis=False),
"data": get_data_deps(key, p, has_axis=False, basis="rpz"),
"transforms": get_derivs(key, p, has_axis=False, basis="rpz"),
"params": get_params(key, p, has_axis=False, basis="rpz"),
"profiles": get_profiles(key, p, has_axis=False, basis="rpz"),
}
data_index[p][key]["full_dependencies"] = full

Expand All @@ -81,9 +81,9 @@ def _build_data_index():
else:
full_with_axis = {
"data": full_with_axis_data,
"transforms": get_derivs(key, p, has_axis=True),
"params": get_params(key, p, has_axis=True),
"profiles": get_profiles(key, p, has_axis=True),
"transforms": get_derivs(key, p, has_axis=True, basis="rpz"),
"params": get_params(key, p, has_axis=True, basis="rpz"),
"profiles": get_profiles(key, p, has_axis=True, basis="rpz"),
}
for _key, val in full_with_axis.items():
if full[_key] == val:
Expand Down
Loading

0 comments on commit 0fbee18

Please sign in to comment.