diff --git a/pykokkos/lib/ufuncs.py b/pykokkos/lib/ufuncs.py index 5052cd6f..723e900a 100644 --- a/pykokkos/lib/ufuncs.py +++ b/pykokkos/lib/ufuncs.py @@ -1142,7 +1142,8 @@ def broadcast_view(val, viewB): val = val[0] if len(val.shape) == 1 else val[0][0] if is_view: - if not check_broadcastable_impl(val, viewB) or not val.shape < viewB.shape: + is_first_small = len(val.shape) < len(viewB.shape) or ((len(val.shape) == len(viewB.shape)) and val.shape < viewB.shape) + if not check_broadcastable_impl(val, viewB) or not is_first_small: raise ValueError("Incompatible broadcast") if not val.dtype == viewB.dtype: raise ValueError("Broadcastable views must have same dtypes") @@ -1235,9 +1236,9 @@ def subtract(viewA, valB): raise ValueError("Views must be broadcastable") # check if size is same otherwise broadcast and fix - if viewA.shape < valB.shape: + if len(viewA.shape) < len(valB.shape) or (len(viewA.shape) == len(valB.shape) and viewA.shape < valB.shape): viewA = broadcast_view(viewA, valB) - elif valB.shape < viewA.shape: + elif len(valB.shape) < len(viewA.shape) or (len(viewA.shape) == len(valB.shape) and valB.shape < viewA.shape): valB = broadcast_view(valB, viewA) if viewA.dtype.__name__ == "float64" and valB.dtype.__name__ == "float64":