Fix the logic in LocalExchange for acquiring the bucket count
Passing an empty node list would make `getBucketNodeMap` get stuck.
The correct way to acquire the bucket count is to call `getBucketCount`.
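
A condensed before/after sketch of the change in `createPartitionFunction`, reconstructed from the diff below (all names come from the diff; the `checkArgument` null check is omitted for brevity):

    // Before: the bucket count came from a bucket-node map built with an empty
    // node list, which could block indefinitely
    ConnectorBucketNodeMap connectorBucketNodeMap = partitioningProvider.getBucketNodeMap(
            partitioning.getTransactionHandle().orElse(null),
            session.toConnectorSession(partitioning.getConnectorId().get()),
            partitioning.getConnectorHandle(),
            ImmutableList.of());
    int bucketCount = connectorBucketNodeMap.getBucketCount();

    // After: ask the partitioning provider for the bucket count directly
    int bucketCount = partitioningProvider.getBucketCount(
            partitioning.getTransactionHandle().orElse(null),
            session.toConnectorSession(partitioning.getConnectorId().get()),
            partitioning.getConnectorHandle());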
kewang1024 authored and highker committed Nov 10, 2021
1 parent 48fadf1 commit 426fbdd
Showing 2 changed files with 85 additions and 8 deletions.
@@ -23,7 +23,6 @@
import com.facebook.presto.operator.PipelineExecutionStrategy;
import com.facebook.presto.operator.PrecomputedHashGenerator;
import com.facebook.presto.spi.BucketFunction;
import com.facebook.presto.spi.connector.ConnectorBucketNodeMap;
import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
@@ -150,7 +149,7 @@ else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getConnect
}
}

private static PartitionFunction createPartitionFunction(
static PartitionFunction createPartitionFunction(
PartitioningProviderManager partitioningProviderManager,
Session session,
PartitioningHandle partitioning,
@@ -170,14 +169,10 @@ private static PartitionFunction createPartitionFunction(
}

ConnectorNodePartitioningProvider partitioningProvider = partitioningProviderManager.getPartitioningProvider(partitioning.getConnectorId().get());
ConnectorBucketNodeMap connectorBucketNodeMap = partitioningProvider.getBucketNodeMap(
int bucketCount = partitioningProvider.getBucketCount(
partitioning.getTransactionHandle().orElse(null),
session.toConnectorSession(partitioning.getConnectorId().get()),
partitioning.getConnectorHandle(),
ImmutableList.of());
checkArgument(connectorBucketNodeMap != null, "No partition map %s", partitioning);

int bucketCount = connectorBucketNodeMap.getBucketCount();
partitioning.getConnectorHandle());
int[] bucketToPartition = new int[bucketCount];
for (int bucket = 0; bucket < bucketCount; bucket++) {
bucketToPartition[bucket] = bucket % partitionCount;
@@ -20,10 +20,21 @@
import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.operator.InterpretedHashGenerator;
import com.facebook.presto.operator.PageAssertions;
import com.facebook.presto.operator.PartitionFunction;
import com.facebook.presto.operator.PipelineExecutionStrategy;
import com.facebook.presto.operator.exchange.LocalExchange.LocalExchangeFactory;
import com.facebook.presto.operator.exchange.LocalExchange.LocalExchangeSinkFactory;
import com.facebook.presto.operator.exchange.LocalExchange.LocalExchangeSinkFactoryId;
import com.facebook.presto.spi.BucketFunction;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.ConnectorSplit;
import com.facebook.presto.spi.Node;
import com.facebook.presto.spi.connector.ConnectorBucketNodeMap;
import com.facebook.presto.spi.connector.ConnectorNodePartitioningProvider;
import com.facebook.presto.spi.connector.ConnectorPartitioningHandle;
import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
import com.facebook.presto.sql.planner.PartitioningHandle;
import com.facebook.presto.sql.planner.PartitioningProviderManager;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
Expand All @@ -36,17 +47,23 @@
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.ToIntFunction;
import java.util.stream.Stream;

import static com.facebook.airlift.testing.Assertions.assertContains;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.operator.PipelineExecutionStrategy.GROUPED_EXECUTION;
import static com.facebook.presto.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION;
import static com.facebook.presto.operator.exchange.LocalExchange.createPartitionFunction;
import static com.facebook.presto.spi.connector.ConnectorBucketNodeMap.createBucketNodeMap;
import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.SOFT_AFFINITY;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_PASSTHROUGH_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.units.DataSize.Unit.BYTE;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
@@ -434,6 +451,71 @@ public void testPartition(PipelineExecutionStrategy executionStrategy)
        });
    }

    @Test
    public void testCreatePartitionFunction()
    {
        int partitionCount = 10;
        PartitioningProviderManager partitioningProviderManager = new PartitioningProviderManager();
        partitioningProviderManager.addPartitioningProvider(
                new ConnectorId("prism"),
                new ConnectorNodePartitioningProvider() {
                    @Override
                    public ConnectorBucketNodeMap getBucketNodeMap(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle, List<Node> sortedNodes)
                    {
                        return createBucketNodeMap(Stream.generate(() -> sortedNodes).flatMap(List::stream).limit(10).collect(toImmutableList()), SOFT_AFFINITY);
                    }

                    @Override
                    public ToIntFunction<ConnectorSplit> getSplitBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle)
                    {
                        return null;
                    }

                    @Override
                    public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle, List<Type> partitionChannelTypes, int bucketCount)
                    {
                        return (Page page, int position) -> partitionCount;
                    }

                    @Override
                    public int getBucketCount(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorPartitioningHandle partitioningHandle)
                    {
                        return 10;
                    }
                });
        PartitioningHandle partitioningHandle = new PartitioningHandle(
                Optional.of(new ConnectorId("prism")),
                Optional.of(new ConnectorTransactionHandle() {
                    @Override
                    public int hashCode()
                    {
                        return super.hashCode();
                    }

                    @Override
                    public boolean equals(Object obj)
                    {
                        return super.equals(obj);
                    }
                }),
                new ConnectorPartitioningHandle() {
                    @Override
                    public boolean isSingleNode()
                    {
                        return false;
                    }

                    @Override
                    public boolean isCoordinatorOnly()
                    {
                        return false;
                    }
                });
        PartitionFunction partitionFunction = createPartitionFunction(partitioningProviderManager, session, partitioningHandle, 600, ImmutableList.of(), false);

        assertEquals(partitionFunction.getPartitionCount(), partitionCount);
    }

    @Test(dataProvider = "executionStrategy")
    public void writeUnblockWhenAllReadersFinish(PipelineExecutionStrategy executionStrategy)
    {