Skip to content

Commit

Permalink
fix local_exchange
Browse files Browse the repository at this point in the history
  • Loading branch information
Nitin-Kashyap committed Dec 13, 2024
1 parent 8e9a566 commit e6a978e
Show file tree
Hide file tree
Showing 14 changed files with 53 additions and 16 deletions.
3 changes: 2 additions & 1 deletion be/src/pipeline/dependency.h
Original file line number Diff line number Diff line change
Expand Up @@ -728,14 +728,15 @@ inline std::string get_exchange_type_name(ExchangeType idx) {
}

struct DataDistribution {
DataDistribution(ExchangeType type) : distribution_type(type) {}
DataDistribution(ExchangeType type) : distribution_type(type), hash_type(THashType::CRC32) {}
DataDistribution(ExchangeType type, const std::vector<TExpr>& partition_exprs_)
: distribution_type(type), partition_exprs(partition_exprs_) {}
DataDistribution(const DataDistribution& other) = default;
bool need_local_exchange() const { return distribution_type != ExchangeType::NOOP; }
DataDistribution& operator=(const DataDistribution& other) = default;
ExchangeType distribution_type;
std::vector<TExpr> partition_exprs;
THashType::type hash_type;
};

class ExchangerBase;
Expand Down
2 changes: 2 additions & 0 deletions be/src/pipeline/exec/partitioned_hash_join_sink_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ PartitionedHashJoinSinkOperatorX::PartitionedHashJoinSinkOperatorX(ObjectPool* p
descs),
_join_distribution(tnode.hash_join_node.__isset.dist_type ? tnode.hash_join_node.dist_type
: TJoinDistributionType::NONE),
_hash_type(tnode.hash_join_node.__isset.__isset.hash_type ? tnode.hash_join_node.hash_type
: THashType::CRC32),
_distribution_partition_exprs(tnode.__isset.distribute_expr_lists
? tnode.distribute_expr_lists[1]
: std::vector<TExpr> {}),
Expand Down
4 changes: 3 additions & 1 deletion be/src/pipeline/exec/partitioned_hash_join_sink_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ class PartitionedHashJoinSinkOperatorX
return _join_distribution == TJoinDistributionType::BUCKET_SHUFFLE ||
_join_distribution == TJoinDistributionType::COLOCATE
? DataDistribution(ExchangeType::BUCKET_HASH_SHUFFLE,
_distribution_partition_exprs)
_distribution_partition_exprs,
hash_type)
: DataDistribution(ExchangeType::HASH_SHUFFLE,
_distribution_partition_exprs);
}
Expand All @@ -134,6 +135,7 @@ class PartitionedHashJoinSinkOperatorX
Status _setup_internal_operator(RuntimeState* state);

const TJoinDistributionType::type _join_distribution;
THashType::type _hash_type;

std::vector<TExpr> _build_exprs;

Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/columns/column_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ void ColumnStr<T>::update_crcs_with_value(uint32_t* __restrict hashes, doris::Pr
}
}

void ColumnString::update_murmurs_with_value(int32_t* __restrict hashes, doris::PrimitiveType type,
template <typename T>
void ColumnStr<T>::update_murmurs_with_value(int32_t* __restrict hashes, doris::PrimitiveType type,
int32_t rows, uint32_t offset,
const uint8_t* __restrict null_data) const {
auto s = rows;
Expand Down
4 changes: 2 additions & 2 deletions be/src/vec/runtime/partitioner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ Status Murmur32HashPartitioner<ChannelIds>::clone(RuntimeState* state,
}

template <typename ChannelIds>
int32_t Murmur32HashPartitioner<ChannelIds>::_get_default_seed() {
return static_cast<int32_t>(HashUtil::SPARK_MURMUR_32_SEED);
int32_t Murmur32HashPartitioner<ChannelIds>::_get_default_seed() const {
return reinterpret_cast<int32_t>(HashUtil::SPARK_MURMUR_32_SEED);
}

template class Crc32HashPartitioner<ShuffleChannelIds>;
Expand Down
6 changes: 3 additions & 3 deletions be/src/vec/runtime/partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class Crc32HashPartitioner : public PartitionerBase {

void _do_hash(const ColumnPtr& column, uint32_t* __restrict result, int idx) const;

HashValueType _get_default_seed() {
return 0;
HashValueType _get_default_seed() const {
return reinterpret_cast<HashValueType>(0);
}

VExprContextSPtrs _partition_expr_ctxs;
Expand Down Expand Up @@ -126,7 +126,7 @@ class Murmur32HashPartitioner final : public Partitioner<int32_t, ChannelIds> {

Status clone(RuntimeState* state, std::unique_ptr<PartitionerBase>& partitioner) override;

int32_t _get_default_seed() override;
int32_t _get_default_seed() const;

private:
void _do_hash(const ColumnPtr& column, int32_t* __restrict result, int idx) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,10 @@ private TScanRangeLocations splitToScanRange(
isSparkBucketedHiveTable = ((HMSExternalTable) targetTable).isSparkBucketedTable();
if (isSparkBucketedHiveTable) {
bucketNum = HiveBucketUtil.getBucketNumberFromPath(fileSplit.getPath().getName()).getAsInt();
if (!bucketSeq2locations.containsKey(bucketNum)) {
bucketSeq2locations.put(bucketNum, curLocations);
}
curLocations = bucketSeq2locations.get(bucketNum).get(0);
}
}

Expand Down Expand Up @@ -481,13 +485,10 @@ private TScanRangeLocations splitToScanRange(
curLocations.addToLocations(location);

if (LOG.isDebugEnabled()) {
LOG.debug("assign to backend {} with table split: {} ({}, {}), location: {}",
LOG.debug("assign to backend {} with table split: {} ({}, {}), location: {}, bucketNum: {}",
curLocations.getLocations().get(0).getBackendId(), fileSplit.getPath(),
fileSplit.getStart(), fileSplit.getLength(),
Joiner.on("|").join(fileSplit.getHosts()));
}
if (isSparkBucketedHiveTable) {
bucketSeq2locations.put(bucketNum, curLocations);
Joiner.on("|").join(fileSplit.getHosts()), bucketNum);
}

return curLocations;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.apache.doris.thrift.TTableType;

import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
Expand All @@ -84,10 +85,8 @@
import org.apache.logging.log4j.Logger;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,8 @@ public PlanFragment visitPhysicalHashJoin(
hashJoinNode.setDistributionMode(DistributionMode.BROADCAST);
} else if (JoinUtils.shouldBucketShuffleJoin(physicalHashJoin)) {
hashJoinNode.setDistributionMode(DistributionMode.BUCKET_SHUFFLE);
hashJoinNode.setHashType(((DistributionSpecHash) physicalHashJoin.left()
.getPhysicalProperties().getDistributionSpec()).getShuffleFunction());
} else {
hashJoinNode.setDistributionMode(DistributionMode.PARTITIONED);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ public enum StorageBucketHashType {
// SPARK_MURMUR32 is the hash function for Spark bucketed hive table's storage and computation
STORAGE_BUCKET_SPARK_MURMUR32;

/**
* convert to thrift
*/
public THashType toThrift() {
switch (this) {
case STORAGE_BUCKET_CRC32:
Expand All @@ -357,6 +360,21 @@ public THashType toThrift() {
return THashType.XXHASH64;
}
}

/**
* convert from thrift
*/
public static StorageBucketHashType fromThrift(THashType hashType) {
switch (hashType) {
case CRC32:
return STORAGE_BUCKET_CRC32;
case SPARK_MURMUR32:
return STORAGE_BUCKET_SPARK_MURMUR32;
case XXHASH64:
default:
return STORAGE_BUCKET_XXHASH64;
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.doris.common.UserException;
import org.apache.doris.datasource.hive.HMSExternalTable;
import org.apache.doris.datasource.hive.source.HiveScanNode;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.THashType;
import org.apache.doris.thrift.TPartitionType;
Expand Down Expand Up @@ -339,6 +340,7 @@ private PlanFragment createHashJoinFragment(
Ref<THashType> hashType = Ref.from(THashType.CRC32);
if (canBucketShuffleJoin(node, leftChildFragment, rhsPartitionExprs, hashType)) {
node.setDistributionMode(HashJoinNode.DistributionMode.BUCKET_SHUFFLE);
node.setHashType(DistributionSpecHash.StorageBucketHashType.fromThrift(hashType.value));
DataPartition rhsJoinPartition =
new DataPartition(TPartitionType.BUCKET_SHFFULE_HASH_PARTITIONED,
rhsPartitionExprs, hashType.value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.common.CheckedMath;
import org.apache.doris.common.Pair;
import org.apache.doris.common.UserException;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.statistics.StatisticalType;
import org.apache.doris.thrift.TEqJoinCondition;
Expand Down Expand Up @@ -79,6 +80,7 @@ public class HashJoinNode extends JoinNodeBase {
private List<Expr> markJoinConjuncts;

private DistributionMode distrMode;
private DistributionSpecHash.StorageBucketHashType hashType;
private boolean isColocate = false; //the flag for colocate join
private String colocateReason = ""; // if can not do colocate join, set reason here

Expand Down Expand Up @@ -249,6 +251,10 @@ public void setColocate(boolean colocate, String reason) {
colocateReason = reason;
}

public void setHashType(DistributionSpecHash.StorageBucketHashType hashType) {
this.hashType = hashType;
}

/**
* Calculate the slots output after going through the hash table in the hash join node.
* The most essential difference between 'hashOutputSlots' and 'outputSlots' is that
Expand Down Expand Up @@ -817,6 +823,7 @@ protected void toThrift(TPlanNode msg) {
}
}
msg.hash_join_node.setDistType(isColocate ? TJoinDistributionType.COLOCATE : distrMode.toThrift());
msg.hash_join_node.setHashType(hashType.toThrift());
msg.hash_join_node.setUseSpecificProjections(useSpecificProjections);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2718,6 +2718,7 @@ private void computeScanRangeAssignmentByBucketForHive(
fragmentIdToSeqToAddressMap.put(scanNode.getFragmentId(), new HashMap<>());
fragmentIdBucketSeqToScanRangeMap.put(scanNode.getFragmentId(), new BucketSeqToScanRange());
fragmentIdToBuckendIdBucketCountMap.put(scanNode.getFragmentId(), new HashMap<>());
scanNode.getFragment().setBucketNum(bucketNum);
}
Map<Integer, TNetworkAddress> bucketSeqToAddress
= fragmentIdToSeqToAddressMap.get(scanNode.getFragmentId());
Expand Down
1 change: 1 addition & 0 deletions gensrc/thrift/PlanNodes.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ struct THashJoinNode {
13: optional list<Exprs.TExpr> mark_join_conjuncts
// use_specific_projections true, if output exprssions is denoted by srcExprList represents, o.w. PlanNode.projections
14: optional bool use_specific_projections
15: optional Partitions.THashType hash_type
}

struct TNestedLoopJoinNode {
Expand Down

0 comments on commit e6a978e

Please sign in to comment.