From f3387148dadb3b5e820d14a03036dc873a0b6e96 Mon Sep 17 00:00:00 2001
From: Mahmoud Hanafy
Date: Mon, 23 May 2016 07:53:58 +0200
Subject: [PATCH 1/2] Port UDFs to Java

---
Review notes (text in this section is ignored by `git am`):
- The line structure of this series was restored from a whitespace-collapsed
  copy. Leading indentation inside the hunks is a reconstruction (4 spaces for
  Java, 2 spaces for Scala) -- TODO confirm against the real files before
  applying; context lines must match the target files exactly.
- "strlen" is registered with DataTypes.IntegerType because the lambda returns
  String.length(), an int. The previous DataTypes.StringType did not match the
  value the UDF produces.
- The blob ids on the "index" lines are the original ones and do not reflect
  the strlen change.

 .../examples/dataframe/JavaUDFs.java          | 74 +++++++++++++++++++
 1 file changed, 74 insertions(+)
 create mode 100644 src/main/java/com/highperformancespark/examples/dataframe/JavaUDFs.java

diff --git a/src/main/java/com/highperformancespark/examples/dataframe/JavaUDFs.java b/src/main/java/com/highperformancespark/examples/dataframe/JavaUDFs.java
new file mode 100644
index 0000000..467b8b7
--- /dev/null
+++ b/src/main/java/com/highperformancespark/examples/dataframe/JavaUDFs.java
@@ -0,0 +1,74 @@
+package com.highperformancespark.examples.dataframe;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.*;
+
+public class JavaUDFs {
+
+    public static void setupUDFs(SQLContext sqlContext) {
+        sqlContext.udf().register("strlen", (String s) -> s.length(), DataTypes.IntegerType);
+    }
+
+    public static void setupUDAFs(SQLContext sqlContext) {
+
+        class Avg extends UserDefinedAggregateFunction {
+
+            @Override
+            public StructType inputSchema() {
+                StructType inputSchema =
+                    new StructType(new StructField[]{new StructField("value", DataTypes.DoubleType, true, Metadata.empty())});
+                return inputSchema;
+            }
+
+            @Override
+            public StructType bufferSchema() {
+                StructType bufferSchema =
+                    new StructType(new StructField[]{
+                        new StructField("count", DataTypes.LongType, true, Metadata.empty()),
+                        new StructField("sum", DataTypes.DoubleType, true, Metadata.empty())
+                    });
+
+                return bufferSchema;
+            }
+
+            @Override
+            public DataType dataType() {
+                return DataTypes.DoubleType;
+            }
+
+            @Override
+            public boolean deterministic() {
+                return true;
+            }
+
+            @Override
+            public void initialize(MutableAggregationBuffer buffer) {
+                buffer.update(0, 0L);
+                buffer.update(1, 0.0);
+            }
+
+            @Override
+            public void update(MutableAggregationBuffer buffer, Row input) {
+                buffer.update(0, buffer.getLong(0) + 1);
+                buffer.update(1, buffer.getDouble(1) + input.getDouble(0));
+            }
+
+            @Override
+            public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+                buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));
+                buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1));
+            }
+
+            @Override
+            public Object evaluate(Row buffer) {
+                return Math.pow(buffer.getDouble(1), 1.0 / buffer.getLong(0));
+            }
+        }
+
+        Avg average = new Avg();
+        sqlContext.udf().register("ourAvg", average);
+    }
+}

From 6e4ea731d557dcc0bd986122ec9124883b7da34f Mon Sep 17 00:00:00 2001
From: Mahmoud Hanafy
Date: Mon, 23 May 2016 10:14:17 +0200
Subject: [PATCH 2/2] Fix evaluate average UDF

---
 .../com/highperformancespark/examples/dataframe/JavaUDFs.java | 2 +-
 .../com/high-performance-spark-examples/dataframe/UDFs.scala  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/main/java/com/highperformancespark/examples/dataframe/JavaUDFs.java b/src/main/java/com/highperformancespark/examples/dataframe/JavaUDFs.java
index 467b8b7..6866e4b 100644
--- a/src/main/java/com/highperformancespark/examples/dataframe/JavaUDFs.java
+++ b/src/main/java/com/highperformancespark/examples/dataframe/JavaUDFs.java
@@ -64,7 +64,7 @@ public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
 
             @Override
             public Object evaluate(Row buffer) {
-                return Math.pow(buffer.getDouble(1), 1.0 / buffer.getLong(0));
+                return buffer.getDouble(1) / buffer.getLong(0);
             }
         }
 
diff --git a/src/main/scala/com/high-performance-spark-examples/dataframe/UDFs.scala b/src/main/scala/com/high-performance-spark-examples/dataframe/UDFs.scala
index 2274781..56d4beb 100644
--- a/src/main/scala/com/high-performance-spark-examples/dataframe/UDFs.scala
+++ b/src/main/scala/com/high-performance-spark-examples/dataframe/UDFs.scala
@@ -47,7 +47,7 @@ object UDFs {
       }
 
       def evaluate(buffer: Row): Any = {
-        math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
+        buffer.getDouble(1) / buffer.getLong(0)
       }
     }
     // Optionally register