From fcd04589ca6457bff331a8187d294c4a9a4775f2 Mon Sep 17 00:00:00 2001
From: Zac Blanco
Date: Mon, 8 Jul 2024 14:21:40 -0700
Subject: [PATCH] Add Sql-invoked function support for column statistics
Adds support in the ANALYZE infrastructure for additional
projections or constant arguments to aggregation functions.
This allows connectors to support a wider range of functionality
when generating metadata for column statistic collection
---
.../presto/iceberg/TableStatisticsMaker.java | 4 +-
.../presto/iceberg/util/StatisticsUtil.java | 5 +
.../operator/StatisticsWriterOperator.java | 4 +-
.../presto/operator/TableFinishOperator.java | 2 +-
.../sql/planner/BasePlanFragmenter.java | 2 +-
.../presto/sql/planner/LogicalPlanner.java | 14 ++-
.../planner/StatisticsAggregationPlanner.java | 99 +++++++++++++++--
.../plan/StatisticAggregationsDescriptor.java | 103 +++++++++++-------
.../sanity/ValidateDependenciesChecker.java | 3 +-
.../sql/relational/SqlFunctionUtils.java | 30 +++++
.../TestStatisticAggregationsDescriptor.java | 39 -------
.../statistics/ColumnStatisticMetadata.java | 64 +++++++++--
.../spi/statistics/ColumnStatisticType.java | 10 +-
13 files changed, 270 insertions(+), 109 deletions(-)
delete mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java
diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java
index b998dc4dc1a72..9da3de696f620 100644
--- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java
+++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java
@@ -88,6 +88,7 @@
import static com.facebook.presto.iceberg.IcebergUtil.getIdentityPartitions;
import static com.facebook.presto.iceberg.Partition.toMap;
import static com.facebook.presto.iceberg.util.StatisticsUtil.calculateAndSetTableSize;
+import static com.facebook.presto.iceberg.util.StatisticsUtil.formatIdentifier;
import static com.facebook.presto.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES;
import static com.facebook.presto.spi.statistics.ColumnStatisticType.TOTAL_SIZE_IN_BYTES;
import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.HIGH;
@@ -533,7 +534,8 @@ public static List getSupportedColumnStatistics(String
if (isNumericType(type) || type.equals(DATE) || isVarcharType(type) ||
type.equals(TIMESTAMP) ||
type.equals(TIMESTAMP_WITH_TIME_ZONE)) {
- supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction(columnName, "sketch_theta"));
+ supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction(
+ columnName, format("RETURN sketch_theta(%s)", formatIdentifier(columnName)), ImmutableList.of(columnName)));
}
if (!(type instanceof FixedWidthType)) {
diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/util/StatisticsUtil.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/util/StatisticsUtil.java
index ff9cd814615ba..4200610239c16 100644
--- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/util/StatisticsUtil.java
+++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/util/StatisticsUtil.java
@@ -172,4 +172,9 @@ public static List combineSelectedAndPredicateColumns(List<
.iterator())
.build();
}
+
+ public static String formatIdentifier(String s)
+ {
+ return '"' + s.replace("\"", "\"\"") + '"';
+ }
}
diff --git a/presto-main/src/main/java/com/facebook/presto/operator/StatisticsWriterOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/StatisticsWriterOperator.java
index 20a901f138428..509f4a7628251 100644
--- a/presto-main/src/main/java/com/facebook/presto/operator/StatisticsWriterOperator.java
+++ b/presto-main/src/main/java/com/facebook/presto/operator/StatisticsWriterOperator.java
@@ -175,7 +175,9 @@ private ComputedStatistics getComputedStatistics(Page page, int position)
descriptor.getTableStatistics().forEach((type, channel) ->
statistics.addTableStatistic(type, page.getBlock(channel).getSingleValueBlock(position)));
- descriptor.getColumnStatistics().forEach((metadata, channel) -> statistics.addColumnStatistic(metadata, page.getBlock(channel).getSingleValueBlock(position)));
+ descriptor.getColumnStatistics().forEach((descriptor) ->
+ statistics.addColumnStatistic(descriptor.getMetadata(),
+ page.getBlock(descriptor.getItem()).getSingleValueBlock(position)));
return statistics.build();
}
diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TableFinishOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/TableFinishOperator.java
index 52b825d496167..bcdab3dde9fe3 100644
--- a/presto-main/src/main/java/com/facebook/presto/operator/TableFinishOperator.java
+++ b/presto-main/src/main/java/com/facebook/presto/operator/TableFinishOperator.java
@@ -311,7 +311,7 @@ private ComputedStatistics getComputedStatistics(Page page, int position)
descriptor.getTableStatistics().forEach((type, channel) ->
statistics.addTableStatistic(type, page.getBlock(channel).getSingleValueBlock(position)));
- descriptor.getColumnStatistics().forEach((metadata, channel) -> statistics.addColumnStatistic(metadata, page.getBlock(channel).getSingleValueBlock(position)));
+ descriptor.getColumnStatistics().forEach((descriptor) -> statistics.addColumnStatistic(descriptor.getMetadata(), page.getBlock(descriptor.getItem()).getSingleValueBlock(position)));
return statistics.build();
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java
index de30f542a007b..aa134e91d03a5 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java
@@ -123,7 +123,7 @@ public BasePlanFragmenter(
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
this.outputTableWriterNodeIds = ImmutableSet.copyOf(requireNonNull(outputTableWriterNodeIds, "outputTableWriterNodeIds is null"));
- this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(variableAllocator, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
+ this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(variableAllocator, metadata.getFunctionAndTypeManager(), session);
}
public SubPlan buildRootFragment(PlanNode root, FragmentProperties properties)
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java
index 1763247be121e..d43b439de9ff9 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java
@@ -110,6 +110,7 @@
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Streams.zip;
import static java.util.Objects.requireNonNull;
+import static java.util.function.Function.identity;
public class LogicalPlanner
{
@@ -131,7 +132,7 @@ public LogicalPlanner(
this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
this.metadata = requireNonNull(metadata, "metadata is null");
this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
- this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(this.variableAllocator, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
+ this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(this.variableAllocator, metadata.getFunctionAndTypeManager(), session);
this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
}
@@ -224,7 +225,6 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme
columnNameToVariable.put(column.getName(), variable);
}
- List tableScanOutputs = tableScanOutputsBuilder.build();
TableStatisticsMetadata tableStatisticsMetadata = metadata.getStatisticsCollectionMetadata(
session,
targetTable.getConnectorId().getCatalogName(),
@@ -232,14 +232,20 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme
TableStatisticAggregation tableStatisticAggregation = statisticsAggregationPlanner.createStatisticsAggregation(tableStatisticsMetadata, columnNameToVariable.build());
StatisticAggregations statisticAggregations = tableStatisticAggregation.getAggregations();
-
+ List tableScanOutputs = tableScanOutputsBuilder.build();
+ Map assignments = ImmutableMap.builder()
+ .putAll(tableScanOutputs.stream().collect(toImmutableMap(identity(), identity())))
+ .putAll(tableStatisticAggregation.getAdditionalVariables())
+ .build();
+ TableScanNode scanNode = new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all());
+ PlanNode project = PlannerUtils.addProjections(scanNode, idAllocator, assignments);
PlanNode planNode = new StatisticsWriterNode(
getSourceLocation(analyzeStatement),
idAllocator.getNextId(),
new AggregationNode(
getSourceLocation(analyzeStatement),
idAllocator.getNextId(),
- new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all()),
+ project,
statisticAggregations.getAggregations(),
singleGroupingSet(statisticAggregations.getGroupingVariables()),
ImmutableList.of(),
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java
index 0f17eaf78d2d3..f9722664ecebb 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java
@@ -13,7 +13,9 @@
*/
package com.facebook.presto.sql.planner;
+import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionHandle;
@@ -21,6 +23,7 @@
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.statistics.ColumnStatisticMetadata;
import com.facebook.presto.spi.statistics.ColumnStatisticType;
@@ -33,15 +36,19 @@
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.stream.Collectors;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
+import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE;
import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT;
+import static com.facebook.presto.sql.relational.SqlFunctionUtils.sqlFunctionToRowExpression;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
@@ -51,11 +58,15 @@ public class StatisticsAggregationPlanner
{
private final VariableAllocator variableAllocator;
private final FunctionAndTypeResolver functionAndTypeResolver;
+ private final Session session;
+ private final FunctionAndTypeManager functionAndTypeManager;
- public StatisticsAggregationPlanner(VariableAllocator variableAllocator, FunctionAndTypeResolver functionAndTypeResolver)
+ public StatisticsAggregationPlanner(VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager, Session session)
{
this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
- this.functionAndTypeResolver = requireNonNull(functionAndTypeResolver, "functionAndTypeResolver is null");
+ this.session = requireNonNull(session, "session is null");
+ this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
+ this.functionAndTypeResolver = functionAndTypeManager.getFunctionAndTypeResolver();
}
public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata statisticsMetadata, Map columnToVariableMap)
@@ -70,6 +81,7 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta
for (int i = 0; i < groupingVariables.size(); i++) {
descriptor.addGrouping(groupingColumns.get(i), groupingVariables.get(i));
}
+ ImmutableMap.Builder additionalVariables = ImmutableMap.builder();
ImmutableMap.Builder aggregations = ImmutableMap.builder();
StandardFunctionResolution functionResolution = new FunctionResolution(functionAndTypeResolver);
@@ -98,18 +110,62 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta
VariableReferenceExpression inputVariable = columnToVariableMap.get(columnName);
verify(inputVariable != null, "inputVariable is null");
ColumnStatisticsAggregation aggregation = createColumnAggregation(columnStatisticMetadata, inputVariable);
+ additionalVariables.putAll(aggregation.getInputProjections());
VariableReferenceExpression variable = variableAllocator.newVariable(statisticType + ":" + columnName, aggregation.getOutputType());
aggregations.put(variable, aggregation.getAggregation());
descriptor.addColumnStatistic(columnStatisticMetadata, variable);
}
StatisticAggregations aggregation = new StatisticAggregations(aggregations.build(), groupingVariables);
- return new TableStatisticAggregation(aggregation, descriptor.build());
+ return new TableStatisticAggregation(aggregation, descriptor.build(), additionalVariables.build());
}
- private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input)
+ private ColumnStatisticsAggregation createColumnAggregationFromSqlFunction(String sqlFunction, VariableReferenceExpression input)
{
- FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunctionName(), TypeSignatureProvider.fromTypes(ImmutableList.of(input.getType())));
+ RowExpression expression = sqlFunctionToRowExpression(
+ sqlFunction,
+ ImmutableSet.of(input),
+ functionAndTypeManager,
+ session);
+ verify(expression instanceof CallExpression, "column statistic SQL expressions must represent a function call");
+ CallExpression call = (CallExpression) expression;
+ FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(call.getFunctionHandle());
+ verify(functionMeta.getFunctionKind().equals(AGGREGATE), "column statistic function must be aggregates");
+ // Aggregations input arguments are required to be variable reference expressions.
+ // For each one that isn't, allocate a new variable to reference
+ ImmutableMap.Builder inputProjections = ImmutableMap.builder();
+ List callVariableArguments = call.getArguments()
+ .stream()
+ .map(argument -> {
+ if (argument instanceof VariableReferenceExpression) {
+ return argument;
+ }
+ VariableReferenceExpression newArgument = variableAllocator.newVariable(argument);
+ inputProjections.put(newArgument, argument);
+ return newArgument;
+ })
+ .collect(Collectors.toList());
+ CallExpression callWithVariables = new CallExpression(
+ call.getSourceLocation(),
+ call.getDisplayName(),
+ call.getFunctionHandle(),
+ call.getType(),
+ callVariableArguments);
+ return new ColumnStatisticsAggregation(
+ new AggregationNode.Aggregation(callWithVariables,
+ Optional.empty(),
+ Optional.empty(),
+ false,
+ Optional.empty()),
+ functionAndTypeResolver.getType(functionMeta.getReturnType()),
+ inputProjections.build());
+ }
+
+ private ColumnStatisticsAggregation createColumnAggregationFromFunctionName(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input)
+ {
+ FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunction(), TypeSignatureProvider.fromTypes(ImmutableList.builder()
+ .add(input.getType())
+ .build()));
FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(functionHandle);
Type inputType = functionAndTypeResolver.getType(getOnlyElement(functionMeta.getArgumentTypes()));
Type outputType = functionAndTypeResolver.getType(functionMeta.getReturnType());
@@ -118,7 +174,7 @@ private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetad
new AggregationNode.Aggregation(
new CallExpression(
input.getSourceLocation(),
- columnStatisticMetadata.getFunctionName(),
+ columnStatisticMetadata.getFunction(),
functionHandle,
outputType,
ImmutableList.of(input)),
@@ -126,20 +182,33 @@ private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetad
Optional.empty(),
false,
Optional.empty()),
- outputType);
+ outputType,
+ ImmutableMap.of());
+ }
+
+ private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input)
+ {
+ if (columnStatisticMetadata.isSqlExpression()) {
+ return createColumnAggregationFromSqlFunction(columnStatisticMetadata.getFunction(), input);
+ }
+
+ return createColumnAggregationFromFunctionName(columnStatisticMetadata, input);
}
public static class TableStatisticAggregation
{
private final StatisticAggregations aggregations;
private final StatisticAggregationsDescriptor descriptor;
+ private final Map additionalVariables;
private TableStatisticAggregation(
StatisticAggregations aggregations,
- StatisticAggregationsDescriptor descriptor)
+ StatisticAggregationsDescriptor descriptor,
+ Map additionalVariables)
{
this.aggregations = requireNonNull(aggregations, "statisticAggregations is null");
this.descriptor = requireNonNull(descriptor, "descriptor is null");
+ this.additionalVariables = requireNonNull(additionalVariables, "additionalVariables is null");
}
public StatisticAggregations getAggregations()
@@ -151,17 +220,24 @@ public StatisticAggregationsDescriptor getDescripto
{
return descriptor;
}
+
+ public Map getAdditionalVariables()
+ {
+ return additionalVariables;
+ }
}
public static class ColumnStatisticsAggregation
{
private final AggregationNode.Aggregation aggregation;
private final Type outputType;
+ private final Map inputProjections;
- private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType)
+ private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType, Map inputProjections)
{
this.aggregation = requireNonNull(aggregation, "aggregation is null");
this.outputType = requireNonNull(outputType, "outputType is null");
+ this.inputProjections = requireNonNull(inputProjections, "additionalVariable is null");
}
public AggregationNode.Aggregation getAggregation()
@@ -173,5 +249,10 @@ public Type getOutputType()
{
return outputType;
}
+
+ public Map getInputProjections()
+ {
+ return inputProjections;
+ }
}
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java
index 95878fb4310f6..7f7db1692d0a1 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java
@@ -14,28 +14,18 @@
package com.facebook.presto.sql.planner.plan;
import com.facebook.presto.spi.statistics.ColumnStatisticMetadata;
-import com.facebook.presto.spi.statistics.ColumnStatisticType;
import com.facebook.presto.spi.statistics.TableStatisticType;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
-import com.fasterxml.jackson.core.JsonGenerator;
-import com.fasterxml.jackson.databind.DeserializationContext;
-import com.fasterxml.jackson.databind.JsonSerializer;
-import com.fasterxml.jackson.databind.KeyDeserializer;
-import com.fasterxml.jackson.databind.SerializerProvider;
-import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
-import com.fasterxml.jackson.databind.annotation.JsonSerialize;
-import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import java.io.IOException;
+import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import static com.google.common.base.MoreObjects.toStringHelper;
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Objects.requireNonNull;
@@ -43,7 +33,7 @@ public class StatisticAggregationsDescriptor
{
private final Map grouping;
private final Map tableStatistics;
- private final Map columnStatistics;
+ private final List> columnStatistics;
public static StatisticAggregationsDescriptor empty()
{
@@ -54,11 +44,22 @@ public static StatisticAggregationsDescriptor empty()
public StatisticAggregationsDescriptor(
@JsonProperty("grouping") Map grouping,
@JsonProperty("tableStatistics") Map tableStatistics,
- @JsonProperty("columnStatistics") Map columnStatistics)
+ @JsonProperty("columnStatistics") List> columnStatistics)
{
this.grouping = ImmutableMap.copyOf(requireNonNull(grouping, "grouping is null"));
this.tableStatistics = ImmutableMap.copyOf(requireNonNull(tableStatistics, "tableStatistics is null"));
- this.columnStatistics = ImmutableMap.copyOf(requireNonNull(columnStatistics, "columnStatistics is null"));
+ this.columnStatistics = requireNonNull(columnStatistics, "columnStatistics is null");
+ }
+
+ public StatisticAggregationsDescriptor(
+ Map grouping,
+ Map tableStatistics,
+ Map columnStatistics)
+ {
+ this(grouping, tableStatistics, ImmutableList.>builder()
+ .addAll(columnStatistics.entrySet().stream()
+ .map(e -> new ColumnStatisticsDescriptor<>(e.getKey(), e.getValue())).iterator())
+ .build());
}
@JsonProperty
@@ -74,9 +75,7 @@ public Map getTableStatistics()
}
@JsonProperty
- @JsonSerialize(keyUsing = ColumnStatisticMetadataKeySerializer.class)
- @JsonDeserialize(keyUsing = ColumnStatisticMetadataKeyDeserializer.class)
- public Map getColumnStatistics()
+ public List> getColumnStatistics()
{
return columnStatistics;
}
@@ -122,7 +121,11 @@ public StatisticAggregationsDescriptor map(Function mapper)
return new StatisticAggregationsDescriptor<>(
map(this.getGrouping(), mapper),
map(this.getTableStatistics(), mapper),
- map(this.getColumnStatistics(), mapper));
+ map(this.getColumnStatistics().stream()
+ .collect(toImmutableMap(
+ ColumnStatisticsDescriptor::getMetadata,
+ ColumnStatisticsDescriptor::getItem)),
+ mapper));
}
private static Map map(Map input, Function mapper)
@@ -159,39 +162,59 @@ public StatisticAggregationsDescriptor build()
}
}
- public static class ColumnStatisticMetadataKeySerializer
- extends JsonSerializer
+ public static class ColumnStatisticsDescriptor
{
- @Override
- public void serialize(ColumnStatisticMetadata value, JsonGenerator gen, SerializerProvider serializers)
- throws IOException
+ private final ColumnStatisticMetadata metadata;
+ private final T item;
+
+ @JsonCreator
+ public ColumnStatisticsDescriptor(@JsonProperty("metadata") ColumnStatisticMetadata metadata, @JsonProperty("item") T item)
{
- verify(value != null, "value is null");
- gen.writeFieldName(serialize(value));
+ this.metadata = requireNonNull(metadata, "metadata is null");
+ this.item = requireNonNull(item, "item is null");
}
- @VisibleForTesting
- static String serialize(ColumnStatisticMetadata value)
+ @JsonProperty
+ public T getItem()
{
- return value.getStatisticType().name() + ":" + value.getFunctionName() + ":" + value.getColumnName();
+ return item;
+ }
+
+ @JsonProperty
+ public ColumnStatisticMetadata getMetadata()
+ {
+ return metadata;
}
- }
- public static class ColumnStatisticMetadataKeyDeserializer
- extends KeyDeserializer
- {
@Override
- public ColumnStatisticMetadata deserializeKey(String key, DeserializationContext ctx)
+ public int hashCode()
{
- return deserialize(requireNonNull(key, "key is null"));
+ return Objects.hash(metadata, item);
}
- @VisibleForTesting
- static ColumnStatisticMetadata deserialize(String value)
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+
+ if (!(o instanceof ColumnStatisticsDescriptor)) {
+ return false;
+ }
+
+ ColumnStatisticsDescriptor> other = (ColumnStatisticsDescriptor>) o;
+ return metadata.equals(other.metadata) &&
+ item.equals(other.item);
+ }
+
+ @Override
+ public String toString()
{
- String[] values = value.split(":", 3);
- checkArgument(values.length == 3, "separator(s) not found: %s", value);
- return new ColumnStatisticMetadata(values[2], ColumnStatisticType.valueOf(values[0]), values[1]);
+ return toStringHelper(this)
+ .add("metadata", metadata)
+ .add("item", item)
+ .toString();
}
}
}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java
index 7ffc277dfde93..cda1acb7658f1 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java
@@ -65,6 +65,7 @@
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor;
+import com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor.ColumnStatisticsDescriptor;
import com.facebook.presto.sql.planner.plan.StatisticsWriterNode;
import com.facebook.presto.sql.planner.plan.TableFinishNode;
import com.facebook.presto.sql.planner.plan.TableWriterMergeNode;
@@ -656,7 +657,7 @@ public Void visitStatisticsWriterNode(StatisticsWriterNode node, Set descriptor = node.getDescriptor();
Set dependencies = ImmutableSet.builder()
.addAll(descriptor.getGrouping().values())
- .addAll(descriptor.getColumnStatistics().values())
+ .addAll(descriptor.getColumnStatistics().stream().map(ColumnStatisticsDescriptor::getItem).iterator())
.addAll(descriptor.getTableStatistics().values())
.build();
List outputVariables = node.getSource().getOutputVariables();
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java
index d31a9c7c3425f..fdbcca5d44bbe 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java
@@ -13,6 +13,7 @@
*/
package com.facebook.presto.sql.relational;
+import com.facebook.presto.Session;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.RowExpressionRewriter;
@@ -45,6 +46,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import static com.facebook.presto.spi.function.FunctionImplementationType.SQL;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.analyzeSqlFunctionExpression;
@@ -113,6 +115,34 @@ public static RowExpression getSqlFunctionRowExpression(
argumentVariables);
}
+ public static RowExpression sqlFunctionToRowExpression(String functionBody,
+ Set variables,
+ FunctionAndTypeManager functionAndTypeManager,
+ Session session)
+ {
+ Expression expression = parseSqlFunctionExpression(
+ new SqlInvokedScalarFunctionImplementation(functionBody),
+ session.getSqlFunctionProperties());
+ return SqlToRowExpressionTranslator.translate(
+ expression,
+ analyzeSqlFunctionExpression(
+ functionAndTypeManager.getFunctionAndTypeResolver(),
+ session.getSqlFunctionProperties(),
+ expression,
+ variables.stream()
+ .collect(toImmutableMap(
+ VariableReferenceExpression::getName,
+ VariableReferenceExpression::getType)))
+ .getExpressionTypes(),
+ ImmutableMap.of(),
+ functionAndTypeManager,
+ Optional.empty(),
+ Optional.empty(),
+ session.getSqlFunctionProperties(),
+ session.getSessionFunctions(),
+ new SqlToRowExpressionTranslator.Context());
+ }
+
private static Expression getSqlFunctionImplementationExpression(
FunctionMetadata functionMetadata,
SqlInvokedScalarFunctionImplementation implementation,
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java
deleted file mode 100644
index fe4081cb64252..0000000000000
--- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package com.facebook.presto.sql.planner.plan;
-
-import com.facebook.presto.spi.statistics.ColumnStatisticMetadata;
-import com.facebook.presto.spi.statistics.ColumnStatisticType;
-import com.google.common.collect.ImmutableList;
-import org.testng.annotations.Test;
-
-import static com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor.ColumnStatisticMetadataKeyDeserializer.deserialize;
-import static com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor.ColumnStatisticMetadataKeySerializer.serialize;
-import static com.facebook.presto.testing.assertions.Assert.assertEquals;
-
-public class TestStatisticAggregationsDescriptor
-{
- private static final ImmutableList COLUMNS = ImmutableList.of("", "col1", "$:###:;", "abc+dddd___");
-
- @Test
- public void testColumnStatisticMetadataKeySerializationRoundTrip()
- {
- for (String column : COLUMNS) {
- for (ColumnStatisticType type : ColumnStatisticType.values()) {
- ColumnStatisticMetadata expected = type.getColumnStatisticMetadata(column);
- assertEquals(deserialize(serialize(expected)), expected);
- }
- }
- }
-}
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java
index ff32b0167ed62..474e6d00de1b4 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java
@@ -16,26 +16,56 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.List;
import java.util.Objects;
import static java.util.Objects.requireNonNull;
+/**
+ *
+ * Represents a column statistic that should be computed during an {@code ANALYZE} query. The
+ * field {@code isSqlExpression} denotes whether the value of the {@code function} field
+ * represents a simple function name, or a complex SQL expression.
+ *
+ *
+ * In the case of a SQL expression, it should parse to a valid function call that may use
+ * constant-value arguments or references to specific column names in the arguments. Column
+ * references must be properly quoted, as the string should be valid SQL. The SQL itself parses
+ * as a function, so it should always begin with a {@code RETURN} keyword. Any references
+ * columns used in the SQL expression as function arguments should be added to the
+ * {@code columnArguments} field.
+ *
+ *
+ * Example: Suppose you want to compute the column statistic is of a t-digest with a
+ * configurable weight and compression for a column named "x". The {@code function} should be:
+ *
+ * {@code "RETURN tdigest_agg("x", 1, 200)"}
+ *
+ * Additionally, the {@code columnArguments} field should be populated with {@code ["x"]}
+ *
+ *
+ */
public class ColumnStatisticMetadata
{
private final String columnName;
private final ColumnStatisticType statisticType;
-
- private final String functionName;
+ private final String function;
+ private final List columnArguments;
+ private final boolean isSqlExpression;
@JsonCreator
public ColumnStatisticMetadata(
@JsonProperty("columnName") String columnName,
@JsonProperty("statisticType") ColumnStatisticType statisticType,
- @JsonProperty("functionName") String functionName)
+ @JsonProperty("function") String function,
+ @JsonProperty("columnArguments") List columnArguments,
+ @JsonProperty("isSqlExpression") boolean isSqlExpression)
{
this.columnName = requireNonNull(columnName, "columnName is null");
this.statisticType = requireNonNull(statisticType, "statisticType is null");
- this.functionName = requireNonNull(functionName, "functionName is null");
+ this.function = requireNonNull(function, "functionName is null");
+ this.columnArguments = requireNonNull(columnArguments, "additionalArguments is null");
+ this.isSqlExpression = isSqlExpression;
}
@JsonProperty
@@ -51,9 +81,21 @@ public ColumnStatisticType getStatisticType()
}
@JsonProperty
- public String getFunctionName()
+ public String getFunction()
+ {
+ return function;
+ }
+
+ @JsonProperty
+ public List getColumnArguments()
+ {
+ return columnArguments;
+ }
+
+ @JsonProperty
+ public boolean isSqlExpression()
{
- return functionName;
+ return isSqlExpression;
}
@Override
@@ -68,13 +110,15 @@ public boolean equals(Object o)
ColumnStatisticMetadata that = (ColumnStatisticMetadata) o;
return Objects.equals(columnName, that.columnName) &&
statisticType == that.statisticType &&
- Objects.equals(functionName, that.functionName);
+ Objects.equals(function, that.function) &&
+ Objects.equals(columnArguments, that.columnArguments) &&
+ Objects.equals(isSqlExpression, that.isSqlExpression);
}
@Override
public int hashCode()
{
- return Objects.hash(columnName, statisticType, functionName);
+ return Objects.hash(columnName, statisticType, function, columnArguments, isSqlExpression);
}
@Override
@@ -83,7 +127,9 @@ public String toString()
return "ColumnStatisticMetadata{" +
"columnName='" + columnName + '\'' +
", statisticType=" + statisticType +
- ", functionName=" + functionName +
+ ", function=" + function +
+ ", columnArguments=" + columnArguments +
+ ", isSqlExpression=" + isSqlExpression +
'}';
}
}
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java
index 20b9261e7bc25..c196076d15f32 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java
@@ -13,6 +13,10 @@
*/
package com.facebook.presto.spi.statistics;
+import java.util.List;
+
+import static java.util.Collections.emptyList;
+
public enum ColumnStatisticType
{
MAX_VALUE("max"),
@@ -32,11 +36,11 @@ public enum ColumnStatisticType
public ColumnStatisticMetadata getColumnStatisticMetadata(String columnName)
{
- return new ColumnStatisticMetadata(columnName, this, this.functionName);
+ return new ColumnStatisticMetadata(columnName, this, this.functionName, false, emptyList());
}
- public ColumnStatisticMetadata getColumnStatisticMetadataWithCustomFunction(String columnName, String functionName)
+ public ColumnStatisticMetadata getColumnStatisticMetadataWithCustomFunction(String columnName, String functionSql, List columnArguments)
{
- return new ColumnStatisticMetadata(columnName, this, functionName);
+ return new ColumnStatisticMetadata(columnName, this, functionSql, true, columnArguments);
}
}