Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Join reordering stats11 speedup #605

Open
wants to merge 15 commits into
base: join-reordering-stats11
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.plan.PlanNode;

import java.util.HashMap;
import java.util.Map;

import static com.facebook.presto.cost.PlanNodeStatsEstimate.buildFrom;
import static com.facebook.presto.cost.SymbolStatsEstimate.UNKNOWN_STATS;
import static com.google.common.base.Predicates.not;

public class EnsureStatsMatchOutput
Expand All @@ -29,16 +30,16 @@ public class EnsureStatsMatchOutput
@Override
public PlanNodeStatsEstimate normalize(PlanNode node, PlanNodeStatsEstimate estimate, Map<Symbol, Type> types)
{
Map<Symbol, SymbolStatsEstimate> symbolSymbolStats = new HashMap<>();
estimate.getSymbolsWithKnownStatistics().stream()
.filter(node.getOutputSymbols()::contains)
.forEach(symbol -> symbolSymbolStats.put(symbol, estimate.getSymbolStatistics(symbol)));
PlanNodeStatsEstimate.Builder builder = buildFrom(estimate);

node.getOutputSymbols().stream()
.filter(not(estimate.getSymbolsWithKnownStatistics()::contains))
.filter(not(symbolSymbolStats::containsKey))
.forEach(symbol -> symbolSymbolStats.put(symbol, SymbolStatsEstimate.UNKNOWN_STATS));
.forEach(symbol -> builder.addSymbolStatistics(symbol, UNKNOWN_STATS));

estimate.getSymbolsWithKnownStatistics().stream()
.filter(not(node.getOutputSymbols()::contains))
.forEach(builder::removeSymbolStatistics);

return PlanNodeStatsEstimate.buildFrom(estimate).setSymbolStatistics(symbolSymbolStats).build();
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
package com.facebook.presto.cost;

import com.facebook.presto.sql.planner.Symbol;
import com.google.common.collect.ImmutableMap;
import org.pcollections.HashTreePMap;
import org.pcollections.PMap;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
Expand All @@ -34,13 +33,13 @@ public class PlanNodeStatsEstimate
public static final double DEFAULT_DATA_SIZE_PER_COLUMN = 10;

private final double outputRowCount;
private final Map<Symbol, SymbolStatsEstimate> symbolStatistics;
private final PMap<Symbol, SymbolStatsEstimate> symbolStatistics;

private PlanNodeStatsEstimate(double outputRowCount, Map<Symbol, SymbolStatsEstimate> symbolStatistics)
private PlanNodeStatsEstimate(double outputRowCount, PMap<Symbol, SymbolStatsEstimate> symbolStatistics)
{
checkArgument(isNaN(outputRowCount) || outputRowCount >= 0, "outputRowCount cannot be negative");
this.outputRowCount = outputRowCount;
this.symbolStatistics = ImmutableMap.copyOf(symbolStatistics);
this.symbolStatistics = symbolStatistics;
}

/**
Expand Down Expand Up @@ -83,26 +82,15 @@ public PlanNodeStatsEstimate mapOutputRowCount(Function<Double, Double> mappingF
public PlanNodeStatsEstimate mapSymbolColumnStatistics(Symbol symbol, Function<SymbolStatsEstimate, SymbolStatsEstimate> mappingFunction)
{
return buildFrom(this)
.setSymbolStatistics(symbolStatistics.entrySet().stream()
.collect(Collectors.toMap(
Map.Entry::getKey,
e -> {
if (e.getKey().equals(symbol)) {
return mappingFunction.apply(e.getValue());
}
return e.getValue();
})))
.addSymbolStatistics(symbol, mappingFunction.apply(symbolStatistics.get(symbol)))
.build();
}

public PlanNodeStatsEstimate add(PlanNodeStatsEstimate other)
{
// TODO this is broken (it does not operate on symbol stats at all). Remove or fix
ImmutableMap.Builder<Symbol, SymbolStatsEstimate> symbolsStatsBuilder = ImmutableMap.builder();
symbolsStatsBuilder.putAll(symbolStatistics).putAll(other.symbolStatistics); // This may not count all information

PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
return statsBuilder.setSymbolStatistics(symbolsStatsBuilder.build())
return buildFrom(this)
.addSymbolStatistics(other.symbolStatistics)
.setOutputRowCount(getOutputRowCount() + other.getOutputRowCount())
.build();
}
Expand Down Expand Up @@ -153,30 +141,46 @@ public static Builder builder()

public static Builder buildFrom(PlanNodeStatsEstimate other)
{
return builder().setOutputRowCount(other.getOutputRowCount())
.setSymbolStatistics(other.symbolStatistics);
return new Builder(other.getOutputRowCount(), other.symbolStatistics);
}

public static final class Builder
{
private double outputRowCount = NaN;
private Map<Symbol, SymbolStatsEstimate> symbolStatistics = new HashMap<>();
private double outputRowCount;
private PMap<Symbol, SymbolStatsEstimate> symbolStatistics;

public Builder()
{
this(NaN, HashTreePMap.empty());
}

private Builder(double outputRowCount, PMap<Symbol, SymbolStatsEstimate> symbolStatistics)
{
this.outputRowCount = outputRowCount;
this.symbolStatistics = symbolStatistics;
}

public Builder setOutputRowCount(double outputRowCount)
{
this.outputRowCount = outputRowCount;
return this;
}

public Builder setSymbolStatistics(Map<Symbol, SymbolStatsEstimate> symbolStatistics)
public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics)
{
this.symbolStatistics = new HashMap<>(symbolStatistics);
symbolStatistics = symbolStatistics.plus(symbol, statistics);
return this;
}

public Builder addSymbolStatistics(Symbol symbol, SymbolStatsEstimate statistics)
public Builder addSymbolStatistics(Map<Symbol, SymbolStatsEstimate> symbolStatistics)
{
this.symbolStatistics = this.symbolStatistics.plusAll(symbolStatistics);
return this;
}

public Builder removeSymbolStatistics(Symbol symbol)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having this method makes me think that this whole Builder is not needed anymore, as you can simply operate on SymbolStatsEstimate.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, but you still need to construct the initial PlanNodeStatsEstimate, for which the builder seems to be useful.

{
this.symbolStatistics.put(symbol, statistics);
symbolStatistics = symbolStatistics.minus(symbol);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public Optional<PlanNodeStatsEstimate> calculate(PlanNode node, Lookup lookup, S

return Optional.of(PlanNodeStatsEstimate.builder()
.setOutputRowCount(tableStatistics.getRowCount().getValue())
.setSymbolStatistics(outputSymbolStats)
.addSymbolStatistics(outputSymbolStats)
.build());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ public void testKeepsOutputSymbols()
.withStats(ImmutableMap.of(
new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(5000)
.setSymbolStatistics(ImmutableMap.of(
.addSymbolStatistics(ImmutableMap.of(
new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 100),
new Symbol("A2"), new SymbolStatsEstimate(0, 100, 0, 100, 100)))
.build(),
new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 100)))
.build()))
.matches(join(
INNER,
Expand All @@ -125,11 +125,11 @@ public void testReplicatesAndFlipsWhenOneTableMuchSmaller()
.withStats(ImmutableMap.of(
new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(100)
.setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100)))
.build(),
new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.build()))
.matches(join(
INNER,
Expand Down Expand Up @@ -157,11 +157,11 @@ public void testRepartitionsWhenRequiredBySession()
.withStats(ImmutableMap.of(
new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(100)
.setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 6400, 100)))
.build(),
new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.build()))
.matches(join(
INNER,
Expand All @@ -188,11 +188,11 @@ public void testRepartitionsWhenBothTablesEqual()
.withStats(ImmutableMap.of(
new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.build(),
new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.build()))
.matches(join(
INNER,
Expand Down Expand Up @@ -220,11 +220,11 @@ public void testReplicatesWhenRequiredBySession()
.withStats(ImmutableMap.of(
new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.build(),
new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.build()))
.matches(join(
INNER,
Expand All @@ -251,11 +251,11 @@ public void testDoesNotFireForCrossJoin()
.withStats(ImmutableMap.of(
new PlanNodeId("valuesA"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.build(),
new PlanNodeId("valuesB"), PlanNodeStatsEstimate.builder()
.setOutputRowCount(10000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 640000, 100)))
.build()))
.doesNotFire();
}
Expand Down Expand Up @@ -314,19 +314,19 @@ public void testPredicatesPushedDown()
new PlanNodeId("valuesA"),
PlanNodeStatsEstimate.builder()
.setOutputRowCount(10)
.setSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), new SymbolStatsEstimate(0, 100, 0, 100, 10)))
.build(),
new PlanNodeId("valuesB"),
PlanNodeStatsEstimate.builder()
.setOutputRowCount(5)
.setSymbolStatistics(ImmutableMap.of(
.addSymbolStatistics(ImmutableMap.of(
new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 100, 10),
new Symbol("B2"), new SymbolStatsEstimate(0, 100, 0, 100, 10)))
.build(),
new PlanNodeId("valuesC"),
PlanNodeStatsEstimate.builder()
.setOutputRowCount(1000)
.setSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100)))
.addSymbolStatistics(ImmutableMap.of(new Symbol("C1"), new SymbolStatsEstimate(0, 100, 0, 100, 100)))
.build()))
.matches(
join(
Expand Down