diff --git a/presto-main/src/main/java/com/facebook/presto/operator/exchange/LocalExchange.java b/presto-main/src/main/java/com/facebook/presto/operator/exchange/LocalExchange.java index 0c855d158983..74cdcad53b1d 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/exchange/LocalExchange.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/exchange/LocalExchange.java @@ -23,7 +23,6 @@ import com.facebook.presto.operator.PipelineExecutionStrategy; import com.facebook.presto.operator.PrecomputedHashGenerator; import com.facebook.presto.spi.BucketFunction; -import com.facebook.presto.spi.connector.ConnectorBucketNodeMap; import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; import com.facebook.presto.sql.planner.PartitioningHandle; import com.facebook.presto.sql.planner.PartitioningProviderManager; @@ -150,7 +149,7 @@ else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnect } } - private static PartitionFunction createPartitionFunction( + static PartitionFunction createPartitionFunction( PartitioningProviderManager partitioningProviderManager, Session session, PartitioningHandle partitioning, @@ -170,14 +169,10 @@ private static PartitionFunction createPartitionFunction( } ConnectorNodePartitioningProvider partitioningProvider = partitioningProviderManager.getPartitioningProvider(partitioning.getConnectorId().get()); - ConnectorBucketNodeMap connectorBucketNodeMap = partitioningProvider.getBucketNodeMap( + int bucketCount = partitioningProvider.getBucketCount( partitioning.getTransactionHandle().orElse(null), session.toConnectorSession(partitioning.getConnectorId().get()), - partitioning.getConnectorHandle(), - ImmutableList.of()); - checkArgument(connectorBucketNodeMap != null, "No partition map %s", partitioning); - - int bucketCount = connectorBucketNodeMap.getBucketCount(); + partitioning.getConnectorHandle()); int[] bucketToPartition = new int[bucketCount]; for (int bucket = 0; bucket < bucketCount; bucket++) { bucketToPartition[bucket] = bucket % partitionCount; diff --git a/presto-main/src/test/java/com/facebook/presto/operator/exchange/TestLocalExchange.java b/presto-main/src/test/java/com/facebook/presto/operator/exchange/TestLocalExchange.java index 64287a15ed87..63133d04008f 100644 --- a/presto-main/src/test/java/com/facebook/presto/operator/exchange/TestLocalExchange.java +++ b/presto-main/src/test/java/com/facebook/presto/operator/exchange/TestLocalExchange.java @@ -20,10 +20,21 @@ import com.facebook.presto.execution.Lifespan; import com.facebook.presto.operator.InterpretedHashGenerator; import com.facebook.presto.operator.PageAssertions; +import com.facebook.presto.operator.PartitionFunction; import com.facebook.presto.operator.PipelineExecutionStrategy; import com.facebook.presto.operator.exchange.LocalExchange.LocalExchangeFactory; import com.facebook.presto.operator.exchange.LocalExchange.LocalExchangeSinkFactory; import com.facebook.presto.operator.exchange.LocalExchange.LocalExchangeSinkFactoryId; +import com.facebook.presto.spi.BucketFunction; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.Node; +import com.facebook.presto.spi.connector.ConnectorBucketNodeMap; +import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider; +import com.facebook.presto.spi.connector.ConnectorPartitioningHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.sql.planner.PartitioningHandle; import com.facebook.presto.sql.planner.PartitioningProviderManager; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; @@ -36,17 +47,23 @@ import java.util.List; import java.util.Optional; import java.util.function.Consumer; +import java.util.function.ToIntFunction; +import java.util.stream.Stream; import static com.facebook.airlift.testing.Assertions.assertContains; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.operator.PipelineExecutionStrategy.GROUPED_EXECUTION; import static com.facebook.presto.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION; +import static com.facebook.presto.operator.exchange.LocalExchange.createPartitionFunction; +import static com.facebook.presto.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap; +import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.SOFT_AFFINITY; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.units.DataSize.Unit.BYTE; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -434,6 +451,71 @@ public void testPartition(PipelineExecutionStrategy executionStrategy) }); } + @Test + public void testCreatePartitionFunction() + { + int partitionCount = 10; + PartitioningProviderManager partitioningProviderManager = new PartitioningProviderManager(); + partitioningProviderManager.addPartitioningProvider( + new ConnectorId("prism"), + new ConnectorNodePartitioningProvider() { + @Override + public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle, List sortedNodes) + { + return createBucketNodeMap(Stream.generate(() -> sortedNodes).flatMap(List::stream).limit(10).collect(toImmutableList()), SOFT_AFFINITY); + } + + @Override + public ToIntFunction getSplitBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + { + return null; + } + + @Override + public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle, List partitionChannelTypes, int bucketCount) + { + return (Page page, int position) -> partitionCount; + } + + @Override + public int getBucketCount(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle) + { + return 10; + } + }); + PartitioningHandle partitioningHandle = new PartitioningHandle( + Optional.of(new ConnectorId("prism")), + Optional.of(new ConnectorTransactionHandle() { + @Override + public int hashCode() + { + return super.hashCode(); + } + + @Override + public boolean equals(Object obj) + { + return super.equals(obj); + } + }), + new ConnectorPartitioningHandle() { + @Override + public boolean isSingleNode() + { + return false; + } + + @Override + public boolean isCoordinatorOnly() + { + return false; + } + }); + PartitionFunction partitionFunction = createPartitionFunction(partitioningProviderManager, session, partitioningHandle, 600, ImmutableList.of(), false); + + assertEquals(partitionFunction.getPartitionCount(), partitionCount); + } + @Test(dataProvider = "executionStrategy") public void writeUnblockWhenAllReadersFinish(PipelineExecutionStrategy executionStrategy) {