diff --git a/parser/grammar/sql-spec.txt b/parser/grammar/sql-spec.txt index cbff421..25902cd 100644 --- a/parser/grammar/sql-spec.txt +++ b/parser/grammar/sql-spec.txt @@ -1123,7 +1123,7 @@ void numeric_primary(): | character_value_expression() } -void numeric_value_function() #FunctionCall: +void numeric_value_function() #BuiltinFunctionCall(1): {} { position_expression() @@ -1371,7 +1371,7 @@ void string_value_function(): | binary_value_function() } -void character_value_function() #FunctionCall: +void character_value_function() #BuiltinFunctionCall: {} { character_substring_function() @@ -1595,7 +1595,7 @@ void time_zone_specifier(): | "TIME" "ZONE" interval_primary() } -void datetime_value_function() #FunctionCall: +void datetime_value_function() #BuiltinFunctionCall: {} { current_date_value_function() @@ -1668,7 +1668,7 @@ void interval_primary(): } -void interval_value_function() #FunctionCall: +void interval_value_function() #BuiltinFunctionCall: {} { interval_absolute_value_function() diff --git a/parser/src/main/java/com/facebook/coresql/parser/AstNode.java b/parser/src/main/java/com/facebook/coresql/parser/AstNode.java index e5f6df8..ce9de52 100644 --- a/parser/src/main/java/com/facebook/coresql/parser/AstNode.java +++ b/parser/src/main/java/com/facebook/coresql/parser/AstNode.java @@ -91,4 +91,9 @@ public String toString(String prefix) return super.toString(prefix) + " (" + getLocation().toString() + ")" + (NumChildren() == 0 ? " (" + beginToken.image + ")" : ""); } + + public String GetSqlString() + { + return Unparser.unparseClean(this); + } } diff --git a/parser/src/main/java/com/facebook/coresql/parser/Unparser.java b/parser/src/main/java/com/facebook/coresql/parser/Unparser.java index 690b00c..c101451 100644 --- a/parser/src/main/java/com/facebook/coresql/parser/Unparser.java +++ b/parser/src/main/java/com/facebook/coresql/parser/Unparser.java @@ -14,6 +14,7 @@ package com.facebook.coresql.parser; import static com.facebook.coresql.parser.SqlParserConstants.EOF; +import static com.facebook.coresql.parser.SqlParserConstants.tokenImage; public class Unparser extends com.facebook.coresql.parser.SqlParserDefaultVisitor @@ -21,7 +22,7 @@ public class Unparser protected StringBuilder stringBuilder = new StringBuilder(); private Token lastToken = new Token(); - public static String unparse(AstNode node, Unparser unparser) + public static String unparseClean(AstNode node, Unparser unparser) { unparser.stringBuilder.setLength(0); unparser.lastToken.next = node.beginToken; @@ -30,9 +31,9 @@ public static String unparse(AstNode node, Unparser unparser) return unparser.stringBuilder.toString(); } - public static String unparse(AstNode node) + public static String unparseClean(AstNode node) { - return unparse(node, new Unparser()); + return unparseClean(node, new Unparser()); } private void printSpecialTokens(Token t) @@ -62,6 +63,11 @@ public final void printToken(String s) stringBuilder.append(" " + s + " "); } + public final void printKeyword(int keyword) + { + printToken(tokenImage[keyword].substring(1, tokenImage[keyword].length() - 1)); + } + private void printToken(Token t) { while (lastToken != t) { diff --git a/parser/src/test/java/com/facebook/coresql/parser/TestSqlParser.java b/parser/src/test/java/com/facebook/coresql/parser/TestSqlParser.java index eafe051..a2461ea 100644 --- a/parser/src/test/java/com/facebook/coresql/parser/TestSqlParser.java +++ b/parser/src/test/java/com/facebook/coresql/parser/TestSqlParser.java @@ -19,7 +19,7 @@ import java.io.IOException; import static com.facebook.coresql.parser.ParserHelper.parseStatement; -import static com.facebook.coresql.parser.Unparser.unparse; +import static com.facebook.coresql.parser.Unparser.unparseClean; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; @@ -66,7 +66,7 @@ public void parseUnparseTest() for (String sql : TEST_SQL_TESTSTRINGS) { AstNode ast = parse(sql); assertNotNull(ast); - assertEquals(sql.trim(), unparse(ast).trim()); + assertEquals(sql.trim(), unparseClean(ast).trim()); } } diff --git a/parser/src/test/java/com/facebook/coresql/parser/sqllogictest/java/SqlLogicTest.java b/parser/src/test/java/com/facebook/coresql/parser/sqllogictest/java/SqlLogicTest.java index c8cdfe9..99bad71 100644 --- a/parser/src/test/java/com/facebook/coresql/parser/sqllogictest/java/SqlLogicTest.java +++ b/parser/src/test/java/com/facebook/coresql/parser/sqllogictest/java/SqlLogicTest.java @@ -34,7 +34,7 @@ import java.util.Optional; import static com.facebook.coresql.parser.ParserHelper.parseStatement; -import static com.facebook.coresql.parser.Unparser.unparse; +import static com.facebook.coresql.parser.Unparser.unparseClean; import static com.facebook.coresql.parser.sqllogictest.java.SqlLogicTest.CoreSqlParsingError.PARSING_ERROR; import static com.facebook.coresql.parser.sqllogictest.java.SqlLogicTest.CoreSqlParsingError.UNPARSED_DOES_NOT_MATCH_ORIGINAL_ERROR; import static com.facebook.coresql.parser.sqllogictest.java.SqlLogicTest.CoreSqlParsingError.UNPARSING_ERROR; @@ -132,7 +132,7 @@ private Optional checkStatementForParserError(String statem } String unparsed; try { - unparsed = unparse(ast.get()); + unparsed = unparseClean(ast.get()); } catch (Exception e) { return Optional.of(UNPARSING_ERROR); diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/OrderByRewriter.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/OrderByRewriter.java index ff12f66..fb33d0c 100644 --- a/rewriter/src/main/java/com/facebook/coresql/rewriter/OrderByRewriter.java +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/OrderByRewriter.java @@ -51,7 +51,7 @@ public Optional rewrite() if (patternMatchedNodes.isEmpty()) { return Optional.empty(); } - String rewrittenSql = Unparser.unparse(root, this); + String rewrittenSql = Unparser.unparseClean(root, this); return Optional.of(new RewriteResult(REWRITE_NAME, rewrittenSql)); } diff --git a/rewriter/src/main/java/com/facebook/coresql/rewriter/UnionAllToDisjunction.java b/rewriter/src/main/java/com/facebook/coresql/rewriter/UnionAllToDisjunction.java new file mode 100644 index 0000000..a779d77 --- /dev/null +++ b/rewriter/src/main/java/com/facebook/coresql/rewriter/UnionAllToDisjunction.java @@ -0,0 +1,333 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; +import com.facebook.coresql.parser.SetOperation; +import com.facebook.coresql.parser.SqlParserDefaultVisitor; +import com.facebook.coresql.parser.Unparser; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.coresql.parser.SqlParserConstants.ALL; +import static com.facebook.coresql.parser.SqlParserConstants.AND; +import static com.facebook.coresql.parser.SqlParserConstants.CASE; +import static com.facebook.coresql.parser.SqlParserConstants.END; +import static com.facebook.coresql.parser.SqlParserConstants.FROM; +import static com.facebook.coresql.parser.SqlParserConstants.SELECT; +import static com.facebook.coresql.parser.SqlParserConstants.THEN; +import static com.facebook.coresql.parser.SqlParserConstants.TRUE; +import static com.facebook.coresql.parser.SqlParserConstants.UNION; +import static com.facebook.coresql.parser.SqlParserConstants.WHEN; +import static com.facebook.coresql.parser.SqlParserConstants.WHERE; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTALIAS; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTFROMCLAUSE; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTGROUPBYCLAUSE; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTHAVINGCLAUSE; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTSELECT; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTSELECTLIST; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTTABLEEXPRESSION; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTTABLENAME; +import static com.facebook.coresql.parser.SqlParserTreeConstants.JJTWHERECLAUSE; +import static java.util.Objects.requireNonNull; + +public class UnionAllToDisjunction + extends Rewriter +{ + private static final String REWRITE_NAME = "UNION ALL of same dataset to ORT"; + private static final String UNNEST = "CROSS JOIN UNNEST(SEQUENCE(%s, %s)) AS UNNEST__(index__)"; + + private final AstNode root; + private final Set matchedNodes; + private final PatternMatcher patternMatcher; + + public UnionAllToDisjunction(AstNode root) + { + this.root = requireNonNull(root, "AST passed to rewriter was null"); + this.patternMatcher = new PatternMatcher(root); + this.matchedNodes = patternMatcher.matchPattern(); + } + + @Override + public Optional rewrite() + { + if (matchedNodes.isEmpty()) { + return Optional.empty(); + } + + String rewrittenSql = Unparser.unparseClean(root, this); + return Optional.of(new RewriteResult(REWRITE_NAME, rewrittenSql)); + } + + private void mergeUionAllChildren(String from, List selectClauses) + { + printKeyword(SELECT); + AstNode firstSelect = selectClauses.get(0); + int selectItemCount = firstSelect.GetFirstChildOfKind(JJTSELECTLIST).NumChildren(); + + for (int i = 0; i < selectItemCount; i++) { + if (i > 0) { + printToken(","); + } + + int selectIndex = 1; + printKeyword(CASE); + printToken("index__"); + + for (AstNode select : selectClauses) { + AstNode selectList = select.GetFirstChildOfKind(JJTSELECTLIST); + printKeyword(WHEN); + printToken(String.valueOf(selectIndex)); + printKeyword(THEN); + AstNode expression = selectList.GetChild(i).GetChild(0); + // Just unparse the select expression excluing the alias + printToken(unparseClean(expression)); + selectIndex++; + } + + printKeyword(END); + } + + if (from != null) { + // Just generate the source here + printToken(unparseClean(firstSelect.GetFirstChildOfKind(JJTTABLEEXPRESSION).GetFirstChildOfKind(JJTFROMCLAUSE))); + } + else { + // Generate a dummy select 1 + printToken("("); + printKeyword(SELECT); + printToken("1"); + printToken(")"); + } + + // Generate a lateral join + printToken(String.format(UNNEST, 1, selectClauses.size())); + + // Now generate the where clauses with an OR + boolean first = true; + printKeyword(WHERE); + printToken("("); + for (AstNode select : selectClauses) { + AstNode tableExpression = select.GetFirstChildOfKind(JJTTABLEEXPRESSION); + if (tableExpression.NumChildren() == 2) { + AstNode whereClause = tableExpression.GetFirstChildOfKind(JJTWHERECLAUSE); + if (!first) { + printToken(" OR "); + } + else { + first = false; + } + + printToken("("); + printToken(unparseClean(whereClause.GetChild(0))); + printToken(")"); + } + } + + printToken(")"); + + if (first) { + printKeyword(TRUE); + } + + // Now we generate the specific predicate for each of the branches + printKeyword(AND); + printToken("("); + printKeyword(CASE); + printToken("index__"); + int caseIndex = 1; + for (AstNode select : selectClauses) { + AstNode tableExpression = select.GetFirstChildOfKind(JJTTABLEEXPRESSION); + printKeyword(WHEN); + printToken(String.valueOf(caseIndex++)); + printKeyword(THEN); + if (tableExpression.NumChildren() == 2) { + AstNode whereClause = tableExpression.GetFirstChildOfKind(JJTWHERECLAUSE); + printToken("("); + printToken(unparseClean(whereClause.GetChild(0))); + printToken(")"); + } + else { + // Missing where just true + printKeyword(TRUE); + } + } + printKeyword(END); + printToken(")"); + } + + private void rewriteUnionBranches(AstNode union) + { + Map> repeatedSources = patternMatcher.getRepeatedSources(union); + + // Collect the select aliases from the first branch + ImmutableList.Builder finalSelects = ImmutableList.builder(); + AstNode firstBranch = union.GetChild(0); + AstNode selectList = firstBranch.GetFirstChildOfKind(JJTSELECTLIST); + for (int i = 0; i < selectList.NumChildren(); i++) { + AstNode item = selectList.GetChild(i); + AstNode alias = item.GetFirstChildOfKind(JJTALIAS); + if (alias == null) { + // TODO(kaikalur): Use a method to make it uniq + finalSelects.add("\"" + item.GetChild(0).GetSqlString() .trim() + "\""); + } + else { + // Simple hack for simple ids + finalSelects.add(alias.GetChild(0).GetSqlString()); + } + } + + // We generate an outer select for ease of aliasing: + printKeyword(SELECT); + printToken("*"); + + printKeyword(FROM); + printToken("("); + + boolean first = true; + // Now the real thing + for (Map.Entry> entry : repeatedSources.entrySet()) { + if (!first) { + printKeyword(UNION); + printKeyword(ALL); + } + else { + first = false; + } + + // The real rewrite + mergeUionAllChildren(entry.getKey(), entry.getValue()); + } + + // TODO(sreeni): Generate unique table alias + printToken(") AS t__ ("); + first = true; + for (String s : finalSelects.build()) { + if (!first) { + printToken(","); + } + else { + first = false; + } + printToken(s); + } + printToken(")"); + } + + @Override + public void visit(SetOperation node, Void data) + { + if (!matchedNodes.contains(node)) { + defaultVisit(node, data); + return; + } + + // Unparse upto the first token + unparseUpto(node); + + // Rewrite to avoid repeated scans + rewriteUnionBranches(node); + + // Move to the end of the node + moveToEndOfNode(node); + } + + private static class PatternMatcher + extends SqlParserDefaultVisitor + { + private final AstNode root; + private final ImmutableSet.Builder builder = ImmutableSet.builder(); + private final ImmutableMap.Builder>> repeatedSourcesBuilder = ImmutableMap.builder(); + private static final int MINIMUM_SUBQUERY_DEPTH = 2; // Past this depth, all queries we encounter are subqueries + + public PatternMatcher(AstNode root) + { + this.root = requireNonNull(root, "AST passed to rewriter was null"); + } + + public Set matchPattern() + { + root.jjtAccept(this, null); + repeatedSourcesBuilder.build(); + return builder.build(); + } + + private String getSourceNode(AstNode select) + { + // TODO(kaikalur): mvoe this to ASTUtils + + if (select.Kind() == JJTSELECT && select.GetChild(1).Kind() == JJTTABLEEXPRESSION) { + AstNode tableExpression = select.GetFirstChildOfKind(JJTTABLEEXPRESSION); + if (tableExpression.NumChildren() == 1 || (tableExpression.NumChildren() == 2 && tableExpression.GetChild(1).Kind() == JJTWHERECLAUSE)) { + AstNode from = tableExpression.GetFirstChildOfKind(JJTFROMCLAUSE); + // See that it is a simple select/filter + if (from.GetChild(0).Kind() == JJTTABLENAME && + !(select.hasChildOfKind(JJTGROUPBYCLAUSE) || select.hasChildOfKind(JJTHAVINGCLAUSE))) { + // like FROM T + return from.GetChild(0).GetSqlString(); + } + } + } + + return null; + } + + @Override + public void visit(SetOperation setOperation, Void data) + { + if (true) { // (setOperation.beginToken.kind == UNION && setOperation.beginToken.next.kind == ALL) { + boolean repeated = true; + Map> selectsBySource = new HashMap<>(); + // Union all. + for (int i = 0; i < setOperation.jjtGetNumChildren(); i++) { + AstNode select = setOperation.GetChild(i); + String source = getSourceNode(select); + List selects = selectsBySource.get(source); + if (selects == null) { + selects = new ArrayList<>(); + selectsBySource.put(source, selects); + } + else { + // Has some repeated sources + repeated = true; + } + selects.add(select); + } + + // For now we support all branches having the same table. + if (repeated) { + builder.add(setOperation); + repeatedSourcesBuilder.put(setOperation, selectsBySource); + } + } + + defaultVisit(setOperation, data); + } + + private Map> getRepeatedSources(AstNode setOperation) + { + Map>> repeatedSources = repeatedSourcesBuilder.build(); + return repeatedSources.get(setOperation); + } + } +} diff --git a/rewriter/src/test/java/com/facebook/coresql/rewriter/TestOrderByRewriter.java b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestOrderByRewriter.java index 774037c..af9c3fa 100644 --- a/rewriter/src/test/java/com/facebook/coresql/rewriter/TestOrderByRewriter.java +++ b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestOrderByRewriter.java @@ -103,7 +103,7 @@ private Optional getRewriteResult(String originalStatement) { AstNode ast = parseStatement(originalStatement); assertNotNull(ast); - return new OrderByRewriter(ast).rewrite(); + return new UnionAllToDisjunction(ast).rewrite(); } @Test diff --git a/rewriter/src/test/java/com/facebook/coresql/rewriter/TestUnionAllToDisjunction.java b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestUnionAllToDisjunction.java new file mode 100644 index 0000000..8d0af45 --- /dev/null +++ b/rewriter/src/test/java/com/facebook/coresql/rewriter/TestUnionAllToDisjunction.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.coresql.rewriter; + +import com.facebook.coresql.parser.AstNode; +import org.testng.annotations.Test; + +import java.io.File; +import java.nio.file.Files; +import java.util.Optional; + +import static com.facebook.coresql.parser.ParserHelper.parseStatement; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +public class TestUnionAllToDisjunction +{ + @Test + public void rewriteTest() throws Throwable + { + String sql = new String(Files.readAllBytes(new File("/tmp/uniona.sql").toPath())); + AstNode ast = parseStatement(sql); + assertNotNull(ast); + Optional rewriteResult = new UnionAllToDisjunction(ast).rewrite(); + assertEquals(rewriteResult.get().getRewrittenSql(), ""); + } +}