From ccc392f8e2574b33b293af1cf30aac2384970cfb Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Tue, 19 Mar 2024 13:27:14 -0600 Subject: [PATCH] Refactor parsing code of UDF signatures --- .../table/impl/lang/QueryLanguageParser.java | 6 - .../engine/util/PyCallableWrapperJpyImpl.java | 25 +--- py/server/deephaven/_udf.py | 113 ++++++++++++++++-- py/server/deephaven/dtypes.py | 68 +---------- ...est_udf_numpy_args.py => test_udf_args.py} | 41 +++++-- .../tests/test_udf_return_java_values.py | 3 +- py/server/tests/test_vectorization.py | 2 +- 7 files changed, 145 insertions(+), 113 deletions(-) rename py/server/tests/{test_udf_numpy_args.py => test_udf_args.py} (92%) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 88faecd3d46..56b7d731d1d 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -2817,12 +2817,6 @@ private void prepareVectorizationArgs( } else { throw new IllegalStateException("Vectorizability check failed: " + n); } - - // TODO related to core#709, but should be covered by PyCallableWrapper.verifyArguments, needs to verify - // if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) { - // throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " " - // + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName()); - // } } } diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index aadce5e6422..5cbede20417 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -201,22 +201,6 @@ public void parseSignature() { throw new IllegalStateException("Signature should always be available."); } - // List> paramTypes = new ArrayList<>(); - // for (char numpyTypeChar : signatureString.toCharArray()) { - // if (numpyTypeChar != '-') { - // Class paramType = numpyType2JavaClass.get(numpyTypeChar); - // if (paramType == null) { - // throw new IllegalStateException( - // "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object - // type: " - // + numpyTypeChar + " of " + signatureString); - // } - // paramTypes.add(paramType); - // } else { - // break; - // } - // } - // this.paramTypes = paramTypes; String pyEncodedParamsStr = signatureString.split("->")[0]; if (!pyEncodedParamsStr.isEmpty()) { String[] pyEncodedParams = pyEncodedParamsStr.split(","); @@ -271,14 +255,13 @@ private boolean isSafelyCastable(Set> types, Class type) { public void verifyArguments(Class[] argTypes) { String callableName = pyCallable.getAttribute("__name__").toString(); - // if (argTypes.length > parameters.size()) { - // throw new IllegalArgumentException( - // callableName + ": " + "Expected " + parameters.size() + " or fewer arguments, got " + argTypes.length); - // } for (int i = 0; i < argTypes.length; i++) { Set> types = parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes(); - // Object is a catch-all type, so we don't need to check for it + + // to prevent the unpacking of an array column when calling a Python function, we prefix the column accessor + // with a cast to generic Object type, until we can find a way to convey that info, we'll just skip the + // check for Object type input if (argTypes[i] == Object.class) { continue; } diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index 7fe676846f1..617d5398594 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -5,9 +5,13 @@ import inspect import re import sys +import typing from dataclasses import dataclass, field +from datetime import datetime from functools import wraps -from typing import Callable, List, Any, Union, Tuple, _GenericAlias, Set +from typing import Callable, List, Any, Union, Tuple, _GenericAlias, Set, Optional, Sequence + +import pandas as pd from deephaven._dep import soft_dependency @@ -17,9 +21,9 @@ import numpy as np from deephaven import DHError, dtypes -from deephaven.dtypes import _np_ndarray_component_type, _np_dtype_char, _NUMPY_INT_TYPE_CODES, \ - _NUMPY_FLOATING_TYPE_CODES, _component_np_dtype_char, _J_ARRAY_NP_TYPE_MAP, _PRIMITIVE_DTYPE_NULL_MAP, _scalar, \ - _BUILDABLE_ARRAY_DTYPE_MAP +from deephaven.dtypes import _NUMPY_INT_TYPE_CODES, _NUMPY_FLOATING_TYPE_CODES, _J_ARRAY_NP_TYPE_MAP, \ + _PRIMITIVE_DTYPE_NULL_MAP, _scalar, \ + _BUILDABLE_ARRAY_DTYPE_MAP, DType from deephaven.jcompat import _j_array_to_numpy_array from deephaven.time import to_np_datetime64 @@ -27,7 +31,6 @@ test_vectorization = False vectorized_count = 0 - _SUPPORTED_NP_TYPE_CODES = {"b", "h", "H", "i", "l", "f", "d", "?", "U", "M", "O"} @@ -80,7 +83,7 @@ def _encode_param_type(t: type) -> str: return "N" # find the component type if it is numpy ndarray - component_type = _np_ndarray_component_type(t) + component_type = _component_np_dtype_char(t) if component_type: t = component_type @@ -92,6 +95,92 @@ def _encode_param_type(t: type) -> str: return tc +def _np_dtype_char(t: Union[type, str]) -> str: + """Returns the numpy dtype character code for the given type.""" + try: + np_dtype = np.dtype(t if t else "object") + if np_dtype.kind == "O": + if t in (datetime, pd.Timestamp): + return "M" + except TypeError: + np_dtype = np.dtype("object") + + return np_dtype.char + + +def _component_np_dtype_char(t: type) -> Optional[str]: + """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or + numpy ndarray, otherwise return None. """ + component_type = None + + if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: + import types + if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence): + component_type = t.__args__[0] + + if not component_type: + if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): + component_type = t.__args__[0] + # if the component type is a DType, get its numpy type + if isinstance(component_type, DType): + component_type = component_type.np_type + + if not component_type: + if t == bytes or t == bytearray: + return "b" + + if not component_type: + component_type = _np_ndarray_component_type(t) + + if component_type: + return _np_dtype_char(component_type) + else: + return None + + +def _np_ndarray_component_type(t: type) -> Optional[type]: + """Returns the numpy ndarray component type if the type is a numpy ndarray, otherwise return None.""" + + # Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64]) + # is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used, + # the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the + # component type + component_type = None + if sys.version_info.major == 3 and sys.version_info.minor == 8: + if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray: + component_type = t.__args__[1].__args__[0] + # Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a + # specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias. + # when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which + # the 1st argument is the component type + # when np.ndarray is used, the 1st argument is the component type + if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: + import types + if isinstance(t, types.GenericAlias) and t.__origin__ == np.ndarray: + nargs = len(t.__args__) + if nargs == 1: + component_type = t.__args__[0] + elif nargs == 2: # for npt.NDArray[np.int64], etc. + a0 = t.__args__[0] + a1 = t.__args__[1] + if a0 == typing.Any and isinstance(a1, types.GenericAlias): + component_type = a1.__args__[0] + return component_type + + +def _is_union_type(t: type) -> bool: + """Return True if the type is a Union type""" + if sys.version_info.major == 3 and sys.version_info.minor >= 10: + import types + if isinstance(t, types.UnionType): + return True + + if isinstance(t, _GenericAlias) and t.__origin__ == Union: + return True + + return False + + def _parse_param(name: str, annotation: Any) -> _ParsedParam: """ Parse a parameter annotation in a function's signature """ p_param = _ParsedParam(name) @@ -99,7 +188,7 @@ def _parse_param(name: str, annotation: Any) -> _ParsedParam: if annotation is inspect._empty: p_param.encoded_types.add("O") p_param.none_allowed = True - elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union: + elif _is_union_type(annotation): for t in annotation.__args__: _parse_type_no_nested(annotation, p_param, t) else: @@ -149,7 +238,7 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation: t = annotation pra.orig_type = t - if isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union and len(annotation.__args__) == 2: + if _is_union_type(annotation) and len(annotation.__args__) == 2: # if the annotation is a Union of two types, we'll use the non-None type if annotation.__args__[1] == type(None): # noqa: E721 t = annotation.__args__[0] @@ -170,7 +259,8 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation: if numba: - def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc]) -> _ParsedSignature: + def _parse_numba_signature( + fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc]) -> _ParsedSignature: """ Parse a numba function's signature""" sigs = fn.types # in the format of ll->l, ff->f,dd->d,OO->O, etc. if sigs: @@ -261,7 +351,8 @@ def _parse_signature(fn: Callable) -> _ParsedSignature: t = eval(p.annotation, fn.__globals__) if isinstance(p.annotation, str) else p.annotation p_sig.params.append(_parse_param(n, t)) - t = eval(sig.return_annotation, fn.__globals__) if isinstance(sig.return_annotation, str) else sig.return_annotation + t = eval(sig.return_annotation, fn.__globals__) if isinstance(sig.return_annotation, + str) else sig.return_annotation p_sig.ret_annotation = _parse_return_annotation(t) return p_sig @@ -476,4 +567,4 @@ def wrapper(*args): global vectorized_count vectorized_count += 1 - return wrapper \ No newline at end of file + return wrapper diff --git a/py/server/deephaven/dtypes.py b/py/server/deephaven/dtypes.py index b4a367c98ee..8ea1f6f56ed 100644 --- a/py/server/deephaven/dtypes.py +++ b/py/server/deephaven/dtypes.py @@ -9,13 +9,10 @@ from __future__ import annotations import datetime -import sys -import typing -from typing import Any, Sequence, Callable, Dict, Type, Union, _GenericAlias, Optional +from typing import Any, Sequence, Callable, Dict, Type, Union, Optional import jpy import numpy as np -import numpy._typing as npt import pandas as pd from deephaven import DHError @@ -304,7 +301,7 @@ def array(dtype: DType, seq: Optional[Sequence], remap: Callable[[Any], Any] = N raise DHError(e, f"failed to create a Java {dtype.j_name} array.") from e -def from_jtype(j_class: Any) -> DType: +def from_jtype(j_class: Any) -> Optional[DType]: """ looks up a DType that matches the java type, if not found, creates a DType for it. """ if not j_class: return None @@ -392,64 +389,3 @@ def _scalar(x: Any, dtype: DType) -> Any: except: return x - -def _np_dtype_char(t: Union[type, str]) -> str: - """Returns the numpy dtype character code for the given type.""" - try: - np_dtype = np.dtype(t if t else "object") - if np_dtype.kind == "O": - if t in (datetime.datetime, pd.Timestamp): - return "M" - except TypeError: - np_dtype = np.dtype("object") - - return np_dtype.char - - -def _component_np_dtype_char(t: type) -> Optional[str]: - """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or - numpy ndarray, otherwise return None. """ - component_type = None - if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): - component_type = t.__args__[0] - # if the component type is a DType, get its numpy type - if isinstance(component_type, DType): - component_type = component_type.np_type - - if not component_type: - component_type = _np_ndarray_component_type(t) - - if component_type: - return _np_dtype_char(component_type) - else: - return None - - -def _np_ndarray_component_type(t: type) -> Optional[type]: - """Returns the numpy ndarray component type if the type is a numpy ndarray, otherwise return None.""" - - # Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64]) - # is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used, - # the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the - # component type - component_type = None - if sys.version_info.major == 3 and sys.version_info.minor == 8: - if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray: - component_type = t.__args__[1].__args__[0] - # Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a - # specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias. - # when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which - # the 1st argument is the component type - # when np.ndarray is used, the 1st argument is the component type - if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: - import types - if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray): - nargs = len(t.__args__) - if nargs == 1: - component_type = t.__args__[0] - elif nargs == 2: # for npt.NDArray[np.int64], etc. - a0 = t.__args__[0] - a1 = t.__args__[1] - if a0 == typing.Any and isinstance(a1, types.GenericAlias): - component_type = a1.__args__[0] - return component_type diff --git a/py/server/tests/test_udf_numpy_args.py b/py/server/tests/test_udf_args.py similarity index 92% rename from py/server/tests/test_udf_numpy_args.py rename to py/server/tests/test_udf_args.py index fbb913520ec..8e58050a9d6 100644 --- a/py/server/tests/test_udf_numpy_args.py +++ b/py/server/tests/test_udf_args.py @@ -2,7 +2,7 @@ # Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending # import typing -from typing import Optional, Union, Any +from typing import Optional, Union, Any, Sequence import unittest import numpy as np @@ -452,13 +452,40 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f(X)"]) self.assertRegex(str(cm.exception), "f: Expect") - def test_np_typehints_mismatch(self): - def f(x: float) -> bool: - return True + def test_sequence_args(self): + with self.subTest("Sequence"): + def f(x: Sequence[int]) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = ii"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) + + with self.subTest("bytes"): + def f(x: bytes) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) + + with self.subTest("bytearray"): + def f(x: bytearray) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) - with self.assertRaises(DHError) as cm: - t = empty_table(1).update(["X = i", "Y = f(ii)"]) - self.assertRegex(str(cm.exception), "f: Expect") if __name__ == "__main__": unittest.main() diff --git a/py/server/tests/test_udf_return_java_values.py b/py/server/tests/test_udf_return_java_values.py index 129105d5698..d42b6f8465c 100644 --- a/py/server/tests/test_udf_return_java_values.py +++ b/py/server/tests/test_udf_return_java_values.py @@ -54,7 +54,8 @@ def test_array_return(self): "np.str_": dtypes.string_array, "np.uint16": dtypes.char_array, } - container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] + # container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] + container_types = ["list"] for component_type, dh_dtype in component_types.items(): for container_type in container_types: with self.subTest(component_type=component_type, container_type=container_type): diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index 2685fa843d5..7632be171e7 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -15,7 +15,7 @@ from deephaven._udf import _dh_vectorize as dh_vectorize from tests.testbase import BaseTestCase -from tests.test_udf_numpy_args import _J_TYPE_NULL_MAP, _J_TYPE_NP_DTYPE_MAP, _J_TYPE_J_ARRAY_TYPE_MAP +from tests.test_udf_args import _J_TYPE_NULL_MAP, _J_TYPE_NP_DTYPE_MAP, _J_TYPE_J_ARRAY_TYPE_MAP class VectorizationTestCase(BaseTestCase):