Skip to content

Commit

Permalink
Add Sql-invoked function support for column statistics
Browse files Browse the repository at this point in the history
Adds support in the ANALYZE infrastructure for additional
projections or constant arguments to aggregation functions.
This allows connectors to support a wider range of functionality
when generating metadata for column statistic collection
  • Loading branch information
ZacBlanco committed Jul 16, 2024
1 parent 3d25cd9 commit fcd0458
Show file tree
Hide file tree
Showing 13 changed files with 270 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
import static com.facebook.presto.iceberg.IcebergUtil.getIdentityPartitions;
import static com.facebook.presto.iceberg.Partition.toMap;
import static com.facebook.presto.iceberg.util.StatisticsUtil.calculateAndSetTableSize;
import static com.facebook.presto.iceberg.util.StatisticsUtil.formatIdentifier;
import static com.facebook.presto.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES;
import static com.facebook.presto.spi.statistics.ColumnStatisticType.TOTAL_SIZE_IN_BYTES;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.HIGH;
Expand Down Expand Up @@ -533,7 +534,8 @@ public static List<ColumnStatisticMetadata> getSupportedColumnStatistics(String
if (isNumericType(type) || type.equals(DATE) || isVarcharType(type) ||
type.equals(TIMESTAMP) ||
type.equals(TIMESTAMP_WITH_TIME_ZONE)) {
supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction(columnName, "sketch_theta"));
supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction(
columnName, format("RETURN sketch_theta(%s)", formatIdentifier(columnName)), ImmutableList.of(columnName)));
}

if (!(type instanceof FixedWidthType)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,9 @@ public static List<IcebergColumnHandle> combineSelectedAndPredicateColumns(List<
.iterator())
.build();
}

public static String formatIdentifier(String s)
{
return '"' + s.replace("\"", "\"\"") + '"';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ private ComputedStatistics getComputedStatistics(Page page, int position)
descriptor.getTableStatistics().forEach((type, channel) ->
statistics.addTableStatistic(type, page.getBlock(channel).getSingleValueBlock(position)));

descriptor.getColumnStatistics().forEach((metadata, channel) -> statistics.addColumnStatistic(metadata, page.getBlock(channel).getSingleValueBlock(position)));
descriptor.getColumnStatistics().forEach((descriptor) ->
statistics.addColumnStatistic(descriptor.getMetadata(),
page.getBlock(descriptor.getItem()).getSingleValueBlock(position)));

return statistics.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ private ComputedStatistics getComputedStatistics(Page page, int position)
descriptor.getTableStatistics().forEach((type, channel) ->
statistics.addTableStatistic(type, page.getBlock(channel).getSingleValueBlock(position)));

descriptor.getColumnStatistics().forEach((metadata, channel) -> statistics.addColumnStatistic(metadata, page.getBlock(channel).getSingleValueBlock(position)));
descriptor.getColumnStatistics().forEach((descriptor) -> statistics.addColumnStatistic(descriptor.getMetadata(), page.getBlock(descriptor.getItem()).getSingleValueBlock(position)));

return statistics.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public BasePlanFragmenter(
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
this.outputTableWriterNodeIds = ImmutableSet.copyOf(requireNonNull(outputTableWriterNodeIds, "outputTableWriterNodeIds is null"));
this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(variableAllocator, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(variableAllocator, metadata.getFunctionAndTypeManager(), session);
}

public SubPlan buildRootFragment(PlanNode root, FragmentProperties properties)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Streams.zip;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

public class LogicalPlanner
{
Expand All @@ -131,7 +132,7 @@ public LogicalPlanner(
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(this.variableAllocator, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(this.variableAllocator, metadata.getFunctionAndTypeManager(), session);
this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
}

Expand Down Expand Up @@ -224,22 +225,27 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme
columnNameToVariable.put(column.getName(), variable);
}

List<VariableReferenceExpression> tableScanOutputs = tableScanOutputsBuilder.build();
TableStatisticsMetadata tableStatisticsMetadata = metadata.getStatisticsCollectionMetadata(
session,
targetTable.getConnectorId().getCatalogName(),
tableMetadata.getMetadata());

TableStatisticAggregation tableStatisticAggregation = statisticsAggregationPlanner.createStatisticsAggregation(tableStatisticsMetadata, columnNameToVariable.build());
StatisticAggregations statisticAggregations = tableStatisticAggregation.getAggregations();

List<VariableReferenceExpression> tableScanOutputs = tableScanOutputsBuilder.build();
Map<VariableReferenceExpression, RowExpression> assignments = ImmutableMap.<VariableReferenceExpression, RowExpression>builder()
.putAll(tableScanOutputs.stream().collect(toImmutableMap(identity(), identity())))
.putAll(tableStatisticAggregation.getAdditionalVariables())
.build();
TableScanNode scanNode = new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all());
PlanNode project = PlannerUtils.addProjections(scanNode, idAllocator, assignments);
PlanNode planNode = new StatisticsWriterNode(
getSourceLocation(analyzeStatement),
idAllocator.getNextId(),
new AggregationNode(
getSourceLocation(analyzeStatement),
idAllocator.getNextId(),
new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all()),
project,
statisticAggregations.getAggregations(),
singleGroupingSet(statisticAggregations.getGroupingVariables()),
ImmutableList.of(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
*/
package com.facebook.presto.sql.planner;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.statistics.ColumnStatisticMetadata;
import com.facebook.presto.spi.statistics.ColumnStatisticType;
Expand All @@ -33,15 +36,19 @@
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE;
import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT;
import static com.facebook.presto.sql.relational.SqlFunctionUtils.sqlFunctionToRowExpression;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
Expand All @@ -51,11 +58,15 @@ public class StatisticsAggregationPlanner
{
private final VariableAllocator variableAllocator;
private final FunctionAndTypeResolver functionAndTypeResolver;
private final Session session;
private final FunctionAndTypeManager functionAndTypeManager;

public StatisticsAggregationPlanner(VariableAllocator variableAllocator, FunctionAndTypeResolver functionAndTypeResolver)
public StatisticsAggregationPlanner(VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager, Session session)
{
this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
this.functionAndTypeResolver = requireNonNull(functionAndTypeResolver, "functionAndTypeResolver is null");
this.session = requireNonNull(session, "session is null");
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
this.functionAndTypeResolver = functionAndTypeManager.getFunctionAndTypeResolver();
}

public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata statisticsMetadata, Map<String, VariableReferenceExpression> columnToVariableMap)
Expand All @@ -70,6 +81,7 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta
for (int i = 0; i < groupingVariables.size(); i++) {
descriptor.addGrouping(groupingColumns.get(i), groupingVariables.get(i));
}
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> additionalVariables = ImmutableMap.builder();

ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
StandardFunctionResolution functionResolution = new FunctionResolution(functionAndTypeResolver);
Expand Down Expand Up @@ -98,18 +110,62 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta
VariableReferenceExpression inputVariable = columnToVariableMap.get(columnName);
verify(inputVariable != null, "inputVariable is null");
ColumnStatisticsAggregation aggregation = createColumnAggregation(columnStatisticMetadata, inputVariable);
additionalVariables.putAll(aggregation.getInputProjections());
VariableReferenceExpression variable = variableAllocator.newVariable(statisticType + ":" + columnName, aggregation.getOutputType());
aggregations.put(variable, aggregation.getAggregation());
descriptor.addColumnStatistic(columnStatisticMetadata, variable);
}

StatisticAggregations aggregation = new StatisticAggregations(aggregations.build(), groupingVariables);
return new TableStatisticAggregation(aggregation, descriptor.build());
return new TableStatisticAggregation(aggregation, descriptor.build(), additionalVariables.build());
}

private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input)
private ColumnStatisticsAggregation createColumnAggregationFromSqlFunction(String sqlFunction, VariableReferenceExpression input)
{
FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunctionName(), TypeSignatureProvider.fromTypes(ImmutableList.of(input.getType())));
RowExpression expression = sqlFunctionToRowExpression(
sqlFunction,
ImmutableSet.of(input),
functionAndTypeManager,
session);
verify(expression instanceof CallExpression, "column statistic SQL expressions must represent a function call");
CallExpression call = (CallExpression) expression;
FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(call.getFunctionHandle());
verify(functionMeta.getFunctionKind().equals(AGGREGATE), "column statistic function must be aggregates");
// Aggregations input arguments are required to be variable reference expressions.
// For each one that isn't, allocate a new variable to reference
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> inputProjections = ImmutableMap.builder();
List<RowExpression> callVariableArguments = call.getArguments()
.stream()
.map(argument -> {
if (argument instanceof VariableReferenceExpression) {
return argument;
}
VariableReferenceExpression newArgument = variableAllocator.newVariable(argument);
inputProjections.put(newArgument, argument);
return newArgument;
})
.collect(Collectors.toList());
CallExpression callWithVariables = new CallExpression(
call.getSourceLocation(),
call.getDisplayName(),
call.getFunctionHandle(),
call.getType(),
callVariableArguments);
return new ColumnStatisticsAggregation(
new AggregationNode.Aggregation(callWithVariables,
Optional.empty(),
Optional.empty(),
false,
Optional.empty()),
functionAndTypeResolver.getType(functionMeta.getReturnType()),
inputProjections.build());
}

private ColumnStatisticsAggregation createColumnAggregationFromFunctionName(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input)
{
FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunction(), TypeSignatureProvider.fromTypes(ImmutableList.<Type>builder()
.add(input.getType())
.build()));
FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(functionHandle);
Type inputType = functionAndTypeResolver.getType(getOnlyElement(functionMeta.getArgumentTypes()));
Type outputType = functionAndTypeResolver.getType(functionMeta.getReturnType());
Expand All @@ -118,28 +174,41 @@ private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetad
new AggregationNode.Aggregation(
new CallExpression(
input.getSourceLocation(),
columnStatisticMetadata.getFunctionName(),
columnStatisticMetadata.getFunction(),
functionHandle,
outputType,
ImmutableList.of(input)),
Optional.empty(),
Optional.empty(),
false,
Optional.empty()),
outputType);
outputType,
ImmutableMap.of());
}

private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input)
{
if (columnStatisticMetadata.isSqlExpression()) {
return createColumnAggregationFromSqlFunction(columnStatisticMetadata.getFunction(), input);
}

return createColumnAggregationFromFunctionName(columnStatisticMetadata, input);
}

public static class TableStatisticAggregation
{
private final StatisticAggregations aggregations;
private final StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor;
private final Map<VariableReferenceExpression, RowExpression> additionalVariables;

private TableStatisticAggregation(
StatisticAggregations aggregations,
StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor)
StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor,
Map<VariableReferenceExpression, RowExpression> additionalVariables)
{
this.aggregations = requireNonNull(aggregations, "statisticAggregations is null");
this.descriptor = requireNonNull(descriptor, "descriptor is null");
this.additionalVariables = requireNonNull(additionalVariables, "additionalVariables is null");
}

public StatisticAggregations getAggregations()
Expand All @@ -151,17 +220,24 @@ public StatisticAggregationsDescriptor<VariableReferenceExpression> getDescripto
{
return descriptor;
}

public Map<VariableReferenceExpression, RowExpression> getAdditionalVariables()
{
return additionalVariables;
}
}

public static class ColumnStatisticsAggregation
{
private final AggregationNode.Aggregation aggregation;
private final Type outputType;
private final Map<VariableReferenceExpression, RowExpression> inputProjections;

private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType)
private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType, Map<VariableReferenceExpression, RowExpression> inputProjections)
{
this.aggregation = requireNonNull(aggregation, "aggregation is null");
this.outputType = requireNonNull(outputType, "outputType is null");
this.inputProjections = requireNonNull(inputProjections, "additionalVariable is null");
}

public AggregationNode.Aggregation getAggregation()
Expand All @@ -173,5 +249,10 @@ public Type getOutputType()
{
return outputType;
}

public Map<VariableReferenceExpression, RowExpression> getInputProjections()
{
return inputProjections;
}
}
}
Loading

0 comments on commit fcd0458

Please sign in to comment.