diff --git a/autograd/numpy/numpy_vjps.py b/autograd/numpy/numpy_vjps.py index d0e73b64..99875ce0 100644 --- a/autograd/numpy/numpy_vjps.py +++ b/autograd/numpy/numpy_vjps.py @@ -411,6 +411,7 @@ def grad_transpose(ans, x, axes=None): defvjp(anp.transpose, grad_transpose) +defvjp(anp.permute_dims, grad_transpose) def repeat_to_match_shape(g, shape, dtype, axis, keepdims): @@ -428,6 +429,8 @@ def repeat_to_match_shape(g, shape, dtype, axis, keepdims): def grad_broadcast_to(ans, x, new_shape): + while len(anp.shape(x)) < len(new_shape): + x = x[None] old_shape = anp.shape(x) assert anp.shape(ans) == new_shape assert len(old_shape) == len(new_shape), "Can't handle extra leading dims" diff --git a/autograd/numpy/numpy_wrapper.py b/autograd/numpy/numpy_wrapper.py index baa0aed3..26ffb9b3 100644 --- a/autograd/numpy/numpy_wrapper.py +++ b/autograd/numpy/numpy_wrapper.py @@ -45,6 +45,7 @@ def concatenate_args(axis, *args): concatenate = lambda arr_list, axis=0: concatenate_args(axis, *arr_list) +concat = concatenate vstack = row_stack = lambda tup: concatenate([atleast_2d(_m) for _m in tup], axis=0)