Skip to content

Commit

Permalink
Fix bug in desc.backend.put with numpy arrays modified in place
Browse files Browse the repository at this point in the history
note that the change in grid.py is only made because the bugfix in
desc.backend now makes the copy parameter useless. the replace_at_axis
function was working fine before and still is.
  • Loading branch information
unalmis committed Oct 22, 2024
1 parent 84a051b commit d245b4b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
7 changes: 5 additions & 2 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,12 @@ def put(arr, inds, vals):
Returns
-------
arr : array-like
Input array with vals inserted at inds.
Copy of input array with vals inserted at inds.
In some cases JAX may decide a copy is not necessary.
"""
if isinstance(arr, np.ndarray):
arr = arr.copy()
arr[inds] = vals
return arr
return jnp.asarray(arr).at[inds].set(vals)
Expand Down Expand Up @@ -509,9 +511,10 @@ def put(arr, inds, vals):
Returns
-------
arr : array-like
Input array with vals inserted at inds.
Copy of input array with vals inserted at inds.
"""
arr = arr.copy()
arr[inds] = vals
return arr

Expand Down
9 changes: 3 additions & 6 deletions desc/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def body(i, y):
y = fori_loop(0, other_grid.num_nodes, body, y)
return y

def replace_at_axis(self, x, y, copy=False, **kwargs):
def replace_at_axis(self, x, y, **kwargs):
"""Replace elements of ``x`` with elements of ``y`` at the axis of grid.
Parameters
Expand All @@ -575,10 +575,6 @@ def replace_at_axis(self, x, y, copy=False, **kwargs):
Replacement values. Should broadcast with arrays of size
``grid.num_nodes``. Can also be a function that returns such an
array. Additional keyword arguments are then input to ``y``.
copy : bool
If some value of ``x`` is to be replaced by some value in ``y``,
then setting ``copy`` to true ensures that ``x`` will not be
modified in-place.
Returns
-------
Expand All @@ -588,11 +584,12 @@ def replace_at_axis(self, x, y, copy=False, **kwargs):
others match ``x``.
"""
kwargs.pop("copy", None)
if self.axis.size:
if callable(y):
y = y(**kwargs)
x = put(
x.copy() if copy else x,
x,
self.axis,
y[self.axis] if jnp.ndim(y) else y,
)
Expand Down

0 comments on commit d245b4b

Please sign in to comment.