Aggregate UDFs
Drill allows two types of UDFs: simple (per-row) and aggregate (per-group). Thus far we've discussed simple functions. Let's now turn our attention to aggregates.
Let's start with a simple average function. (Drill provides one, but we'll create one anyway since we'll later expand it to be a weighted average.)
First, declare the function much as we did for simple functions:
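import org.apache.drill.exec.expr.DrillAggFunc;
import org.apache.drill.exec.expr.annotations.FunctionTemplate;
import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope;
import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling;
import org.apache.drill.exec.expr.annotations.Output;
import org.apache.drill.exec.expr.annotations.Param;
import org.apache.drill.exec.expr.holders.Float8Holder;
import org.apache.drill.exec.expr.holders.NullableFloat8Holder;

public class WeightedAverageImpl {
  @FunctionTemplate(
      name = "myAvg",
      scope = FunctionScope.POINT_AGGREGATE,
      nulls = NullHandling.INTERNAL)
  public static class MyAvgFunc implements DrillAggFunc {
    @Param NullableFloat8Holder input;
    @Output Float8Holder output;

    // Method bodies are filled in below.
    @Override
    public void setup() { }
    @Override
    public void reset() { }
    @Override
    public void add() { }
    @Override
    public void output() { }
  }
}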
The declaration differs from that of a simple function. Aggregate functions:
- Are defined by the setting scope = FunctionScope.POINT_AGGREGATE.
- Implement DrillAggFunc.
- Must do their own null handling, so we must use nulls = NullHandling.INTERNAL.
- Presumably must use nullable input parameters.
As we can see from the above, Drill requires that we provide four methods:
- setup() is similar to the simple-function setup(): it is a one-time operation called at the start of each batch (or schema).
- reset() is called before the start of each new group. It is where we reset our per-group variables.
- add() is the equivalent of the simple function eval(): it is called once per row with the values of the input parameters filled in.
- output() is called at the end of the group. Here we set the output value into our @Output holder.
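The calling order can be made concrete with a toy driver that plays the role of a streaming aggregate. Rows arrive sorted by group key, so setup() runs once, reset() runs when a new key appears, add() runs once per row after the input field is filled in, and output() runs when the group ends. This is a sketch under stated assumptions, not Drill's generated code. The AggFunc interface, the *Stub holder classes and StreamingDriver are hypothetical stand-ins for DrillAggFunc and Drill's holders, so the example compiles without Drill. The average logic is the one filled in later on this page.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for Drill's holder classes (not the real API).
class NullableFloat8Stub { int isSet; double value; }
class Float8Stub { double value; }
class IntStub { int value; }

// Hypothetical mirror of the four DrillAggFunc lifecycle methods.
interface AggFunc {
  void setup();
  void reset();
  void add();
  void output();
}

// The myAvg logic from this page, written against the stand-ins.
class MyAvg implements AggFunc {
  NullableFloat8Stub input = new NullableFloat8Stub();
  Float8Stub output = new Float8Stub();
  Float8Stub sum = new Float8Stub();
  IntStub count = new IntStub();

  public void setup() { }

  public void reset() {
    sum.value = 0;
    count.value = 0;
  }

  public void add() {
    if (input.isSet == 1) {                       // nulls are ignored
      count.value++;
      sum.value += input.value;
    }
  }

  public void output() {
    output.value = count.value == 0 ? 0 : sum.value / count.value;
  }
}

// Toy streaming-aggregate driver: input is sorted by key, so a change of
// key ends one group and starts the next.
class StreamingDriver {
  static List<Double> run(MyAvg fn, String[] keys, Double[] values) {
    List<Double> results = new ArrayList<>();
    fn.setup();                                   // one-time setup (per batch or schema in Drill)
    for (int i = 0; i < keys.length; i++) {
      boolean newGroup = i == 0 || !keys[i].equals(keys[i - 1]);
      if (newGroup && i > 0) {
        fn.output();                              // end of the previous group
        results.add(fn.output.value);
      }
      if (newGroup) {
        fn.reset();                               // start of a new group
      }
      fn.input.isSet = values[i] == null ? 0 : 1; // fill in the parameter...
      fn.input.value = values[i] == null ? 0.0 : values[i];
      fn.add();                                   // ...then call add() once per row
    }
    if (keys.length > 0) {
      fn.output();                                // end of the last group
      results.add(fn.output.value);
    }
    return results;
  }
}
```

For keys a, a, b, c with values 1.0, 3.0, null, 4.0, the driver yields one average per group: 2.0, then 0.0 for the all-null group b, then 4.0.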
In simple UDFs, we use the @Workspace annotation for fields that are neither parameters nor outputs. Drill allows variables of any type in simple functions. But (for reasons explained below), workspace variables in aggregate functions must be holders. For an average, we need a count and a sum:
@Workspace Float8Holder sum;
@Workspace IntHolder count;
Some items to note:
- Workspace variables must be holders.
- Drill's code generation will create the holder objects and assign them to our variables.
- It is impossible to use working variables other than holders. (It is possible, though, to use an ObjectHolder, which can reference a Java object, to work around this restriction. See below.)
We can now fill in the straightforward implementation of our average function:
@Override
public void setup() { }
@Override
public void reset() {
sum.value = 0;
count.value = 0;
}
@Override
public void add() {
if (input.isSet == 1) {
count.value++;
sum.value += input.value;
}
}
@Override
public void output() {
if (count.value == 0) {
output.value = 0;
} else {
output.value = sum.value / count.value;
}
}
Some things to note:
- No setup is needed.
- We never create any of the holders (even for workspace variables): Drill does that for us.
- We reset our working variables to 0 in reset() at the start of each group.
- We use the value field of holders in place of simple Java primitive variables.
- We implement the common convention that null values are ignored for averages.
- Since we ignore nulls, we may have a zero count. In this case, we define the average to be 0. (We could define a nullable output and set the average to null if we preferred.)
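The last bullet mentions a nullable output as an alternative. Presumably that would mean declaring the output as an @Output NullableFloat8Holder and clearing isSet for an empty or all-null group. Below is a minimal sketch of that variant. The *Stub classes are hypothetical stand-ins for Drill's NullableFloat8Holder, Float8Holder and IntHolder, so the example compiles without Drill. The method bodies otherwise mirror the average above.

```java
// Hypothetical stand-ins for Drill's holder classes (not the real API).
class NullableFloat8Stub { int isSet; double value; }
class Float8Stub { double value; }
class IntStub { int value; }

// Average whose output is null (isSet = 0) when the group has no non-null input.
class NullableAvg {
  NullableFloat8Stub input = new NullableFloat8Stub();
  NullableFloat8Stub output = new NullableFloat8Stub();  // nullable instead of required
  Float8Stub sum = new Float8Stub();
  IntStub count = new IntStub();

  void reset() {
    sum.value = 0;
    count.value = 0;
  }

  void add() {
    if (input.isSet == 1) {   // ignore null rows, as before
      count.value++;
      sum.value += input.value;
    }
  }

  void output() {
    if (count.value == 0) {
      output.isSet = 0;       // empty or all-null group: emit SQL NULL
    } else {
      output.isSet = 1;
      output.value = sum.value / count.value;
    }
  }
}
```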
Drill provides a holder called the ObjectHolder. While all the other holders contain only primitive Java types, the ObjectHolder can contain a Java object. This seems like an escape hatch that lets us use Java objects as working values for an aggregate UDF. Indeed, examples exist that show this usage, and Drill's own Decimal AVG, MIN and MAX functions use this approach.
Unfortunately, the ObjectHolder should no longer be used. It works only when the aggregate values remain in memory; it fails when Drill tries to spill working values to disk. Since Drill 1.11 added spilling to the Hash Aggregate operator, Drill does, in fact, now spill intermediate values.
Thus, if you use the ObjectHolder, your code will work for small data sets. It will also work for sorted queries that use the Streaming Aggregate operator. But your query will fail with an UnsupportedOperationException if your data is large enough that the Hash Aggregate must spill.
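One way to picture the difference is a toy model of a spill, in which each group's workspace values are written to bytes and later read back. This is an illustration only, not Drill's spill code; ToySpill and its methods are hypothetical. Fixed-width values such as a Float8Holder's double round-trip through a ByteBuffer. An object slot has no fixed byte layout, so the model refuses it with the same exception type the text reports.

```java
import java.nio.ByteBuffer;

// Toy model only (hypothetical, not Drill's spill code).
class ToySpill {
  // Fixed-width workspace values, like Float8Holder.value, write out as bytes...
  static byte[] spillDoubles(double[] slots) {
    ByteBuffer buf = ByteBuffer.allocate(slots.length * Double.BYTES);
    for (double v : slots) {
      buf.putDouble(v);
    }
    return buf.array();
  }

  // ...and read back to the same values.
  static double[] readDoubles(byte[] bytes) {
    ByteBuffer buf = ByteBuffer.wrap(bytes);
    double[] slots = new double[bytes.length / Double.BYTES];
    for (int i = 0; i < slots.length; i++) {
      slots[i] = buf.getDouble();
    }
    return slots;
  }

  // An object reference, like ObjectHolder.obj, is only meaningful inside the
  // running JVM and has no fixed byte layout, so the model refuses it.
  static byte[] spillObjects(Object[] slots) {
    throw new UnsupportedOperationException(
        "cannot spill " + slots.length + " object workspace values");
  }
}
```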
A possible work-around is to disable the Hash Aggregate:
ALTER SESSION SET `planner.enable_hashagg` = false
There is a performance impact from disabling Hash Aggregate: Drill must sort data in order to use the Streaming Aggregate. If your data is large enough to require spilling, then the added cost of the sort may be substantial.
Although these workarounds have not been tested, possible ideas to implement complex aggregates include:
- Volunteer to start a project to modify Drill's aggregate mechanism to provide the desired functionality.
- Use thread-local storage to hold aggregate data. (Be careful, however: you cannot tell whether two or more aggregates share the same thread. Even if they do, they won't run concurrently, so no synchronization is necessary. Knowing when to release the data may be a challenge.) In this model, store an index key in a holder, then use that key to reference your on-heap data.
- Use other tools for complex aggregates.
- Other ideas?
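The thread-local idea in the second bullet might look roughly like the following sketch. The workspace keeps only an int key (in a real UDF, an @Workspace IntHolder), and that key indexes on-heap state held in a per-thread map. This is an untested, hypothetical sketch. ThreadLocalRegistry, MedianAgg and IntStub are invented for illustration, and add() takes its value as an argument rather than reading an @Param field. The sketch does not solve the release problem noted above; it simply assumes output() is the last call for a group. A median is used because it needs every value in the group, which primitive holders alone cannot store.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical registry: per-thread on-heap state addressed by an int key.
class ThreadLocalRegistry {
  private static final ThreadLocal<Map<Integer, List<Double>>> STATE =
      ThreadLocal.withInitial(HashMap::new);
  private static final ThreadLocal<Integer> NEXT_KEY = ThreadLocal.withInitial(() -> 0);

  // Allocates empty state and returns the key a workspace IntHolder would store.
  static int allocate() {
    int key = NEXT_KEY.get();
    NEXT_KEY.set(key + 1);
    STATE.get().put(key, new ArrayList<>());
    return key;
  }

  static List<Double> lookup(int key) {
    return STATE.get().get(key);
  }

  static void release(int key) {
    STATE.get().remove(key);
  }
}

// Stand-in for Drill's IntHolder (hypothetical).
class IntStub { int value; }

class MedianAgg {
  IntStub key = new IntStub();   // would be an @Workspace IntHolder in a real UDF
  double output;                 // would be an @Output holder's value

  void reset() {
    key.value = ThreadLocalRegistry.allocate();  // fresh state for the new group
  }

  // Simplified: takes the value as an argument instead of reading an @Param field.
  void add(double value) {
    ThreadLocalRegistry.lookup(key.value).add(value);
  }

  void output() {
    List<Double> values = ThreadLocalRegistry.lookup(key.value);
    Collections.sort(values);
    int n = values.size();
    if (n == 0) {
      output = 0;
    } else if (n % 2 == 1) {
      output = values.get(n / 2);
    } else {
      output = (values.get(n / 2 - 1) + values.get(n / 2)) / 2;
    }
    // Assumes output() is the last call for this group; deciding when it is
    // safe to release is exactly the open problem noted above.
    ThreadLocalRegistry.release(key.value);
  }
}
```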
Caveat: See the above note! Consider this section experimental and not for use in production code.
We discussed earlier the value of strong unit testing. Clearly, we can't test the above aggregate function without running queries, and we certainly can't step through the code. The example was chosen for its simplicity. But testing and debugging become increasingly important as your functions grow more complex (and the complex functions are the most valuable).
An alternative is to leverage two ideas:
- The implementation-and-wrapper approach described in Simplified UDF Framework.
- The fact that Drill provides a hack (the ObjectHolder) that allows aggregate functions to hold onto arbitrary Java objects.
For our example, we'll implement a simple weighted average function: FLOAT8-REQUIRED wtAvg(FLOAT8-NULLABLE value, FLOAT8-NULLABLE weight).
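Written out, this is the quantity that the implementation and wrapper below compute. Let S be the set of rows where both the value x_i and the weight w_i are non-null. A row with either argument null is skipped, and the result is 0 when the total weight is 0, as WeightedAvgImpl.output() does.

```latex
\mathrm{wtAvg}(x, w) =
\begin{cases}
  \dfrac{\sum_{i \in S} w_i x_i}{\sum_{i \in S} w_i}, & \text{if } \sum_{i \in S} w_i \neq 0,\\[1ex]
  0, & \text{otherwise.}
\end{cases}
```

For example, the rows (100, 10), (1000, 0) and (200, 5) give (100 × 10 + 1000 × 0 + 200 × 5) / (10 + 0 + 5) = 2000 / 15 ≈ 133.33.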
First, let's define an outer class to hold our work:
package org.apache.drill.exec.expr.contrib.udfExample;
public class WeightedAverageImpl {
}
Then, let's define a non-UDF class that implements the algorithm:
public static class WeightedAvgImpl {
double sum;
double totalWeights;
public void reset() {
sum = 0;
totalWeights = 0;
}
public void add(double input, double weight) {
totalWeights += weight;
sum += input * weight;
}
public double output() {
if (totalWeights == 0) {
return 0;
} else {
return sum / totalWeights;
}
}
}
The above is plain Java, so we can use JUnit to test it. Once we are convinced it works, we can create a UDF wrapper function using the ObjectHolder hack:
@FunctionTemplate(
name = "wtAvg2",
scope = FunctionScope.POINT_AGGREGATE,
nulls = NullHandling.INTERNAL)
public static class WeightedAvgWrapper implements DrillAggFunc {
@Param NullableFloat8Holder input;
@Param NullableFloat8Holder weight;
@Workspace ObjectHolder impl;
@Output Float8Holder output;
@Override
public void setup() {
impl.obj = new org.apache.drill.exec.expr.contrib.udfExample
.WeightedAverageImpl.WeightedAvgImpl();
}
@Override
public void reset() {
((org.apache.drill.exec.expr.contrib.udfExample
.WeightedAverageImpl.WeightedAvgImpl) impl.obj).reset();
}
@Override
public void add() {
if (input.isSet == 1 && weight.isSet == 1) {
((org.apache.drill.exec.expr.contrib.udfExample
.WeightedAverageImpl.WeightedAvgImpl) impl.obj).add(
input.value, weight.value);
}
}
@Override
public void output() {
output.value = ((org.apache.drill.exec.expr.contrib.udfExample
.WeightedAverageImpl.WeightedAvgImpl) impl.obj).output();
}
}
The general theme is:
- Use the ObjectHolder to hold our actual implementation.
- Since the ObjectHolder is not parameterized, we must do an explicit cast each time we use it.
- This is a UDF, so we must use the fully-qualified class name when we do the cast.
- The methods in the implementation follow those in the UDF.
- The UDF handles null checking, so we don't have to pass the null flags into the implementation. (If we wanted more complex null handling, then we could pass in the isSet flags.)
As before, the cost of this approach is one extra Java method call per row. But, for complex functions, that cost is well worth the time we save when debugging.
Also, the only testing we need to do with a running Drill is to ensure the wrapper works; we've already tested the actual algorithm.
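Testing the algorithm class on its own needs nothing from Drill. The sketch below copies WeightedAvgImpl as a top-level class so that it compiles standalone. It also uses a small hand-written check in place of JUnit's assertEquals(expected, actual, delta), so it runs without a test framework; in a real project these would be ordinary JUnit @Test methods. The helper WeightedAvgImplCheck is hypothetical. The cases chosen (empty group, zero total weight, ordinary values, reuse after reset()) are illustrative, not a prescribed suite.

```java
// Copy of the page's WeightedAvgImpl (top-level here so the sketch compiles alone).
class WeightedAvgImpl {
  double sum;
  double totalWeights;

  public void reset() {
    sum = 0;
    totalWeights = 0;
  }

  public void add(double input, double weight) {
    totalWeights += weight;
    sum += input * weight;
  }

  public double output() {
    return totalWeights == 0 ? 0 : sum / totalWeights;
  }
}

// Plain-Java checks of the algorithm; with JUnit, check() would be assertEquals.
class WeightedAvgImplCheck {
  static void check(double expected, double actual) {
    if (Math.abs(expected - actual) > 1e-4) {
      throw new AssertionError("expected " + expected + " but got " + actual);
    }
  }

  static void runAll() {
    WeightedAvgImpl impl = new WeightedAvgImpl();

    impl.reset();                  // empty group
    check(0, impl.output());

    impl.reset();                  // zero total weight
    impl.add(1000, 0);
    check(0, impl.output());

    impl.reset();                  // ordinary weighted values
    impl.add(100, 10);
    impl.add(200, 5);
    check((100.0 * 10 + 200.0 * 5) / 15, impl.output());

    impl.reset();                  // reset() really clears the previous group
    impl.add(7, 1);
    check(7, impl.output());
  }
}
```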
Drill provides two kinds of aggregates: streaming and hash. In a streaming aggregate, Drill receives data sorted by the group key, so Drill works on the records for the first group, then the records for the second group, and so on. In a hash aggregate, by contrast, rows arrive in arbitrary order and Drill builds all the groups in parallel.
What this means for your code is that, even if a single fragment is running, Drill may create multiple "instances" of your aggregate function. For the hash aggregate, Drill immediately creates 64K instances of your aggregate, regardless of the actual number of groups in your data. Since Drill inlines the code, how does Drill create these instances?
The answer is that Drill creates internal value vectors to hold your intermediate results. This is why your workspace variables must be holders: for each row, Drill copies the intermediate values out of a vector into your workspace variables, calls the add() method, then copies the workspace variables back into the corresponding vector.
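The copy-in, add(), copy-out cycle can be pictured with a toy model in which plain arrays stand in for the value vectors. This is a sketch under stated assumptions, not Drill's generated code. ToyHashAgg, its field names, and the use of a HashMap to assign slots are all hypothetical. The slot count of 65,536 simply takes "64K" from the text, and the toy assumes non-null input and fewer groups than slots.

```java
import java.util.HashMap;
import java.util.Map;

// Toy model of a hash aggregate's per-group state (not Drill's generated code).
class ToyHashAgg {
  static final int SLOTS = 65_536;                // "64K instances" from the text
  final double[] sumVector = new double[SLOTS];   // backs a Float8 "sum" workspace field
  final int[] countVector = new int[SLOTS];       // backs an Int "count" workspace field
  final Map<String, Integer> slotOf = new HashMap<>();

  // The workspace values the function body actually sees.
  double sum;
  int count;

  // Assumes non-null input and fewer than SLOTS distinct keys.
  void addRow(String key, double input) {
    Integer slot = slotOf.get(key);
    if (slot == null) {
      slot = slotOf.size();        // first row of a new group: take the next slot
      slotOf.put(key, slot);
      sum = 0;                     // reset()
      count = 0;
    } else {
      sum = sumVector[slot];       // copy vector -> workspace
      count = countVector[slot];
    }
    count++;                       // add(): the function body
    sum += input;
    sumVector[slot] = sum;         // copy workspace -> vector
    countVector[slot] = count;
  }

  // output() for one group, reading its slot back out of the vectors.
  double outputFor(String key) {
    int slot = slotOf.get(key);
    return countVector[slot] == 0 ? 0 : sumVector[slot] / countVector[slot];
  }
}
```

Because each group's state lives in the arrays rather than in the function object, rows for different groups can interleave freely. For example, rows a=1, b=10, a=3, b=20 give averages of 2.0 for a and 15.0 for b.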
Drill added spilling to the hash aggregate in Drill 1.11. The implication is that your temporary values may be written to disk, then read back. This, in turn, places a constraint on what you can do with an ObjectHolder. ((TODO: Write up how Drill serializes Java objects, and what the UDF writer must do.)) ((TODO: Discuss how to force spilling to allow testing.))
Because each @Workspace field gives rise to a value vector of 64K values in the hash aggregate, be conservative when creating such fields. The values of these fields compete for memory with data rows, and can cause excessive spilling if the workspace values are large in number or size. This would be a particular concern if, for example, you created large internal lists of strings, such as if you were trying to do machine learning on word vectors.
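As a rough illustration, take 64K as 65,536 slots and count only the fixed-width value bytes, ignoring any vector overhead, which this estimate does not model. The average example from earlier has two workspace fields, a Float8Holder sum (8 bytes) and an IntHolder count (4 bytes). Under those assumptions, they would occupy:

```latex
65{,}536 \times (8 + 4)\ \text{bytes} = 786{,}432\ \text{bytes} = 768\ \text{KiB}
```

Under the same assumptions, each additional 8-byte field adds another 65,536 × 8 = 524,288 bytes (512 KiB).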
One common complaint voiced by UDF authors is that it is too hard to test aggregate UDFs. This is true if we build the UDF, deploy it to Drill, and try to test it in the Drill server. However, with a little bit of code, we can test our aggregate UDF using JUnit without the use of the Drill server.
As an example, let's test the weighted average UDF. Here we test a holder-based version, WeightedAvgFunc, whose sum and totalWeights workspace fields are Float8Holders. As we did for simple UDFs, let's create some wrapper methods in our test class. First up is a function that creates an instance of our function class, populates it with the required holders, then calls the setup() method:
public class TestWeightedAverage {
private WeightedAvgFunc newAfgFn() {
WeightedAvgFunc fn = new WeightedAvgFunc();
fn.input = new NullableFloat8Holder();
fn.weight = new NullableFloat8Holder();
fn.output = new Float8Holder();
fn.sum = new Float8Holder();
fn.totalWeights = new Float8Holder();
fn.setup();
return fn;
}
}
Then, we need a method to call the add() method. Since our function takes nullable Float8 values, let's handle nulls. Since this is a test method, we can favor convenience over speed: we'll pass in values as Double objects, which can be null:
private void callAdd(WeightedAvgFunc fn, Double value, Double weight) {
setHolder(fn.input, value);
setHolder(fn.weight, weight);
fn.add();
}
private void setHolder(NullableFloat8Holder holder, Double value) {
if (value == null) {
holder.isSet = 0;
holder.value = Double.MAX_VALUE;
} else {
holder.isSet = 1;
holder.value = value;
}
}
Note that we deliberately set the holder to a bogus value if the value is null. This should catch code that does not honor the isSet = 0 setting.
Then, we need a way to get the output:
private double callOutput(WeightedAvgFunc fn) {
// Catch unset values.
fn.output.value = Double.MAX_VALUE;
fn.output();
return fn.output.value;
}
Again, we use a trick to catch code paths that don't set the output value.
Finally, we can write our actual tests:
@Test
public void testAggFn() throws Exception {
// Create an instance of the aggregate function
WeightedAvgFunc fn = newAfgFn();
// Test an empty group
fn.reset();
assertEquals(0D, callOutput(fn), 0.0001);
// Test only nulls
fn.reset();
callAdd(fn, null, 1.0);
callAdd(fn, 100.0, null);
callAdd(fn, null, null);
assertEquals(0D, callOutput(fn), 0.0001);
// Actual values
fn.reset();
callAdd(fn, null, 1.0);
callAdd(fn, 100.0, 10.0);
callAdd(fn, 1000.0, 0.0);
callAdd(fn, 200.0, 5.0);
double expected = ((100.0 * 10) + (200 * 5)) / (10 + 5);
assertEquals(expected, callOutput(fn), 0.0001);
}
You can now run this test as a JUnit test in your IDE. No Drill needed. Also, since we are executing our actual aggregate function class (not the code generated from the class), we can easily set breakpoints, step through the code, and look at variables.