Skip to content

Commit

Permalink
subtract + broadcast: Fixed tuple comparison for view shapes (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
HannanNaeem authored Mar 24, 2024
1 parent 2f662df commit f51eae6
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pykokkos/lib/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,8 @@ def broadcast_view(val, viewB):
val = val[0] if len(val.shape) == 1 else val[0][0]

if is_view:
if not check_broadcastable_impl(val, viewB) or not val.shape < viewB.shape:
is_first_small = len(val.shape) < len(viewB.shape) or ((len(val.shape) == len(viewB.shape)) and val.shape < viewB.shape)
if not check_broadcastable_impl(val, viewB) or not is_first_small:
raise ValueError("Incompatible broadcast")
if not val.dtype == viewB.dtype:
raise ValueError("Broadcastable views must have same dtypes")
Expand Down Expand Up @@ -1235,9 +1236,9 @@ def subtract(viewA, valB):
raise ValueError("Views must be broadcastable")

# check if size is same otherwise broadcast and fix
if viewA.shape < valB.shape:
if len(viewA.shape) < len(valB.shape) or (len(viewA.shape) == len(valB.shape) and viewA.shape < valB.shape):
viewA = broadcast_view(viewA, valB)
elif valB.shape < viewA.shape:
elif len(valB.shape) < len(viewA.shape) or (len(viewA.shape) == len(valB.shape) and valB.shape < viewA.shape):
valB = broadcast_view(valB, viewA)

if viewA.dtype.__name__ == "float64" and valB.dtype.__name__ == "float64":
Expand Down

0 comments on commit f51eae6

Please sign in to comment.