Skip to content

Commit

Permalink
Add rewriter for multiple similar APPROX_PERCENTILEs in SELECT clause
Browse files Browse the repository at this point in the history
  • Loading branch information
jnmugerwa committed Mar 16, 2021
1 parent 14b2319 commit a1aaeea
Show file tree
Hide file tree
Showing 8 changed files with 593 additions and 0 deletions.
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
<modules>
<module>parser</module>
<module>linter</module>
<module>rewriter</module>
</modules>

<build>
Expand Down
180 changes: 180 additions & 0 deletions rewriter/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-coresql</artifactId>
<version>0.2-SNAPSHOT</version>
</parent>

<artifactId>presto-coresql-rewriter</artifactId>
<name>presto-coresql-rewriter</name>

<properties>
    <air.main.basedir>${project.parent.basedir}</air.main.basedir>
    <!-- The rewriter sources use lambdas, streams and Java 8 Map methods, so the language level must be at least 1.8
         (this also matches the modernizer javaVersion configured in pluginManagement below) -->
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
</properties>

<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.12</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-coresql-parser</artifactId>
<version>0.2-SNAPSHOT</version>
</dependency>

<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>${project.build.directory}/generated-sources</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>

<plugin>
<groupId>com.facebook.presto</groupId>
<artifactId>presto-maven-plugin</artifactId>
<version>0.3</version>
<extensions>true</extensions>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.1.1</version>
</plugin>

<plugin>
<groupId>org.skife.maven</groupId>
<artifactId>really-executable-jar-maven-plugin</artifactId>
<version>1.0.5</version>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-antrun-plugin</artifactId>
<version>1.8</version>
</plugin>

<plugin>
<groupId>io.airlift.maven.plugins</groupId>
<artifactId>sphinx-maven-plugin</artifactId>
<version>2.1</version>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-enforcer-plugin</artifactId>
<configuration>
<rules>
<requireUpperBoundDeps>
<excludes combine.children="append">
<!-- TODO: fix this in Airlift resolver -->
<exclude>org.alluxio:alluxio-shaded-client</exclude>
<exclude>org.codehaus.plexus:plexus-utils</exclude>
<exclude>com.google.guava:guava</exclude>
</excludes>
</requireUpperBoundDeps>
</rules>
</configuration>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-release-plugin</artifactId>
<configuration>
<preparationGoals>clean verify -DskipTests</preparationGoals>
</configuration>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration combine.children="append">
<fork>true</fork>
<compilerArgs>
<arg>-verbose</arg>
<arg>-J-Xss100M</arg>
</compilerArgs>
</configuration>
</plugin>

<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration combine.children="append">
<includes>
<include>**/*.java</include>
<include>target/**/*.java</include>
<include>**/Benchmark*.java</include>
</includes>
<excludes>
<exclude>**/*jmhTest*.java</exclude>
<exclude>**/*jmhType*.java</exclude>
</excludes>
</configuration>
</plugin>

<!-- Always build a jar with the test classes -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<!-- do not build an empty jar if the project is
e.g. a pom project -->
<skipIfEmpty>true</skipIfEmpty>
<archive>
<manifest>
<addDefaultImplementationEntries>true</addDefaultImplementationEntries>
<addDefaultSpecificationEntries>true</addDefaultSpecificationEntries>
<addClasspath>false</addClasspath>
</manifest>
</archive>
</configuration>
</plugin>
</plugins>
<pluginManagement>
<plugins>
<plugin>
<groupId>org.gaul</groupId>
<artifactId>modernizer-maven-plugin</artifactId>
<version>2.1.0</version>
<configuration>
<javaVersion>1.8</javaVersion>
</configuration>
</plugin>
</plugins>
</pluginManagement>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.coresql.rewriter;

import com.facebook.coresql.parser.AstNode;
import com.facebook.coresql.parser.FunctionCall;
import com.facebook.coresql.parser.SqlParserDefaultVisitor;
import com.facebook.coresql.parser.Unparser;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.Multimap;

import java.util.ArrayList;
import java.util.Formatter;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.facebook.coresql.parser.ParserHelper.parseStatement;
import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTARGUMENTLIST;
import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTIDENTIFIER;
import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTUNSIGNEDNUMERICLITERAL;
import static java.util.Collections.binarySearch;
import static java.util.Collections.sort;
import static java.util.Objects.requireNonNull;

public class ApproxPercentileRewriter
extends Rewriter
{
private final PatternMatcher<Multimap<String, AstNode>> matcher;
private static final String REPLACEMENT = " APPROX_PERCENTILE(%s, ARRAY%s)[%d]";
private Multimap<String, AstNode> firstArgMap; // A map of String to the APPROX_PERCENTILE nodes with that String as its first argument
private Map<String, ArrayList<Double>> percentiles;
private static final String REWRITE_NAME = "Multiple APPROX PERCENTILE with same first arg and literal second arg";

public ApproxPercentileRewriter()
{
this.matcher = new ApproxPercentilePatternMatcher();
this.firstArgMap = ArrayListMultimap.create();
this.percentiles = new HashMap<>();
}

@Override
public boolean rewritePatternIsPresent(String sql)
{
AstNode root = requireNonNull(parseStatement(sql));
firstArgMap = matcher.matchPattern(root);
return firstArgMap.keySet().stream().anyMatch(key -> firstArgMap.get(key).size() >= 2);
}

@Override
public RewriteResult rewrite(String sql)
{
AstNode root = requireNonNull(parseStatement(sql));
this.firstArgMap = matcher.matchPattern(root);
getPercentilesFromFirstArgMap();
String rewrittenSql = Unparser.unparse(root, this);
return new RewriteResult(REWRITE_NAME, sql, rewrittenSql);
}

@Override
public void visit(FunctionCall node, Void data)
{
if (canRewrite(node)) {
applyRewrite(node);
}
else {
defaultVisit(node, data);
}
}

/**
* Generates a rewritten version of the current subtree.
*
* @param node The function call node we're rewriting
*/
private void applyRewrite(AstNode node)
{
// First, unparse up to the node. This ensures we don't miss any special tokens
unparseUpto((AstNode) node.jjtGetChild(0));
// Then, add the rewritten version to the Unparser
String firstArg = getFirstArgAsString(node);
Double secondArg = getSecondArgAsDouble(node);

Formatter formatter = new Formatter(stringBuilder);
formatter.format(REPLACEMENT, firstArg, percentiles.get(firstArg), binarySearch(percentiles.get(firstArg), secondArg) + 1);
// Move to end of this node -- we've already put in a rewritten version of it, so we don't need to unparse it
moveToEndOfNode(node);
}

private String getFirstArgAsString(AstNode approxPercentile)
{
AstNode args = approxPercentile.GetFirstChildOfKind(JJTARGUMENTLIST);
AstNode firstArg = (AstNode) args.jjtGetChild(0);
return Unparser.unparse(firstArg);
}

private Double getSecondArgAsDouble(AstNode approxPercentile)
{
AstNode args = approxPercentile.GetFirstChildOfKind(JJTARGUMENTLIST);
AstNode secondArg = (AstNode) args.jjtGetChild(1);
return Double.parseDouble(Unparser.unparse(secondArg));
}

private boolean canRewrite(AstNode node)
{
String firstArg = getFirstArgAsString(node);
return firstArgMap.containsValue(node) && firstArgMap.get(firstArg).size() >= 2;
}

private void getPercentilesFromFirstArgMap()
{
// Map each first argument to a list of the percentiles of the APPROX_PERCENTILE nodes that have that first argument
for (Map.Entry<String, AstNode> entry : firstArgMap.entries()) {
String firstArg = entry.getKey();
AstNode approxPercentileNode = entry.getValue();
percentiles.putIfAbsent(firstArg, new ArrayList<>());
List<Double> percentilesWithThisFirstArg = percentiles.get(firstArg);
percentilesWithThisFirstArg.add(getSecondArgAsDouble(approxPercentileNode));
}
// Sort each percentile list. This will allow binary sort downstream
for (String key : percentiles.keySet()) {
sort(percentiles.get(key));
}
}

private static class ApproxPercentilePatternMatcher
extends SqlParserDefaultVisitor
implements PatternMatcher<Multimap<String, AstNode>>
{
private Multimap<String, AstNode> firstArgMap; // A map of String to the APPROX_PERCENTILE nodes with that String as its first argument

public ApproxPercentilePatternMatcher()
{ }

@Override
public Multimap<String, AstNode> matchPattern(AstNode root)
{
this.firstArgMap = ArrayListMultimap.create();
requireNonNull(root, "AST passed to pattern matcher was null");
root.jjtAccept(this, null);
return ImmutableListMultimap.copyOf(firstArgMap);
}

@Override
public void visit(FunctionCall node, Void data)
{
if (isApproxPercentile(node)) {
AstNode argList = node.GetFirstChildOfKind(JJTARGUMENTLIST);
AstNode secondArg = (AstNode) argList.jjtGetChild(1);
if (!isUnsignedLiteral(secondArg)) {
return;
}
AstNode firstArg = (AstNode) argList.jjtGetChild(0);
String firstArgAsString = Unparser.unparse(firstArg);
firstArgMap.put(firstArgAsString, node);
}
defaultVisit(node, data);
}

public static boolean isUnsignedLiteral(AstNode node)
{
return node.getId() == JJTUNSIGNEDNUMERICLITERAL;
}

private static boolean isApproxPercentile(AstNode node)
{
AstNode identifier = node.GetFirstChildOfKind(JJTIDENTIFIER);
if (identifier == null) {
return false;
}
String image = identifier.GetImage();
return image != null && image.equalsIgnoreCase("APPROX_PERCENTILE");
}
}
}
Loading

0 comments on commit a1aaeea

Please sign in to comment.