Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: numba apply supports positional arguments passed as **kwargs #58995

Merged
merged 31 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3a5fc90
add *args for raw numba apply
auderson May 18, 2024
3165efe
add whatsnew
auderson May 18, 2024
de89574
fix test_case
auderson May 18, 2024
3f13b30
fix pre-commit
auderson May 18, 2024
c026845
fix test case
auderson May 18, 2024
96581a3
add *args for raw=False as well; merge tests together
auderson May 19, 2024
2aae933
add prepare_function_arguments
auderson May 21, 2024
1a6f1ae
fix mypy
auderson May 21, 2024
8925b3a
update get_jit_arguments
auderson May 26, 2024
085ae73
add nopython test in `test_apply_args`
auderson May 26, 2024
c75e0b7
fix test
auderson May 26, 2024
ceb8178
fix pre-commit
auderson May 26, 2024
e191be9
Merge branch 'refs/heads/main' into enh_numba_apply_support_kwargs
auderson Jun 13, 2024
aa91722
modify prepare_function_arguments
auderson Jun 13, 2024
0de3224
add tests
auderson Jun 13, 2024
82252be
add tests
auderson Jun 13, 2024
da6dbc7
add whatsnew
auderson Jun 13, 2024
e72bfb2
compat for python 3.12
auderson Jun 13, 2024
8168d9b
pre-commit
auderson Jun 13, 2024
c211119
compat for python 3.12
auderson Jun 13, 2024
98499a1
Merge remote-tracking branch 'origin/enh_numba_apply_support_kwargs' …
auderson Oct 22, 2024
54fdb4a
Merge branch 'main' into enh_numba_apply_support_kwargs
auderson Oct 22, 2024
f7936f2
update doc; use kw-only
auderson Oct 23, 2024
f00fdd7
Merge branch 'main' into enh_numba_apply_support_kwargs
auderson Oct 27, 2024
33be300
Merge branch 'main' into enh_numba_apply_support_kwargs
auderson Oct 30, 2024
2400d28
add more tests
auderson Oct 30, 2024
8d10211
update whatsnew
auderson Oct 30, 2024
f672f9b
pre-commit
auderson Oct 30, 2024
1eba10b
move the tests to test_numba.py
auderson Oct 30, 2024
93925ba
Update doc/source/whatsnew/v3.0.0.rst
auderson Oct 31, 2024
09bdae0
Update doc/source/whatsnew/v3.0.0.rst
auderson Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Other enhancements
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
- :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply`: when using numba engine in these apply methods, positional arguments now can be passed as kwargs (:issue:`58995`)
- :meth:`Series.map` can now accept kwargs to pass on to func (:issue:`59814`)
- :meth:`pandas.concat` will raise a ``ValueError`` when ``ignore_index=True`` and ``keys`` is not ``None`` (:issue:`59274`)
- :meth:`str.get_dummies` now accepts a ``dtype`` parameter to specify the dtype of the resulting DataFrame (:issue:`47872`)
Expand All @@ -63,6 +64,7 @@ Other enhancements
- Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`)
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)


.. ---------------------------------------------------------------------------
.. _whatsnew_300.notable_bug_fixes:

Expand Down
15 changes: 10 additions & 5 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,14 +994,15 @@ def wrapper(*args, **kwargs):
self.func, # type: ignore[arg-type]
self.args,
self.kwargs,
num_required_args=1,
)
# error: Argument 1 to "__call__" of "_lru_cache_wrapper" has
# incompatible type "Callable[..., Any] | str | list[Callable
# [..., Any] | str] | dict[Hashable,Callable[..., Any] | str |
# list[Callable[..., Any] | str]]"; expected "Hashable"
nb_looper = generate_apply_looper(
self.func, # type: ignore[arg-type]
**get_jit_arguments(engine_kwargs, kwargs),
**get_jit_arguments(engine_kwargs),
)
result = nb_looper(self.values, self.axis, *args)
# If we made the result 2-D, squeeze it back to 1-D
Expand Down Expand Up @@ -1158,9 +1159,11 @@ def numba_func(values, col_names, df_index, *args):

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
args, kwargs = prepare_function_arguments(
func, self.args, self.kwargs, num_required_args=1
)
nb_func = self.generate_numba_apply_func(
func, **get_jit_arguments(self.engine_kwargs, kwargs)
func, **get_jit_arguments(self.engine_kwargs)
)
from pandas.core._numba.extensions import set_numba_data

Expand Down Expand Up @@ -1298,9 +1301,11 @@ def numba_func(values, col_names_index, index, *args):

def apply_with_numba(self) -> dict[int, Any]:
func = cast(Callable, self.func)
args, kwargs = prepare_function_arguments(func, self.args, self.kwargs)
args, kwargs = prepare_function_arguments(
func, self.args, self.kwargs, num_required_args=1
)
nb_func = self.generate_numba_apply_func(
func, **get_jit_arguments(self.engine_kwargs, kwargs)
func, **get_jit_arguments(self.engine_kwargs)
)

from pandas.core._numba.extensions import set_numba_data
Expand Down
11 changes: 9 additions & 2 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ class providing the base-class of operations.
from pandas.core.util.numba_ import (
get_jit_arguments,
maybe_use_numba,
prepare_function_arguments,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -1289,8 +1290,11 @@ def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):

starts, ends, sorted_index, sorted_data = self._numba_prep(df)
numba_.validate_udf(func)
args, kwargs = prepare_function_arguments(
func, args, kwargs, num_required_args=2
)
numba_transform_func = numba_.generate_numba_transform_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
func, **get_jit_arguments(engine_kwargs)
)
result = numba_transform_func(
sorted_data,
Expand Down Expand Up @@ -1325,8 +1329,11 @@ def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):

starts, ends, sorted_index, sorted_data = self._numba_prep(df)
numba_.validate_udf(func)
args, kwargs = prepare_function_arguments(
func, args, kwargs, num_required_args=2
)
numba_agg_func = numba_.generate_numba_agg_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
func, **get_jit_arguments(engine_kwargs)
)
result = numba_agg_func(
sorted_data,
Expand Down
47 changes: 24 additions & 23 deletions pandas/core/util/numba_.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,14 @@ def set_use_numba(enable: bool = False) -> None:
GLOBAL_USE_NUMBA = enable


def get_jit_arguments(
engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
) -> dict[str, bool]:
def get_jit_arguments(engine_kwargs: dict[str, bool] | None = None) -> dict[str, bool]:
"""
Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.

Parameters
----------
engine_kwargs : dict, default None
user passed keyword arguments for numba.JIT
kwargs : dict, default None
user passed keyword arguments to pass into the JITed function

Returns
-------
Expand All @@ -55,16 +51,6 @@ def get_jit_arguments(
engine_kwargs = {}

nopython = engine_kwargs.get("nopython", True)
if kwargs:
# Note: in case numba supports keyword-only arguments in
# a future version, we should remove this check. But this
# seems unlikely to happen soon.

raise NumbaUtilError(
"numba does not support keyword-only arguments"
"https://github.com/numba/numba/issues/2916, "
"https://github.com/numba/numba/issues/6846"
)
nogil = engine_kwargs.get("nogil", False)
parallel = engine_kwargs.get("parallel", False)
return {"nopython": nopython, "nogil": nogil, "parallel": parallel}
Expand Down Expand Up @@ -109,7 +95,7 @@ def jit_user_function(func: Callable) -> Callable:


def prepare_function_arguments(
func: Callable, args: tuple, kwargs: dict
func: Callable, args: tuple, kwargs: dict, *, num_required_args: int
) -> tuple[tuple, dict]:
"""
Prepare arguments for jitted function. As numba functions do not support kwargs,
Expand All @@ -118,11 +104,17 @@ def prepare_function_arguments(
Parameters
----------
func : function
user defined function
User defined function
args : tuple
user input positional arguments
User input positional arguments
kwargs : dict
user input keyword arguments
User input keyword arguments
num_required_args : int
The number of leading positional arguments we will pass to udf.
These are not supplied by the user.
e.g. for groupby we require "values", "index" as the first two arguments:
`numba_func(group, group_index, *args)`, in this case num_required_args=2.
See :func:`pandas.core.groupby.numba_.generate_numba_agg_func`

Returns
-------
Expand All @@ -133,17 +125,26 @@ def prepare_function_arguments(
if not kwargs:
return args, kwargs

# the udf should have this pattern: def udf(value, *args, **kwargs):...
# the udf should have this pattern: def udf(arg1, arg2, ..., *args, **kwargs):...
signature = inspect.signature(func)
arguments = signature.bind(_sentinel, *args, **kwargs)
arguments = signature.bind(*[_sentinel] * num_required_args, *args, **kwargs)
arguments.apply_defaults()
# Ref: https://peps.python.org/pep-0362/
# Arguments which could be passed as part of either *args or **kwargs
# will be included only in the BoundArguments.args attribute.
args = arguments.args
kwargs = arguments.kwargs

assert args[0] is _sentinel
args = args[1:]
if kwargs:
# Note: in case numba supports keyword-only arguments in
# a future version, we should remove this check. But this
# seems unlikely to happen soon.

raise NumbaUtilError(
"numba does not support keyword-only arguments"
"https://github.com/numba/numba/issues/2916, "
"https://github.com/numba/numba/issues/6846"
)

args = args[num_required_args:]
return args, kwargs
9 changes: 6 additions & 3 deletions pandas/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from pandas.core.util.numba_ import (
get_jit_arguments,
maybe_use_numba,
prepare_function_arguments,
)
from pandas.core.window.common import (
flex_binary_moment,
Expand Down Expand Up @@ -1472,14 +1473,16 @@ def apply(
if maybe_use_numba(engine):
if raw is False:
raise ValueError("raw must be `True` when using the numba engine")
numba_args = args
numba_args, kwargs = prepare_function_arguments(
func, args, kwargs, num_required_args=1
)
if self.method == "single":
apply_func = generate_numba_apply_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
func, **get_jit_arguments(engine_kwargs)
)
else:
apply_func = generate_numba_table_func(
func, **get_jit_arguments(engine_kwargs, kwargs)
func, **get_jit_arguments(engine_kwargs)
)
elif engine in ("cython", None):
if engine_kwargs is not None:
Expand Down
10 changes: 10 additions & 0 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def test_apply_args(float_frame, axis, raw, engine, nopython):
tm.assert_frame_equal(result, expected)

if engine == "numba":
# py signature binding
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
float_frame.apply(
lambda x, a: x + a,
b=2,
raw=raw,
engine=engine,
engine_kwargs=engine_kwargs,
)

# keyword-only arguments are not supported in numba
with pytest.raises(
pd.errors.NumbaUtilError,
Expand Down
29 changes: 27 additions & 2 deletions pandas/tests/groupby/aggregate/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,43 @@ def incorrect_function(x):
def test_check_nopython_kwargs():
pytest.importorskip("numba")

def incorrect_function(values, index):
return sum(values) * 2.7
def incorrect_function(values, index, *, a):
return sum(values) * 2.7 + a

def correct_function(values, index, a):
return sum(values) * 2.7 + a

data = DataFrame(
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
columns=["key", "data"],
)
expected = data.groupby("key").sum() * 2.7

# py signature binding
with pytest.raises(
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
):
data.groupby("key").agg(incorrect_function, engine="numba", b=1)
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
data.groupby("key").agg(correct_function, engine="numba", b=1)

with pytest.raises(
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
):
data.groupby("key")["data"].agg(incorrect_function, engine="numba", b=1)
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
data.groupby("key")["data"].agg(correct_function, engine="numba", b=1)

# numba signature check after binding
with pytest.raises(NumbaUtilError, match="numba does not support"):
data.groupby("key").agg(incorrect_function, engine="numba", a=1)
actual = data.groupby("key").agg(correct_function, engine="numba", a=1)
tm.assert_frame_equal(expected + 1, actual)

with pytest.raises(NumbaUtilError, match="numba does not support"):
data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1)
actual = data.groupby("key")["data"].agg(correct_function, engine="numba", a=1)
tm.assert_series_equal(expected["data"] + 1, actual)


@pytest.mark.filterwarnings("ignore")
Expand Down
29 changes: 27 additions & 2 deletions pandas/tests/groupby/transform/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,43 @@ def incorrect_function(x):
def test_check_nopython_kwargs():
pytest.importorskip("numba")

def incorrect_function(values, index):
return values + 1
def incorrect_function(values, index, *, a):
return values + a

def correct_function(values, index, a):
return values + a

data = DataFrame(
{"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
columns=["key", "data"],
)
# py signature binding
with pytest.raises(
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
):
data.groupby("key").transform(incorrect_function, engine="numba", b=1)
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
data.groupby("key").transform(correct_function, engine="numba", b=1)

with pytest.raises(
TypeError, match="missing a required (keyword-only argument|argument): 'a'"
):
data.groupby("key")["data"].transform(incorrect_function, engine="numba", b=1)
with pytest.raises(TypeError, match="missing a required argument: 'a'"):
data.groupby("key")["data"].transform(correct_function, engine="numba", b=1)

# numba signature check after binding
with pytest.raises(NumbaUtilError, match="numba does not support"):
data.groupby("key").transform(incorrect_function, engine="numba", a=1)
actual = data.groupby("key").transform(correct_function, engine="numba", a=1)
tm.assert_frame_equal(data[["data"]] + 1, actual)

with pytest.raises(NumbaUtilError, match="numba does not support"):
data.groupby("key")["data"].transform(incorrect_function, engine="numba", a=1)
actual = data.groupby("key")["data"].transform(
correct_function, engine="numba", a=1
)
tm.assert_series_equal(data["data"] + 1, actual)


@pytest.mark.filterwarnings("ignore")
Expand Down
Loading
Loading