Skip to content

Commit

Permalink
feat: initial HashJoinRel support (substrait-io#187)
Browse files Browse the repository at this point in the history
* fix: pojo to proto JoinRel conversion skipped postJoinFilter
  • Loading branch information
vibhatha authored Oct 25, 2023
1 parent 31b91f3 commit d6328a6
Show file tree
Hide file tree
Showing 9 changed files with 303 additions and 0 deletions.
29 changes: 29 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.physical.HashJoin;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -165,6 +166,34 @@ private Join join(
.build();
}

/**
 * Builds a {@link HashJoin} of {@code left} and {@code right} without an output remap.
 *
 * <p>Convenience overload: forwards every argument unchanged to the remap-accepting variant and
 * supplies {@link Optional#empty()} for the remap.
 *
 * @param leftKeys integer indexes handed to {@code fieldReferences} together with {@code left}
 * @param rightKeys integer indexes handed to {@code fieldReferences} together with {@code right}
 * @param joinType kind of join to perform
 * @param left left input relation
 * @param right right input relation
 * @return the constructed hash join
 */
public HashJoin hashJoin(
    List<Integer> leftKeys,
    List<Integer> rightKeys,
    HashJoin.JoinType joinType,
    Rel left,
    Rel right) {
  Optional<Rel.Remap> noRemap = Optional.empty();
  return hashJoin(leftKeys, rightKeys, joinType, noRemap, left, right);
}

/**
 * Builds a {@link HashJoin} of {@code left} and {@code right}, optionally remapping its output.
 *
 * <p>Each list of integer keys is unboxed into an {@code int[]} and turned into field references
 * through {@code fieldReferences}, using the input relation of the matching side.
 *
 * @param leftKeys integer indexes resolved against {@code left}
 * @param rightKeys integer indexes resolved against {@code right}
 * @param joinType kind of join to perform
 * @param remap optional output remapping stored on the resulting relation
 * @param left left input relation
 * @param right right input relation
 * @return the constructed hash join
 */
public HashJoin hashJoin(
    List<Integer> leftKeys,
    List<Integer> rightKeys,
    HashJoin.JoinType joinType,
    Optional<Rel.Remap> remap,
    Rel left,
    Rel right) {
  // Populate the builder step by step, in the same order as a fluent chain would.
  var joinBuilder = HashJoin.builder();
  joinBuilder.left(left);
  joinBuilder.right(right);

  int[] leftIndexes = leftKeys.stream().mapToInt(Integer::intValue).toArray();
  joinBuilder.leftKeys(this.fieldReferences(left, leftIndexes));

  int[] rightIndexes = rightKeys.stream().mapToInt(Integer::intValue).toArray();
  joinBuilder.rightKeys(this.fieldReferences(right, rightIndexes));

  joinBuilder.joinType(joinType);
  joinBuilder.remap(remap);
  return joinBuilder.build();
}

public NamedScan namedScan(
Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
return namedScan(tableName, columnNames, types, Optional.empty());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;

public abstract class AbstractRelVisitor<OUTPUT, EXCEPTION extends Exception>
implements RelVisitor<OUTPUT, EXCEPTION> {
public abstract OUTPUT visitFallback(Rel rel);
Expand Down Expand Up @@ -83,4 +85,9 @@ public OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION {
public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION {
// No dedicated handling for extension tables here: hand the relation to visitFallback.
return visitFallback(extensionTable);
}

/**
 * Visits a {@link HashJoin} by routing it to {@link #visitFallback(Rel)}.
 *
 * @param hashJoin the hash join being visited
 * @return whatever {@code visitFallback} produces for this relation
 * @throws EXCEPTION as declared by the visitor contract
 */
@Override
public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION {
  final OUTPUT fallbackResult = visitFallback(hashJoin);
  return fallbackResult;
}
}
37 changes: 37 additions & 0 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.substrait.proto.ExtensionSingleRel;
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
Expand All @@ -27,6 +28,7 @@
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.files.ImmutableFileFormat;
import io.substrait.relation.files.ImmutableFileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -95,6 +97,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case EXTENSION_MULTI -> {
return newExtensionMulti(rel.getExtensionMulti());
}
case HASH_JOIN -> {
return newHashJoin(rel.getHashJoin());
}
default -> {
throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType);
}
Expand Down Expand Up @@ -490,6 +495,38 @@ private Set newSet(SetRel rel) {
return builder.build();
}

/**
 * Converts a protobuf {@code HashJoinRel} into a {@link HashJoin}.
 *
 * <p>Left keys are converted against the left input's record type, right keys against the right
 * input's record type, and the optional post-join filter against a struct built from both.
 *
 * @param rel the protobuf hash join
 * @return the equivalent {@link HashJoin}
 */
private Rel newHashJoin(HashJoinRel rel) {
  Rel leftInput = from(rel.getLeft());
  Rel rightInput = from(rel.getRight());
  var protoLeftKeys = rel.getLeftKeysList();
  var protoRightKeys = rel.getRightKeysList();

  Type.Struct leftType = leftInput.getRecordType();
  Type.Struct rightType = rightInput.getRecordType();
  Type.Struct combinedType = Type.Struct.builder().from(leftType).from(rightType).build();
  var leftKeyConverter = new ProtoExpressionConverter(lookup, extensions, leftType, this);
  var rightKeyConverter = new ProtoExpressionConverter(lookup, extensions, rightType, this);
  var filterConverter = new ProtoExpressionConverter(lookup, extensions, combinedType, this);

  var hashJoin = HashJoin.builder();
  hashJoin.left(leftInput);
  hashJoin.right(rightInput);
  hashJoin.leftKeys(
      protoLeftKeys.stream().map(leftKeyConverter::from).collect(Collectors.toList()));
  hashJoin.rightKeys(
      protoRightKeys.stream().map(rightKeyConverter::from).collect(Collectors.toList()));
  hashJoin.joinType(HashJoin.JoinType.fromProto(rel.getType()));

  // An unset filter leaves the builder's optional attribute at its empty default.
  if (rel.hasPostJoinFilter()) {
    hashJoin.postJoinFilter(
        Optional.ofNullable(filterConverter.from(rel.getPostJoinFilter())));
  }

  hashJoin.commonExtension(optionalAdvancedExtension(rel.getCommon()));
  hashJoin.remap(optionalRelmap(rel.getCommon()));
  if (rel.hasAdvancedExtension()) {
    hashJoin.extension(advancedExtension(rel.getAdvancedExtension()));
  }
  return hashJoin.build();
}

private static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.ImmutableHashJoin;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -166,6 +168,29 @@ public Optional<Rel> visit(Cross cross) throws RuntimeException {
.build());
}

/**
 * Copy-on-write visit of a {@link HashJoin}.
 *
 * <p>Visits both inputs and the post-join filter (when present). If none of them produced a
 * replacement, returns {@link Optional#empty()}; otherwise returns a copy of the join in which
 * each replaced part is swapped in and every other part is taken from the original.
 *
 * @param hashJoin the hash join being visited
 * @return a rewritten join, or empty when nothing changed
 */
@Override
public Optional<Rel> visit(HashJoin hashJoin) throws RuntimeException {
  var newLeft = hashJoin.getLeft().accept(this);
  var newRight = hashJoin.getRight().accept(this);
  var newPostFilter = hashJoin.getPostJoinFilter().flatMap(expr -> visitExpression(expr));

  if (allEmpty(newLeft, newRight, newPostFilter)) {
    return Optional.empty();
  }

  Rel rewritten =
      ImmutableHashJoin.builder()
          .from(hashJoin)
          .left(newLeft.orElse(hashJoin.getLeft()))
          .right(newRight.orElse(hashJoin.getRight()))
          // Keys are carried over as-is; they are not visited.
          .leftKeys(hashJoin.getLeftKeys())
          .rightKeys(hashJoin.getRightKeys())
          // Prefer the rewritten filter, otherwise keep whatever the original had (if any).
          .postJoinFilter(newPostFilter.or(hashJoin::getPostJoinFilter))
          .build();
  return Optional.of(rewritten);
}

private Optional<Expression> visitExpression(Expression expression) {
ExpressionVisitor<Optional<Expression>, RuntimeException> visitor =
new AbstractExpressionVisitor<>() {
Expand Down
34 changes: 34 additions & 0 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
Expand All @@ -12,6 +13,7 @@
import io.substrait.proto.ExtensionSingleRel;
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
Expand All @@ -21,6 +23,7 @@
import io.substrait.proto.SortField;
import io.substrait.proto.SortRel;
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.Collection;
import java.util.List;
Expand Down Expand Up @@ -68,6 +71,10 @@ private List<SortField> toProtoS(Collection<Expression.SortField> sorts) {
.collect(java.util.stream.Collectors.toList());
}

/**
 * Converts a {@link FieldReference} to its protobuf form by running it through
 * {@code exprProtoConverter} and extracting the resulting expression's selection.
 *
 * @param fieldReference the field reference to convert
 * @return the selection of the converted expression
 */
private io.substrait.proto.Expression.FieldReference toProto(FieldReference fieldReference) {
  var convertedExpression = fieldReference.accept(exprProtoConverter);
  return convertedExpression.getSelection();
}

@Override
public Rel visit(Aggregate aggregate) throws RuntimeException {
var builder =
Expand Down Expand Up @@ -166,6 +173,8 @@ public Rel visit(Join join) throws RuntimeException {

join.getCondition().ifPresent(t -> builder.setExpression(toProto(t)));

join.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

join.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setJoin(builder).build();
}
Expand Down Expand Up @@ -228,6 +237,31 @@ public Rel visit(ExtensionTable extensionTable) throws RuntimeException {
return Rel.newBuilder().setRead(builder).build();
}

/**
 * Converts a {@link HashJoin} into a protobuf {@code Rel} wrapping a {@code HashJoinRel}.
 *
 * <p>The common section, both inputs and the join type are always written. Keys are written
 * only after checking that both sides declare the same number of them. The post-join filter
 * and the advanced extension are written only when present on the relation.
 *
 * @param hashJoin the hash join to convert
 * @return the protobuf relation
 * @throws IllegalArgumentException if the left and right key lists differ in size
 */
@Override
public Rel visit(HashJoin hashJoin) throws RuntimeException {
  var builder =
      HashJoinRel.newBuilder()
          .setCommon(common(hashJoin))
          .setLeft(toProto(hashJoin.getLeft()))
          .setRight(toProto(hashJoin.getRight()))
          .setType(hashJoin.getJoinType().toProto());

  List<FieldReference> leftKeys = hashJoin.getLeftKeys();
  List<FieldReference> rightKeys = hashJoin.getRightKeys();

  // A mismatch is a caller error: report it as an invalid argument and include both counts
  // so the offending plan is easy to diagnose. IllegalArgumentException is still a
  // RuntimeException, so existing handlers keep working.
  if (leftKeys.size() != rightKeys.size()) {
    throw new IllegalArgumentException(
        "Number of left and right keys must be equal, but found "
            + leftKeys.size()
            + " left and "
            + rightKeys.size()
            + " right keys.");
  }

  builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList()));
  builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList()));

  hashJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

  hashJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
  return Rel.newBuilder().setHashJoin(builder).build();
}

@Override
public Rel visit(Project project) throws RuntimeException {
var builder =
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/io/substrait/relation/RelVisitor.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;

public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
OUTPUT visit(Aggregate aggregate) throws EXCEPTION;

Expand Down Expand Up @@ -32,4 +34,6 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION;

OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION;

OUTPUT visit(HashJoin hashJoin) throws EXCEPTION;
}
85 changes: 85 additions & 0 deletions core/src/main/java/io/substrait/relation/physical/HashJoin.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package io.substrait.relation.physical;

import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.proto.HashJoinRel;
import io.substrait.relation.BiRel;
import io.substrait.relation.HasExtension;
import io.substrait.relation.RelVisitor;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.immutables.value.Value;

/**
 * Immutable relation describing a hash join between a left and a right input.
 *
 * <p>Besides the two inputs inherited from {@link BiRel}, it carries the key references of each
 * side, the join type and an optional post-join filter expression.
 */
@Value.Immutable
public abstract class HashJoin extends BiRel implements HasExtension {

  /** Returns the key field references declared for the left side of the join. */
  public abstract List<FieldReference> getLeftKeys();

  /** Returns the key field references declared for the right side of the join. */
  public abstract List<FieldReference> getRightKeys();

  /** Returns the kind of join to perform. */
  public abstract JoinType getJoinType();

  /** Returns the filter expression attached to the join output, if one was set. */
  public abstract Optional<Expression> getPostJoinFilter();

  /** Join kinds of a hash join, each tied to its {@link HashJoinRel.JoinType} counterpart. */
  public enum JoinType {
    UNKNOWN(HashJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED),
    INNER(HashJoinRel.JoinType.JOIN_TYPE_INNER),
    OUTER(HashJoinRel.JoinType.JOIN_TYPE_OUTER),
    LEFT(HashJoinRel.JoinType.JOIN_TYPE_LEFT),
    RIGHT(HashJoinRel.JoinType.JOIN_TYPE_RIGHT),
    LEFT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI),
    RIGHT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI),
    LEFT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI),
    RIGHT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI);

    // Assigned once in the constructor and never changed, hence final.
    private final HashJoinRel.JoinType proto;

    JoinType(HashJoinRel.JoinType proto) {
      this.proto = proto;
    }

    /**
     * Finds the constant associated with the given protobuf join type.
     *
     * @param proto the protobuf join type
     * @return the matching constant
     * @throws IllegalArgumentException if no constant is associated with {@code proto}
     */
    public static JoinType fromProto(HashJoinRel.JoinType proto) {
      for (var v : values()) {
        if (v.proto == proto) {
          return v;
        }
      }
      throw new IllegalArgumentException("Unknown type: " + proto);
    }

    /** Returns the protobuf join type associated with this constant. */
    public HashJoinRel.JoinType toProto() {
      return proto;
    }
  }

  /**
   * Derives the output record type from both inputs and the join type.
   *
   * <p>Left fields are made nullable for {@code RIGHT} and {@code OUTER}, omitted for {@code
   * RIGHT_ANTI} and {@code RIGHT_SEMI}, and used unchanged otherwise. Right fields are made
   * nullable for {@code LEFT} and {@code OUTER}, omitted for {@code LEFT_ANTI} and {@code
   * LEFT_SEMI}, and used unchanged otherwise. The result is the left fields followed by the
   * right fields, wrapped in a struct created through {@code TypeCreator.REQUIRED}.
   */
  @Override
  protected Type.Struct deriveRecordType() {
    Stream<Type> leftTypes =
        switch (getJoinType()) {
          case RIGHT, OUTER -> getLeft().getRecordType().fields().stream()
              .map(TypeCreator::asNullable);
          case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty();
          default -> getLeft().getRecordType().fields().stream();
        };
    Stream<Type> rightTypes =
        switch (getJoinType()) {
          case LEFT, OUTER -> getRight().getRecordType().fields().stream()
              .map(TypeCreator::asNullable);
          case LEFT_ANTI, LEFT_SEMI -> Stream.empty();
          default -> getRight().getRecordType().fields().stream();
        };
    return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
  }

  /** Dispatches to the visitor's {@code visit(HashJoin)} overload. */
  @Override
  public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
    return visitor.visit(this);
  }

  /** Returns a fresh builder for the generated immutable implementation. */
  public static ImmutableHashJoin.Builder builder() {
    return ImmutableHashJoin.builder();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.utils.StringHolder;
import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Nested;
Expand Down Expand Up @@ -174,6 +176,26 @@ void join() {
verifyRoundTrip(rel);
}

@Test
void hashJoin() {
  // Degenerate case: neither side of the join declares any key column.
  List<Integer> noLeftKeys = Collections.emptyList();
  List<Integer> noRightKeys = Collections.emptyList();
  HashJoin keylessJoin =
      b.hashJoin(noLeftKeys, noRightKeys, HashJoin.JoinType.INNER, commonTable, commonTable);

  // Attach both extensions so the round trip also covers them.
  Rel relWithoutKeys =
      HashJoin.builder()
          .from(keylessJoin)
          .commonExtension(commonExtension)
          .extension(relExtension)
          .build();

  verifyRoundTrip(relWithoutKeys);
}

@Test
void project() {
Rel rel =
Expand Down
Loading

0 comments on commit d6328a6

Please sign in to comment.