Skip to content

Commit

Permalink
feat: add scalar support to where
Browse files Browse the repository at this point in the history
PR-URL: #860
Ref: #807
Co-authored-by: Athan Reines <[email protected]>
Reviewed-by: Athan Reines <[email protected]> 
Reviewed-by: Evgeni Burovski
Reviewed-by: Lucas Colley <[email protected]>
  • Loading branch information
betatim and kgryte authored Jan 9, 2025
1 parent fd6f507 commit d12a5e3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
3 changes: 3 additions & 0 deletions spec/draft/API_specification/type_promotion.rst
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ Notes
.. note::
Mixed integer and floating-point type promotion rules are not specified because behavior varies between implementations.


.. _mixing-scalars-and-arrays:

Mixing arrays with Python scalars
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
20 changes: 17 additions & 3 deletions src/array_api_stubs/_draft/searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,21 +168,35 @@ def searchsorted(
"""


def where(condition: array, x1: array, x2: array, /) -> array:
def where(
condition: array,
x1: Union[array, int, float, complex, bool],
x2: Union[array, int, float, complex, bool],
/,
) -> array:
"""
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``.
Parameters
----------
condition: array
when ``True``, yield ``x1_i``; otherwise, yield ``x2_i``. Should have a boolean data type. Must be compatible with ``x1`` and ``x2`` (see :ref:`broadcasting`).
x1: array
x1: Union[array, int, float, complex, bool]
first input array. Must be compatible with ``condition`` and ``x2`` (see :ref:`broadcasting`).
x2: array
x2: Union[array, int, float, complex, bool]
second input array. Must be compatible with ``condition`` and ``x1`` (see :ref:`broadcasting`).
Returns
-------
out: array
an array with elements from ``x1`` where ``condition`` is ``True``, and elements from ``x2`` elsewhere. The returned array must have a data type determined by :ref:`type-promotion` rules with the arrays ``x1`` and ``x2``.
Notes
-----
- At least one of ``x1`` and ``x2`` must be an array.
- If either ``x1`` or ``x2`` is a scalar value, the returned array must have a data type determined according to :ref:`mixing-scalars-and-arrays`.
.. versionchanged:: 2024.12
Added support for scalar arguments.
"""

0 comments on commit d12a5e3

Please sign in to comment.