This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit bfc7e1f

icexelloss authored and HyukjinKwon committed
[SPARK-20396][SQL][PYSPARK] groupby().apply() with pandas udf
## What changes were proposed in this pull request?

This PR adds an apply() function on df.groupby(). apply() takes a pandas udf that is a transformation on `pandas.DataFrame` -> `pandas.DataFrame`.

Static schema
-------------------
```
schema = df.schema

@pandas_udf(schema)
def normalize(df):
    df = df.assign(v1=(df.v1 - df.v1.mean()) / df.v1.std())
    return df

df.groupBy('id').apply(normalize)
```

Dynamic schema
-----------------------
**This use case has been removed from the PR and will be discussed as a follow-up. See the discussion in apache#18732 (review).**

Another example uses pd.DataFrame dtypes as the output schema of the udf:

```
sample_df = df.filter(df.id == 1).toPandas()

def foo(df):
    ret = ...  # Some transformation on the input pd.DataFrame
    return ret

foo_udf = pandas_udf(foo, foo(sample_df).dtypes)

df.groupBy('id').apply(foo_udf)
```

In the interactive use case, users usually have a sample pd.DataFrame at hand to test the function `foo` in their notebook. Being able to use `foo(sample_df).dtypes` frees the user from specifying the output schema of `foo`.

Design doc: https://github.com/icexelloss/spark/blob/pandas-udf-doc/docs/pyspark-pandas-udf.md

## How was this patch tested?

* Added GroupbyApplyTest

Author: Li Jin <[email protected]>
Author: Takuya UESHIN <[email protected]>
Author: Bryan Cutler <[email protected]>

Closes apache#18732 from icexelloss/groupby-apply-SPARK-20396.
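For reference, a minimal runnable sketch of the static-schema usage described above. It assumes an active SparkSession named `spark` and pandas/pyarrow available on the workers; the data and column names are illustrative, not part of the patch.

```python
from pyspark.sql.functions import pandas_udf

# illustrative data; any DataFrame with a numeric column works
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v1"))

@pandas_udf(returnType=df.schema)
def normalize(pdf):
    # pdf is the full pandas.DataFrame for one group
    return pdf.assign(v1=(pdf.v1 - pdf.v1.mean()) / pdf.v1.std())

df.groupBy('id').apply(normalize).show()
```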
1 parent 2028e5a commit bfc7e1f

14 files changed (+561, -69 lines)

python/pyspark/sql/dataframe.py

+3 -3

```diff
@@ -1227,7 +1227,7 @@ def groupBy(self, *cols):
         """
         jgd = self._jdf.groupBy(self._jcols(*cols))
         from pyspark.sql.group import GroupedData
-        return GroupedData(jgd, self.sql_ctx)
+        return GroupedData(jgd, self)
 
     @since(1.4)
     def rollup(self, *cols):
@@ -1248,7 +1248,7 @@ def rollup(self, *cols):
         """
         jgd = self._jdf.rollup(self._jcols(*cols))
         from pyspark.sql.group import GroupedData
-        return GroupedData(jgd, self.sql_ctx)
+        return GroupedData(jgd, self)
 
     @since(1.4)
     def cube(self, *cols):
@@ -1271,7 +1271,7 @@ def cube(self, *cols):
         """
         jgd = self._jdf.cube(self._jcols(*cols))
         from pyspark.sql.group import GroupedData
-        return GroupedData(jgd, self.sql_ctx)
+        return GroupedData(jgd, self)
 
     @since(1.3)
     def agg(self, *exprs):
```
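A small sketch (not part of the patch) of why groupBy/rollup/cube now pass the DataFrame itself: GroupedData keeps the parent DataFrame for the new apply() path while still deriving sql_ctx from it, so existing callers that read sql_ctx are unaffected. Assumes a SparkSession named `spark`.

```python
df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ("id", "v"))

gd = df.groupBy("id")            # now constructed as GroupedData(jgd, df)
assert gd.sql_ctx is df.sql_ctx  # still available, derived from the DataFrame
assert gd._df is df              # new: parent DataFrame kept for apply()
```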

python/pyspark/sql/functions.py

+72 -26

```diff
@@ -2058,7 +2058,7 @@ def __init__(self, func, returnType, name=None, vectorized=False):
         self._name = name or (
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
-        self._vectorized = vectorized
+        self.vectorized = vectorized
 
     @property
     def returnType(self):
@@ -2090,7 +2090,7 @@ def _create_judf(self):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self._vectorized)
+            self._name, wrapped_func, jdt, self.vectorized)
         return judf
 
     def __call__(self, *cols):
@@ -2118,8 +2118,10 @@ def wrapper(*args):
         wrapper.__name__ = self._name
         wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
                               else self.func.__class__.__module__)
+
         wrapper.func = self.func
         wrapper.returnType = self.returnType
+        wrapper.vectorized = self.vectorized
 
         return wrapper
 
@@ -2129,8 +2131,12 @@ def _create_udf(f, returnType, vectorized):
     def _udf(f, returnType=StringType(), vectorized=vectorized):
         if vectorized:
             import inspect
-            if len(inspect.getargspec(f).args) == 0:
-                raise NotImplementedError("0-parameter pandas_udfs are not currently supported")
+            argspec = inspect.getargspec(f)
+            if len(argspec.args) == 0 and argspec.varargs is None:
+                raise ValueError(
+                    "0-arg pandas_udfs are not supported. "
+                    "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
+                )
         udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
         return udf_obj._wrapped()
 
@@ -2146,7 +2152,7 @@ def _udf(f, returnType=StringType(), vectorized=vectorized):
 
 @since(1.3)
 def udf(f=None, returnType=StringType()):
-    """Creates a :class:`Column` expression representing a user defined function (UDF).
+    """Creates a user defined function (UDF).
 
     .. note:: The user-defined functions must be deterministic. Due to optimization,
         duplicate invocations may be eliminated or the function may even be invoked more times than
@@ -2181,30 +2187,70 @@ def udf(f=None, returnType=StringType()):
 @since(2.3)
 def pandas_udf(f=None, returnType=StringType()):
     """
-    Creates a :class:`Column` expression representing a user defined function (UDF) that accepts
-    `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length.
+    Creates a vectorized user defined function (UDF).
 
-    :param f: python function if used as a standalone function
+    :param f: user-defined function. A python function if used as a standalone function
     :param returnType: a :class:`pyspark.sql.types.DataType` object
 
-    >>> from pyspark.sql.types import IntegerType, StringType
-    >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
-    >>> @pandas_udf(returnType=StringType())
-    ... def to_upper(s):
-    ...     return s.str.upper()
-    ...
-    >>> @pandas_udf(returnType="integer")
-    ... def add_one(x):
-    ...     return x + 1
-    ...
-    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
-    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
-    ...     .show()  # doctest: +SKIP
-    +----------+--------------+------------+
-    |slen(name)|to_upper(name)|add_one(age)|
-    +----------+--------------+------------+
-    |         8|      JOHN DOE|          22|
-    +----------+--------------+------------+
+    The user-defined function can define one of the following transformations:
+
+    1. One or more `pandas.Series` -> A `pandas.Series`
+
+       This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and
+       :meth:`pyspark.sql.DataFrame.select`.
+       The returnType should be a primitive data type, e.g., `DoubleType()`.
+       The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
+
+       >>> from pyspark.sql.types import IntegerType, StringType
+       >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
+       >>> @pandas_udf(returnType=StringType())
+       ... def to_upper(s):
+       ...     return s.str.upper()
+       ...
+       >>> @pandas_udf(returnType="integer")
+       ... def add_one(x):
+       ...     return x + 1
+       ...
+       >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+       >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
+       ...     .show()  # doctest: +SKIP
+       +----------+--------------+------------+
+       |slen(name)|to_upper(name)|add_one(age)|
+       +----------+--------------+------------+
+       |         8|      JOHN DOE|          22|
+       +----------+--------------+------------+
+
+    2. A `pandas.DataFrame` -> A `pandas.DataFrame`
+
+       This udf is only used with :meth:`pyspark.sql.GroupedData.apply`.
+       The returnType should be a :class:`StructType` describing the schema of the returned
+       `pandas.DataFrame`.
+
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))
+       >>> @pandas_udf(returnType=df.schema)
+       ... def normalize(pdf):
+       ...     v = pdf.v
+       ...     return pdf.assign(v=(v - v.mean()) / v.std())
+       >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
+       +---+-------------------+
+       | id|                  v|
+       +---+-------------------+
+       |  1|-0.7071067811865475|
+       |  1| 0.7071067811865475|
+       |  2|-0.8320502943378437|
+       |  2|-0.2773500981126146|
+       |  2| 1.1094003924504583|
+       +---+-------------------+
+
+    .. note:: This type of udf cannot be used with functions such as `withColumn` or `select`
+              because it defines a `DataFrame` transformation rather than a `Column`
+              transformation.
+
+    .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
+
+    .. note:: The user-defined function must be deterministic.
     """
     return _create_udf(f, returnType=returnType, vectorized=True)
 
```

python/pyspark/sql/group.py

+84 -4

```diff
@@ -54,9 +54,10 @@ class GroupedData(object):
     .. versionadded:: 1.3
     """
 
-    def __init__(self, jgd, sql_ctx):
+    def __init__(self, jgd, df):
         self._jgd = jgd
-        self.sql_ctx = sql_ctx
+        self._df = df
+        self.sql_ctx = df.sql_ctx
 
     @ignore_unicode_prefix
     @since(1.3)
@@ -170,7 +171,7 @@ def sum(self, *cols):
     @since(1.6)
     def pivot(self, pivot_col, values=None):
         """
-        Pivots a column of the current [[DataFrame]] and perform the specified aggregation.
+        Pivots a column of the current :class:`DataFrame` and perform the specified aggregation.
         There are two versions of pivot function: one that requires the caller to specify the list
         of distinct values to pivot on, and one that does not. The latter is more concise but less
         efficient, because Spark needs to first compute the list of distinct values internally.
@@ -192,7 +193,85 @@ def pivot(self, pivot_col, values=None):
             jgd = self._jgd.pivot(pivot_col)
         else:
             jgd = self._jgd.pivot(pivot_col, values)
-        return GroupedData(jgd, self.sql_ctx)
+        return GroupedData(jgd, self._df)
+
+    @since(2.3)
+    def apply(self, udf):
+        """
+        Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
+        as a `DataFrame`.
+
+        The user-defined function should take a `pandas.DataFrame` and return another
+        `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
+        to the user-function and the returned `pandas.DataFrame`s are combined as a
+        :class:`DataFrame`.
+        The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
+        returnType of the pandas udf.
+
+        This function does not support partial aggregation, and requires shuffling all the data in
+        the :class:`DataFrame`.
+
+        :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
+
+        >>> from pyspark.sql.functions import pandas_udf
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))
+        >>> @pandas_udf(returnType=df.schema)
+        ... def normalize(pdf):
+        ...     v = pdf.v
+        ...     return pdf.assign(v=(v - v.mean()) / v.std())
+        >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
+        +---+-------------------+
+        | id|                  v|
+        +---+-------------------+
+        |  1|-0.7071067811865475|
+        |  1| 0.7071067811865475|
+        |  2|-0.8320502943378437|
+        |  2|-0.2773500981126146|
+        |  2| 1.1094003924504583|
+        +---+-------------------+
+
+        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
+
+        """
+        from pyspark.sql.functions import pandas_udf
+
+        # Columns are special because hasattr always return True
+        if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized:
+            raise ValueError("The argument to apply must be a pandas_udf")
+        if not isinstance(udf.returnType, StructType):
+            raise ValueError("The returnType of the pandas_udf must be a StructType")
+
+        df = self._df
+        func = udf.func
+        returnType = udf.returnType
+
+        # The python executors expects the function to use pd.Series as input and output
+        # So we to create a wrapper function that turns that to a pd.DataFrame before passing
+        # down to the user function, then turn the result pd.DataFrame back into pd.Series
+        columns = df.columns
+
+        def wrapped(*cols):
+            from pyspark.sql.types import to_arrow_type
+            import pandas as pd
+            result = func(pd.concat(cols, axis=1, keys=columns))
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError("Return type of the user-defined function should be "
+                                "Pandas.DataFrame, but is {}".format(type(result)))
+            if not len(result.columns) == len(returnType):
+                raise RuntimeError(
+                    "Number of columns of the returned Pandas.DataFrame "
+                    "doesn't match specified schema. "
+                    "Expected: {} Actual: {}".format(len(returnType), len(result.columns)))
+            arrow_return_types = (to_arrow_type(field.dataType) for field in returnType)
+            return [(result[result.columns[i]], arrow_type)
+                    for i, arrow_type in enumerate(arrow_return_types)]
+
+        wrapped_udf_obj = pandas_udf(wrapped, returnType)
+        udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
+        jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
+        return DataFrame(jdf, self.sql_ctx)
 
 
 def _test():
@@ -206,6 +285,7 @@ def _test():
         .getOrCreate()
     sc = spark.sparkContext
     globs['sc'] = sc
+    globs['spark'] = spark
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))
```
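To make the `wrapped()` helper above concrete, here is a pure-pandas sketch (no Spark required) of what happens for a single group: the executor supplies one `pandas.Series` per column, they are reassembled into a `pandas.DataFrame` for the user function, and the result is split back into per-column Series. The column names and `normalize` function mirror the doctest; the simplified return skips the Arrow-type pairing that the real wrapper performs.

```python
import pandas as pd

columns = ["id", "v"]

def normalize(pdf):
    v = pdf.v
    return pdf.assign(v=(v - v.mean()) / v.std())

def wrapped(*cols):
    # one pandas.Series per input column of the group
    pdf = pd.concat(cols, axis=1, keys=columns)
    result = normalize(pdf)
    # the real wrapper pairs each output column with its Arrow type; omitted here
    return [result[c] for c in result.columns]

group = [pd.Series([1, 1], name="id"), pd.Series([1.0, 2.0], name="v")]
for series in wrapped(*group):
    print(series)
```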
