Skip to content

Commit

Permalink
refactor: lax.broadcast_in_dim_p
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Sep 23, 2024
1 parent c8f711f commit fae184c
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions tests/myarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,8 @@ def _bitcast_convert_type_p() -> MyArray:


@register(lax.broadcast_in_dim_p)
def _broadcast_in_dim_p(
operand: MyArray,
*,
shape: Any,
broadcast_dimensions: Any,
) -> MyArray:
return replace(
operand,
array=lax.broadcast_in_dim(operand.array, shape, broadcast_dimensions),
)
def _broadcast_in_dim_p(operand: MyArray, **kwargs: Any) -> MyArray:
return replace(operand, array=lax.broadcast_in_dim(operand.array, **kwargs))


# ==============================================================================
Expand Down

0 comments on commit fae184c

Please sign in to comment.