Skip to content

Commit

Permalink
Add missing array_api functions (#576)
Browse files Browse the repository at this point in the history
* add array_api.var and array_api.std

* add arange/asarray

* fix test

* nonzero, argmin, argmax, unique_values

* flake8

* add tests for statistical functions and fix bug where the array result type was wrong

* another fix

* another fix

* flake8

* flake8

* astype, result_type, iinfo, finfo, can_cast

* doctest

* any/all

* concat

* flip

* where

* isort

* isort

* remove flip from the list of unimplemented functions

* increase timeout

* debug CI failure

* undo changes to rbc_test.yml
  • Loading branch information
guilhermeleobas authored Aug 1, 2023
1 parent 8fdca78 commit fc2ce64
Show file tree
Hide file tree
Showing 22 changed files with 1,463 additions and 103 deletions.
20 changes: 19 additions & 1 deletion rbc/heavydb/array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Implement HeavyDB Array type support
"""

__all__ = ['ArrayPointer', 'Array', 'HeavyDBArrayType']
__all__ = ['ArrayPointer', 'Array', 'HeavyDBArrayType', 'type_can_asarray']

from typing import Optional, Union

Expand Down Expand Up @@ -224,6 +224,11 @@ def typer(lst):
return typer


def type_can_asarray(typ):
    """Tell whether ``typ`` is one of the types ``asarray`` can convert.

    Returns ``True`` when ``typ`` is an instance of ``ArrayPointer`` or of
    ``nb_types.List``, and ``False`` for anything else.
    """
    return isinstance(typ, (ArrayPointer, nb_types.List))


@extending.overload_attribute(ArrayPointer, 'ndim')
def get_ndim(arr):
def impl(arr):
Expand All @@ -247,3 +252,16 @@ def impl(arr):
lst.append(arr[i])
return lst
return impl


@extending.overload_method(ArrayPointer, 'drop_null')
def ol_drop_null(arr):
    """Provide the ``drop_null`` method for ``ArrayPointer`` values.

    The returned implementation walks the array by index, keeps every
    element for which ``is_null`` is false, and converts the collected
    list with ``array_api.asarray``.
    """
    # Imported at call time; presumably this avoids a circular import, since
    # rbc.stdlib itself imports from rbc.heavydb -- TODO confirm.
    from rbc.stdlib import array_api

    def impl(arr):
        kept = []
        size = len(arr)
        for idx in range(size):
            if arr.is_null(idx):
                continue
            kept.append(arr[idx])
        return array_api.asarray(kept)
    return impl
36 changes: 22 additions & 14 deletions rbc/stdlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import textwrap
import functools
import textwrap
from enum import Enum

import numpy as np
from numba.core import extending
from numba.np.numpy_support import as_dtype, from_dtype

from rbc import errors, typesystem
from rbc.heavydb import Array, ArrayPointer
from rbc import typesystem, errors

ARRAY_API_ADDRESS = ("https://data-apis.org/array-api/latest/API_specification/"
"generated/array_api.{0}.html#array_api.{0}")
Expand All @@ -16,13 +20,6 @@ class API(Enum):
ARRAY_API = 1


def determine_dtype(a, dtype):
    """Choose the dtype to use for ``a``.

    An explicit ``dtype`` always wins. Without one, the element type of an
    ``ArrayPointer`` is returned; any other ``a`` is returned unchanged.
    """
    if dtype is not None:
        return dtype
    if isinstance(a, ArrayPointer):
        return a.eltype
    return a


def determine_input_type(argty):
if isinstance(argty, ArrayPointer):
return determine_input_type(argty.eltype)
Expand All @@ -33,6 +30,14 @@ def determine_input_type(argty):
return argty


def result_type(*args, dtype=None):
    """Return the type resulting from combining ``args``.

    Parameters
    ----------
    *args
        Types to combine. Each one is converted with ``as_dtype``, except
        ``typesystem.boolean8`` which is mapped to NumPy's boolean scalar
        type directly.
    dtype
        Optional override. When given it is returned as-is and ``args`` are
        not inspected.

    Returns
    -------
    ``dtype`` when it is provided, otherwise the ``from_dtype`` conversion of
    ``np.result_type`` applied to the converted ``args``.
    """
    if dtype is not None:
        return dtype

    # Use np.bool_ rather than np.bool: the latter alias was removed in
    # NumPy 1.24 (and only came back in 2.0), while np.bool_ exists everywhere.
    np_dtypes = [
        np.bool_ if arg == typesystem.boolean8 else as_dtype(arg)
        for arg in args
    ]
    return from_dtype(np.result_type(*np_dtypes))


class Expose:
def __init__(self, globals, module_name):
self._globals = globals
Expand Down Expand Up @@ -124,29 +129,29 @@ def broadcast(e, sz, dtype):
return b

if isinstance(a, ArrayPointer) and isinstance(b, ArrayPointer):
nb_dtype = determine_dtype(a, dtype)
nb_dtype = result_type(a.eltype, b.eltype, dtype=dtype)

def impl(a, b):
return binary_impl(a, b, nb_dtype)
return impl
elif isinstance(a, ArrayPointer):
nb_dtype = determine_dtype(a, dtype)
nb_dtype = result_type(a.eltype, b, dtype=dtype)
other_dtype = b

def impl(a, b):
b = broadcast(b, len(a), other_dtype)
return binary_impl(a, b, nb_dtype)
return impl
elif isinstance(b, ArrayPointer):
nb_dtype = determine_dtype(b, dtype)
nb_dtype = result_type(a, b.eltype, dtype=dtype)
other_dtype = a

def impl(a, b):
a = broadcast(a, len(b), other_dtype)
return binary_impl(a, b, nb_dtype)
return impl
else:
nb_dtype = determine_dtype(a, dtype)
nb_dtype = result_type(a, b, dtype=dtype)

def impl(a, b):
cast_a = typA(a)
Expand Down Expand Up @@ -176,10 +181,11 @@ def implements(self, ufunc, func_name=None, dtype=None, api=API.ARRAY_API):
func_name = ufunc.__name__

def unary_ufunc_impl(a):
nb_dtype = determine_dtype(a, dtype)
typ = determine_input_type(a)

if isinstance(a, ArrayPointer):
nb_dtype = result_type(a.eltype, dtype=dtype)

def impl(a):
sz = len(a)
x = Array(sz, nb_dtype)
Expand All @@ -190,6 +196,8 @@ def impl(a):
return x
return impl
else:
nb_dtype = result_type(a, dtype=dtype)

def impl(a):
# Convert the value to type typ
cast = typ(a)
Expand Down
1 change: 0 additions & 1 deletion rbc/stdlib/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""
import numpy as np


__all__ = [
'e', 'inf', 'nan', 'pi'
]
Expand Down
134 changes: 125 additions & 9 deletions rbc/stdlib/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
https://data-apis.org/array-api/latest/API_specification/creation_functions.html
"""

from rbc import typesystem, errors
from rbc.heavydb import Array, ArrayPointer
from rbc.stdlib import Expose, API
from numba.core import extending, types
from numba import TypingError
from numba.core import cgutils, extending, types
from numba.core.typing import asnumbatype

from rbc import errors, typesystem
from rbc.heavydb import Array, ArrayPointer
from rbc.stdlib import API, Expose

__all__ = [
'full', 'full_like', 'empty_like', 'empty', 'zeros', 'zeros_like',
'ones', 'ones_like', 'cumsum', 'arange', 'asarray',
Expand Down Expand Up @@ -41,20 +42,135 @@ def _determine_dtype(dtype, fill_value):
return typesystem.Type.fromobject(dtype).tonumba()


@expose.implements('arange')
def _array_api_arange(start, stop=None, step=1, dtype=None, device=None):
    """
    Return evenly spaced values within a given interval.

    Examples
    --------
    IGNORE:
    >>> from rbc.heavydb import global_heavydb_singleton
    >>> heavydb = next(global_heavydb_singleton)
    >>> heavydb.unregister()

    IGNORE

    >>> from rbc.stdlib import array_api
    >>> @heavydb('int64[](int64)')
    ... def rbc_arange(start):
    ...     return array_api.arange(start)
    >>> rbc_arange(5).execute()
    array([0, 1, 2, 3, 4])

    >>> @heavydb('double[](int64, float, float)')
    ... def rbc_arange3(start, stop, step):
    ...     return array_api.arange(start, stop, step)
    >>> rbc_arange3(0, 1.0, 0.2).execute()
    array([0. , 0.2, 0.4, 0.6, 0.8], dtype=float32)

    >>> @heavydb('int8[](int64, int64)')
    ... def rbc_arange3(start, stop):
    ...     return array_api.arange(start, stop, dtype=array_api.int8)
    >>> rbc_arange3(2, 5).execute()
    array([2, 3, 4], dtype=int8)
    """

    # Code based on Numba implementation of "np.arange(...)"

    def _arange_dtype(*args):
        """
        If dtype is None, the output array data type must be inferred from
        start, stop and step. If those are all integers, the output array dtype
        must be the default integer dtype; if one or more have type float, then
        the output array dtype must be the default real-valued floating-point
        data type.
        """
        if any(isinstance(a, types.Float) for a in args):
            return types.float64
        else:
            return types.int64

    # Typing-time validation: the arguments seen here are types, not values.
    if not isinstance(start, types.Number):
        raise errors.NumbaTypeError('"start" must be a number')

    if not cgutils.is_nonelike(stop) and not isinstance(stop, types.Number):
        raise errors.NumbaTypeError('"stop" must be a number')

    # Resolve the element type of the result once, outside the implementation.
    if cgutils.is_nonelike(dtype):
        true_dtype = _arange_dtype(start, stop, step)
    elif isinstance(dtype, types.DTypeSpec):
        true_dtype = dtype.dtype
    elif isinstance(dtype, types.StringLiteral):
        true_dtype = dtype.literal_value
    else:
        msg = f'If specified, "dtype" must be a DTypeSpec. Got {dtype}'
        raise errors.NumbaTypeError(msg)

    from rbc.stdlib import array_api

    def impl(start, stop=None, step=1, dtype=None, device=None):
        # With a single positional argument the interval is [0, start).
        if stop is None:
            _start, _stop = 0, start
        else:
            _start, _stop = start, stop

        _step = step

        if _step == 0:
            raise ValueError("Maximum allowed size exceeded")

        nitems_c = (_stop - _start) / _step
        nitems_r = int(array_api.ceil(nitems_c))

        # A negative count (e.g. stop < start with a positive step) yields
        # an empty result, which is returned as a null array.
        nitems = max(nitems_r, 0)
        arr = Array(nitems, true_dtype)
        if nitems == 0:
            arr.set_null()
            return arr
        val = _start
        for i in range(nitems):
            arr[i] = val + (i * _step)
        return arr
    return impl


@expose.implements('asarray')
def _array_api_asarray(obj, dtype=None, device=None, copy=None):
    """
    Convert the input to an array.

    Only ``ArrayPointer`` and ``types.List`` inputs are handled; for any other
    input no implementation is returned. When ``dtype`` is not given, the
    element type of ``obj`` is used.

    Examples
    --------
    IGNORE:
    >>> from rbc.heavydb import global_heavydb_singleton
    >>> heavydb = next(global_heavydb_singleton)
    >>> heavydb.unregister()

    IGNORE

    >>> from rbc.stdlib import array_api
    >>> @heavydb('float[](int64[])')
    ... def rbc_asarray(arr):
    ...     return array_api.asarray(arr, dtype=array_api.float32)
    >>> rbc_asarray([1, 2, 3]).execute()
    array([1., 2., 3.], dtype=float32)
    """
    if not isinstance(obj, (ArrayPointer, types.List)):
        return None

    # is_nonelike also covers a dtype that arrives as a numba NoneType/Omitted
    # type rather than the Python None object; a plain "is None" test would
    # then fall through to dtype.dtype and fail.
    if not cgutils.is_nonelike(dtype):
        nb_dtype = dtype.dtype
    elif isinstance(obj, ArrayPointer):
        nb_dtype = obj.eltype
    else:
        nb_dtype = obj.dtype

    def impl(obj, dtype=None, device=None, copy=None):
        sz = len(obj)
        arr = Array(sz, nb_dtype)
        for i in range(sz):
            # Cast each element to the resolved element type.
            arr[i] = nb_dtype(obj[i])
        return arr
    return impl


@expose.not_implemented('eye')
Expand Down
Loading

0 comments on commit fc2ce64

Please sign in to comment.