From fae184c6242d22b36308854ccf091139f32018a1 Mon Sep 17 00:00:00 2001 From: nstarman Date: Mon, 23 Sep 2024 13:59:50 -0400 Subject: [PATCH] refactor: lax.broadcast_in_dim_p Signed-off-by: nstarman --- tests/myarray.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/myarray.py b/tests/myarray.py index 651f65f..c3d2f46 100644 --- a/tests/myarray.py +++ b/tests/myarray.py @@ -219,16 +219,8 @@ def _bitcast_convert_type_p() -> MyArray: @register(lax.broadcast_in_dim_p) -def _broadcast_in_dim_p( - operand: MyArray, - *, - shape: Any, - broadcast_dimensions: Any, -) -> MyArray: - return replace( - operand, - array=lax.broadcast_in_dim(operand.array, shape, broadcast_dimensions), - ) +def _broadcast_in_dim_p(operand: MyArray, **kwargs: Any) -> MyArray: + return replace(operand, array=lax.broadcast_in_dim(operand.array, **kwargs)) # ==============================================================================