From 4a46b4c64e3c48b9f3b090bc405796d380f944c2 Mon Sep 17 00:00:00 2001 From: Nathan Mugerwa Date: Wed, 10 Mar 2021 22:24:19 -0500 Subject: [PATCH] Add rewriter for multiple similar APPROX_PERCENTILEs in SELECT clause --- pom.xml | 1 + rewriter/pom.xml | 179 +++++++++++++++++ .../rewriter/ApproxPercentileRewriter.java | 188 ++++++++++++++++++ .../coresql/rewriter/PatternMatcher.java | 28 +++ .../coresql/rewriter/RewriteResult.java | 46 +++++ .../facebook/coresql/rewriter/Rewriter.java | 37 ++++ .../TestApproxPercentileRewriter.java | 112 +++++++++++ rewriter/~$pom.xml | Bin 0 -> 162 bytes 8 files changed, 591 insertions(+) create mode 100644 rewriter/pom.xml create mode 100644 rewriter/src/main/java/com/facebook/coresql/rewriter/ApproxPercentileRewriter.java create mode 100644 rewriter/src/main/java/com/facebook/coresql/rewriter/PatternMatcher.java create mode 100644 rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteResult.java create mode 100644 rewriter/src/main/java/com/facebook/coresql/rewriter/Rewriter.java create mode 100644 rewriter/src/test/java/com/facebook/coresql/rewriter/TestApproxPercentileRewriter.java create mode 100644 rewriter/~$pom.xml diff --git a/pom.xml b/pom.xml index 3c227f4..fad61bc 100644 --- a/pom.xml +++ b/pom.xml @@ -24,6 +24,7 @@ parser linter + rewriter diff --git a/rewriter/pom.xml b/rewriter/pom.xml new file mode 100644 index 0000000..424c7a2 --- /dev/null +++ b/rewriter/pom.xml @@ -0,0 +1,179 @@ + + + 4.0.0 + + + com.facebook.presto + presto-coresql + 0.2-SNAPSHOT + + + presto-coresql-rewriter + presto-coresql-rewriter + + + ${project.parent.basedir} + 1.6 + 1.6 + + + + + junit + junit + 4.12 + test + + + + org.testng + testng + test + + + + com.facebook.presto + presto-coresql-parser + 0.2-SNAPSHOT + + + com.google.guava + guava + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + generate-sources + + add-source + + + + ${project.build.directory}/generated-sources + + + + + + + + com.facebook.presto + presto-maven-plugin + 0.3 + true + + + + org.apache.maven.plugins + maven-shade-plugin + 3.1.1 + + + + org.skife.maven + really-executable-jar-maven-plugin + 1.0.5 + + + + org.apache.maven.plugins + maven-antrun-plugin + 1.8 + + + + io.airlift.maven.plugins + sphinx-maven-plugin + 2.1 + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + + + + org.alluxio:alluxio-shaded-client + org.codehaus.plexus:plexus-utils + com.google.guava:guava + + + + + + + + org.apache.maven.plugins + maven-release-plugin + + clean verify -DskipTests + + + + + org.apache.maven.plugins + maven-compiler-plugin + + true + + -verbose + -J-Xss100M + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/*.java + target/**/*.java + **/Benchmark*.java + + + **/*jmhTest*.java + **/*jmhType*.java + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + true + + + true + true + false + + + + + + + + + org.gaul + modernizer-maven-plugin + 2.1.0 + + 1.8 + + + + + + diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/ApproxPercentileRewriter.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/ApproxPercentileRewriter.java new file mode 100644 index 0000000..9bfe8dc --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/ApproxPercentileRewriter.java @@ -0,0 +1,188 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; +import com.facebook.coresql.parser.FunctionCall; +import com.facebook.coresql.parser.SqlParserDefaultVisitor; +import com.facebook.coresql.parser.Unparser; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Multimap; + +import java.util.ArrayList; +import java.util.Formatter; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.facebook.coresql.parser.ParserHelper.parseStatement; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTARGUMENTLIST; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTIDENTIFIER; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTUNSIGNEDNUMERICLITERAL; +import static java.util.Collections.binarySearch; +import static java.util.Collections.sort; +import static java.util.Objects.requireNonNull; + +public class ApproxPercentileRewriter + extends Rewriter +{ + private final PatternMatcher> matcher; + private static final String REPLACEMENT = " APPROX_PERCENTILE(%s, ARRAY%s)[%d]"; + private Multimap firstArgMap; // A map of String to the APPROX_PERCENTILE nodes with that String as its first argument + private Map> percentiles; + private static final String REWRITE_NAME = "Multiple APPROX PERCENTILE with same first arg and literal second arg"; + + public ApproxPercentileRewriter() + { + this.matcher = new ApproxPercentilePatternMatcher(); + this.firstArgMap = ArrayListMultimap.create(); + this.percentiles = new HashMap<>(); + } + + @Override + public boolean rewritePatternIsPresent(String sql) + { + AstNode root = requireNonNull(parseStatement(sql)); + firstArgMap = matcher.matchPattern(root); + return firstArgMap.keySet().stream().anyMatch(key -> firstArgMap.get(key).size() >= 2); + } + + @Override + public RewriteResult rewrite(String sql) + { + AstNode root = requireNonNull(parseStatement(sql)); + this.firstArgMap = matcher.matchPattern(root); + getPercentilesFromFirstArgMap(); + String rewrittenSql = Unparser.unparse(root, this); + return new RewriteResult(REWRITE_NAME, sql, rewrittenSql); + } + + @Override + public void visit(FunctionCall node, Void data) + { + if (canRewrite(node)) { + applyRewrite(node); + } + else { + defaultVisit(node, data); + } + } + + /** + * Generates a rewritten version of the SELECT clause given, places that version in the Unparser, then skips the original SELECT. + * + * @param node The APPROX_PERCENTILE node we're rewriting + */ + private void applyRewrite(AstNode node) + { + // First, unparse up to the node. This ensures we don't miss any special tokens + unparseUpto((AstNode) node.jjtGetChild(0)); + // Then, add the rewritten version to the Unparser + String firstArg = getFirstArgAsString(node); + Double secondArg = getSecondArgAsDouble(node); + + Formatter formatter = new Formatter(stringBuilder); + formatter.format(REPLACEMENT, firstArg, percentiles.get(firstArg), binarySearch(percentiles.get(firstArg), secondArg) + 1); + // Move to end of this node -- we've already put in a rewritten version of it, so we don't need to unparse it + moveToEndOfNode(node); + } + + private String getFirstArgAsString(AstNode approxPercentile) + { + AstNode args = approxPercentile.GetFirstChildOfKind(JJTARGUMENTLIST); + AstNode firstArg = (AstNode) args.jjtGetChild(0); + return Unparser.unparse(firstArg); + } + + private Double getSecondArgAsDouble(AstNode approxPercentile) + { + AstNode args = approxPercentile.GetFirstChildOfKind(JJTARGUMENTLIST); + AstNode secondArg = (AstNode) args.jjtGetChild(1); + return Double.parseDouble(Unparser.unparse(secondArg)); + } + + private boolean canRewrite(AstNode node) + { + String firstArg = getFirstArgAsString(node); + return firstArgMap.containsValue(node) && firstArgMap.get(firstArg).size() >= 2; + } + + private void getPercentilesFromFirstArgMap() + { + // Map each first argument to a list of the percentiles of the APPROX_PERCENTILE nodes that have that first argument + for (Map.Entry entry : firstArgMap.entries()) { + String firstArg = entry.getKey(); + AstNode approxPercentileNode = entry.getValue(); + percentiles.putIfAbsent(firstArg, new ArrayList<>()); + List percentilesWithThisFirstArg = percentiles.get(firstArg); + percentilesWithThisFirstArg.add(getSecondArgAsDouble(approxPercentileNode)); + } + // Sort each percentile list. This will allow binary sort downstream + for (String key : percentiles.keySet()) { + sort(percentiles.get(key)); + } + } + + private static class ApproxPercentilePatternMatcher + extends SqlParserDefaultVisitor + implements PatternMatcher> + { + private Multimap firstArgMap; // A map of String to the APPROX_PERCENTILE nodes with that String as its first argument + + public ApproxPercentilePatternMatcher() + { } + + @Override + public Multimap matchPattern(AstNode root) + { + this.firstArgMap = ArrayListMultimap.create(); + requireNonNull(root, "AST passed to pattern matcher was null"); + root.jjtAccept(this, null); + return ImmutableListMultimap.copyOf(firstArgMap); + } + + @Override + public void visit(FunctionCall node, Void data) + { + if (isApproxPercentile(node)) { + AstNode argList = node.GetFirstChildOfKind(JJTARGUMENTLIST); + AstNode secondArg = (AstNode) argList.jjtGetChild(1); + if (!isUnsignedLiteral(secondArg)) { + return; + } + AstNode firstArg = (AstNode) argList.jjtGetChild(0); + String firstArgAsString = Unparser.unparse(firstArg); + firstArgMap.put(firstArgAsString, node); + } + defaultVisit(node, data); + } + + public static boolean isUnsignedLiteral(AstNode node) + { + return node.getId() == JJTUNSIGNEDNUMERICLITERAL; + } + + private static boolean isApproxPercentile(AstNode node) + { + AstNode identifier = node.GetFirstChildOfKind(JJTIDENTIFIER); + if (identifier == null) { + return false; + } + String image = identifier.GetImage(); + return image != null && image.equalsIgnoreCase("APPROX_PERCENTILE"); + } + } +} diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/PatternMatcher.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/PatternMatcher.java new file mode 100644 index 0000000..bf7e8d5 --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/PatternMatcher.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; + +public interface PatternMatcher +{ + /** + * Traverses an AST, finding the set of nodes that match this pattern matcher's pattern + * + * @param root root of the AST + * @return The set of nodes that match this pattern matcher's pattern + */ + T matchPattern(AstNode root); +} diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteResult.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteResult.java new file mode 100644 index 0000000..7bef822 --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/RewriteResult.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import static java.util.Objects.requireNonNull; + +public class RewriteResult +{ + private String nameOfRewrite; + private String originalSql; + private String rewrittenSql; + + public RewriteResult(String nameOfRewrite, String originalSql, String rewrittenSql) + { + this.nameOfRewrite = requireNonNull(nameOfRewrite, "name of rewrite is null"); + this.originalSql = requireNonNull(originalSql, "original sql statement is null"); + this.rewrittenSql = requireNonNull(rewrittenSql, "rewritten sql statement is null"); + } + + public String getNameOfRewrite() + { + return nameOfRewrite; + } + + public String getOriginalSql() + { + return originalSql; + } + + public String getRewrittenSql() + { + return rewrittenSql; + } +} diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/Rewriter.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/Rewriter.java new file mode 100644 index 0000000..39ba6a1 --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/Rewriter.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.Unparser; + +public abstract class Rewriter + extends Unparser +{ + /** + * Attempts to rewrite a SQL statement, storing the name of the rewriter, original string, and rewritten string in a RewriteResult object. + * + * @param sql The statement that will be rewritten + * @return A RewriteResult object containing the name of the rewriter, original string, and rewritten string + */ + public abstract RewriteResult rewrite(String sql); + + /** + * Checks if the pattern we're trying to rewrite is present within a SQL statement. + * + * @param sql The statement we're checking + * @return true if the rewrite pattern is present else false + */ + public abstract boolean rewritePatternIsPresent(String sql); +} diff --git a/rewriter/src/test/java/com/facebook/coresql/rewriter/TestApproxPercentileRewriter.java b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestApproxPercentileRewriter.java new file mode 100644 index 0000000..f2e4dc4 --- /dev/null +++ b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestApproxPercentileRewriter.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import org.testng.annotations.Test; + +import static java.lang.String.format; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestApproxPercentileRewriter +{ + private static final String[] statementsThatDontNeedAnyRewrite = new String[] { + // False Positive + "CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT foo FROM T ORDER BY x LIMIT 10) ORDER BY y LIMIT 10) ORDER BY z LIMIT 10;", + "SELECT dealer_id, sales OVER (PARTITION BY dealer_id ORDER BY sales);", + "INSERT INTO blah SELECT * FROM (SELECT t.date, t.code, t.qty FROM sales AS t ORDER BY t.date LIMIT 100);", + "SELECT (true or false) and false;", + // True Negative + "SELECT * FROM T ORDER BY y;", + "SELECT * FROM T ORDER BY y LIMIT 10;", + "use a.b;", + " SELECT 1;", + "SELECT a FROM T;", + "SELECT a FROM T WHERE p1 > p2;", + "SELECT a, b, c FROM T WHERE c1 < c2 and c3 < c4;", + "SELECT CASE a WHEN IN ( 1 ) THEN b ELSE c END AS x, b, c FROM T WHERE c1 < c2 and c3 < c4;", + "SELECT T.* FROM T JOIN W ON T.x = W.x;", + "SELECT NULL;", + "SELECT ARRAY[x] FROM T;", + "SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "CREATE TABLE T AS SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "INSERT INTO T SELECT TRANSFORM(ARRAY[x], x -> x + 2) AS arra FROM T;", + "SELECT ROW_NUMBER() OVER(PARTITION BY x) FROM T;", + "SELECT x, SUM(y) OVER (PARTITION BY y ORDER BY 1) AS min\n" + + "FROM (values ('b',10), ('a', 10)) AS T(x, y)\n;", + "SELECT\n" + + " CAST(MAP() AS map>) AS \"bool_tensor_features\";", + "SELECT f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f(f())))))))))))))))))))))))))))));", + "SELECT abs, 2 as abs;", + }; + + private static final String[] statementsThatNeedApproxPercentileRewrite = new String[] { + // True Positive + "SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) AS percentile_30 FROM T;", + "SELECT APPROX_PERCENTILE(x, 0.1), APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) FROM T;", + "SELECT approx_percentile(y, 0.2), x + 1, approx_percentile(y, 0.1) from T group by 2;", + // False Negative + "SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, 0.4), APPROX_PERCENTILE(x, 0.5) FROM T;", + "SELECT APPROX_PERCENTILE(x, 0.1) FROM (SELECT APPROX_PERCENTILE(x, 0.2) FROM T);", + // Subquery + "SELECT x FROM (SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(x, 0.2) AS percentile_20, APPROX_PERCENTILE(x, 0.3) AS percentile_30 FROM T);", + "CREATE TABLE blah AS SELECT * FROM (SELECT * FROM (SELECT approx_percentile(y, 0.2), x + 1, approx_percentile(y, 0.1) from T group by 2));", + "SELECT APPROX_PERCENTILE(x, 0.1) FROM (SELECT * FROM (SELECT approx_percentile(x, 0.1), approx_percentile(y, 0.1) from T group by 2));" + }; + + private static final String[] statementsThatDontNeedApproxPercentileRewrite = new String[] { + // False Positive + "SELECT APPROX_PERCENTILE(x, 0.1) AS percentile_10, APPROX_PERCENTILE(y, 0.2) AS percentile_20, APPROX_PERCENTILE(z, 0.3) AS percentile_30 FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]), APPROX_PERCENTILE(x, 0.4) FROM T;", + "SELECT APPROX_PERCENTILE(x, ARRAY[0.1, 0.2, 0.3]) FROM (SELECT APPROX_PERCENTILE(x, 0.1) from T);" + }; + + private void assertApproxPercentilePatternIsMatchedOrUnmatched(String sql, boolean isMatched) + { + Rewriter rewriter = new ApproxPercentileRewriter(); + if (isMatched) { + assertTrue(rewriter.rewritePatternIsPresent(sql)); + } + else { + assertFalse(rewriter.rewritePatternIsPresent(sql)); + } + } + + private void rewriteThenPrint(String sql) + { + String rewritten = new ApproxPercentileRewriter().rewrite(sql).getRewrittenSql(); + System.out.println(format("Before --> %s", sql)); + System.out.println(format("AFTER --> %s", rewritten)); + System.out.println(); + } + + @Test + public void approxPercentilePatternDetectionTest() + { + for (String sql : statementsThatNeedApproxPercentileRewrite) { + assertApproxPercentilePatternIsMatchedOrUnmatched(sql, true); + rewriteThenPrint(sql); + } + + for (String sql : statementsThatDontNeedApproxPercentileRewrite) { + assertApproxPercentilePatternIsMatchedOrUnmatched(sql, false); + } + + for (String sql : statementsThatDontNeedAnyRewrite) { + assertApproxPercentilePatternIsMatchedOrUnmatched(sql, false); + } + } +} diff --git a/rewriter/~$pom.xml b/rewriter/~$pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..0499b883fe28a22d5c66c4abccba984dbf061fb8 GIT binary patch literal 162 zcmaiqF%E-35Ci8E^jA=aPozauJRmyVN+&=|ZjngE+i~)NtmL(9PdC=wJlu{nnblm< z9HsH+Iz}p)h7*!;keKB5>x