From 6d905e56dca69d8cb9320a705654aa9fd599e9d0 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:36 -0400 Subject: [PATCH 01/15] Replace distributed_joins with join_distribution_type session property The join_distribution_type property has three options: "repartitioned", "replicated", and "automatic". This property replaces the "distributed_joins" property. "Repartitioned" has the same behavior as distributed_joins=true. "Replicated" is equivalent to distributed_joins=false. "Automatic" will use stats to evaluate whether partitioned or replicated is better, or if no information is available, it will choose partitioned. The default value for join_distribution_type is "repartitioned". --- .../presto/cli/TestClientOptions.java | 8 ++-- .../src/main/sphinx/admin/properties.rst | 36 +++++++++------ presto-main/etc/config.properties | 1 - .../presto/SystemSessionProperties.java | 32 ++++++++----- .../presto/sql/analyzer/FeaturesConfig.java | 46 +++++++++++++------ .../DetermineJoinDistributionType.java | 11 +++-- .../server/TestHttpRequestSessionFactory.java | 6 +-- .../facebook/presto/server/TestServer.java | 6 +-- .../sql/analyzer/TestFeaturesConfig.java | 10 ++-- .../TestUnionWithReplicatedJoin.java | 2 +- .../conf/presto/etc/config.properties | 1 - .../facebook/presto/tests/jdbc/JdbcTests.java | 16 +++---- .../presto/tests/DistributedQueryRunner.java | 3 +- 13 files changed, 110 insertions(+), 68 deletions(-) diff --git a/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java b/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java index 1b4ebb9a0d72..a15ac6f661ea 100644 --- a/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java +++ b/presto-cli/src/test/java/com/facebook/presto/cli/TestClientOptions.java @@ -143,27 +143,27 @@ public void testUpdateSessionParameters() ClientSession session = options.toClientSession(); SqlParser sqlParser = new SqlParser(); - ImmutableMap 
existingProperties = ImmutableMap.of("query_max_memory", "10GB", "distributed_join", "true"); + ImmutableMap existingProperties = ImmutableMap.of("query_max_memory", "10GB", "join_distribution_type", "repartitioned"); ImmutableMap preparedStatements = ImmutableMap.of("my_query", "select * from foo"); session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_catalog.test_schema"), session, existingProperties, preparedStatements); assertEquals(session.getCatalog(), "test_catalog"); assertEquals(session.getSchema(), "test_schema"); assertEquals(session.getProperties().get("query_max_memory"), "10GB"); - assertEquals(session.getProperties().get("distributed_join"), "true"); + assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned"); assertEquals(session.getPreparedStatements().get("my_query"), "select * from foo"); session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_schema_b"), session, existingProperties, preparedStatements); assertEquals(session.getCatalog(), "test_catalog"); assertEquals(session.getSchema(), "test_schema_b"); assertEquals(session.getProperties().get("query_max_memory"), "10GB"); - assertEquals(session.getProperties().get("distributed_join"), "true"); + assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned"); assertEquals(session.getPreparedStatements().get("my_query"), "select * from foo"); session = Console.processSessionParameterChange(sqlParser.createStatement("USE test_catalog_2.test_schema"), session, existingProperties, preparedStatements); assertEquals(session.getCatalog(), "test_catalog_2"); assertEquals(session.getSchema(), "test_schema"); assertEquals(session.getProperties().get("query_max_memory"), "10GB"); - assertEquals(session.getProperties().get("distributed_join"), "true"); + assertEquals(session.getProperties().get("join_distribution_type"), "repartitioned"); assertEquals(session.getPreparedStatements().get("my_query"), 
"select * from foo"); } } diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 6284c52f07bf..116f99286dbd 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -6,26 +6,34 @@ This section describes the most important config properties that may be used to tune Presto or alter its behavior when required. .. contents:: - :local: +:local: :backlinks: none - :depth: 1 + :depth: 1 General Properties ------------------ -``distributed-joins-enabled`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +``join-distribution-type`` +^^^^^^^^^^^^^^^^^^^^^^^^^^ - * **Type:** ``boolean`` - * **Default value:** ``true`` - - Use hash distributed joins instead of broadcast joins. Distributed joins - require redistributing both tables using a hash of the join key. This can - be slower (sometimes substantially) than broadcast joins, but allows much - larger joins. Broadcast joins require that the tables on the right side of - the join after filtering fit in memory on each node, whereas distributed joins - only need to fit in distributed memory across all nodes. This can also be - specified on a per-query basis using the ``distributed_join`` session property. + * **Type:** ``string`` + * **Allowed values:** ``AUTOMATIC``, ``REPARTITIONED``, ``REPLICATED`` + * **Default value:** ``REPARTITIONED`` + + The type of distributed join to use. When set to ``REPARTITIONED``, presto will + use hash distributed joins. When set to ``REPLICATED``, it will broadcast the + right table to all nodes in the cluster that have data from the left table. + Repartitioned joins require redistributing both tables using a hash of the join key. + This can be slower (sometimes substantially) than broadcast joins, but allows much + larger joins. In particular broadcast joins will be faster if the right table is + much smaller than the left. 
However, broadcast joins require that the tables on the right + side of the join after filtering fit in memory on each node, whereas distributed joins + only need to fit in distributed memory across all nodes. When set to ``AUTOMATIC``, + Presto will make a cost-based decision as to which distribution type is optimal. + It will also consider switching the left and right inputs to the join. In ``AUTOMATIC`` + mode, Presto will default to repartitioned joins if no cost could be computed, such as if + the tables do not have statistics. This can also be specified on a per-query basis using + the ``join_distribution_type`` session property. ``redistribute-writes`` ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-main/etc/config.properties b/presto-main/etc/config.properties index 44d7c148de4e..30c32d6f1fcb 100644 --- a/presto-main/etc/config.properties +++ b/presto-main/etc/config.properties @@ -41,5 +41,4 @@ plugin.bundles=\ ../presto-postgresql/pom.xml presto.version=testversion -distributed-joins-enabled=true node-scheduler.include-coordinator=true diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 3ed749a98a31..3074484d4581 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; import io.airlift.units.Duration; @@ -27,6 +28,7 @@ import javax.inject.Inject; import java.util.List; +import java.util.stream.Stream; import static com.facebook.presto.spi.session.PropertyMetadata.booleanSessionProperty; import static 
com.facebook.presto.spi.session.PropertyMetadata.integerSessionProperty; @@ -37,11 +39,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.units.DataSize.succinctBytes; import static java.lang.String.format; +import static java.util.stream.Collectors.joining; public final class SystemSessionProperties { public static final String OPTIMIZE_HASH_GENERATION = "optimize_hash_generation"; - public static final String DISTRIBUTED_JOIN = "distributed_join"; + public static final String JOIN_DISTRIBUTION_TYPE = "join_distribution_type"; public static final String DISTRIBUTED_INDEX_JOIN = "distributed_index_join"; public static final String HASH_PARTITION_COUNT = "hash_partition_count"; public static final String PREFER_STREAMING_OPERATORS = "prefer_streaming_operators"; @@ -101,11 +104,18 @@ public SystemSessionProperties( "Compute hash codes for distribution, joins, and aggregations early in query plan", featuresConfig.isOptimizeHashGeneration(), false), - booleanSessionProperty( - DISTRIBUTED_JOIN, - "Use a distributed join instead of a broadcast join", - featuresConfig.isDistributedJoinsEnabled(), - false), + new PropertyMetadata<>( + JOIN_DISTRIBUTION_TYPE, + format("The join method to use. 
Options are %s", + Stream.of(JoinDistributionType.values()) + .map(FeaturesConfig.JoinDistributionType::name) + .collect(joining(","))), + VARCHAR, + JoinDistributionType.class, + featuresConfig.getJoinDistributionType(), + false, + value -> JoinDistributionType.valueOf(((String) value).toUpperCase()), + JoinDistributionType::name), booleanSessionProperty( DISTRIBUTED_INDEX_JOIN, "Distribute index joins on join keys instead of executing inline", @@ -347,11 +357,6 @@ public static boolean isOptimizeHashGenerationEnabled(Session session) return session.getSystemProperty(OPTIMIZE_HASH_GENERATION, Boolean.class); } - public static boolean isDistributedJoinEnabled(Session session) - { - return session.getSystemProperty(DISTRIBUTED_JOIN, Boolean.class); - } - public static boolean isDistributedIndexJoinEnabled(Session session) { return session.getSystemProperty(DISTRIBUTED_INDEX_JOIN, Boolean.class); @@ -515,4 +520,9 @@ public static boolean isUseNewStatsCalculator(Session session) { return session.getSystemProperty(USE_NEW_STATS_CALCULATOR, Boolean.class); } + + public static JoinDistributionType getJoinDistributionType(Session session) + { + return session.getSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 4c7b3a0d9a99..c2aa7ddc91c2 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -27,8 +27,10 @@ import java.nio.file.Paths; import java.util.List; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPARTITIONED; import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; import static 
java.util.concurrent.TimeUnit.MINUTES; @DefunctConfig({ @@ -42,8 +44,8 @@ public class FeaturesConfig private double cpuCostWeight = 0.75; private double memoryCostWeight = 0; private double networkCostWeight = 0.25; + private boolean distributedIndexJoinsEnabled; - private boolean distributedJoinsEnabled = true; private boolean colocatedJoinsEnabled; private boolean fastInequalityJoins = true; private boolean reorderJoins = true; @@ -59,6 +61,7 @@ public class FeaturesConfig private boolean legacyMapSubscript; private boolean newMapBlock = true; private boolean optimizeMixedDistinctAggregations; + private JoinDistributionType joinDistributionType = REPARTITIONED; private boolean dictionaryAggregation; private boolean resourceGroups; @@ -77,6 +80,23 @@ public class FeaturesConfig private Duration iterativeOptimizerTimeout = new Duration(3, MINUTES); // by default let optimizer wait a long time in case it retrieves some data from ConnectorMetadata + public enum JoinDistributionType + { + AUTOMATIC, + REPLICATED, + REPARTITIONED; + + public boolean canRepartition() + { + return this == REPARTITIONED || this == AUTOMATIC; + } + + public boolean canReplicate() + { + return this == REPLICATED || this == AUTOMATIC; + } + } + public double getCpuCostWeight() { return cpuCostWeight; @@ -137,11 +157,6 @@ public FeaturesConfig setDistributedIndexJoinsEnabled(boolean distributedIndexJo return this; } - public boolean isDistributedJoinsEnabled() - { - return distributedJoinsEnabled; - } - @Config("deprecated.legacy-array-agg") public FeaturesConfig setLegacyArrayAgg(boolean legacyArrayAgg) { @@ -190,13 +205,6 @@ public boolean isNewMapBlock() return newMapBlock; } - @Config("distributed-joins-enabled") - public FeaturesConfig setDistributedJoinsEnabled(boolean distributedJoinsEnabled) - { - this.distributedJoinsEnabled = distributedJoinsEnabled; - return this; - } - public boolean isColocatedJoinsEnabled() { return colocatedJoinsEnabled; @@ -490,4 +498,16 @@ public boolean 
isUseNewStatsCalculator() { return useNewStatsCalculator; } + + @Config("join-distribution-type") + public FeaturesConfig setJoinDistributionType(JoinDistributionType joinDistributionType) + { + this.joinDistributionType = requireNonNull(joinDistributionType, "joinDistributionType is null"); + return this; + } + + public JoinDistributionType getJoinDistributionType() + { + return joinDistributionType; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java index cb912fbbb092..6a9f6991f298 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java @@ -27,7 +27,7 @@ import java.util.Map; import java.util.Optional; -import static com.facebook.presto.SystemSessionProperties.isDistributedJoinEnabled; +import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; import static com.facebook.presto.sql.planner.optimizations.ScalarQueryUtil.isScalar; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; @@ -113,7 +113,7 @@ private JoinNode.DistributionType getTargetJoinDistributionType(JoinNode node) { // The implementation of full outer join only works if the data is hash partitioned. 
See LookupJoinOperators#buildSideOuterJoinUnvisitedPositions JoinNode.Type type = node.getType(); - if (type == RIGHT || type == FULL || (isDistributedJoinEnabled(session) && !mustBroadcastJoin(node))) { + if (type == RIGHT || type == FULL || (isRepartitionedJoinEnabled(session) && !mustBroadcastJoin(node))) { return JoinNode.DistributionType.PARTITIONED; } @@ -132,11 +132,16 @@ private static boolean isCrossJoin(JoinNode node) private SemiJoinNode.DistributionType getTargetSemiJoinDistributionType(boolean isDeleteQuery) { - if (isDistributedJoinEnabled(session) && !isDeleteQuery) { + if (isRepartitionedJoinEnabled(session) && !isDeleteQuery) { return SemiJoinNode.DistributionType.PARTITIONED; } return SemiJoinNode.DistributionType.REPLICATED; } + + private static boolean isRepartitionedJoinEnabled(Session session) + { + return getJoinDistributionType(session).canRepartition(); + } } } diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java index 7f8b8e0a6c32..f516618b1563 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestHttpRequestSessionFactory.java @@ -27,8 +27,8 @@ import java.util.Locale; -import static com.facebook.presto.SystemSessionProperties.DISTRIBUTED_JOIN; import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT; +import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO; @@ -59,7 +59,7 @@ public void testCreateSession() .put(PRESTO_TIME_ZONE, "Asia/Taipei") .put(PRESTO_CLIENT_INFO, "client-info") .put(PRESTO_SESSION, QUERY_MAX_MEMORY + "=1GB") 
- .put(PRESTO_SESSION, DISTRIBUTED_JOIN + "=true," + HASH_PARTITION_COUNT + " = 43") + .put(PRESTO_SESSION, JOIN_DISTRIBUTION_TYPE + "=repartitioned," + HASH_PARTITION_COUNT + " = 43") .put(PRESTO_PREPARED_STATEMENT, "query1=select * from foo,query2=select * from bar") .build(), "testRemote"); @@ -82,7 +82,7 @@ public void testCreateSession() assertEquals(session.getClientInfo().get(), "client-info"); assertEquals(session.getSystemProperties(), ImmutableMap.builder() .put(QUERY_MAX_MEMORY, "1GB") - .put(DISTRIBUTED_JOIN, "true") + .put(JOIN_DISTRIBUTION_TYPE, "repartitioned") .put(HASH_PARTITION_COUNT, "43") .build()); assertEquals(session.getPreparedStatements(), ImmutableMap.builder() diff --git a/presto-main/src/test/java/com/facebook/presto/server/TestServer.java b/presto-main/src/test/java/com/facebook/presto/server/TestServer.java index 3fad76a15ea3..a2ddec819a83 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/TestServer.java +++ b/presto-main/src/test/java/com/facebook/presto/server/TestServer.java @@ -35,8 +35,8 @@ import java.net.URI; import java.util.List; -import static com.facebook.presto.SystemSessionProperties.DISTRIBUTED_JOIN; import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT; +import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static com.facebook.presto.SystemSessionProperties.QUERY_MAX_MEMORY; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CATALOG; import static com.facebook.presto.client.PrestoHeaders.PRESTO_CLIENT_INFO; @@ -134,7 +134,7 @@ public void testQuery() .setHeader(PRESTO_SCHEMA, "schema") .setHeader(PRESTO_CLIENT_INFO, "{\"clientVersion\":\"testVersion\"}") .addHeader(PRESTO_SESSION, QUERY_MAX_MEMORY + "=1GB") - .addHeader(PRESTO_SESSION, DISTRIBUTED_JOIN + "=true," + HASH_PARTITION_COUNT + " = 43") + .addHeader(PRESTO_SESSION, JOIN_DISTRIBUTION_TYPE + "=repartitioned," + HASH_PARTITION_COUNT + " = 43") 
.addHeader(PRESTO_PREPARED_STATEMENT, "foo=select * from bar") .build(); @@ -146,7 +146,7 @@ public void testQuery() // verify session properties assertEquals(queryInfo.getSession().getSystemProperties(), ImmutableMap.builder() .put(QUERY_MAX_MEMORY, "1GB") - .put(DISTRIBUTED_JOIN, "true") + .put(JOIN_DISTRIBUTION_TYPE, "repartitioned") .put(HASH_PARTITION_COUNT, "43") .build()); diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 780c48507c26..ceac4b721233 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -21,6 +21,8 @@ import java.util.Map; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPARTITIONED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPLICATED; import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI; import static com.facebook.presto.sql.analyzer.RegexLibrary.RE2J; import static io.airlift.configuration.testing.ConfigAssertions.assertDeprecatedEquivalence; @@ -40,7 +42,7 @@ public void testDefaults() .setNetworkCostWeight(0.25) .setResourceGroupsEnabled(false) .setDistributedIndexJoinsEnabled(false) - .setDistributedJoinsEnabled(true) + .setJoinDistributionType(REPARTITIONED) .setFastInequalityJoins(true) .setColocatedJoinsEnabled(false) .setJoinReorderingEnabled(true) @@ -86,7 +88,7 @@ public void testExplicitPropertyMappings() .put("deprecated.legacy-map-subscript", "true") .put("deprecated.new-map-block", "false") .put("distributed-index-joins-enabled", "true") - .put("distributed-joins-enabled", "false") + .put("join-distribution-type", "REPLICATED") .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") .put("reorder-joins", "false") @@ -122,7 +124,7 @@ public void 
testExplicitPropertyMappings() .put("deprecated.legacy-map-subscript", "true") .put("deprecated.new-map-block", "false") .put("distributed-index-joins-enabled", "true") - .put("distributed-joins-enabled", "false") + .put("join-distribution-type", "REPLICATED") .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") .put("reorder-joins", "false") @@ -155,7 +157,7 @@ public void testExplicitPropertyMappings() .setIterativeOptimizerEnabled(false) .setIterativeOptimizerTimeout(new Duration(10, SECONDS)) .setDistributedIndexJoinsEnabled(true) - .setDistributedJoinsEnabled(false) + .setJoinDistributionType(REPLICATED) .setFastInequalityJoins(false) .setColocatedJoinsEnabled(true) .setJoinReorderingEnabled(false) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java index 0235bba4f0cb..f29253a8c247 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestUnionWithReplicatedJoin.java @@ -21,6 +21,6 @@ public class TestUnionWithReplicatedJoin { public TestUnionWithReplicatedJoin() { - super(ImmutableMap.of(SystemSessionProperties.DISTRIBUTED_JOIN, "false")); + super(ImmutableMap.of(SystemSessionProperties.JOIN_DISTRIBUTION_TYPE, "replicated")); } } diff --git a/presto-product-tests/conf/presto/etc/config.properties b/presto-product-tests/conf/presto/etc/config.properties index 44aef5efc782..c15d874f879a 100644 --- a/presto-product-tests/conf/presto/etc/config.properties +++ b/presto-product-tests/conf/presto/etc/config.properties @@ -38,7 +38,6 @@ plugin.bundles=\ ../../../presto-sqlserver/pom.xml presto.version=testversion -distributed-joins-enabled=true query.max-memory-per-node=1GB query.max-memory=1GB redistribute-writes=false diff --git 
a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java index 5be87ee48944..025394cd81b5 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/jdbc/JdbcTests.java @@ -14,6 +14,7 @@ package com.facebook.presto.tests.jdbc; import com.facebook.presto.jdbc.PrestoConnection; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.teradata.tempto.BeforeTestWithContext; import com.teradata.tempto.ProductTest; import com.teradata.tempto.Requirement; @@ -52,8 +53,6 @@ import static com.teradata.tempto.internal.convention.SqlResultDescriptor.sqlResultDescriptorForResource; import static com.teradata.tempto.query.QueryExecutor.defaultQueryExecutor; import static com.teradata.tempto.query.QueryExecutor.query; -import static java.lang.Boolean.FALSE; -import static java.lang.Boolean.TRUE; import static java.util.Locale.CHINESE; import static org.assertj.core.api.Assertions.assertThat; @@ -247,13 +246,14 @@ public void testSqlEscapeFunctions() public void testSessionProperties() throws SQLException { - final String distributedJoin = "distributed_join"; + final String joinDistributionType = "join_distribution_type"; + final String defaultValue = new FeaturesConfig().getJoinDistributionType().name(); - assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(TRUE.toString()); - setSessionProperty(connection, distributedJoin, FALSE.toString()); - assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(FALSE.toString()); - resetSessionProperty(connection, distributedJoin); - assertThat(getSessionProperty(connection, distributedJoin)).isEqualTo(TRUE.toString()); + assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo(defaultValue); + setSessionProperty(connection, joinDistributionType, "REPLICATED"); + 
assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo("REPLICATED"); + resetSessionProperty(connection, joinDistributionType); + assertThat(getSessionProperty(connection, joinDistributionType)).isEqualTo(defaultValue); } private QueryResult queryResult(Statement statement, String query) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java index 6f683b3ef757..7a767c2264b2 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/DistributedQueryRunner.java @@ -171,11 +171,10 @@ private static TestingPrestoServer createTestingPrestoServer(URI discoveryUri, b .put("compiler.interpreter-enabled", "false") .put("task.max-index-memory", "16kB") // causes index joins to fault load .put("datasources", "system") - .put("distributed-index-joins-enabled", "true") .put("optimizer.optimize-mixed-distinct-aggregations", "true"); if (coordinator) { propertiesBuilder.put("node-scheduler.include-coordinator", "true"); - propertiesBuilder.put("distributed-joins-enabled", "true"); + propertiesBuilder.put("join-distribution-type", "REPARTITIONED"); } HashMap properties = new HashMap<>(propertiesBuilder.build()); properties.putAll(extraProperties); From 50370b4c3806ddd8aa64c80cae5981f1885a6114 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:38 -0400 Subject: [PATCH 02/15] Replace reorder_joins with join_reordering_strategy Create a new property called join_reordering_strategy with the options ELIMINATE_CROSS_JOINS, COST_BASED and NONE. ELIMINATE_CROSS_JOINS is equivalent to the previous reorder_joins=true. COST_BASED will use join enumeration to make a cost-based decision of join order. NONE will maintain the syntactic join order, and is equivalent to the previous reorder_joins=false. 
--- .../resources/benchmarks/presto/tpcds.yaml | 12 +++++----- .../resources/benchmarks/presto/tpch.yaml | 14 +++++------ .../session_set_join_reordering_strategy.sql | 1 + .../sql/presto/session_set_reorder_joins.sql | 1 - .../src/main/sphinx/admin/properties.rst | 16 +++++++++++++ .../presto/SystemSessionProperties.java | 24 ++++++++++++------- .../presto/sql/analyzer/FeaturesConfig.java | 22 +++++++++++------ .../iterative/rule/EliminateCrossJoins.java | 7 ++++-- .../optimizations/EliminateCrossJoins.java | 6 +++-- .../sql/analyzer/TestFeaturesConfig.java | 10 ++++---- .../rule/TestEliminateCrossJoins.java | 8 +++---- .../optimizations/TestReorderJoins.java | 2 +- .../conf/presto/etc/config.properties | 1 - .../presto/etc/multinode-master.properties | 1 - .../conf/presto/etc/singlenode.properties | 1 - 15 files changed, 81 insertions(+), 45 deletions(-) create mode 100644 presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql delete mode 100644 presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql diff --git a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml index aff146e980a2..be416992230d 100644 --- a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml +++ b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpcds.yaml @@ -2,7 +2,7 @@ datasource: presto query-names: presto/tpcds/${query}.sql runs: 6 prewarm-runs: 2 -before-execution: sleep-4s, presto/session_set_reorder_joins.sql +before-execution: sleep-4s, presto/session_set_join_reordering_strategy.sql frequency: 7 database: hive tpcds_small: tpcds_10gb_orc @@ -12,22 +12,22 @@ variables: 1: query: q01,q06,q14_1,q39_1,q39_2,q47,q57,q67,q81 schema: ${tpcds_small} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 2: query: 
q02,q03,q04,q05,q07,q09,q10,q11,q13,q14_2,q16,q17,q19,q22,q23_1,q23_2,q24_1,q24_2,q25,q28,q29,q30,q31,q32,q33,q35,q37,q38,q42,q43,q44,q46,q48,q49,q50,q51,q52,q53,q54,q55,q56,q58,q59,q60,q61,q63,q65,q66,q68,q69,q70,q71,q72,q74,q75,q77,q78,q80,q82,q88,q89,q94,q95 schema: ${tpcds_medium} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 3: # query not passing quick enough without reordering query: q18,q64 schema: ${tpcds_medium} - reorder_joins: true + join_reordering_strategy: ELIMINATE_CROSS_JOINS 4: query: q08,q12,q15,q20,q21,q26,q27,q34,q36,q40,q41,q45,q62,q73,q76,q79,q83,q84,q85,q86,q87,q90,q91,q92,q93,q96,q97,q98,q99 schema: ${tpcds_large} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 5: # extra runs with reordering on 1tb schema (too slow without reordering on 1tb). For 100g we keep both runs, with and without reordering query: q03,q37,q42,q43,q52,q53 schema: ${tpcds_large} - reorder_joins: true + join_reordering_strategy: ELIMINATE_CROSS_JOINS diff --git a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml index 3a19db13125a..2463762311a1 100644 --- a/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml +++ b/presto-benchto-benchmarks/src/main/resources/benchmarks/presto/tpch.yaml @@ -2,7 +2,7 @@ datasource: presto query-names: presto/tpch/${query}.sql runs: 6 prewarm-runs: 2 -before-execution: sleep-4s, presto/session_set_reorder_joins.sql +before-execution: sleep-4s, presto/session_set_join_reordering_strategy.sql frequency: 7 database: hive tpch_small: tpch_10gb_orc @@ -14,28 +14,28 @@ variables: # queries too slow to run on 100gb without reordering query: q2, q8, q9 schema: ${tpch_small} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 2: # queries too slow to run on 100gb without reordering query: q8, q9 schema: ${tpch_medium} - 
reorder_joins: true + join_reordering_strategy: ELIMINATE_CROSS_JOINS 3: # queries too slow to run on 100gb without reordering query: q2 schema: ${tpch_large} - reorder_joins: true + join_reordering_strategy: ELIMINATE_CROSS_JOINS 4: # queries too slow to run on 1tb query: q3, q4, q5, q7, q17, q18, q21 schema: ${tpch_medium} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 5: query: q10, q11, q12, q13, q14, q15, q16, q19, q20, q22 schema: ${tpch_large} - reorder_joins: true, false + join_reordering_strategy: ELIMINATE_CROSS_JOINS, NONE 6: # queries without joins query: q1, q6 schema: ${tpch_large} - reorder_joins: false + join_reordering_strategy: NONE diff --git a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql new file mode 100644 index 000000000000..a832822eda53 --- /dev/null +++ b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_join_reordering_strategy.sql @@ -0,0 +1 @@ +SET SESSION join_reordering_strategy='${join_reordering_strategy}' diff --git a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql b/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql deleted file mode 100644 index 43d4a7faa470..000000000000 --- a/presto-benchto-benchmarks/src/main/resources/sql/presto/session_set_reorder_joins.sql +++ /dev/null @@ -1 +0,0 @@ -SET SESSION reorder_joins='${reorder_joins}' diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 116f99286dbd..2b79bcc5562a 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -376,6 +376,22 @@ Optimizer Properties using the ``push_table_write_through_union`` session property. 
+``optimizer.join-reordering-strategy`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type:** ``string`` + * **Allowed values:** ``COST_BASED``, ``ELIMINATE_CROSS_JOINS``, ``NONE`` + * **Default value:** ``ELIMINATE_CROSS_JOINS`` + + The join reordering strategy to use. ``NONE`` maintains the order the tables are listed in the + query. ``ELIMINATE_CROSS_JOINS`` reorders joins to eliminate cross joins where possible and + otherwise maintains the original query order. When reordering joins it also strives to maintain the + original table order as much as possible. ``COST_BASED`` enumerates possible orders and uses + statistics-based cost estimation to determine the least cost order. If stats are not available or if + for any reason a cost could not be computed, the ``ELIMINATE_CROSS_JOINS`` strategy is used. This can + also be specified on a per-query basis using the ``join_reordering_strategy`` session property. + + Regular Expression Function Properties -------------------------------------- diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index 3074484d4581..02d717269eef 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -21,6 +21,7 @@ import com.facebook.presto.spi.session.PropertyMetadata; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType; +import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy; import com.google.common.collect.ImmutableList; import io.airlift.units.DataSize; import io.airlift.units.Duration; @@ -61,7 +62,7 @@ public final class SystemSessionProperties public static final String DICTIONARY_AGGREGATION = "dictionary_aggregation"; public static final String PLAN_WITH_TABLE_NODE_PARTITIONING = 
"plan_with_table_node_partitioning"; public static final String COLOCATED_JOIN = "colocated_join"; - public static final String REORDER_JOINS = "reorder_joins"; + public static final String JOIN_REORDERING_STRATEGY = "join_reordering_strategy"; public static final String INITIAL_SPLITS_PER_NODE = "initial_splits_per_node"; public static final String SPLIT_CONCURRENCY_ADJUSTMENT_INTERVAL = "split_concurrency_adjustment_interval"; public static final String OPTIMIZE_METADATA_QUERIES = "optimize_metadata_queries"; @@ -250,11 +251,18 @@ public SystemSessionProperties( "Experimental: Adapt plan to pre-partitioned tables", true, false), - booleanSessionProperty( - REORDER_JOINS, - "Experimental: Reorder joins to optimize plan", - featuresConfig.isJoinReorderingEnabled(), - false), + new PropertyMetadata<>( + JOIN_REORDERING_STRATEGY, + format("The join reordering strategy to use. Options are %s", + Stream.of(JoinReorderingStrategy.values()) + .map(FeaturesConfig.JoinReorderingStrategy::name) + .collect(joining(","))), + VARCHAR, + JoinReorderingStrategy.class, + featuresConfig.getJoinReorderingStrategy(), + false, + value -> JoinReorderingStrategy.valueOf(((String) value).toUpperCase()), + JoinReorderingStrategy::name), booleanSessionProperty( FAST_INEQUALITY_JOINS, "Use faster handling of inequality join if it is possible", @@ -432,9 +440,9 @@ public static boolean isFastInequalityJoin(Session session) return session.getSystemProperty(FAST_INEQUALITY_JOINS, Boolean.class); } - public static boolean isJoinReorderingEnabled(Session session) + public static JoinReorderingStrategy getJoinReorderingStrategy(Session session) { - return session.getSystemProperty(REORDER_JOINS, Boolean.class); + return session.getSystemProperty(JOIN_REORDERING_STRATEGY, JoinReorderingStrategy.class); } public static boolean isColocatedJoinEnabled(Session session) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java 
b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index c2aa7ddc91c2..b5b3205dd068 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -28,6 +28,7 @@ import java.util.List; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPARTITIONED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -48,7 +49,7 @@ public class FeaturesConfig private boolean distributedIndexJoinsEnabled; private boolean colocatedJoinsEnabled; private boolean fastInequalityJoins = true; - private boolean reorderJoins = true; + private JoinReorderingStrategy joinReorderingStrategy = ELIMINATE_CROSS_JOINS; private boolean redistributeWrites = true; private boolean optimizeMetadataQueries; private boolean optimizeHashGeneration = true; @@ -80,6 +81,13 @@ public class FeaturesConfig private Duration iterativeOptimizerTimeout = new Duration(3, MINUTES); // by default let optimizer wait a long time in case it retrieves some data from ConnectorMetadata + public enum JoinReorderingStrategy + { + ELIMINATE_CROSS_JOINS, + COST_BASED, + NONE + } + public enum JoinDistributionType { AUTOMATIC, @@ -231,16 +239,16 @@ public boolean isFastInequalityJoins() return fastInequalityJoins; } - public boolean isJoinReorderingEnabled() + public JoinReorderingStrategy getJoinReorderingStrategy() { - return reorderJoins; + return joinReorderingStrategy; } - @Config("reorder-joins") - @ConfigDescription("Experimental: Reorder joins to optimize plan") - public FeaturesConfig setJoinReorderingEnabled(boolean reorderJoins) + @Config("optimizer.join-reordering-strategy") + 
@ConfigDescription("The strategy to use for reordering joins") + public FeaturesConfig setJoinReorderingStrategy(JoinReorderingStrategy joinReorderingStrategy) { - this.reorderJoins = reorderJoins; + this.joinReorderingStrategy = joinReorderingStrategy; return this; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java index 2452f75ae7a4..a4e7907a8fff 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; -import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.matching.Pattern; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; @@ -41,6 +40,9 @@ import java.util.PriorityQueue; import java.util.Set; +import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -65,7 +67,8 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.empty(); } - if (!SystemSessionProperties.isJoinReorderingEnabled(session)) { + // we run this for cost_based reordering also for cases when some of the tables do not have statistics + if (getJoinReorderingStrategy(session) != ELIMINATE_CROSS_JOINS && 
getJoinReorderingStrategy(session) != COST_BASED) { return Optional.empty(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java index 89a4eccfc674..57f29e4dad79 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/EliminateCrossJoins.java @@ -14,7 +14,6 @@ package com.facebook.presto.sql.planner.optimizations; import com.facebook.presto.Session; -import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.spi.type.Type; import com.facebook.presto.sql.planner.PlanNodeIdAllocator; import com.facebook.presto.sql.planner.Symbol; @@ -28,6 +27,9 @@ import java.util.Map; import java.util.Objects; +import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.buildJoinTree; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder; import static com.facebook.presto.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder; @@ -47,7 +49,7 @@ public PlanNode optimize( SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { - if (!SystemSessionProperties.isJoinReorderingEnabled(session)) { + if (getJoinReorderingStrategy(session) != ELIMINATE_CROSS_JOINS && getJoinReorderingStrategy(session) != COST_BASED) { return plan; } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 
ceac4b721233..bb4db2ba43b4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -23,6 +23,8 @@ import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPARTITIONED; import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType.REPLICATED; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.ELIMINATE_CROSS_JOINS; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.NONE; import static com.facebook.presto.sql.analyzer.RegexLibrary.JONI; import static com.facebook.presto.sql.analyzer.RegexLibrary.RE2J; import static io.airlift.configuration.testing.ConfigAssertions.assertDeprecatedEquivalence; @@ -45,7 +47,7 @@ public void testDefaults() .setJoinDistributionType(REPARTITIONED) .setFastInequalityJoins(true) .setColocatedJoinsEnabled(false) - .setJoinReorderingEnabled(true) + .setJoinReorderingStrategy(ELIMINATE_CROSS_JOINS) .setRedistributeWrites(true) .setOptimizeMetadataQueries(false) .setOptimizeHashGeneration(true) @@ -91,7 +93,7 @@ public void testExplicitPropertyMappings() .put("join-distribution-type", "REPLICATED") .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") - .put("reorder-joins", "false") + .put("optimizer.join-reordering-strategy", "NONE") .put("redistribute-writes", "false") .put("optimizer.optimize-metadata-queries", "true") .put("optimizer.optimize-hash-generation", "false") @@ -127,7 +129,7 @@ public void testExplicitPropertyMappings() .put("join-distribution-type", "REPLICATED") .put("fast-inequality-joins", "false") .put("colocated-joins-enabled", "true") - .put("reorder-joins", "false") + .put("optimizer.join-reordering-strategy", "NONE") .put("redistribute-writes", "false") .put("optimizer.optimize-metadata-queries", "true") .put("optimizer.optimize-hash-generation", 
"false") @@ -160,7 +162,7 @@ public void testExplicitPropertyMappings() .setJoinDistributionType(REPLICATED) .setFastInequalityJoins(false) .setColocatedJoinsEnabled(true) - .setJoinReorderingEnabled(false) + .setJoinReorderingStrategy(NONE) .setRedistributeWrites(false) .setOptimizeMetadataQueries(true) .setOptimizeHashGeneration(false) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java index e46826bfb567..3609523de8c5 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -35,7 +35,7 @@ import java.util.Optional; import java.util.function.Function; -import static com.facebook.presto.SystemSessionProperties.REORDER_JOINS; +import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; @@ -60,7 +60,7 @@ public class TestEliminateCrossJoins public void testEliminateCrossJoin() { tester().assertThat(new EliminateCrossJoins()) - .setSystemProperty(REORDER_JOINS, "true") + .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(INNER)) .matches( join(INNER, @@ -79,7 +79,7 @@ public void testEliminateCrossJoin() public void testRetainOutgoingGroupReferences() { tester().assertThat(new EliminateCrossJoins()) - .setSystemProperty(REORDER_JOINS, "true") + .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(INNER)) .matches( node(JoinNode.class, @@ -96,7 +96,7 @@ public void testRetainOutgoingGroupReferences() 
public void testDoNotReorderOuterJoin() { tester().assertThat(new EliminateCrossJoins()) - .setSystemProperty(REORDER_JOINS, "true") + .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") .on(crossJoinAndJoin(JoinNode.Type.LEFT)) .doesNotFire(); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java index 40a6321583d9..f1589980d0a0 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestReorderJoins.java @@ -59,7 +59,7 @@ public class TestReorderJoins public TestReorderJoins() { - super(ImmutableMap.of(SystemSessionProperties.REORDER_JOINS, "true")); + super(ImmutableMap.of(SystemSessionProperties.JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS")); } @Test diff --git a/presto-product-tests/conf/presto/etc/config.properties b/presto-product-tests/conf/presto/etc/config.properties index c15d874f879a..a7e918b66a97 100644 --- a/presto-product-tests/conf/presto/etc/config.properties +++ b/presto-product-tests/conf/presto/etc/config.properties @@ -41,4 +41,3 @@ presto.version=testversion query.max-memory-per-node=1GB query.max-memory=1GB redistribute-writes=false -reorder-joins=true diff --git a/presto-product-tests/conf/presto/etc/multinode-master.properties b/presto-product-tests/conf/presto/etc/multinode-master.properties index 5be52ed2cb2d..98d6602e0844 100644 --- a/presto-product-tests/conf/presto/etc/multinode-master.properties +++ b/presto-product-tests/conf/presto/etc/multinode-master.properties @@ -15,4 +15,3 @@ query.max-memory=1GB query.max-memory-per-node=512MB discovery-server.enabled=true discovery.uri=http://presto-master:8080 -reorder-joins=true diff --git a/presto-product-tests/conf/presto/etc/singlenode.properties 
b/presto-product-tests/conf/presto/etc/singlenode.properties index d0e2cda196b1..a2e66146a4eb 100644 --- a/presto-product-tests/conf/presto/etc/singlenode.properties +++ b/presto-product-tests/conf/presto/etc/singlenode.properties @@ -15,4 +15,3 @@ query.max-memory=2GB query.max-memory-per-node=1GB discovery-server.enabled=true discovery.uri=http://presto-master:8080 -reorder-joins=true From 74c80b2fec8caddae2cf9295274358943343d816 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:39 -0400 Subject: [PATCH 03/15] Support matching join distribution type in tests --- .../presto/sql/planner/assertions/JoinMatcher.java | 8 +++++++- .../presto/sql/planner/assertions/PlanMatchPattern.java | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java index a78c85624707..6d002c104dd7 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/JoinMatcher.java @@ -37,12 +37,14 @@ final class JoinMatcher private final JoinNode.Type joinType; private final List> equiCriteria; private final Optional filter; + private final Optional distributionType; - JoinMatcher(JoinNode.Type joinType, List> equiCriteria, Optional filter) + JoinMatcher(JoinNode.Type joinType, List> equiCriteria, Optional filter, Optional distributionType) { this.joinType = requireNonNull(joinType, "joinType is null"); this.equiCriteria = requireNonNull(equiCriteria, "equiCriteria is null"); this.filter = requireNonNull(filter, "filter can not be null"); + this.distributionType = requireNonNull(distributionType, "distribtuionType cannot be null"); } @Override @@ -81,6 +83,10 @@ public MatchResult detailMatches(PlanNode node, PlanNodeStatsEstimate stats, Ses } } + if 
(distributionType.isPresent() && !distributionType.equals(joinNode.getDistributionType())) { + return NO_MATCH; + } + /* * Have to use order-independent comparison; there are no guarantees what order * the equi criteria will have after planning and optimizing. diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index 9e39e36ec950..ce7c03dabffe 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -266,12 +266,18 @@ public static PlanMatchPattern join(JoinNode.Type joinType, List> expectedEquiCriteria, Optional expectedFilter, PlanMatchPattern left, PlanMatchPattern right) + { + return join(joinType, expectedEquiCriteria, expectedFilter, Optional.empty(), left, right); + } + + public static PlanMatchPattern join(JoinNode.Type joinType, List> expectedEquiCriteria, Optional expectedFilter, Optional distributionType, PlanMatchPattern left, PlanMatchPattern right) { return node(JoinNode.class, left, right).with( new JoinMatcher( joinType, expectedEquiCriteria, - expectedFilter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate))))); + expectedFilter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate))), + distributionType)); } public static PlanMatchPattern exchange(PlanMatchPattern... 
sources) From 5454fffc1d49cf911eb443d1a25ea111f81e7c26 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:40 -0400 Subject: [PATCH 04/15] Support inserting stats for plan unit tests --- .../presto/testing/TestingLookup.java | 95 +++++++++++++++++++ .../iterative/rule/test/RuleAssert.java | 49 ++++++++-- 2 files changed, 138 insertions(+), 6 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/testing/TestingLookup.java diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingLookup.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingLookup.java new file mode 100644 index 000000000000..7099545aecbe --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestingLookup.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.testing; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.PlanNodeCostEstimate; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.GroupReference; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.google.common.collect.ImmutableMap; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +import static java.util.Objects.requireNonNull; + +public class TestingLookup + implements Lookup +{ + private final StatsCalculator statsCalculator; + private final CostCalculator costCalculator; + private final Map stats = new HashMap<>(); + private final Map costs = new HashMap<>(); + private final Function resolver; + + public TestingLookup(StatsCalculator statsCalculator, CostCalculator costCalculator, Function resolver) + { + this(statsCalculator, costCalculator, ImmutableMap.of(), ImmutableMap.of(), resolver); + } + + private TestingLookup(StatsCalculator statsCalculator, CostCalculator costCalculator, Map stats, Map costs, Function resolver) + { + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); + this.stats.putAll(stats); + this.costs.putAll(costs); + this.resolver = requireNonNull(resolver, "resolver is null"); + } + + public TestingLookup withStats(Map stats) + { + return new TestingLookup(statsCalculator, costCalculator, stats, ImmutableMap.of(), resolver); + } + + @Override + public PlanNode resolve(PlanNode node) + { + if (node instanceof GroupReference) { + return resolver.apply((GroupReference) node); + } + return node; + } + + @Override + 
public PlanNodeStatsEstimate getStats(PlanNode planNode, Session session, Map types) + { + PlanNode resolved = resolve(planNode); + PlanNodeStatsEstimate statsEstimate = stats.get(resolved); + if (statsEstimate == null) { + statsEstimate = statsCalculator.calculateStats(resolved, this, session, types); + stats.put(resolved, statsEstimate); + } + return statsEstimate; + } + + @Override + public PlanNodeCostEstimate getCumulativeCost(PlanNode planNode, Session session, Map types) + { + PlanNode resolved = resolve(planNode); + PlanNodeCostEstimate costEstimate = costs.get(resolved); + if (costEstimate == null) { + costEstimate = costCalculator.calculateCumulativeCost(resolved, this, session, types); + costs.put(resolved, costEstimate); + } + return costEstimate; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 2b79a9d07992..30dba2e2dd40 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -15,6 +15,7 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.CostCalculator; +import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; @@ -28,7 +29,9 @@ import com.facebook.presto.sql.planner.iterative.Memo; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; +import com.facebook.presto.testing.TestingLookup; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableSet; @@ -37,27 
+40,39 @@ import java.util.function.Function; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.Objects.requireNonNull; import static org.testng.Assert.fail; public class RuleAssert { private final Metadata metadata; + private final StatsCalculator statsCalculator; + private final CostCalculator costCalculator; private Session session; private final Rule rule; private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); private Map symbols; + private TestingLookup lookup; private PlanNode plan; private final TransactionManager transactionManager; private final AccessControl accessControl; - private final StatsCalculator statsCalculator; - private final CostCalculator costCalculator; - - public RuleAssert(Metadata metadata, Session session, Rule rule, TransactionManager transactionManager, AccessControl accessControl, StatsCalculator statsCalculator, CostCalculator costCalculator) + private Memo memo; + + public RuleAssert( + Metadata metadata, + Session session, + Rule rule, + TransactionManager transactionManager, + AccessControl accessControl, + StatsCalculator statsCalculator, + CostCalculator costCalculator) { this.metadata = metadata; this.session = session; @@ -88,9 +103,33 @@ public RuleAssert on(Function planProvider) PlanBuilder builder = new PlanBuilder(idAllocator, metadata); plan = planProvider.apply(builder); symbols = builder.getSymbols(); + memo = new Memo(idAllocator, plan); + lookup = new TestingLookup(statsCalculator, costCalculator, memo::resolve); + return this; + } + + public RuleAssert withStats(Map stats) + { + 
checkState(lookup != null, "lookup has not yet been initialized"); + Map planNodeMap = buildPlanNodeMap(); + lookup = lookup.withStats( + stats.entrySet() + .stream() + .collect(toImmutableMap( + entry -> { + checkState(planNodeMap.containsKey(entry.getKey()), "planNodeMap does not contain key"); + return planNodeMap.get(entry.getKey()); + }, + Map.Entry::getValue))); return this; } + private Map buildPlanNodeMap() + { + return searchFrom(plan, lookup).findAll().stream() + .collect(toImmutableMap(PlanNode::getId, planNode -> planNode)); + } + public void doesNotFire() { RuleApplication ruleApplication = applyRule(); @@ -143,8 +182,6 @@ public void matches(PlanMatchPattern pattern) private RuleApplication applyRule() { SymbolAllocator symbolAllocator = new SymbolAllocator(symbols); - Memo memo = new Memo(idAllocator, plan); - Lookup lookup = Lookup.from(memo::resolve, statsCalculator, costCalculator); if (!rule.getPattern().matches(plan)) { return new RuleApplication(lookup, symbolAllocator.getTypes(), Optional.empty()); From 0e609c8d805f018cfc6b4f8e20b1bccf98764249 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:41 -0400 Subject: [PATCH 05/15] Support passing statsCalculator to RuleAssert --- .../sql/planner/iterative/rule/test/RuleAssert.java | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 30dba2e2dd40..1a7d97ea7c2d 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -51,7 +51,7 @@ public class RuleAssert { private final Metadata metadata; - private final StatsCalculator statsCalculator; + private StatsCalculator statsCalculator; private final CostCalculator 
costCalculator; private Session session; private final Rule rule; @@ -96,6 +96,13 @@ public RuleAssert withSession(Session session) return this; } + public RuleAssert withStatsCalculator(StatsCalculator statsCalculator) + { + checkState(lookup == null, "lookup has been set"); + this.statsCalculator = statsCalculator; + return this; + } + public RuleAssert on(Function planProvider) { checkArgument(plan == null, "plan has already been set"); From 6f496082b41f003f3b6d51f68deec404fee3bf88 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:42 -0400 Subject: [PATCH 06/15] Support using a fake node count for unit tests Allowing the LocalQueryRunner to estimate costs using a fake node count allows unit tests to consider network costs and different cluster configurations. --- .../presto/testing/LocalQueryRunner.java | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index d938c37b8e7f..268336a2e88b 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -251,15 +251,16 @@ public LocalQueryRunner(Session defaultSession) new FeaturesConfig() .setOptimizeMixedDistinctAggregations(true) .setIterativeOptimizerEnabled(true), - false); + false, + 1); } public LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig) { - this(defaultSession, featuresConfig, false); + this(defaultSession, featuresConfig, false, 1); } - private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, boolean withInitialTransaction) + private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, boolean withInitialTransaction, int nodeCountForStats) { requireNonNull(defaultSession, "defaultSession is null"); 
checkArgument(!defaultSession.getTransactionId().isPresent() || !withInitialTransaction, "Already in transaction"); @@ -379,14 +380,19 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new CoefficientBasedStatsCalculator(metadata), ServerMainModule.createNewStatsCalculator(metadata, new FilterStatsCalculator(metadata), new ScalarStatsCalculator(metadata))); this.costCalculator = new CostCalculatorUsingExchanges(getNodeCount()); - this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, getNodeCount()); + this.estimatedExchangesCostCalculator = new CostCalculatorWithEstimatedExchanges(costCalculator, nodeCountForStats); this.lookup = new StatelessLookup(statsCalculator, costCalculator); } public static LocalQueryRunner queryRunnerWithInitialTransaction(Session defaultSession) { checkArgument(!defaultSession.getTransactionId().isPresent(), "Already in transaction!"); - return new LocalQueryRunner(defaultSession, new FeaturesConfig(), true); + return new LocalQueryRunner(defaultSession, new FeaturesConfig(), true, 1); + } + + public static LocalQueryRunner queryRunnerWithFakeNodeCountForStats(Session defaultSession, int nodeCount) + { + return new LocalQueryRunner(defaultSession, new FeaturesConfig(), false, nodeCount); } @Override From 8d126e1a24b03b3f8181480efc0dc65ae0a507c7 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:44 -0400 Subject: [PATCH 07/15] Make binaryExpression() handle empty list Change ExpressionUtils.binaryExpression to return TRUE on an empty list --- .../com/facebook/presto/sql/ExpressionUtils.java | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java index 4ecc694ec6a0..012ea3a6b51d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java +++ 
b/presto-main/src/main/java/com/facebook/presto/sql/ExpressionUtils.java @@ -26,7 +26,6 @@ import com.facebook.presto.sql.tree.LogicalBinaryExpression; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.SymbolReference; -import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; @@ -103,7 +102,17 @@ public static Expression binaryExpression(LogicalBinaryExpression.Type type, Col { requireNonNull(type, "type is null"); requireNonNull(expressions, "expressions is null"); - Preconditions.checkArgument(!expressions.isEmpty(), "expressions is empty"); + + if (expressions.isEmpty()) { + switch (type) { + case AND: + return TRUE_LITERAL; + case OR: + return FALSE_LITERAL; + default: + throw new IllegalArgumentException("Unsupported LogicalBinaryExpression type"); + } + } // Build balanced tree for efficient recursive processing that // preserves the evaluation order of the input expressions. 
@@ -309,7 +318,8 @@ public static Expression normalize(Expression expression) public static Expression rewriteIdentifiersToSymbolReferences(Expression expression) { - return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() { + return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() + { @Override public Expression rewriteIdentifier(Identifier node, Void context, ExpressionTreeRewriter treeRewriter) { From 3ba7f241e3ca838d686bd4b6f771f296a5011af4 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:45 -0400 Subject: [PATCH 08/15] Add methods to flip join and set distribution type --- .../presto/sql/planner/plan/JoinNode.java | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java index 3ecc8862ae35..0eb4615f145f 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/JoinNode.java @@ -28,10 +28,14 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.stream.Collectors; import java.util.stream.Stream; import static com.facebook.presto.sql.planner.SortExpressionExtractor.extractSortExpression; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @@ -100,6 +104,58 @@ public JoinNode(@JsonProperty("id") PlanNodeId id, checkArgument(!(criteria.isEmpty() && rightHashSymbol.isPresent()), "Right hash 
symbol is only valid in an equijoin"); } + public JoinNode flipChildren() + { + return new JoinNode( + getId(), + flipType(type), + right, + left, + flipJoinCriteria(criteria), + flipOutputSymbols(getOutputSymbols(), left, right), + filter, + rightHashSymbol, + leftHashSymbol, + distributionType); + } + + private static Type flipType(Type type) + { + switch (type) { + case INNER: + return INNER; + case FULL: + return FULL; + case LEFT: + return RIGHT; + case RIGHT: + return LEFT; + default: + throw new IllegalStateException("No inverse defined for join type: " + type); + } + } + + private static List flipJoinCriteria(List joinCriteria) + { + return joinCriteria.stream() + .map(EquiJoinClause::flip) + .collect(toImmutableList()); + } + + private static List flipOutputSymbols(List outputSymbols, PlanNode left, PlanNode right) + { + List leftSymbols = outputSymbols.stream() + .filter(symbol -> left.getOutputSymbols().contains(symbol)) + .collect(Collectors.toList()); + List rightSymbols = outputSymbols.stream() + .filter(symbol -> right.getOutputSymbols().contains(symbol)) + .collect(Collectors.toList()); + return ImmutableList.builder() + .addAll(rightSymbols) + .addAll(leftSymbols) + .build(); + } + public enum DistributionType { PARTITIONED, @@ -230,6 +286,11 @@ public PlanNode replaceChildren(List newChildren) return new JoinNode(getId(), type, newLeft, newRight, criteria, newOutputSymbols, filter, leftHashSymbol, rightHashSymbol, distributionType); } + public JoinNode withDistributionType(DistributionType distributionType) + { + return new JoinNode(getId(), type, left, right, criteria, outputSymbols, filter, leftHashSymbol, rightHashSymbol, Optional.of(distributionType)); + } + public boolean isCrossJoin() { return criteria.isEmpty() && !filter.isPresent() && type == INNER; @@ -264,6 +325,11 @@ public ComparisonExpression toExpression() return new ComparisonExpression(ComparisonExpressionType.EQUAL, left.toSymbolReference(), right.toSymbolReference()); } + public 
EquiJoinClause flip() + { + return new EquiJoinClause(right, left); + } + @Override public boolean equals(Object obj) { From e6d0332de41acaa5658209216e841ca6cd69b9b4 Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:46 -0400 Subject: [PATCH 09/15] Add ReorderJoins rule to pick the best join order Add a rule to enumerate join order possibilities for a join graph and choose the least cost option. This does a minimal form of cross join elimination, by only partitioning nodes into groups that have at least one edge between them, which eliminates some unnecessary cross joins from consideration. It also means that necessary cross joins will always be executed as late as possible in the plan (which may be worse). --- .../presto/sql/planner/PlanOptimizers.java | 22 +- .../planner/iterative/rule/MultiJoinNode.java | 116 +++++ .../planner/iterative/rule/ReorderJoins.java | 406 ++++++++++++++++++ .../DetermineJoinDistributionType.java | 3 + .../iterative/rule/TestJoinEnumerator.java | 101 +++++ .../rule/TestMultiJoinNodeBuilder.java | 248 +++++++++++ .../iterative/rule/TestReorderJoins.java | 361 ++++++++++++++++ .../iterative/rule/test/PlanBuilder.java | 13 +- 8 files changed, 1268 insertions(+), 2 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java create mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 
a1f4c44c2834..f279f9699829 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -55,6 +55,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveEmptyDelete; import com.facebook.presto.sql.planner.iterative.rule.RemoveFullSample; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantIdentityProjections; +import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins; import com.facebook.presto.sql.planner.iterative.rule.SimplifyCountOverConstant; import com.facebook.presto.sql.planner.iterative.rule.SingleMarkDistinctToGroupBy; import com.facebook.presto.sql.planner.iterative.rule.SwapAdjacentWindowsBySpecifications; @@ -322,13 +323,32 @@ public PlanOptimizers( ImmutableList.of(new com.facebook.presto.sql.planner.optimizations.EliminateCrossJoins()), // This can pull up Filter and Project nodes from between Joins, so we need to push them down again ImmutableSet.of(new EliminateCrossJoins()) ), + new PredicatePushDown(metadata, sqlParser), new IterativeOptimizer( stats, statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new PushDownTableConstraints(metadata, sqlParser))), - projectionPushDown); + projectionPushDown, + new PruneUnreferencedOutputs(), + new IterativeOptimizer( + stats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new RemoveRedundantIdentityProjections()) + ), + + // Because ReorderJoins runs only once, + // PredicatePushDown, PruneUnreferencedOutputs and RemoveRedundantIdentityProjections + // need to run beforehand in order to produce an optimal join order + // It also needs to run after EliminateCrossJoins so that its chosen order doesn't get undone. 
+ new IterativeOptimizer( + stats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new ReorderJoins(costComparator)) + )); if (featuresConfig.isOptimizeSingleDistinct()) { builder.add( diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java new file mode 100644 index 000000000000..ea394fd43b8e --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MultiJoinNode.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.tree.Expression; +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.List; + +import static com.facebook.presto.sql.ExpressionUtils.and; +import static com.facebook.presto.sql.planner.DeterminismEvaluator.isDeterministic; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +/** + * This class represents a set of inner joins that can be executed in any order. 
+ */ +class MultiJoinNode +{ + private static final int JOIN_LIMIT = 10; + + private final List sources; + private final Expression filter; + private final List outputSymbols; + + public MultiJoinNode(List sources, Expression filter, List outputSymbols) + { + this.sources = ImmutableList.copyOf(requireNonNull(sources, "sources is null")); + this.filter = requireNonNull(filter, "filter is null"); + this.outputSymbols = ImmutableList.copyOf(requireNonNull(outputSymbols, "outputSymbols is null")); + + List inputSymbols = sources.stream().flatMap(source -> source.getOutputSymbols().stream()).collect(toImmutableList()); + checkArgument(inputSymbols.containsAll(outputSymbols), "inputs do not contain all output symbols"); + } + + public Expression getFilter() + { + return filter; + } + + public List getSources() + { + return sources; + } + + public List getOutputSymbols() + { + return outputSymbols; + } + + static MultiJoinNode toMultiJoinNode(JoinNode joinNode, Lookup lookup) + { + return new MultiJoinNodeBuilder(joinNode, lookup).toMultiJoinNode(); + } + + private static class MultiJoinNodeBuilder + { + private final List sources = new ArrayList<>(); + private final List filters = new ArrayList<>(); + private final List outputSymbols; + private final Lookup lookup; + + MultiJoinNodeBuilder(JoinNode node, Lookup lookup) + { + requireNonNull(node, "node is null"); + checkState(node.getType() == INNER, "join type must be INNER"); + this.outputSymbols = node.getOutputSymbols(); + this.lookup = requireNonNull(lookup, "lookup is null"); + flattenNode(node); + } + + private void flattenNode(PlanNode node) + { + PlanNode resolved = lookup.resolve(node); + if (resolved instanceof JoinNode && sources.size() < JOIN_LIMIT) { + JoinNode joinNode = (JoinNode) resolved; + if (joinNode.getType() == INNER && isDeterministic(joinNode.getFilter().orElse(TRUE_LITERAL))) { + flattenNode(joinNode.getLeft()); + flattenNode(joinNode.getRight()); + joinNode.getCriteria().stream() + 
.map(JoinNode.EquiJoinClause::toExpression) + .forEach(filters::add); + joinNode.getFilter().ifPresent(filters::add); + return; + } + } + sources.add(node); + } + + MultiJoinNode toMultiJoinNode() + { + return new MultiJoinNode(sources, and(filters), outputSymbols); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java new file mode 100644 index 000000000000..d2081822b3b5 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -0,0 +1,406 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.PlanNodeCostEstimate; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.planner.EqualityInference; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.SymbolsExtractor; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.FilterNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.ProjectNode; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Ordering; +import com.google.common.collect.Sets; +import io.airlift.log.Logger; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static com.facebook.presto.SystemSessionProperties.getJoinDistributionType; +import static com.facebook.presto.SystemSessionProperties.getJoinReorderingStrategy; +import static com.facebook.presto.cost.PlanNodeCostEstimate.INFINITE_COST; +import static com.facebook.presto.cost.PlanNodeCostEstimate.UNKNOWN_COST; +import static com.facebook.presto.sql.ExpressionUtils.and; +import static 
com.facebook.presto.sql.ExpressionUtils.combineConjuncts; +import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinReorderingStrategy.COST_BASED; +import static com.facebook.presto.sql.planner.EqualityInference.createEqualityInference; +import static com.facebook.presto.sql.planner.iterative.rule.MultiJoinNode.toMultiJoinNode; +import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.INFINITE_COST_RESULT; +import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult.UNKNOWN_COST_RESULT; +import static com.facebook.presto.sql.planner.plan.Assignments.identity; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Predicates.in; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; +import static java.util.stream.StreamSupport.stream; + +public class ReorderJoins + implements Rule +{ + private static final Logger log = Logger.get(ReorderJoins.class); + + private final CostComparator costComparator; + + public ReorderJoins(CostComparator costComparator) + { + this.costComparator = requireNonNull(costComparator, "costComparator is null"); + } + + @Override + public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, 
Session session) + { + if (!(node instanceof JoinNode) || getJoinReorderingStrategy(session) != COST_BASED) { + return Optional.empty(); + } + + JoinNode joinNode = (JoinNode) node; + // We check that join distribution type is absent because we only want to do this transformation once (reordered joins will have distribution type already set). + // We check deterministic filters because we can't reorder joins with non-deterministic filters + if (!(joinNode.getType() == INNER) || joinNode.getDistributionType().isPresent()) { + return Optional.empty(); + } + + MultiJoinNode multiJoinNode = toMultiJoinNode(joinNode, lookup); + if (multiJoinNode.getSources().size() < 2) { + return Optional.empty(); + } + + JoinEnumerationResult result = new JoinEnumerator(idAllocator, symbolAllocator, session, lookup, multiJoinNode.getFilter(), costComparator).chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols()); + return result.getCost().isUnknown() || result.getCost().equals(INFINITE_COST) ? 
Optional.empty() : result.getPlanNode(); + } + + @VisibleForTesting + static class JoinEnumerator + { + private final Map, JoinEnumerationResult> memo = new HashMap<>(); + private final PlanNodeIdAllocator idAllocator; + private final Session session; + private final Ordering resultOrdering; + private final EqualityInference allInference; + private final Expression allFilter; + private final SymbolAllocator symbolAllocator; + private final Lookup lookup; + + @VisibleForTesting + JoinEnumerator(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session, Lookup lookup, Expression filter, CostComparator costComparator) + { + requireNonNull(idAllocator, "idAllocator is null"); + requireNonNull(symbolAllocator, "symbolAllocator is null"); + requireNonNull(session, "session is null"); + requireNonNull(lookup, "lookup is null"); + requireNonNull(filter, "filter is null"); + requireNonNull(costComparator, "costComparator is null"); + this.idAllocator = idAllocator; + this.symbolAllocator = symbolAllocator; + this.session = session; + this.lookup = lookup; + this.resultOrdering = getResultOrdering(costComparator, session); + this.allInference = createEqualityInference(filter); + this.allFilter = filter; + } + + private static Ordering getResultOrdering(CostComparator costComparator, Session session) + { + return new Ordering() + { + @Override + public int compare(JoinEnumerationResult result1, JoinEnumerationResult result2) + { + return costComparator.compare(session, result1.cost, result2.cost); + } + }; + } + + private JoinEnumerationResult chooseJoinOrder(List sources, List outputSymbols) + { + Set multiJoinKey = ImmutableSet.copyOf(sources); + JoinEnumerationResult bestResult = memo.get(multiJoinKey); + if (bestResult == null) { + checkState(sources.size() > 1, "sources size is less than or equal to one"); + ImmutableList.Builder resultBuilder = ImmutableList.builder(); + Set> partitions = 
generatePartitions(sources.size()).collect(toImmutableSet()); + for (Set partition : partitions) { + JoinEnumerationResult result = createJoinAccordingToPartitioning(sources, outputSymbols, partition); + if (result.cost.isUnknown()) { + memo.put(multiJoinKey, result); + return result; + } + if (!result.cost.equals(INFINITE_COST)) { + resultBuilder.add(result); + } + } + + List results = resultBuilder.build(); + if (results.isEmpty()) { + memo.put(multiJoinKey, INFINITE_COST_RESULT); + return INFINITE_COST_RESULT; + } + + bestResult = resultOrdering.min(resultBuilder.build()); + memo.put(multiJoinKey, bestResult); + } + if (bestResult.planNode.isPresent()) { + log.debug("Least cost join was: " + bestResult.planNode.get().toString()); + } + return bestResult; + } + + /** + * This method generates all the ways of dividing totalNodes into two sets + * each containing at least one node. It will generate one set for each + * possible partitioning. The other partition is implied in the absent values. + * In order not to generate the inverse of any set, we always include the 0th + * node in our sets. 
+ * + * @param totalNodes the number of nodes to partition; must be at least 2 + * @return A set of sets each of which defines a partitioning of totalNodes + */ + @VisibleForTesting + static Stream> generatePartitions(int totalNodes) + { + checkArgument(totalNodes >= 2, "totalNodes must be greater than or equal to 2"); + Set numbers = IntStream.range(0, totalNodes) + .boxed() + .collect(toImmutableSet()); + return Sets.powerSet(numbers).stream() + .filter(subSet -> subSet.contains(0)) + .filter(subSet -> subSet.size() < numbers.size()); + } + + JoinEnumerationResult createJoinAccordingToPartitioning(List sources, List outputSymbols, Set partitioning) + { + Set leftSources = partitioning.stream() + .map(sources::get) + .collect(toImmutableSet()); + Set rightSources = Sets.difference(ImmutableSet.copyOf(sources), ImmutableSet.copyOf(leftSources)); + return createJoin(leftSources, rightSources, outputSymbols); + } + + private JoinEnumerationResult createJoin(Set leftSources, Set rightSources, List outputSymbols) + { + Set leftSymbols = leftSources.stream() + .flatMap(node -> node.getOutputSymbols().stream()) + .collect(toImmutableSet()); + Set rightSymbols = rightSources.stream() + .flatMap(node -> node.getOutputSymbols().stream()) + .collect(toImmutableSet()); + ImmutableList.Builder joinPredicatesBuilder = ImmutableList.builder(); + + // add join conjuncts that were not used for inference + stream(EqualityInference.nonInferrableConjuncts(allFilter).spliterator(), false) + .map(conjuct -> allInference.rewriteExpression(conjuct, symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol))) + .filter(Objects::nonNull) + // filter expressions that contain only left or right symbols + .filter(conjuct -> allInference.rewriteExpression(conjuct, leftSymbols::contains) == null) + .filter(conjuct -> allInference.rewriteExpression(conjuct, rightSymbols::contains) == null) + .forEach(joinPredicatesBuilder::add); + + // create equality inference on available symbols + // TODO: make generateEqualitiesPartitionedBy take 
left and right scope + List joinEqualities = allInference.generateEqualitiesPartitionedBy(symbol -> leftSymbols.contains(symbol) || rightSymbols.contains(symbol)).getScopeEqualities(); + EqualityInference joinInference = createEqualityInference(joinEqualities.toArray(new Expression[joinEqualities.size()])); + joinPredicatesBuilder.addAll(joinInference.generateEqualitiesPartitionedBy(in(leftSymbols)).getScopeStraddlingEqualities()); + + List joinPredicates = joinPredicatesBuilder.build(); + List joinConditions = joinPredicates.stream() + .filter(JoinEnumerator::isJoinEqualityCondition) + .map(predicate -> toEquiJoinClause((ComparisonExpression) predicate, leftSymbols)) + .collect(toImmutableList()); + if (joinConditions.isEmpty()) { + return INFINITE_COST_RESULT; + } + List joinFilters = joinPredicates.stream() + .filter(predicate -> !isJoinEqualityCondition(predicate)) + .collect(toImmutableList()); + + Set requiredJoinSymbols = ImmutableSet.builder() + .addAll(outputSymbols) + .addAll(SymbolsExtractor.extractUnique(joinPredicates)) + .build(); + + JoinEnumerationResult leftResult = getJoinSource( + idAllocator, + ImmutableList.copyOf(leftSources), + requiredJoinSymbols.stream().filter(leftSymbols::contains).collect(toImmutableList())); + if (leftResult.cost.isUnknown()) { + return UNKNOWN_COST_RESULT; + } + if (leftResult.cost.equals(INFINITE_COST)) { + return INFINITE_COST_RESULT; + } + PlanNode left = leftResult.planNode.orElseThrow(() -> new IllegalStateException("no planNode present")); + JoinEnumerationResult rightResult = getJoinSource( + idAllocator, + ImmutableList.copyOf(rightSources), + requiredJoinSymbols.stream() + .filter(rightSymbols::contains) + .collect(toImmutableList())); + if (rightResult.cost.isUnknown()) { + return UNKNOWN_COST_RESULT; + } + if (rightResult.cost.equals(INFINITE_COST)) { + return INFINITE_COST_RESULT; + } + PlanNode right = rightResult.planNode.orElseThrow(() -> new IllegalStateException("no planNode present")); + + // sort 
output symbols so that the left input symbols are first + List sortedOutputSymbols = Stream.concat(left.getOutputSymbols().stream(), right.getOutputSymbols().stream()) + .filter(outputSymbols::contains) + .collect(toImmutableList()); + + // Cross joins can't filter symbols as part of the join + // If we're doing a cross join, use all output symbols from the inputs and add a project node + // on top + List joinOutputSymbols = sortedOutputSymbols; + if (joinConditions.isEmpty() && joinFilters.isEmpty()) { + joinOutputSymbols = Stream.concat(left.getOutputSymbols().stream(), right.getOutputSymbols().stream()) + .collect(toImmutableList()); + } + + JoinEnumerationResult result = setJoinNodeProperties(new JoinNode( + idAllocator.getNextId(), + INNER, + left, + right, + joinConditions, + joinOutputSymbols, + joinFilters.isEmpty() ? Optional.empty() : Optional.of(and(joinFilters)), + Optional.empty(), + Optional.empty(), + Optional.empty())); + + if (!joinOutputSymbols.equals(sortedOutputSymbols)) { + PlanNode resultNode = new ProjectNode(idAllocator.getNextId(), result.planNode.get(), identity(sortedOutputSymbols)); + result = new JoinEnumerationResult(lookup.getCumulativeCost(resultNode, session, symbolAllocator.getTypes()), Optional.of(resultNode)); + } + + return result; + } + + private JoinEnumerationResult getJoinSource(PlanNodeIdAllocator idAllocator, List nodes, List outputSymbols) + { + PlanNode planNode; + if (nodes.size() == 1) { + planNode = getOnlyElement(nodes); + ImmutableList.Builder predicates = ImmutableList.builder(); + predicates.addAll(allInference.generateEqualitiesPartitionedBy(outputSymbols::contains).getScopeEqualities()); + stream(EqualityInference.nonInferrableConjuncts(allFilter).spliterator(), false) + .map(conjuct -> allInference.rewriteExpression(conjuct, outputSymbols::contains)) + .filter(Objects::nonNull) + .forEach(predicates::add); + Expression filter = combineConjuncts(predicates.build()); + if (!(TRUE_LITERAL).equals(filter)) { + 
planNode = new FilterNode(idAllocator.getNextId(), planNode, filter); + } + return new JoinEnumerationResult(lookup.getCumulativeCost(planNode, session, symbolAllocator.getTypes()), Optional.of(planNode)); + } + return chooseJoinOrder(nodes, outputSymbols); + } + + private static boolean isJoinEqualityCondition(Expression expression) + { + return expression instanceof ComparisonExpression + && ((ComparisonExpression) expression).getType() == EQUAL + && ((ComparisonExpression) expression).getLeft() instanceof SymbolReference + && ((ComparisonExpression) expression).getRight() instanceof SymbolReference; + } + + private static JoinNode.EquiJoinClause toEquiJoinClause(ComparisonExpression equality, Set leftSymbols) + { + Symbol leftSymbol = Symbol.from(equality.getLeft()); + Symbol rightSymbol = Symbol.from(equality.getRight()); + JoinNode.EquiJoinClause equiJoinClause = new JoinNode.EquiJoinClause(leftSymbol, rightSymbol); + return leftSymbols.contains(leftSymbol) ? equiJoinClause : equiJoinClause.flip(); + } + + private JoinEnumerationResult setJoinNodeProperties(JoinNode joinNode) + { + List possibleJoinNodes = new ArrayList<>(); + FeaturesConfig.JoinDistributionType joinDistributionType = getJoinDistributionType(session); + if (joinDistributionType.canRepartition() && !joinNode.isCrossJoin()) { + JoinNode node = joinNode.withDistributionType(PARTITIONED); + possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); + node = node.flipChildren(); + possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); + } + if (joinDistributionType.canReplicate()) { + JoinNode node = joinNode.withDistributionType(REPLICATED); + possibleJoinNodes.add(new JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); + node = node.flipChildren(); + possibleJoinNodes.add(new 
JoinEnumerationResult(lookup.getCumulativeCost(node, session, symbolAllocator.getTypes()), Optional.of(node))); + } + if (possibleJoinNodes.stream().anyMatch(result -> result.cost.isUnknown())) { + return UNKNOWN_COST_RESULT; + } + return resultOrdering.min(possibleJoinNodes); + } + } + + @VisibleForTesting + static class JoinEnumerationResult + { + static final JoinEnumerationResult UNKNOWN_COST_RESULT = new JoinEnumerationResult(UNKNOWN_COST, Optional.empty()); + static final JoinEnumerationResult INFINITE_COST_RESULT = new JoinEnumerationResult(INFINITE_COST, Optional.empty()); + + private final Optional planNode; + private final PlanNodeCostEstimate cost; + + private JoinEnumerationResult(PlanNodeCostEstimate cost, Optional planNode) + { + this.cost = requireNonNull(cost); + this.planNode = requireNonNull(planNode); + checkArgument(cost.isUnknown() || cost.equals(INFINITE_COST) || planNode.isPresent(), "planNode must be present if cost is known"); + } + + public Optional getPlanNode() + { + return planNode; + } + + public PlanNodeCostEstimate getCost() + { + return cost; + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java index 6a9f6991f298..424a8f4f3135 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/DetermineJoinDistributionType.java @@ -111,6 +111,9 @@ public PlanNode visitDelete(DeleteNode node, RewriteContext context) private JoinNode.DistributionType getTargetJoinDistributionType(JoinNode node) { + if (node.getDistributionType().isPresent()) { + return node.getDistributionType().get(); + } // The implementation of full outer join only works if the data is hash partitioned. 
See LookupJoinOperators#buildSideOuterJoinUnvisitedPositions JoinNode.Type type = node.getType(); if (type == RIGHT || type == FULL || (isRepartitionedJoinEnabled(session) && !mustBroadcastJoin(node))) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java new file mode 100644 index 000000000000..c2449777969c --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.SymbolAllocator; +import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.testing.LocalQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Set; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.generatePartitions; +import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.airlift.testing.Closeables.closeAllRuntimeException; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; + +public class TestJoinEnumerator +{ + private LocalQueryRunner queryRunner; + + @BeforeClass + public void setUp() + { + queryRunner = new LocalQueryRunner(testSessionBuilder().build()); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + closeAllRuntimeException(queryRunner); + queryRunner = null; + } + + @Test + public void testGeneratePartitions() + { + Set> partitions = generatePartitions(4).collect(toImmutableSet()); + assertEquals(partitions, + ImmutableSet.of( + ImmutableSet.of(0), + ImmutableSet.of(0, 1), + ImmutableSet.of(0, 2), + ImmutableSet.of(0, 3), + ImmutableSet.of(0, 1, 2), + ImmutableSet.of(0, 1, 3), + ImmutableSet.of(0, 2, 3))); + + partitions = 
generatePartitions(3).collect(toImmutableSet()); + assertEquals(partitions, + ImmutableSet.of( + ImmutableSet.of(0), + ImmutableSet.of(0, 1), + ImmutableSet.of(0, 2))); + } + + @Test + public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + MultiJoinNode multiJoinNode = new MultiJoinNode( + ImmutableList.of(planBuilder.values(a1), planBuilder.values(b1)), + TRUE_LITERAL, + ImmutableList.of(a1, b1)); + ReorderJoins.JoinEnumerator joinEnumerator = new ReorderJoins.JoinEnumerator( + idAllocator, + new SymbolAllocator(), + queryRunner.getDefaultSession(), + queryRunner.getLookup(), + multiJoinNode.getFilter(), + new CostComparator(1, 1, 1)); + JoinEnumerationResult actual = joinEnumerator.createJoinAccordingToPartitioning(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols(), ImmutableSet.of(0)); + assertFalse(actual.getPlanNode().isPresent()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java new file mode 100644 index 000000000000..1f9d8d402de6 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMultiJoinNodeBuilder.java @@ -0,0 +1,248 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.sql.planner.PlanNodeIdAllocator; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.ValuesNode; +import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.LongLiteral; +import com.facebook.presto.testing.LocalQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.ExpressionUtils.and; +import static com.facebook.presto.sql.ExpressionUtils.extractConjuncts; +import static com.facebook.presto.sql.planner.iterative.rule.MultiJoinNode.toMultiJoinNode; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; +import static com.facebook.presto.sql.tree.ArithmeticBinaryExpression.Type.ADD; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.GREATER_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.LESS_THAN; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.NOT_EQUAL; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertEquals; + +public class TestMultiJoinNodeBuilder +{ + private final LocalQueryRunner queryRunner = new 
LocalQueryRunner(testSessionBuilder().build()); + + @Test(expectedExceptions = IllegalStateException.class) + public void testDoesNotFireForOuterJoins() + { + PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), queryRunner.getMetadata()); + JoinNode outerJoin = p.join( + JoinNode.Type.FULL, + p.values(p.symbol("A1", BIGINT)), + p.values(p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty()); + toMultiJoinNode(outerJoin, queryRunner.getLookup()); + } + + @Test + public void testDoesNotConvertNestedOuterJoins() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + Symbol c1 = planBuilder.symbol("C1", BIGINT); + JoinNode leftJoin = planBuilder.join( + LEFT, + planBuilder.values(a1), + planBuilder.values(b1), + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty()); + ValuesNode valuesC = planBuilder.values(c1); + JoinNode joinNode = planBuilder.join( + INNER, + leftJoin, + valuesC, + ImmutableList.of(new JoinNode.EquiJoinClause(a1, c1)), + ImmutableList.of(a1, b1, c1), + Optional.empty()); + + MultiJoinNode expected = new MultiJoinNode(ImmutableList.of(leftJoin, valuesC), new ComparisonExpression(EQUAL, a1.toSymbolReference(), c1.toSymbolReference()), ImmutableList.of(a1, b1, c1)); + assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected); + } + + @Test + public void testRetainsOutputSymbols() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + 
Symbol b2 = planBuilder.symbol("B2", BIGINT); + Symbol c1 = planBuilder.symbol("C1", BIGINT); + Symbol c2 = planBuilder.symbol("C2", BIGINT); + ValuesNode valuesA = planBuilder.values(a1); + ValuesNode valuesB = planBuilder.values(b1, b2); + ValuesNode valuesC = planBuilder.values(c1, c2); + JoinNode joinNode = planBuilder.join( + INNER, + valuesA, + planBuilder.join( + INNER, + valuesB, + valuesC, + ImmutableList.of(new JoinNode.EquiJoinClause(b1, c1)), + ImmutableList.of( + b1, + b2, + c1, + c2), + Optional.empty()), + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty()); + MultiJoinNode expected = new MultiJoinNode( + ImmutableList.of(valuesA, valuesB, valuesC), + and(new ComparisonExpression(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference())), + ImmutableList.of(a1, b1)); + assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected); + } + + @Test + public void testCombinesCriteriaAndFilters() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + Symbol b2 = planBuilder.symbol("B2", BIGINT); + Symbol c1 = planBuilder.symbol("C1", BIGINT); + Symbol c2 = planBuilder.symbol("C2", BIGINT); + ValuesNode valuesA = planBuilder.values(a1); + ValuesNode valuesB = planBuilder.values(b1, b2); + ValuesNode valuesC = planBuilder.values(c1, c2); + Expression bcFilter = and( + new ComparisonExpression(GREATER_THAN, c2.toSymbolReference(), new LongLiteral("0")), + new ComparisonExpression(NOT_EQUAL, c2.toSymbolReference(), new LongLiteral("7")), + new ComparisonExpression(GREATER_THAN, b2.toSymbolReference(), c2.toSymbolReference())); + ComparisonExpression abcFilter = new ComparisonExpression( + LESS_THAN, + new 
ArithmeticBinaryExpression(ADD, a1.toSymbolReference(), c1.toSymbolReference()), + b1.toSymbolReference()); + JoinNode joinNode = planBuilder.join( + INNER, + valuesA, + planBuilder.join( + INNER, + valuesB, + valuesC, + ImmutableList.of(new JoinNode.EquiJoinClause(b1, c1)), + ImmutableList.of( + b1, + b2, + c1, + c2), + Optional.of(bcFilter)), + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1, b2, c1, c2), + Optional.of(abcFilter)); + MultiJoinNode expected = new MultiJoinNode( + ImmutableList.of(valuesA, valuesB, valuesC), + and(new ComparisonExpression(EQUAL, b1.toSymbolReference(), c1.toSymbolReference()), new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()), bcFilter, abcFilter), + ImmutableList.of(a1, b1, b2, c1, c2)); + assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected); + } + + @Test + public void testConvertsBushyTrees() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder planBuilder = new PlanBuilder(idAllocator, queryRunner.getMetadata()); + Symbol a1 = planBuilder.symbol("A1", BIGINT); + Symbol b1 = planBuilder.symbol("B1", BIGINT); + Symbol c1 = planBuilder.symbol("C1", BIGINT); + Symbol d1 = planBuilder.symbol("D1", BIGINT); + Symbol d2 = planBuilder.symbol("D2", BIGINT); + Symbol e1 = planBuilder.symbol("E1", BIGINT); + Symbol e2 = planBuilder.symbol("E2", BIGINT); + ValuesNode valuesA = planBuilder.values(a1); + ValuesNode valuesB = planBuilder.values(b1); + ValuesNode valuesC = planBuilder.values(c1); + ValuesNode valuesD = planBuilder.values(d1, d2); + ValuesNode valuesE = planBuilder.values(e1, e2); + JoinNode joinNode = planBuilder.join( + INNER, + planBuilder.join( + INNER, + planBuilder.join( + INNER, + valuesA, + valuesB, + ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)), + ImmutableList.of(a1, b1), + Optional.empty()), + valuesC, + ImmutableList.of(new JoinNode.EquiJoinClause(a1, c1)), + ImmutableList.of(a1, b1, 
c1), + Optional.empty()), + planBuilder.join( + INNER, + valuesD, + valuesE, + ImmutableList.of( + new JoinNode.EquiJoinClause(d1, e1), + new JoinNode.EquiJoinClause(d2, e2)), + ImmutableList.of( + d1, + d2, + e1, + e2), + Optional.empty()), + ImmutableList.of(new JoinNode.EquiJoinClause(b1, e1)), + ImmutableList.of( + a1, + b1, + c1, + d1, + d2, + e1, + e2), + Optional.empty()); + MultiJoinNode expected = new MultiJoinNode( + ImmutableList.of(valuesA, valuesB, valuesC, valuesD, valuesE), + and( + new ComparisonExpression(EQUAL, a1.toSymbolReference(), b1.toSymbolReference()), + new ComparisonExpression(EQUAL, a1.toSymbolReference(), c1.toSymbolReference()), + new ComparisonExpression(EQUAL, d1.toSymbolReference(), e1.toSymbolReference()), + new ComparisonExpression(EQUAL, d2.toSymbolReference(), e2.toSymbolReference()), + new ComparisonExpression(EQUAL, b1.toSymbolReference(), e1.toSymbolReference())), + ImmutableList.of(a1, b1, c1, d1, d2, e1, e2)); + assertMultijoinEquals(toMultiJoinNode(joinNode, queryRunner.getLookup()), expected); + } + + private static void assertMultijoinEquals(MultiJoinNode actual, MultiJoinNode expected) + { + assertEquals(ImmutableSet.copyOf(actual.getSources()), ImmutableSet.copyOf(expected.getSources())); + assertEquals(ImmutableSet.copyOf(extractConjuncts(actual.getFilter())), ImmutableSet.copyOf(extractConjuncts(expected.getFilter()))); + assertEquals(actual.getOutputSymbols(), expected.getOutputSymbols()); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java new file mode 100644 index 000000000000..86081190b334 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java @@ -0,0 +1,361 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.cost.SymbolStatsEstimate; +import com.facebook.presto.spi.TestingColumnHandle; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.facebook.presto.sql.tree.ComparisonExpression; +import com.facebook.presto.sql.tree.ComparisonExpressionType; +import com.facebook.presto.sql.tree.FunctionCall; +import com.facebook.presto.sql.tree.QualifiedName; +import com.facebook.presto.testing.LocalQueryRunner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; + +import static com.facebook.presto.cost.PlanNodeStatsEstimate.UNKNOWN_STATS; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static 
com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.ComparisonExpressionType.EQUAL; +import static com.facebook.presto.testing.LocalQueryRunner.queryRunnerWithFakeNodeCountForStats; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static io.airlift.testing.Closeables.closeAllRuntimeException; + +public class TestReorderJoins +{ + private RuleTester tester; + + @BeforeClass + public void setUp() + { + Session session = testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .setSystemProperty("join_distribution_type", "automatic") + .setSystemProperty("join_reordering_strategy", "COST_BASED") + .build(); + LocalQueryRunner queryRunner = queryRunnerWithFakeNodeCountForStats(session, 4); + tester = new RuleTester(queryRunner); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + closeAllRuntimeException(tester); + tester = null; + } + + @Test + public void testKeepsOutputSymbols() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT), p.symbol("A2", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A2", BIGINT)), + Optional.empty())) + .withStats(ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(5000) + .setSymbolStatistics(ImmutableMap.of( + new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 100), + new 
Symbol("A2"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build())) + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("A1", 0, "A2", 1)), + values(ImmutableMap.of("B1", 0)) + )); + } + + @Test + public void testReplicatesAndFlipsWhenOneTableMuchSmaller() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .withStats(ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())) + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("B1", "A1")), + Optional.empty(), + Optional.of(REPLICATED), + values(ImmutableMap.of("B1", 0)), + values(ImmutableMap.of("A1", 0)) + )); + } + + @Test + public void testRepartitionsWhenRequiredBySession() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", 
BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .setSystemProperty("join_distribution_type", "REPARTITIONED") + .withStats(ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())) + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("B1", "A1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("B1", 0)), + values(ImmutableMap.of("A1", 0)) + )); + } + + @Test + public void testRepartitionsWhenBothTablesEqual() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .withStats(ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())) + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(PARTITIONED), + values(ImmutableMap.of("A1", 0)), + values(ImmutableMap.of("B1", 0)) + )); + } + + @Test + public void testReplicatesWhenRequiredBySession() + 
{ + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .setSystemProperty("join_distribution_type", "REPLICATED") + .withStats(ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())) + .matches(join( + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + Optional.empty(), + Optional.of(REPLICATED), + values(ImmutableMap.of("A1", 0)), + values(ImmutableMap.of("B1", 0)) + )); + } + + @Test + public void testDoesNotFireForCrossJoin() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.empty())) + .withStats(ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())) + .doesNotFire(); + } + + @Test + 
public void testDoesNotFireWithNoStats() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .withStatsCalculator(new UnknownStatsCalculator()) + .on(p -> + p.join( + INNER, + p.tableScan(ImmutableList.of(p.symbol("A1", BIGINT)), ImmutableMap.of(p.symbol("A1", BIGINT), new TestingColumnHandle("A1"))), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT)), + Optional.empty())) + .doesNotFire(); + } + + @Test + public void testDoesNotFireForNonDeterministicFilter() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .on(p -> + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT)), + ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), + Optional.of(new ComparisonExpression(ComparisonExpressionType.LESS_THAN, p.symbol("A1", BIGINT).toSymbolReference(), new FunctionCall(QualifiedName.of("random"), ImmutableList.of()))))) + .doesNotFire(); + } + + @Test + public void testPredicatesPushedDown() + { + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + .on(p -> + p.join( + INNER, + p.join( + INNER, + p.values(new PlanNodeId("valuesA"), p.symbol("A1", BIGINT)), + p.values(new PlanNodeId("valuesB"), p.symbol("B1", BIGINT), p.symbol("B2", BIGINT)), + ImmutableList.of(), + ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT), p.symbol("B2", BIGINT)), + Optional.empty()), + p.values(new PlanNodeId("valuesC"), p.symbol("C1", BIGINT)), + ImmutableList.of( + new JoinNode.EquiJoinClause(p.symbol("B2", BIGINT), p.symbol("C1", BIGINT))), + ImmutableList.of(p.symbol("A1", BIGINT)), + Optional.of(new ComparisonExpression(EQUAL, p.symbol("A1", BIGINT).toSymbolReference(), 
p.symbol("B1", BIGINT).toSymbolReference())))) + .withStats(ImmutableMap.of( + new PlanNodeId("valuesA"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .build(), + new PlanNodeId("valuesB"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(5) + .setSymbolStatistics(ImmutableMap.of( + new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 10), + new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .build(), + new PlanNodeId("valuesC"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(1000) + .setSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build())) + .matches( + join( + INNER, + ImmutableList.of(equiJoinClause("C1", "B2")), + values("C1"), + join( + INNER, + ImmutableList.of(equiJoinClause("A1", "B1")), + values("A1"), + values("B1", "B2")) + ) + ); + } + + private static class UnknownStatsCalculator + implements StatsCalculator + { + @Override + public PlanNodeStatsEstimate calculateStats( + PlanNode planNode, + Lookup lookup, + Session session, + Map types) + { + PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.buildFrom(UNKNOWN_STATS); + planNode.getOutputSymbols() + .forEach(symbol -> statsBuilder.addSymbolStatistics(symbol, SymbolStatsEstimate.UNKNOWN_STATS)); + return statsBuilder.build(); + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 64ed14842aab..e4dafc7e55bf 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -42,6 +42,7 @@ import com.facebook.presto.sql.planner.plan.MarkDistinctNode; import 
com.facebook.presto.sql.planner.plan.OutputNode; import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.plan.ProjectNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SemiJoinNode; @@ -90,9 +91,14 @@ public PlanBuilder(PlanNodeIdAllocator idAllocator, Metadata metadata) } public ValuesNode values(Symbol... columns) + { + return values(idAllocator.getNextId(), columns); + } + + public ValuesNode values(PlanNodeId id, Symbol... columns) { return new ValuesNode( - idAllocator.getNextId(), + id, ImmutableList.copyOf(columns), ImmutableList.of()); } @@ -320,6 +326,11 @@ public ExchangeNode exchange(Consumer exchangeBuilderConsumer) return exchangeBuilder.build(); } + public JoinNode join(JoinNode.Type type, PlanNode left, PlanNode right, List criteria, List outputSymbols, Optional filter) + { + return new JoinNode(idAllocator.getNextId(), type, left, right, criteria, outputSymbols, filter, Optional.empty(), Optional.empty(), Optional.empty()); + } + public class ExchangeBuilder { private ExchangeNode.Type type = ExchangeNode.Type.GATHER; From 844fe2d2504058460aa5a40e65ff1b7a3bdbe37a Mon Sep 17 00:00:00 2001 From: Rebecca Schlussel Date: Thu, 29 Jun 2017 14:03:47 -0400 Subject: [PATCH 10/15] Add benchmark for ReorderJoins rule MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Results run on my development vm BenchmarkReorderJoinsConnectedGraph: BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins ELIMINATE_CROSS_JOINS 2 avgt 30 54.610 ± 4.236 ms/op BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins ELIMINATE_CROSS_JOINS 4 avgt 30 153.794 ± 9.075 ms/op BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins ELIMINATE_CROSS_JOINS 6 avgt 30 326.410 ± 19.912 ms/op BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins ELIMINATE_CROSS_JOINS 8 avgt 30 578.028 ± 33.308 ms/op 
BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins ELIMINATE_CROSS_JOINS 10 avgt 30 955.494 ± 44.523 ms/op BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins COST_BASED 2 avgt 30 54.844 ± 4.256 ms/op BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins COST_BASED 4 avgt 30 161.164 ± 11.008 ms/op BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins COST_BASED 6 avgt 30 440.007 ± 28.903 ms/op BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins COST_BASED 8 avgt 30 2491.240 ± 72.341 ms/op BenchmarkReorderJoinsConnectedGraph.benchmarkReorderJoins COST_BASED 10 avgt 30 24026.603 ± 886.696 ms/op BenchmarkReorderJoinsLinearGraph: BenchmarkReorderJoinsLinearQuery.benchmarkReorderJoins ELIMINATE_CROSS_JOINS avgt 30 944.179 ± 42.406 ms/op BenchmarkReorderJoinsLinearQuery.benchmarkReorderJoins COST_BASED avgt 30 1329.194 ± 71.704 ms/op --- .../BenchmarkReorderJoinsConnectedGraph.java | 118 ++++++++++++++++++ .../BenchmarkReorderJoinsLinearGraph.java | 108 ++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java create mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java new file mode 100644 index 000000000000..d6fb0ea483d0 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsConnectedGraph.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.openjdk.jmh.annotations.Mode.AverageTime; +import static org.openjdk.jmh.annotations.Scope.Thread; + +@State(Thread) +@OutputTimeUnit(MILLISECONDS) +@BenchmarkMode(AverageTime) +@Fork(3) +@Warmup(iterations = 10) +@Measurement(iterations = 10) 
+public class BenchmarkReorderJoinsConnectedGraph
+{
+    /**
+     * Executes the {@code EXPLAIN} statement prepared in {@link BenchmarkInfo#setup()}
+     * on the query runner owned by the given state object.
+     */
+    @Benchmark
+    public MaterializedResult benchmarkReorderJoins(BenchmarkInfo benchmarkInfo)
+    {
+        QueryRunner runner = benchmarkInfo.getQueryRunner();
+        String sql = benchmarkInfo.getQuery();
+        return runner.execute(sql);
+    }
+
+    /**
+     * Per-thread benchmark state: a local query runner with the "tpch" catalog registered
+     * and an {@code EXPLAIN} query joining {@code numberOfTables} aliases of the nation table.
+     */
+    @State(Thread)
+    public static class BenchmarkInfo
+    {
+        @Param({"ELIMINATE_CROSS_JOINS", "COST_BASED"})
+        private String joinReorderingStrategy;
+
+        @Param({"2", "4", "6", "8", "10"})
+        private int numberOfTables;
+
+        private String query;
+        private LocalQueryRunner queryRunner;
+
+        @Setup
+        public void setup()
+        {
+            checkState(numberOfTables >= 2, "numberOfTables must be >= 2");
+
+            Session benchmarkSession = testSessionBuilder()
+                    .setSystemProperty("join_reordering_strategy", joinReorderingStrategy)
+                    .setSystemProperty("join_distribution_type", "AUTOMATIC")
+                    .setCatalog("tpch")
+                    .setSchema("tiny")
+                    .build();
+            queryRunner = new LocalQueryRunner(benchmarkSession);
+            queryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());
+
+            query = chainedJoinQuery(numberOfTables);
+        }
+
+        // Builds "EXPLAIN SELECT * FROM NATION n1 JOIN nation n2 on n1.nationkey = n2.nationkey ..."
+        // where every alias is joined to the one before it on nationkey.
+        private static String chainedJoinQuery(int tableCount)
+        {
+            StringBuilder sql = new StringBuilder("EXPLAIN SELECT * FROM NATION n1");
+            for (int left = 1; left < tableCount; left++) {
+                int right = left + 1;
+                sql.append(format(" JOIN nation n%s on n%s.nationkey = n%s.nationkey", right, left, right));
+            }
+            return sql.toString();
+        }
+
+        public String getQuery()
+        {
+            return query;
+        }
+
+        public QueryRunner getQueryRunner()
+        {
+            return queryRunner;
+        }
+
+        @TearDown
+        public void tearDown()
+        {
+            queryRunner.close();
+        }
+    }
+
+    /**
+     * Runs only the benchmarks declared in this class through the JMH runner.
+     */
+    public static void main(String[] args)
+            throws RunnerException
+    {
+        String includePattern = format(".*%s.*", BenchmarkReorderJoinsConnectedGraph.class.getSimpleName());
+        Options jmhOptions = new OptionsBuilder()
+                .verbosity(VerboseMode.NORMAL)
+                .include(includePattern)
+                .build();
+
+        new Runner(jmhOptions).run();
+    }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java
new file
mode 100644 index 000000000000..13c1644c1785 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/BenchmarkReorderJoinsLinearGraph.java @@ -0,0 +1,108 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; +import org.openjdk.jmh.runner.options.VerboseMode; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static 
org.openjdk.jmh.annotations.Mode.AverageTime;
+import static org.openjdk.jmh.annotations.Scope.Thread;
+
+/**
+ * JMH benchmark that runs {@code EXPLAIN} on a fixed eight-table join over the tpch
+ * "tiny" schema, once per join reordering strategy listed in {@link BenchmarkInfo}.
+ * Each table is joined to the table that precedes it, so the join graph is a chain.
+ */
+@State(Thread)
+@OutputTimeUnit(MILLISECONDS)
+@BenchmarkMode(AverageTime)
+@Fork(3)
+@Warmup(iterations = 10)
+@Measurement(iterations = 10)
+public class BenchmarkReorderJoinsLinearGraph
+{
+    /**
+     * Plans the chained join query on the query runner owned by the given state object.
+     */
+    @Benchmark
+    public MaterializedResult benchmarkReorderJoins(BenchmarkInfo benchmarkInfo)
+    {
+        // r2 must be joined to r1: the previous condition "r2.name = r2.name" compared a
+        // column with itself and left r2 disconnected from the rest of the chain.
+        return benchmarkInfo.getQueryRunner().execute(
+                "EXPLAIN SELECT * FROM " +
+                        "nation n1 JOIN nation n2 ON n1.nationkey = n2.nationkey " +
+                        "JOIN nation n3 on n2.comment = n3.comment " +
+                        "JOIN nation n4 on n3.name = n4.name " +
+                        "JOIN region r1 on n4.regionkey = r1.regionkey " +
+                        "JOIN region r2 on r2.name = r1.name " +
+                        "JOIN region r3 on r3.comment = r2.comment " +
+                        "join region r4 on r4.regionkey = r3.regionkey");
+    }
+
+    /**
+     * Per-thread benchmark state: a local query runner with the "tpch" catalog registered,
+     * configured with the join reordering strategy under test.
+     */
+    @State(Thread)
+    public static class BenchmarkInfo
+    {
+        @Param({"ELIMINATE_CROSS_JOINS", "COST_BASED"})
+        private String joinReorderingStrategy;
+
+        private LocalQueryRunner queryRunner;
+
+        @Setup
+        public void setup()
+        {
+            Session session = testSessionBuilder()
+                    .setSystemProperty("join_reordering_strategy", joinReorderingStrategy)
+                    .setSystemProperty("join_distribution_type", "AUTOMATIC")
+                    .setCatalog("tpch")
+                    .setSchema("tiny")
+                    .build();
+            queryRunner = new LocalQueryRunner(session);
+            queryRunner.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());
+        }
+
+        public QueryRunner getQueryRunner()
+        {
+            return queryRunner;
+        }
+
+        @TearDown
+        public void tearDown()
+        {
+            queryRunner.close();
+        }
+    }
+
+    /**
+     * Runs only the benchmarks declared in this class through the JMH runner.
+     */
+    public static void main(String[] args)
+            throws RunnerException
+    {
+        Options options = new OptionsBuilder()
+                .verbosity(VerboseMode.NORMAL)
+                .include(".*" + BenchmarkReorderJoinsLinearGraph.class.getSimpleName() + ".*")
+                .build();
+
+        new Runner(options).run();
+    }
+}
From 79d2efad7f373867b9f87f1dfdfb0ff071e5327e Mon Sep 17 00:00:00 2001
From: Karol Sobczak
Date: Tue, 4 Jul 2017 14:15:53 +0200
Subject: [PATCH 11/15] Add
pcollections library dependency --- pom.xml | 6 ++++++ presto-main/pom.xml | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/pom.xml b/pom.xml index e733d6a8c3d6..5d09e062b9a5 100644 --- a/pom.xml +++ b/pom.xml @@ -588,6 +588,12 @@ 42.0.0 + + org.pcollections + pcollections + 2.1.2 + + org.antlr antlr4-runtime diff --git a/presto-main/pom.xml b/presto-main/pom.xml index dd535dce447c..0aedfbf788d7 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -262,6 +262,11 @@ jgrapht-core + + org.pcollections + pcollections + + org.apache.bval From 6d4246e3ce8b1417045ed32dd86bde9e5e37b238 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Tue, 4 Jul 2017 14:24:01 +0200 Subject: [PATCH 12/15] Use HashTreePMap in PlanNodeStatsEstimate to reduce map copying Previously there was a lot of map copying (through ImmutableMap.copyOf(...) and new HashMap(...)) which was significantly impacting stats code performance. HashTreePMap is much better for cases where individual entries of base map are modified which is common case in stats code. 
--- .../presto/cost/EnsureStatsMatchOutput.java | 17 +++--- .../presto/cost/PlanNodeStatsEstimate.java | 60 ++++++++++--------- .../presto/cost/TableScanStatsRule.java | 2 +- .../iterative/rule/TestReorderJoins.java | 30 +++++----- 4 files changed, 57 insertions(+), 52 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java index 94bc8e96a571..c8d4b20d88ff 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/EnsureStatsMatchOutput.java @@ -18,9 +18,10 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.plan.PlanNode; -import java.util.HashMap; import java.util.Map; +import static com.facebook.presto.cost.PlanNodeStatsEstimate.buildFrom; +import static com.facebook.presto.cost.SymbolStatsEstimate.UNKNOWN_STATS; import static com.google.common.base.Predicates.not; public class EnsureStatsMatchOutput @@ -29,16 +30,16 @@ public class EnsureStatsMatchOutput @Override public PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map types) { - Map symbolSymbolStats = new HashMap<>(); - estimate.getSymbolsWithKnownStatistics().stream() - .filter(node.getOutputSymbols()::contains) - .forEach(symbol -> symbolSymbolStats.put(symbol, estimate.getSymbolStatistics(symbol))); + PlanNodeStatsEstimate.Builder builder = buildFrom(estimate); node.getOutputSymbols().stream() .filter(not(estimate.getSymbolsWithKnownStatistics()::contains)) - .filter(not(symbolSymbolStats::containsKey)) - .forEach(symbol -> symbolSymbolStats.put(symbol, SymbolStatsEstimate.UNKNOWN_STATS)); + .forEach(symbol -> builder.addSymbolStatistics(symbol, UNKNOWN_STATS)); + + estimate.getSymbolsWithKnownStatistics().stream() + .filter(not(node.getOutputSymbols()::contains)) + .forEach(builder::removeSymbolStatistics); - return 
PlanNodeStatsEstimate.buildFrom(estimate).setSymbolStatistics(symbolSymbolStats).build(); + return builder.build(); } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java index 3d59f90dab49..cf081a6926c5 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/PlanNodeStatsEstimate.java @@ -14,14 +14,13 @@ package com.facebook.presto.cost; import com.facebook.presto.sql.planner.Symbol; -import com.google.common.collect.ImmutableMap; +import org.pcollections.HashTreePMap; +import org.pcollections.PMap; -import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.function.Function; -import java.util.stream.Collectors; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -34,13 +33,13 @@ public class PlanNodeStatsEstimate public static final double DEFAULT_DATA_SIZE_PER_COLUMN = 10; private final double outputRowCount; - private final Map symbolStatistics; + private final PMap symbolStatistics; - private PlanNodeStatsEstimate(double outputRowCount, Map symbolStatistics) + private PlanNodeStatsEstimate(double outputRowCount, PMap symbolStatistics) { checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative"); this.outputRowCount = outputRowCount; - this.symbolStatistics = ImmutableMap.copyOf(symbolStatistics); + this.symbolStatistics = symbolStatistics; } /** @@ -83,26 +82,15 @@ public PlanNodeStatsEstimate mapOutputRowCount(Function mappingF public PlanNodeStatsEstimate mapSymbolColumnStatistics(Symbol symbol, Function mappingFunction) { return buildFrom(this) - .setSymbolStatistics(symbolStatistics.entrySet().stream() - .collect(Collectors.toMap( - Map.Entry::getKey, - e -> { - if 
(e.getKey().equals(symbol)) { - return mappingFunction.apply(e.getValue()); - } - return e.getValue(); - }))) + .addSymbolStatistics(symbol, mappingFunction.apply(symbolStatistics.get(symbol))) .build(); } public PlanNodeStatsEstimate add(PlanNodeStatsEstimate other) { // TODO this is broken (it does not operate on symbol stats at all). Remove or fix - ImmutableMap.Builder symbolsStatsBuilder = ImmutableMap.builder(); - symbolsStatsBuilder.putAll(symbolStatistics).putAll(other.symbolStatistics); // This may not count all information - - PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); - return statsBuilder.setSymbolStatistics(symbolsStatsBuilder.build()) + return buildFrom(this) + .addSymbolStatistics(other.symbolStatistics) .setOutputRowCount(getOutputRowCount() + other.getOutputRowCount()) .build(); } @@ -153,14 +141,24 @@ public static Builder builder() public static Builder buildFrom(PlanNodeStatsEstimate other) { - return builder().setOutputRowCount(other.getOutputRowCount()) - .setSymbolStatistics(other.symbolStatistics); + return new Builder(other.getOutputRowCount(), other.symbolStatistics); } public static final class Builder { - private double outputRowCount = NaN; - private Map symbolStatistics = new HashMap<>(); + private double outputRowCount; + private PMap symbolStatistics; + + public Builder() + { + this(NaN, HashTreePMap.empty()); + } + + private Builder(double outputRowCount, PMap symbolStatistics) + { + this.outputRowCount = outputRowCount; + this.symbolStatistics = symbolStatistics; + } public Builder setOutputRowCount(double outputRowCount) { @@ -168,15 +166,21 @@ public Builder setOutputRowCount(double outputRowCount) return this; } - public Builder setSymbolStatistics(Map symbolStatistics) + public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics) { - this.symbolStatistics = new HashMap<>(symbolStatistics); + symbolStatistics = symbolStatistics.plus(symbol, statistics); return this; } - 
public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics) + public Builder addSymbolStatistics(Map symbolStatistics) + { + this.symbolStatistics = this.symbolStatistics.plusAll(symbolStatistics); + return this; + } + + public Builder removeSymbolStatistics(Symbol symbol) { - this.symbolStatistics.put(symbol, statistics); + symbolStatistics = symbolStatistics.minus(symbol); return this; } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java index 234bc6729bb2..713e9c757a30 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/TableScanStatsRule.java @@ -77,7 +77,7 @@ public Optional calculate(PlanNode node, Lookup lookup, S return Optional.of(PlanNodeStatsEstimate.builder() .setOutputRowCount(tableStatistics.getRowCount().getValue()) - .setSymbolStatistics(outputSymbolStats) + .addSymbolStatistics(outputSymbolStats) .build()); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java index 86081190b334..9118bbc32da4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java @@ -92,13 +92,13 @@ public void testKeepsOutputSymbols() .withStats(ImmutableMap.of( new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() .setOutputRowCount(5000) - .setSymbolStatistics(ImmutableMap.of( + .addSymbolStatistics(ImmutableMap.of( new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 100), new Symbol("A2"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) .build(), new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - 
.setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) .build())) .matches(join( INNER, @@ -125,11 +125,11 @@ public void testReplicatesAndFlipsWhenOneTableMuchSmaller() .withStats(ImmutableMap.of( new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) .build(), new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) .build())) .matches(join( INNER, @@ -157,11 +157,11 @@ public void testRepartitionsWhenRequiredBySession() .withStats(ImmutableMap.of( new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() .setOutputRowCount(100) - .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) .build(), new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) .build())) .matches(join( INNER, @@ -188,11 +188,11 @@ public void testRepartitionsWhenBothTablesEqual() .withStats(ImmutableMap.of( new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 
640000, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) .build(), new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) .build())) .matches(join( INNER, @@ -220,11 +220,11 @@ public void testReplicatesWhenRequiredBySession() .withStats(ImmutableMap.of( new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) .build(), new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) .build())) .matches(join( INNER, @@ -251,11 +251,11 @@ public void testDoesNotFireForCrossJoin() .withStats(ImmutableMap.of( new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) .build(), new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) .build())) .doesNotFire(); } @@ -314,19 +314,19 @@ public void testPredicatesPushedDown() new 
PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() .setOutputRowCount(10) - .setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) .build(), new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() .setOutputRowCount(5) - .setSymbolStatistics(ImmutableMap.of( + .addSymbolStatistics(ImmutableMap.of( new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 10), new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) .build(), new PlanNodeId("valuesC"), PlanNodeStatsEstimate.builder() .setOutputRowCount(1000) - .setSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .addSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) .build())) .matches( join( From 09e09241b74fb284de0c5b61f76d23c65b2bec1d Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Tue, 4 Jul 2017 14:57:57 +0200 Subject: [PATCH 13/15] Remove @ThreadSafe annotation from CostCalculator interface Not all CostCalculators are thread safe. --- .../src/main/java/com/facebook/presto/cost/CostCalculator.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java index abccad58b89f..784996cf30bf 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/CostCalculator.java @@ -21,8 +21,6 @@ import com.facebook.presto.sql.planner.plan.PlanNode; import com.google.inject.BindingAnnotation; -import javax.annotation.concurrent.ThreadSafe; - import java.lang.annotation.Retention; import java.lang.annotation.Target; import java.util.Map; @@ -36,7 +34,6 @@ * Computes estimated cost of executing given PlanNode. 
* Implementation may use lookup to compute needed traits for self/source nodes. */ -@ThreadSafe public interface CostCalculator { PlanNodeCostEstimate calculateCost( From bba57cac0e4a3f76931a38b359f47a9193253c20 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Tue, 4 Jul 2017 14:55:48 +0200 Subject: [PATCH 14/15] Introduce caching cost and stats calculator --- .../presto/cost/CachingCostCalculator.java | 51 +++++++++++ .../presto/cost/CachingStatsCalculator.java | 50 +++++++++++ .../planner/iterative/IterativeOptimizer.java | 4 +- .../planner/iterative/MemoBasedLookup.java | 85 ------------------- .../iterative/rule/TestMemoBasedLookup.java | 5 +- 5 files changed, 106 insertions(+), 89 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java create mode 100644 presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java delete mode 100644 presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java new file mode 100644 index 000000000000..4d57c0cc7eea --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingCostCalculator.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * {@link CostCalculator} decorator that remembers the estimate computed for each
+ * {@link PlanNode} and returns the remembered value on later calls for the same node.
+ * The cache is a plain {@link HashMap} with no synchronization, so an instance must
+ * not be shared between threads.
+ */
+public class CachingCostCalculator
+        implements CostCalculator
+{
+    private final CostCalculator costCalculator;
+    private final Map<PlanNode, PlanNodeCostEstimate> costs = new HashMap<>();
+
+    /**
+     * @param costCalculator delegate that computes estimates on a cache miss; must not be null
+     */
+    public CachingCostCalculator(CostCalculator costCalculator)
+    {
+        this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
+    }
+
+    /**
+     * Returns the cached estimate for {@code planNode}, computing it through the delegate
+     * on first use.
+     *
+     * @throws NullPointerException if the delegate returns null
+     * @throws IllegalStateException if an estimate for the same node was stored while the
+     *         delegate was computing this one
+     */
+    @Override
+    public PlanNodeCostEstimate calculateCost(PlanNode planNode, Lookup lookup, Session session, Map<Symbol, Type> types)
+    {
+        if (!costs.containsKey(planNode)) {
+            // cannot use Map.computeIfAbsent due to costs map modification in the mappingFunction callback
+            // NOTE(review): this overrides calculateCost but caches the delegate's calculateCumulativeCost - confirm that is intended
+            PlanNodeCostEstimate cost = costCalculator.calculateCumulativeCost(planNode, lookup, session, types);
+            // validate the computed value (not the cache map) so that a null estimate is never cached
+            requireNonNull(cost, "computed cost can not be null");
+            checkState(costs.put(planNode, cost) == null, "cost for " + planNode + " already computed");
+        }
+        return costs.get(planNode);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java
new file mode 100644
index 000000000000..80d5947c06fa
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/cost/CachingStatsCalculator.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.cost;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.spi.type.Type;
+import com.facebook.presto.sql.planner.Symbol;
+import com.facebook.presto.sql.planner.iterative.Lookup;
+import com.facebook.presto.sql.planner.plan.PlanNode;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.google.common.base.Preconditions.checkState;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * {@link StatsCalculator} decorator that remembers the statistics computed for each
+ * {@link PlanNode} and returns the remembered value on later calls for the same node.
+ * The cache is a plain {@link HashMap} with no synchronization, so an instance must
+ * not be shared between threads.
+ */
+public class CachingStatsCalculator
+        implements StatsCalculator
+{
+    private final StatsCalculator statsCalculator;
+    private final Map<PlanNode, PlanNodeStatsEstimate> stats = new HashMap<>();
+
+    /**
+     * @param statsCalculator delegate that computes statistics on a cache miss; must not be null
+     */
+    public CachingStatsCalculator(StatsCalculator statsCalculator)
+    {
+        this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
+    }
+
+    /**
+     * Returns the cached statistics for {@code planNode}, computing them through the
+     * delegate on first use.
+     *
+     * @throws NullPointerException if the delegate returns null
+     * @throws IllegalStateException if statistics for the same node were stored while the
+     *         delegate was computing these
+     */
+    @Override
+    public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map<Symbol, Type> types)
+    {
+        if (!stats.containsKey(planNode)) {
+            // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback
+            PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(planNode, lookup, session, types);
+            // validate the computed value (not the cache map) so that a null estimate is never cached
+            requireNonNull(statsEstimate, "computed stats can not be null");
+            checkState(stats.put(planNode, statsEstimate) == null, "statistics for " + planNode + " already computed");
+        }
+        return stats.get(planNode);
+    }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java
index 2252155b6697..d76bea4dd65b 100644
---
a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/IterativeOptimizer.java @@ -15,6 +15,8 @@ import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.cost.CachingCostCalculator; +import com.facebook.presto.cost.CachingStatsCalculator; import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.matching.MatchingEngine; @@ -81,7 +83,7 @@ public PlanNode optimize(PlanNode plan, Session session, Map types } Memo memo = new Memo(idAllocator, plan); - Lookup lookup = new MemoBasedLookup(memo, statsCalculator, costCalculator); + Lookup lookup = Lookup.from(memo::resolve, new CachingStatsCalculator(statsCalculator), new CachingCostCalculator(costCalculator)); Duration timeout = SystemSessionProperties.getOptimizerTimeout(session); exploreGroup(memo.getRootGroup(), new Context(memo, lookup, idAllocator, symbolAllocator, System.nanoTime(), timeout.toMillis(), session)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java deleted file mode 100644 index b718fdfc5ded..000000000000 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/MemoBasedLookup.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.facebook.presto.sql.planner.iterative; - -import com.facebook.presto.Session; -import com.facebook.presto.cost.CostCalculator; -import com.facebook.presto.cost.PlanNodeCostEstimate; -import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.cost.StatsCalculator; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.plan.PlanNode; - -import java.util.HashMap; -import java.util.Map; - -import static com.google.common.base.Preconditions.checkState; -import static java.util.Objects.requireNonNull; - -public class MemoBasedLookup - implements Lookup -{ - private final Memo memo; - private final Map stats = new HashMap<>(); - private final Map costs = new HashMap<>(); - private final StatsCalculator statsCalculator; - private final CostCalculator costCalculator; - - public MemoBasedLookup(Memo memo, StatsCalculator statsCalculator, CostCalculator costCalculator) - { - this.memo = requireNonNull(memo, "memo can not be null"); - this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); - this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); - } - - @Override - public PlanNode resolve(PlanNode node) - { - if (node instanceof GroupReference) { - return memo.getNode(((GroupReference) node).getGroupId()); - } - return node; - } - - // todo[LO] maybe lookup passed to stats/cost calculator should be constrained so only - // methods for obtaining traits and only for self and sources would be allowed? 
- - @Override - public PlanNodeStatsEstimate getStats(PlanNode planNode, Session session, Map types) - { - PlanNode key = resolve(planNode); - if (!stats.containsKey(key)) { - // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback - PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(key, this, session, types); - requireNonNull(stats, "computed stats can not be null"); - checkState(stats.put(key, statsEstimate) == null, "statistics for " + key + " already computed"); - } - return stats.get(key); - } - - @Override - public PlanNodeCostEstimate getCumulativeCost(PlanNode planNode, Session session, Map types) - { - PlanNode key = resolve(planNode); - if (!costs.containsKey(key)) { - // cannot use Map.computeIfAbsent due to costs map modification in the mappingFunction callback - PlanNodeCostEstimate cost = costCalculator.calculateCumulativeCost(key, this, session, types); - requireNonNull(costs, "computed cost can not be null"); - checkState(costs.put(key, cost) == null, "cost for " + key + " already computed"); - } - return costs.get(key); - } -} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java index 7bf5a19317a4..0aed522491b1 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestMemoBasedLookup.java @@ -24,7 +24,6 @@ import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.iterative.Memo; -import com.facebook.presto.sql.planner.iterative.MemoBasedLookup; import com.facebook.presto.sql.planner.plan.PlanNode; import com.facebook.presto.sql.planner.plan.PlanNodeId; import 
com.facebook.presto.testing.LocalQueryRunner; @@ -60,7 +59,7 @@ public void testResolvesGroupReferenceNode() PlanNode plan = node(source); Memo memo = new Memo(idAllocator, plan); - MemoBasedLookup lookup = new MemoBasedLookup(memo, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1)); + Lookup lookup = Lookup.from(memo::resolve, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1)); PlanNode memoSource = Iterables.getOnlyElement(memo.getNode(memo.getRootGroup()).getSources()); checkState(memoSource instanceof GroupReference, "expected GroupReference"); assertEquals(lookup.resolve(memoSource), source); @@ -71,7 +70,7 @@ public void testComputesStatsAndResolvesNodes() { PlanNode plan = node(node(node())); Memo memo = new Memo(idAllocator, plan); - MemoBasedLookup lookup = new MemoBasedLookup(memo, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1)); + Lookup lookup = Lookup.from(memo::resolve, new NodeCountingStatsCalculator(), new CostCalculatorUsingExchanges(1)); PlanNodeStatsEstimate actualStats = lookup.getStats(memo.getNode(memo.getRootGroup()), queryRunner.getDefaultSession(), ImmutableMap.of()); PlanNodeStatsEstimate expectedStats = PlanNodeStatsEstimate.builder().setOutputRowCount(3).build(); From 0eb6a5f541a178e6b78027323d447dc46d7e082a Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Wed, 5 Jul 2017 14:12:05 +0200 Subject: [PATCH 15/15] Use stats calculator that is join aware in ReorderJoins Previously JoinEnumerator#setJoinNodeProperties was recomputing join stats for each alternative of join even though actual stats didn't change. Now join stats are memoized by node id. This reduced enumeration time by a factor of 2. 
--- .../cost/JoinNodeCachingStatsCalculator.java | 57 ++++++ .../presto/sql/planner/PlanOptimizers.java | 2 +- .../presto/sql/planner/iterative/Lookup.java | 2 +- .../planner/iterative/rule/ReorderJoins.java | 14 +- .../presto/testing/TestingLookup.java | 95 --------- .../testing/TestingStatsCalculator.java | 51 +++++ .../iterative/rule/TestReorderJoins.java | 189 ++++++++++-------- .../iterative/rule/test/RuleAssert.java | 31 +-- 8 files changed, 229 insertions(+), 212 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java delete mode 100644 presto-main/src/main/java/com/facebook/presto/testing/TestingLookup.java create mode 100644 presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java diff --git a/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java new file mode 100644 index 000000000000..8a95329713d9 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/JoinNodeCachingStatsCalculator.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.cost; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; + +import java.util.HashMap; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class JoinNodeCachingStatsCalculator + implements StatsCalculator +{ + private final StatsCalculator statsCalculator; + private final Map<PlanNodeId, PlanNodeStatsEstimate> stats = new HashMap<>(); + + public JoinNodeCachingStatsCalculator(StatsCalculator statsCalculator) + { + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); + } + + @Override + public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map<Symbol, Type> types) + { + if (!(planNode instanceof JoinNode)) { + return statsCalculator.calculateStats(planNode, lookup, session, types); + } + + PlanNodeId key = planNode.getId(); + if (!stats.containsKey(key)) { + // cannot use Map.computeIfAbsent due to stats map modification in the mappingFunction callback + PlanNodeStatsEstimate statsEstimate = statsCalculator.calculateStats(planNode, lookup, session, types); + requireNonNull(statsEstimate, "computed stats can not be null"); + checkState(stats.put(key, statsEstimate) == null, "statistics for " + planNode + " already computed"); + } + return stats.get(key); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index f279f9699829..c5d07fec5a08 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -347,7 +347,7 @@ public
PlanOptimizers( stats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new ReorderJoins(costComparator)) + ImmutableSet.of(new ReorderJoins(costComparator, statsCalculator, costCalculator)) )); if (featuresConfig.isOptimizeSingleDistinct()) { diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java index 65b2304952fc..d8e5146690f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/Lookup.java @@ -102,7 +102,7 @@ public PlanNodeStatsEstimate getStats(PlanNode node, Session session, Map types) { - return costCalculator.calculateCumulativeCost(node, this, session, types); + return costCalculator.calculateCumulativeCost(resolve(node), this, session, types); } }; } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java index d2081822b3b5..f6dee2122696 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ReorderJoins.java @@ -15,8 +15,13 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.cost.CachingCostCalculator; +import com.facebook.presto.cost.CachingStatsCalculator; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.CostComparator; +import com.facebook.presto.cost.JoinNodeCachingStatsCalculator; import com.facebook.presto.cost.PlanNodeCostEstimate; +import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.EqualityInference; import 
com.facebook.presto.sql.planner.PlanNodeIdAllocator; @@ -81,10 +86,14 @@ public class ReorderJoins private static final Logger log = Logger.get(ReorderJoins.class); private final CostComparator costComparator; + private final StatsCalculator statsCalculator; + private final CostCalculator costCalculator; - public ReorderJoins(CostComparator costComparator) + public ReorderJoins(CostComparator costComparator, StatsCalculator statsCalculator, CostCalculator costCalculator) { this.costComparator = requireNonNull(costComparator, "costComparator is null"); + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); + this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); } @Override @@ -106,7 +115,8 @@ public Optional apply(PlanNode node, Lookup lookup, PlanNodeIdAllocato return Optional.empty(); } - JoinEnumerationResult result = new JoinEnumerator(idAllocator, symbolAllocator, session, lookup, multiJoinNode.getFilter(), costComparator).chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols()); + Lookup joinCachingStatsLookup = Lookup.from(lookup::resolve, new JoinNodeCachingStatsCalculator(new CachingStatsCalculator(statsCalculator)), new CachingCostCalculator(costCalculator)); + JoinEnumerationResult result = new JoinEnumerator(idAllocator, symbolAllocator, session, joinCachingStatsLookup, multiJoinNode.getFilter(), costComparator).chooseJoinOrder(multiJoinNode.getSources(), multiJoinNode.getOutputSymbols()); return result.getCost().isUnknown() || result.getCost().equals(INFINITE_COST) ? 
Optional.empty() : result.getPlanNode(); } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingLookup.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingLookup.java deleted file mode 100644 index 7099545aecbe..000000000000 --- a/presto-main/src/main/java/com/facebook/presto/testing/TestingLookup.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.facebook.presto.testing; - -import com.facebook.presto.Session; -import com.facebook.presto.cost.CostCalculator; -import com.facebook.presto.cost.PlanNodeCostEstimate; -import com.facebook.presto.cost.PlanNodeStatsEstimate; -import com.facebook.presto.cost.StatsCalculator; -import com.facebook.presto.spi.type.Type; -import com.facebook.presto.sql.planner.Symbol; -import com.facebook.presto.sql.planner.iterative.GroupReference; -import com.facebook.presto.sql.planner.iterative.Lookup; -import com.facebook.presto.sql.planner.plan.PlanNode; -import com.google.common.collect.ImmutableMap; - -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; - -import static java.util.Objects.requireNonNull; - -public class TestingLookup - implements Lookup -{ - private final StatsCalculator statsCalculator; - private final CostCalculator costCalculator; - private final Map stats = new HashMap<>(); - private final Map costs = new HashMap<>(); - private final Function resolver; - - public 
TestingLookup(StatsCalculator statsCalculator, CostCalculator costCalculator, Function resolver) - { - this(statsCalculator, costCalculator, ImmutableMap.of(), ImmutableMap.of(), resolver); - } - - private TestingLookup(StatsCalculator statsCalculator, CostCalculator costCalculator, Map stats, Map costs, Function resolver) - { - this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); - this.costCalculator = requireNonNull(costCalculator, "costCalculator is null"); - this.stats.putAll(stats); - this.costs.putAll(costs); - this.resolver = requireNonNull(resolver, "resolver is null"); - } - - public TestingLookup withStats(Map stats) - { - return new TestingLookup(statsCalculator, costCalculator, stats, ImmutableMap.of(), resolver); - } - - @Override - public PlanNode resolve(PlanNode node) - { - if (node instanceof GroupReference) { - return resolver.apply((GroupReference) node); - } - return node; - } - - @Override - public PlanNodeStatsEstimate getStats(PlanNode planNode, Session session, Map types) - { - PlanNode resolved = resolve(planNode); - PlanNodeStatsEstimate statsEstimate = stats.get(resolved); - if (statsEstimate == null) { - statsEstimate = statsCalculator.calculateStats(resolved, this, session, types); - stats.put(resolved, statsEstimate); - } - return statsEstimate; - } - - @Override - public PlanNodeCostEstimate getCumulativeCost(PlanNode planNode, Session session, Map types) - { - PlanNode resolved = resolve(planNode); - PlanNodeCostEstimate costEstimate = costs.get(resolved); - if (costEstimate == null) { - costEstimate = costCalculator.calculateCumulativeCost(resolved, this, session, types); - costs.put(resolved, costEstimate); - } - return costEstimate; - } -} diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java new file mode 100644 index 000000000000..0bf9e06ac13d --- /dev/null +++ 
b/presto-main/src/main/java/com/facebook/presto/testing/TestingStatsCalculator.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.testing; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.PlanNodeStatsEstimate; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.planner.Symbol; +import com.facebook.presto.sql.planner.iterative.Lookup; +import com.facebook.presto.sql.planner.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.PlanNodeId; +import com.google.common.collect.ImmutableMap; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class TestingStatsCalculator + implements StatsCalculator +{ + private final StatsCalculator statsCalculator; + private final Map<PlanNodeId, PlanNodeStatsEstimate> stats; + + public TestingStatsCalculator(StatsCalculator statsCalculator, Map<PlanNodeId, PlanNodeStatsEstimate> stats) + { + this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); + this.stats = ImmutableMap.copyOf(requireNonNull(stats, "stats is null")); + } + + @Override + public PlanNodeStatsEstimate calculateStats(PlanNode planNode, Lookup lookup, Session session, Map<Symbol, Type> types) + { + if (stats.containsKey(planNode.getId())) { + return stats.get(planNode.getId()); + } + + return statsCalculator.calculateStats(planNode, lookup, session, types); + } +} diff --git
a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java index 9118bbc32da4..c64c901c3184 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestReorderJoins.java @@ -14,6 +14,7 @@ package com.facebook.presto.sql.planner.iterative.rule; import com.facebook.presto.Session; +import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.CostComparator; import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.StatsCalculator; @@ -31,6 +32,7 @@ import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.QualifiedName; import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.TestingStatsCalculator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.AfterClass; @@ -56,6 +58,8 @@ public class TestReorderJoins { private RuleTester tester; + private StatsCalculator statsCalculator; + private CostCalculator costCalculator; @BeforeClass public void setUp() @@ -67,6 +71,8 @@ public void setUp() .setSystemProperty("join_reordering_strategy", "COST_BASED") .build(); LocalQueryRunner queryRunner = queryRunnerWithFakeNodeCountForStats(session, 4); + statsCalculator = queryRunner.getStatsCalculator(); + costCalculator = queryRunner.getEstimatedExchangesCostCalculator(); tester = new RuleTester(queryRunner); } @@ -80,7 +86,20 @@ public void tearDown() @Test public void testKeepsOutputSymbols() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + 
.setOutputRowCount(5000) + .addSymbolStatistics(ImmutableMap.of( + new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 100), + new Symbol("A2"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) .on(p -> p.join( INNER, @@ -89,17 +108,6 @@ public void testKeepsOutputSymbols() ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), ImmutableList.of(p.symbol("A2", BIGINT)), Optional.empty())) - .withStats(ImmutableMap.of( - new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(5000) - .addSymbolStatistics(ImmutableMap.of( - new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 100), - new Symbol("A2"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) - .build(), - new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) - .build())) .matches(join( INNER, ImmutableList.of(equiJoinClause("A1", "B1")), @@ -113,7 +121,18 @@ public void testKeepsOutputSymbols() @Test public void testReplicatesAndFlipsWhenOneTableMuchSmaller() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + 
.addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) .on(p -> p.join( INNER, @@ -122,15 +141,6 @@ public void testReplicatesAndFlipsWhenOneTableMuchSmaller() ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), Optional.empty())) - .withStats(ImmutableMap.of( - new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(100) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) - .build(), - new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) - .build())) .matches(join( INNER, ImmutableList.of(equiJoinClause("B1", "A1")), @@ -144,7 +154,18 @@ public void testReplicatesAndFlipsWhenOneTableMuchSmaller() @Test public void testRepartitionsWhenRequiredBySession() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) .on(p -> p.join( INNER, @@ -154,15 +175,6 @@ 
public void testRepartitionsWhenRequiredBySession() ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), Optional.empty())) .setSystemProperty("join_distribution_type", "REPARTITIONED") - .withStats(ImmutableMap.of( - new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(100) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100))) - .build(), - new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) - .build())) .matches(join( INNER, ImmutableList.of(equiJoinClause("B1", "A1")), @@ -176,7 +188,18 @@ public void testRepartitionsWhenRequiredBySession() @Test public void testRepartitionsWhenBothTablesEqual() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) .on(p -> p.join( INNER, @@ -185,15 +208,6 @@ public void testRepartitionsWhenBothTablesEqual() ImmutableList.of(new JoinNode.EquiJoinClause(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT))), ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), Optional.empty())) - .withStats(ImmutableMap.of( - new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - 
.addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) - .build(), - new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) - .build())) .matches(join( INNER, ImmutableList.of(equiJoinClause("A1", "B1")), @@ -207,7 +221,18 @@ public void testRepartitionsWhenBothTablesEqual() @Test public void testReplicatesWhenRequiredBySession() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) .on(p -> p.join( INNER, @@ -217,15 +242,6 @@ public void testReplicatesWhenRequiredBySession() ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), Optional.empty())) .setSystemProperty("join_distribution_type", "REPLICATED") - .withStats(ImmutableMap.of( - new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) - .build(), - new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) - .build())) .matches(join( INNER, 
ImmutableList.of(equiJoinClause("A1", "B1")), @@ -239,7 +255,18 @@ public void testReplicatesWhenRequiredBySession() @Test public void testDoesNotFireForCrossJoin() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build(), + new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() + .setOutputRowCount(10000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) .on(p -> p.join( INNER, @@ -248,23 +275,15 @@ public void testDoesNotFireForCrossJoin() ImmutableList.of(), ImmutableList.of(p.symbol("A1", BIGINT), p.symbol("B1", BIGINT)), Optional.empty())) - .withStats(ImmutableMap.of( - new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) - .build(), - new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder() - .setOutputRowCount(10000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100))) - .build())) .doesNotFire(); } @Test public void testDoesNotFireWithNoStats() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) - .withStatsCalculator(new UnknownStatsCalculator()) + StatsCalculator testingStatsCalculator = new UnknownStatsCalculator(); + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) .on(p -> p.join( INNER, @@ -279,7 
+298,7 @@ public void testDoesNotFireWithNoStats() @Test public void testDoesNotFireForNonDeterministicFilter() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), statsCalculator, costCalculator)) .on(p -> p.join( INNER, @@ -294,7 +313,27 @@ public void testDoesNotFireForNonDeterministicFilter() @Test public void testPredicatesPushedDown() { - tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1))) + StatsCalculator testingStatsCalculator = new TestingStatsCalculator(statsCalculator, ImmutableMap.of( + new PlanNodeId("valuesA"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(10) + .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .build(), + new PlanNodeId("valuesB"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(5) + .addSymbolStatistics(ImmutableMap.of( + new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 10), + new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) + .build(), + new PlanNodeId("valuesC"), + PlanNodeStatsEstimate.builder() + .setOutputRowCount(1000) + .addSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) + .build())); + + tester.assertThat(new ReorderJoins(new CostComparator(1, 1, 1), testingStatsCalculator, costCalculator)) + .withStatsCalculator(testingStatsCalculator) .on(p -> p.join( INNER, @@ -310,24 +349,6 @@ public void testPredicatesPushedDown() new JoinNode.EquiJoinClause(p.symbol("B2", BIGINT), p.symbol("C1", BIGINT))), ImmutableList.of(p.symbol("A1", BIGINT)), Optional.of(new ComparisonExpression(EQUAL, p.symbol("A1", BIGINT).toSymbolReference(), p.symbol("B1", BIGINT).toSymbolReference())))) - .withStats(ImmutableMap.of( - new PlanNodeId("valuesA"), - PlanNodeStatsEstimate.builder() - .setOutputRowCount(10) - .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) - 
.build(), - new PlanNodeId("valuesB"), - PlanNodeStatsEstimate.builder() - .setOutputRowCount(5) - .addSymbolStatistics(ImmutableMap.of( - new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 10), - new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 10))) - .build(), - new PlanNodeId("valuesC"), - PlanNodeStatsEstimate.builder() - .setOutputRowCount(1000) - .addSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100))) - .build())) .matches( join( INNER, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 1a7d97ea7c2d..f125cc50c149 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -15,7 +15,6 @@ import com.facebook.presto.Session; import com.facebook.presto.cost.CostCalculator; -import com.facebook.presto.cost.PlanNodeStatsEstimate; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControl; @@ -29,9 +28,7 @@ import com.facebook.presto.sql.planner.iterative.Memo; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.plan.PlanNode; -import com.facebook.presto.sql.planner.plan.PlanNodeId; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; -import com.facebook.presto.testing.TestingLookup; import com.facebook.presto.transaction.TransactionManager; import com.google.common.collect.ImmutableSet; @@ -40,11 +37,9 @@ import java.util.function.Function; import static com.facebook.presto.sql.planner.assertions.PlanAssert.assertPlan; -import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static 
com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.Objects.requireNonNull; import static org.testng.Assert.fail; @@ -59,7 +54,7 @@ public class RuleAssert private final PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); private Map symbols; - private TestingLookup lookup; + private Lookup lookup; private PlanNode plan; private final TransactionManager transactionManager; private final AccessControl accessControl; @@ -111,32 +106,10 @@ public RuleAssert on(Function planProvider) plan = planProvider.apply(builder); symbols = builder.getSymbols(); memo = new Memo(idAllocator, plan); - lookup = new TestingLookup(statsCalculator, costCalculator, memo::resolve); + lookup = Lookup.from(memo::resolve, statsCalculator, costCalculator); return this; } - public RuleAssert withStats(Map stats) - { - checkState(lookup != null, "lookup has not yet been initialized"); - Map planNodeMap = buildPlanNodeMap(); - lookup = lookup.withStats( - stats.entrySet() - .stream() - .collect(toImmutableMap( - entry -> { - checkState(planNodeMap.containsKey(entry.getKey()), "planNodeMap does not contain key"); - return planNodeMap.get(entry.getKey()); - }, - Map.Entry::getValue))); - return this; - } - - private Map buildPlanNodeMap() - { - return searchFrom(plan, lookup).findAll().stream() - .collect(toImmutableMap(PlanNode::getId, planNode -> planNode)); - } - public void doesNotFire() { RuleApplication ruleApplication = applyRule();