From 966181dbc687fa2cf1d72c25084fe20855df1801 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 12 Aug 2024 14:41:09 +0100 Subject: [PATCH] Allow `bool` in `sum` and `prod` --- cubed/array_api/statistical_functions.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 78ff6ae2a..7ee6525e1 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -1,6 +1,7 @@ import math from cubed.array_api.dtypes import ( + _boolean_dtypes, _numeric_dtypes, _real_floating_dtypes, _real_numeric_dtypes, @@ -124,10 +125,13 @@ def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None) def prod( x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None ): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in prod") + # boolean is allowed by numpy + if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: + raise TypeError("Only numeric or boolean dtypes are allowed in prod") if dtype is None: - if x.dtype in _signed_integer_dtypes: + if x.dtype in _boolean_dtypes: + dtype = int64 + elif x.dtype in _signed_integer_dtypes: dtype = int64 elif x.dtype in _unsigned_integer_dtypes: dtype = uint64 @@ -153,10 +157,13 @@ def prod( def sum( x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None ): - if x.dtype not in _numeric_dtypes: - raise TypeError("Only numeric dtypes are allowed in sum") + # boolean is allowed by numpy + if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes: + raise TypeError("Only numeric or boolean dtypes are allowed in sum") if dtype is None: - if x.dtype in _signed_integer_dtypes: + if x.dtype in _boolean_dtypes: + dtype = int64 + elif x.dtype in _signed_integer_dtypes: dtype = int64 elif x.dtype in _unsigned_integer_dtypes: dtype = uint64