Skip to content

Commit

Permalink
Fixes to work with xarray
Browse files Browse the repository at this point in the history
  • Loading branch information
yairchu committed Nov 25, 2024
1 parent 24d1f01 commit 250e841
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions autograd/numpy/numpy_vjps.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ def grad_transpose(ans, x, axes=None):


defvjp(anp.transpose, grad_transpose)
defvjp(anp.permute_dims, grad_transpose)


def repeat_to_match_shape(g, shape, dtype, axis, keepdims):
Expand All @@ -428,6 +429,8 @@ def repeat_to_match_shape(g, shape, dtype, axis, keepdims):


def grad_broadcast_to(ans, x, new_shape):
while len(x.shape) < len(new_shape):
x = x[None]
old_shape = anp.shape(x)
assert anp.shape(ans) == new_shape
assert len(old_shape) == len(new_shape), "Can't handle extra leading dims"
Expand Down
1 change: 1 addition & 0 deletions autograd/numpy/numpy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def concatenate_args(axis, *args):


concatenate = lambda arr_list, axis=0: concatenate_args(axis, *arr_list)
concat = concatenate
vstack = row_stack = lambda tup: concatenate([atleast_2d(_m) for _m in tup], axis=0)


Expand Down

0 comments on commit 250e841

Please sign in to comment.