-
Notifications
You must be signed in to change notification settings - Fork 19
/
spark-sql-agg-example-and-weird-percent.scala
164 lines (128 loc) · 6.57 KB
/
spark-sql-agg-example-and-weird-percent.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def sumDoubleVectors: UserDefinedAggregateFunction = new UserDefinedAggregateFunction {
  // UDAF that sums array<double> columns element-wise; an empty array acts as the identity.
  def inputSchema: StructType = StructType(Seq(StructField("doubleArray", ArrayType(DoubleType, false))))
  def bufferSchema: StructType = StructType(Seq(StructField("doubleArray", ArrayType(DoubleType, false))))
  def dataType: DataType = ArrayType(DoubleType, false)
  def deterministic: Boolean = true
  // Start from the identity element (empty array) so the first update can adopt the input as-is.
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = Array.empty[Double]
  // TODO Might want to mutate rather than copy (i.e. map)
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val incoming = input.getAs[mutable.WrappedArray[Double]](0)
    val current = buffer.getAs[mutable.WrappedArray[Double]](0)
    buffer(0) =
      if (current.isEmpty) incoming
      else if (incoming.isEmpty) current
      else current.zip(incoming).map { case (a, b) => a + b }
  }
  // Partial buffers combine exactly like inputs, so merge delegates to update.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = update(buffer1, buffer2)
  def evaluate(buffer: Row): Any = buffer.getAs[mutable.WrappedArray[Double]](0)
}
def sumNestedDoubleVectors: UserDefinedAggregateFunction = new UserDefinedAggregateFunction {
  // UDAF that sums array<array<double>> columns element-wise at both nesting levels;
  // an empty outer array acts as the identity.
  def inputSchema: StructType = StructType(Seq(StructField("doubleArrays", ArrayType(ArrayType(DoubleType, false), false))))
  def bufferSchema: StructType = StructType(Seq(StructField("doubleArrays", ArrayType(ArrayType(DoubleType, false), false))))
  def dataType: DataType = ArrayType(ArrayType(DoubleType, false), false)
  def deterministic: Boolean = true
  // Start from the identity element (empty array-of-arrays) so the first update can adopt the input as-is.
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = Array.empty[Array[Double]]
  // TODO Might want to mutate rather than copy (i.e. map)
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val incoming = input.getAs[mutable.WrappedArray[mutable.WrappedArray[Double]]](0)
    val current = buffer.getAs[mutable.WrappedArray[mutable.WrappedArray[Double]]](0)
    buffer(0) =
      if (current.isEmpty) incoming
      else if (incoming.isEmpty) current
      else current.zip(incoming).map {
        case (outerA, outerB) => outerA.zip(outerB).map { case (a, b) => a + b }
      }
  }
  // Partial buffers combine exactly like inputs, so merge delegates to update.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = update(buffer1, buffer2)
  def evaluate(buffer: Row): Any = buffer.getAs[mutable.WrappedArray[mutable.WrappedArray[Double]]](0)
}
.....
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, ArrayType, StructField, StructType}
import org.scalatest.FunSuite
import scala.collection.mutable
// Test-fixture row type: a single array<double> column, used to build DataFrames for the UDAF tests.
case class DoubleArray(doubleArray: Array[Double])
// Test-fixture row type: a single array<array<double>> column, used to build DataFrames for the UDAF tests.
case class DoubleNestedArray(doubleArrays: Array[Array[Double]])
// Exercises the UDAFs that IMAUDFs registers on the SQL context: each test builds a small
// DataFrame, registers it as a temp table, and checks the aggregated result via a SQL query.
// NOTE(review): `sql` is not defined in this file — presumably a SQLContext supplied by a
// mixed-in fixture; confirm against the surrounding test infrastructure.
class IMAUDFsTest extends FunSuite {
  import sql.implicits._
  test("IMAUDFs.apply creates a sumDoubleVectors UDAF") {
    val fixtures = Seq(
      DoubleArray(Array(1.0, 1.0, 1.0)),
      DoubleArray(Array(1.0, 5.0, 2.0))
    )
    sql.createDataset(fixtures).toDF().registerTempTable("double_array")
    IMAUDFs(sql)
    val summed = sql.sql("select sumDoubleVectors(doubleArray) from double_array")
      .collect().toList.map(_ (0).asInstanceOf[mutable.WrappedArray[Double]])
    // Element-wise sum across the two rows.
    assert(summed == List(Seq(2.0, 6.0, 3.0)))
  }
  test("IMAUDFs.apply creates a sumNestedDoubleVectors UDAF") {
    val fixtures = Seq(
      DoubleNestedArray(Array(Array(1.0, 1.0, 1.0), Array(1.0, 1.0, 1.0))),
      DoubleNestedArray(Array(Array(1.0, 5.0, 2.0), Array(3.0, 3.0, 3.0)))
    )
    sql.createDataset(fixtures).toDF().registerTempTable("double_arrays")
    IMAUDFs(sql)
    val summed = sql.sql("select sumNestedDoubleVectors(doubleArrays) from double_arrays")
      .collect().toList
      .map(_ (0).asInstanceOf[mutable.WrappedArray[mutable.WrappedArray[Double]]])
    // Element-wise sum at both nesting levels across the two rows.
    assert(summed == List(Seq(Seq(2.0, 6.0, 3.0), Seq(4.0, 4.0, 4.0))))
  }
}
..........
// Percentile using linear interpolation between closest ranks — the "first variant" (C = 1/2)
// from https://en.wikipedia.org/wiki/Percentile#First_variant,_C_=_1/2
// Each sorted value v(i) (0-based) is assigned the percent rank 100/N * (i + 1 - 0.5).
// Queries below the lowest rank clamp to the minimum, queries above the highest rank clamp to
// the maximum, an exact rank hit returns that value, and anything in between is linearly
// interpolated between the two neighbouring values.
// (Warning: has very little meaning, cannot be understood, invented by statisticians).
//
// NOTE: the previous implementation destructured the sorted pairs with the `::` extractor,
// which only matches List cells — so any non-List Seq argument (e.g. Vector) threw a
// MatchError at runtime. This version accepts every Seq and needs no vars.
def percentileLinInterpFirstVariant(values: Seq[Double], p: Double): Double = {
  val count = values.size
  require(count > 1, "Need at least 2 values")
  // IndexedSeq for O(1) neighbour lookups during interpolation.
  val sorted = values.sorted.toIndexedSeq
  // Percent rank of the i-th (0-based) sorted value; same arithmetic as the original
  // ((100.0 / count) * (index + 1 - 0.5)) so float behaviour is unchanged.
  def percentRank(i: Int): Double = (100.0 / count) * (i + 1 - 0.5)
  sorted.indices.find(percentRank(_) >= p) match {
    case None    => sorted.last   // p above the highest rank: clamp to the maximum
    case Some(0) => sorted.head   // p at or below the lowest rank: the minimum (covers exact hit too)
    case Some(i) if percentRank(i) == p => sorted(i)  // exact hit on an interior rank
    case Some(i) =>
      // Strictly between rank(i - 1) and rank(i): interpolate linearly between the two values.
      val prevRank = percentRank(i - 1)
      count * (p - prevRank) * (sorted(i) - sorted(i - 1)) / 100 + sorted(i - 1)
  }
}
// Pins the function to the worked example of the first variant (ranks 10/30/50/70/90 for five
// values): clamp below the lowest rank, exact rank hit, interpolation, clamp above the highest rank.
test("IMA.percentileLinInterpFirstVariant handles small sequences as per worked example on wikipedia: " +
  "https://en.wikipedia.org/wiki/Percentile#Worked_Example_of_the_First_Variant") {
  assert(IMA.percentileLinInterpFirstVariant(List(15, 20, 35, 40, 50), 5) == 15.0)
  assert(IMA.percentileLinInterpFirstVariant(List(15, 20, 35, 40, 50), 30) == 20.0)
  assert(IMA.percentileLinInterpFirstVariant(List(15, 20, 35, 40, 50), 40) == 27.5)
  assert(IMA.percentileLinInterpFirstVariant(List(15, 20, 35, 40, 50), 95) == 50)
}
// Small counts where the rank spacing (100/N) is coarse; exercises interpolation between the
// only two ranks and exact hits on interior ranks.
test("IMA.percentileLinInterpFirstVariant handles weird cases") {
  assert(IMA.percentileLinInterpFirstVariant(List(0, 100), 50) == 50.0)
  assert(IMA.percentileLinInterpFirstVariant(List(0, 200), 50) == 100.0)
  assert(IMA.percentileLinInterpFirstVariant(List(0, 100), 50 + 12.5) == 75.0)
  assert(IMA.percentileLinInterpFirstVariant(List(0, 100, 200), 50) == 100.0)
  assert(IMA.percentileLinInterpFirstVariant(List(0, 100, 200, 300), 75) == 250.0)
}
// Larger odd-sized ranges where the 50th percentile lands exactly on the median's rank.
test("IMA.percentileLinInterpFirstVariant handles ordinary cases") {
  assert(IMA.percentileLinInterpFirstVariant((1 to 99).map(_.toDouble).toList, 50) == 50.0)
  assert(IMA.percentileLinInterpFirstVariant((101 to 199).map(_.toDouble).toList, 50) == 150.0)
}