From d9cc154eaeb8f2289b64d500893de9b0eb3cdcd5 Mon Sep 17 00:00:00 2001 From: Hannan Naeem Date: Sun, 24 Mar 2024 01:11:14 -0500 Subject: [PATCH] subtract + broadcast: Fixed tuple comparison for view shapes --- pykokkos/lib/ufuncs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pykokkos/lib/ufuncs.py b/pykokkos/lib/ufuncs.py index 5052cd6f..723e900a 100644 --- a/pykokkos/lib/ufuncs.py +++ b/pykokkos/lib/ufuncs.py @@ -1142,7 +1142,8 @@ def broadcast_view(val, viewB): val = val[0] if len(val.shape) == 1 else val[0][0] if is_view: - if not check_broadcastable_impl(val, viewB) or not val.shape < viewB.shape: + is_first_small = len(val.shape) < len(viewB.shape) or ((len(val.shape) == len(viewB.shape)) and val.shape < viewB.shape) + if not check_broadcastable_impl(val, viewB) or not is_first_small: raise ValueError("Incompatible broadcast") if not val.dtype == viewB.dtype: raise ValueError("Broadcastable views must have same dtypes") @@ -1235,9 +1236,9 @@ def subtract(viewA, valB): raise ValueError("Views must be broadcastable") # check if size is same otherwise broadcast and fix - if viewA.shape < valB.shape: + if len(viewA.shape) < len(valB.shape) or (len(viewA.shape) == len(valB.shape) and viewA.shape < valB.shape): viewA = broadcast_view(viewA, valB) - elif valB.shape < viewA.shape: + elif len(valB.shape) < len(viewA.shape) or (len(viewA.shape) == len(valB.shape) and valB.shape < viewA.shape): valB = broadcast_view(valB, viewA) if viewA.dtype.__name__ == "float64" and valB.dtype.__name__ == "float64":