From fcd04589ca6457bff331a8187d294c4a9a4775f2 Mon Sep 17 00:00:00 2001 From: Zac Blanco Date: Mon, 8 Jul 2024 14:21:40 -0700 Subject: [PATCH] Add Sql-invoked function support for column statistics Adds support in the ANALYZE infrastructure for additional projections or constant arguments to aggregation functions. This allows connectors to support a wider range of functionality when generating metadata for column statistic collection --- .../presto/iceberg/TableStatisticsMaker.java | 4 +- .../presto/iceberg/util/StatisticsUtil.java | 5 + .../operator/StatisticsWriterOperator.java | 4 +- .../presto/operator/TableFinishOperator.java | 2 +- .../sql/planner/BasePlanFragmenter.java | 2 +- .../presto/sql/planner/LogicalPlanner.java | 14 ++- .../planner/StatisticsAggregationPlanner.java | 99 +++++++++++++++-- .../plan/StatisticAggregationsDescriptor.java | 103 +++++++++++------- .../sanity/ValidateDependenciesChecker.java | 3 +- .../sql/relational/SqlFunctionUtils.java | 30 +++++ .../TestStatisticAggregationsDescriptor.java | 39 ------- .../statistics/ColumnStatisticMetadata.java | 64 +++++++++-- .../spi/statistics/ColumnStatisticType.java | 10 +- 13 files changed, 270 insertions(+), 109 deletions(-) delete mode 100644 presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java index b998dc4dc1a72..9da3de696f620 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/TableStatisticsMaker.java @@ -88,6 +88,7 @@ import static com.facebook.presto.iceberg.IcebergUtil.getIdentityPartitions; import static com.facebook.presto.iceberg.Partition.toMap; import static com.facebook.presto.iceberg.util.StatisticsUtil.calculateAndSetTableSize; +import static com.facebook.presto.iceberg.util.StatisticsUtil.formatIdentifier; import static com.facebook.presto.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES; import static com.facebook.presto.spi.statistics.ColumnStatisticType.TOTAL_SIZE_IN_BYTES; import static com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel.HIGH; @@ -533,7 +534,8 @@ public static List getSupportedColumnStatistics(String if (isNumericType(type) || type.equals(DATE) || isVarcharType(type) || type.equals(TIMESTAMP) || type.equals(TIMESTAMP_WITH_TIME_ZONE)) { - supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction(columnName, "sketch_theta")); + supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction( + columnName, format("RETURN sketch_theta(%s)", formatIdentifier(columnName)), ImmutableList.of(columnName))); } if (!(type instanceof FixedWidthType)) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/util/StatisticsUtil.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/util/StatisticsUtil.java index ff9cd814615ba..4200610239c16 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/util/StatisticsUtil.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/util/StatisticsUtil.java @@ -172,4 +172,9 @@ public static List combineSelectedAndPredicateColumns(List< .iterator()) .build(); } + + public static String formatIdentifier(String s) + { + return '"' + s.replace("\"", "\"\"") + '"'; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/StatisticsWriterOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/StatisticsWriterOperator.java index 20a901f138428..509f4a7628251 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/StatisticsWriterOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/StatisticsWriterOperator.java @@ -175,7 +175,9 @@ private ComputedStatistics getComputedStatistics(Page page, int position) descriptor.getTableStatistics().forEach((type, channel) -> statistics.addTableStatistic(type, page.getBlock(channel).getSingleValueBlock(position))); - descriptor.getColumnStatistics().forEach((metadata, channel) -> statistics.addColumnStatistic(metadata, page.getBlock(channel).getSingleValueBlock(position))); + descriptor.getColumnStatistics().forEach((descriptor) -> + statistics.addColumnStatistic(descriptor.getMetadata(), + page.getBlock(descriptor.getItem()).getSingleValueBlock(position))); return statistics.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/TableFinishOperator.java b/presto-main/src/main/java/com/facebook/presto/operator/TableFinishOperator.java index 52b825d496167..bcdab3dde9fe3 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/TableFinishOperator.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/TableFinishOperator.java @@ -311,7 +311,7 @@ private ComputedStatistics getComputedStatistics(Page page, int position) descriptor.getTableStatistics().forEach((type, channel) -> statistics.addTableStatistic(type, page.getBlock(channel).getSingleValueBlock(position))); - descriptor.getColumnStatistics().forEach((metadata, channel) -> statistics.addColumnStatistic(metadata, page.getBlock(channel).getSingleValueBlock(position))); + descriptor.getColumnStatistics().forEach((descriptor) -> statistics.addColumnStatistic(descriptor.getMetadata(), page.getBlock(descriptor.getItem()).getSingleValueBlock(position))); return statistics.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index de30f542a007b..aa134e91d03a5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -123,7 +123,7 @@ public BasePlanFragmenter( this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null"); this.outputTableWriterNodeIds = ImmutableSet.copyOf(requireNonNull(outputTableWriterNodeIds, "outputTableWriterNodeIds is null")); - this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(variableAllocator, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(variableAllocator, metadata.getFunctionAndTypeManager(), session); } public SubPlan buildRootFragment(PlanNode root, FragmentProperties properties) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java index 1763247be121e..d43b439de9ff9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/LogicalPlanner.java @@ -110,6 +110,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Streams.zip; import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; public class LogicalPlanner { @@ -131,7 +132,7 @@ public LogicalPlanner( this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); this.metadata = requireNonNull(metadata, "metadata is null"); this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null"); - this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(this.variableAllocator, metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + this.statisticsAggregationPlanner = new StatisticsAggregationPlanner(this.variableAllocator, metadata.getFunctionAndTypeManager(), session); this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); } @@ -224,7 +225,6 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme columnNameToVariable.put(column.getName(), variable); } - List tableScanOutputs = tableScanOutputsBuilder.build(); TableStatisticsMetadata tableStatisticsMetadata = metadata.getStatisticsCollectionMetadata( session, targetTable.getConnectorId().getCatalogName(), @@ -232,14 +232,20 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStateme TableStatisticAggregation tableStatisticAggregation = statisticsAggregationPlanner.createStatisticsAggregation(tableStatisticsMetadata, columnNameToVariable.build()); StatisticAggregations statisticAggregations = tableStatisticAggregation.getAggregations(); - + List tableScanOutputs = tableScanOutputsBuilder.build(); + Map assignments = ImmutableMap.builder() + .putAll(tableScanOutputs.stream().collect(toImmutableMap(identity(), identity()))) + .putAll(tableStatisticAggregation.getAdditionalVariables()) + .build(); + TableScanNode scanNode = new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all()); + PlanNode project = PlannerUtils.addProjections(scanNode, idAllocator, assignments); PlanNode planNode = new StatisticsWriterNode( getSourceLocation(analyzeStatement), idAllocator.getNextId(), new AggregationNode( getSourceLocation(analyzeStatement), idAllocator.getNextId(), - new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all()), + project, statisticAggregations.getAggregations(), singleGroupingSet(statisticAggregations.getGroupingVariables()), ImmutableList.of(), diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java index 0f17eaf78d2d3..f9722664ecebb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/StatisticsAggregationPlanner.java @@ -13,7 +13,9 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.Session; import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.FunctionAndTypeManager; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.function.FunctionHandle; @@ -21,6 +23,7 @@ import com.facebook.presto.spi.function.StandardFunctionResolution; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; import com.facebook.presto.spi.statistics.ColumnStatisticType; @@ -33,15 +36,19 @@ import com.facebook.presto.sql.relational.FunctionResolution; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.UnknownType.UNKNOWN; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; +import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT; +import static com.facebook.presto.sql.relational.SqlFunctionUtils.sqlFunctionToRowExpression; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; @@ -51,11 +58,15 @@ public class StatisticsAggregationPlanner { private final VariableAllocator variableAllocator; private final FunctionAndTypeResolver functionAndTypeResolver; + private final Session session; + private final FunctionAndTypeManager functionAndTypeManager; - public StatisticsAggregationPlanner(VariableAllocator variableAllocator, FunctionAndTypeResolver functionAndTypeResolver) + public StatisticsAggregationPlanner(VariableAllocator variableAllocator, FunctionAndTypeManager functionAndTypeManager, Session session) { this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null"); - this.functionAndTypeResolver = requireNonNull(functionAndTypeResolver, "functionAndTypeResolver is null"); + this.session = requireNonNull(session, "session is null"); + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + this.functionAndTypeResolver = functionAndTypeManager.getFunctionAndTypeResolver(); } public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata statisticsMetadata, Map columnToVariableMap) @@ -70,6 +81,7 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta for (int i = 0; i < groupingVariables.size(); i++) { descriptor.addGrouping(groupingColumns.get(i), groupingVariables.get(i)); } + ImmutableMap.Builder additionalVariables = ImmutableMap.builder(); ImmutableMap.Builder aggregations = ImmutableMap.builder(); StandardFunctionResolution functionResolution = new FunctionResolution(functionAndTypeResolver); @@ -98,18 +110,62 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMeta VariableReferenceExpression inputVariable = columnToVariableMap.get(columnName); verify(inputVariable != null, "inputVariable is null"); ColumnStatisticsAggregation aggregation = createColumnAggregation(columnStatisticMetadata, inputVariable); + additionalVariables.putAll(aggregation.getInputProjections()); VariableReferenceExpression variable = variableAllocator.newVariable(statisticType + ":" + columnName, aggregation.getOutputType()); aggregations.put(variable, aggregation.getAggregation()); descriptor.addColumnStatistic(columnStatisticMetadata, variable); } StatisticAggregations aggregation = new StatisticAggregations(aggregations.build(), groupingVariables); - return new TableStatisticAggregation(aggregation, descriptor.build()); + return new TableStatisticAggregation(aggregation, descriptor.build(), additionalVariables.build()); } - private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input) + private ColumnStatisticsAggregation createColumnAggregationFromSqlFunction(String sqlFunction, VariableReferenceExpression input) { - FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunctionName(), TypeSignatureProvider.fromTypes(ImmutableList.of(input.getType()))); + RowExpression expression = sqlFunctionToRowExpression( + sqlFunction, + ImmutableSet.of(input), + functionAndTypeManager, + session); + verify(expression instanceof CallExpression, "column statistic SQL expressions must represent a function call"); + CallExpression call = (CallExpression) expression; + FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(call.getFunctionHandle()); + verify(functionMeta.getFunctionKind().equals(AGGREGATE), "column statistic function must be aggregates"); + // Aggregations input arguments are required to be variable reference expressions. + // For each one that isn't, allocate a new variable to reference + ImmutableMap.Builder inputProjections = ImmutableMap.builder(); + List callVariableArguments = call.getArguments() + .stream() + .map(argument -> { + if (argument instanceof VariableReferenceExpression) { + return argument; + } + VariableReferenceExpression newArgument = variableAllocator.newVariable(argument); + inputProjections.put(newArgument, argument); + return newArgument; + }) + .collect(Collectors.toList()); + CallExpression callWithVariables = new CallExpression( + call.getSourceLocation(), + call.getDisplayName(), + call.getFunctionHandle(), + call.getType(), + callVariableArguments); + return new ColumnStatisticsAggregation( + new AggregationNode.Aggregation(callWithVariables, + Optional.empty(), + Optional.empty(), + false, + Optional.empty()), + functionAndTypeResolver.getType(functionMeta.getReturnType()), + inputProjections.build()); + } + + private ColumnStatisticsAggregation createColumnAggregationFromFunctionName(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input) + { + FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunction(), TypeSignatureProvider.fromTypes(ImmutableList.builder() + .add(input.getType()) + .build())); FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(functionHandle); Type inputType = functionAndTypeResolver.getType(getOnlyElement(functionMeta.getArgumentTypes())); Type outputType = functionAndTypeResolver.getType(functionMeta.getReturnType()); @@ -118,7 +174,7 @@ private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetad new AggregationNode.Aggregation( new CallExpression( input.getSourceLocation(), - columnStatisticMetadata.getFunctionName(), + columnStatisticMetadata.getFunction(), functionHandle, outputType, ImmutableList.of(input)), @@ -126,20 +182,33 @@ private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetad Optional.empty(), false, Optional.empty()), - outputType); + outputType, + ImmutableMap.of()); + } + + private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input) + { + if (columnStatisticMetadata.isSqlExpression()) { + return createColumnAggregationFromSqlFunction(columnStatisticMetadata.getFunction(), input); + } + + return createColumnAggregationFromFunctionName(columnStatisticMetadata, input); } public static class TableStatisticAggregation { private final StatisticAggregations aggregations; private final StatisticAggregationsDescriptor descriptor; + private final Map additionalVariables; private TableStatisticAggregation( StatisticAggregations aggregations, - StatisticAggregationsDescriptor descriptor) + StatisticAggregationsDescriptor descriptor, + Map additionalVariables) { this.aggregations = requireNonNull(aggregations, "statisticAggregations is null"); this.descriptor = requireNonNull(descriptor, "descriptor is null"); + this.additionalVariables = requireNonNull(additionalVariables, "additionalVariables is null"); } public StatisticAggregations getAggregations() @@ -151,17 +220,24 @@ public StatisticAggregationsDescriptor getDescripto { return descriptor; } + + public Map getAdditionalVariables() + { + return additionalVariables; + } } public static class ColumnStatisticsAggregation { private final AggregationNode.Aggregation aggregation; private final Type outputType; + private final Map inputProjections; - private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType) + private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType, Map inputProjections) { this.aggregation = requireNonNull(aggregation, "aggregation is null"); this.outputType = requireNonNull(outputType, "outputType is null"); + this.inputProjections = requireNonNull(inputProjections, "additionalVariable is null"); } public AggregationNode.Aggregation getAggregation() @@ -173,5 +249,10 @@ public Type getOutputType() { return outputType; } + + public Map getInputProjections() + { + return inputProjections; + } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java index 95878fb4310f6..7f7db1692d0a1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/plan/StatisticAggregationsDescriptor.java @@ -14,28 +14,18 @@ package com.facebook.presto.sql.planner.plan; import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; -import com.facebook.presto.spi.statistics.ColumnStatisticType; import com.facebook.presto.spi.statistics.TableStatisticType; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.KeyDeserializer; -import com.fasterxml.jackson.databind.SerializerProvider; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.function.Function; import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.util.Objects.requireNonNull; @@ -43,7 +33,7 @@ public class StatisticAggregationsDescriptor { private final Map grouping; private final Map tableStatistics; - private final Map columnStatistics; + private final List> columnStatistics; public static StatisticAggregationsDescriptor empty() { @@ -54,11 +44,22 @@ public static StatisticAggregationsDescriptor empty() public StatisticAggregationsDescriptor( @JsonProperty("grouping") Map grouping, @JsonProperty("tableStatistics") Map tableStatistics, - @JsonProperty("columnStatistics") Map columnStatistics) + @JsonProperty("columnStatistics") List> columnStatistics) { this.grouping = ImmutableMap.copyOf(requireNonNull(grouping, "grouping is null")); this.tableStatistics = ImmutableMap.copyOf(requireNonNull(tableStatistics, "tableStatistics is null")); - this.columnStatistics = ImmutableMap.copyOf(requireNonNull(columnStatistics, "columnStatistics is null")); + this.columnStatistics = requireNonNull(columnStatistics, "columnStatistics is null"); + } + + public StatisticAggregationsDescriptor( + Map grouping, + Map tableStatistics, + Map columnStatistics) + { + this(grouping, tableStatistics, ImmutableList.>builder() + .addAll(columnStatistics.entrySet().stream() + .map(e -> new ColumnStatisticsDescriptor<>(e.getKey(), e.getValue())).iterator()) + .build()); } @JsonProperty @@ -74,9 +75,7 @@ public Map getTableStatistics() } @JsonProperty - @JsonSerialize(keyUsing = ColumnStatisticMetadataKeySerializer.class) - @JsonDeserialize(keyUsing = ColumnStatisticMetadataKeyDeserializer.class) - public Map getColumnStatistics() + public List> getColumnStatistics() { return columnStatistics; } @@ -122,7 +121,11 @@ public StatisticAggregationsDescriptor map(Function mapper) return new StatisticAggregationsDescriptor<>( map(this.getGrouping(), mapper), map(this.getTableStatistics(), mapper), - map(this.getColumnStatistics(), mapper)); + map(this.getColumnStatistics().stream() + .collect(toImmutableMap( + ColumnStatisticsDescriptor::getMetadata, + ColumnStatisticsDescriptor::getItem)), + mapper)); } private static Map map(Map input, Function mapper) @@ -159,39 +162,59 @@ public StatisticAggregationsDescriptor build() } } - public static class ColumnStatisticMetadataKeySerializer - extends JsonSerializer + public static class ColumnStatisticsDescriptor { - @Override - public void serialize(ColumnStatisticMetadata value, JsonGenerator gen, SerializerProvider serializers) - throws IOException + private final ColumnStatisticMetadata metadata; + private final T item; + + @JsonCreator + public ColumnStatisticsDescriptor(@JsonProperty("metadata") ColumnStatisticMetadata metadata, @JsonProperty("item") T item) { - verify(value != null, "value is null"); - gen.writeFieldName(serialize(value)); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.item = requireNonNull(item, "item is null"); } - @VisibleForTesting - static String serialize(ColumnStatisticMetadata value) + @JsonProperty + public T getItem() { - return value.getStatisticType().name() + ":" + value.getFunctionName() + ":" + value.getColumnName(); + return item; + } + + @JsonProperty + public ColumnStatisticMetadata getMetadata() + { + return metadata; } - } - public static class ColumnStatisticMetadataKeyDeserializer - extends KeyDeserializer - { @Override - public ColumnStatisticMetadata deserializeKey(String key, DeserializationContext ctx) + public int hashCode() { - return deserialize(requireNonNull(key, "key is null")); + return Objects.hash(metadata, item); } - @VisibleForTesting - static ColumnStatisticMetadata deserialize(String value) + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + + if (!(o instanceof ColumnStatisticsDescriptor)) { + return false; + } + + ColumnStatisticsDescriptor other = (ColumnStatisticsDescriptor) o; + return metadata.equals(other.metadata) && + item.equals(other.item); + } + + @Override + public String toString() { - String[] values = value.split(":", 3); - checkArgument(values.length == 3, "separator(s) not found: %s", value); - return new ColumnStatisticMetadata(values[2], ColumnStatisticType.valueOf(values[0]), values[1]); + return toStringHelper(this) + .add("metadata", metadata) + .add("item", item) + .toString(); } } } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index 7ffc277dfde93..cda1acb7658f1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -65,6 +65,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SpatialJoinNode; import com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor; +import com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor.ColumnStatisticsDescriptor; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFinishNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; @@ -656,7 +657,7 @@ public Void visitStatisticsWriterNode(StatisticsWriterNode node, Set descriptor = node.getDescriptor(); Set dependencies = ImmutableSet.builder() .addAll(descriptor.getGrouping().values()) - .addAll(descriptor.getColumnStatistics().values()) + .addAll(descriptor.getColumnStatistics().stream().map(ColumnStatisticsDescriptor::getItem).iterator()) .addAll(descriptor.getTableStatistics().values()) .build(); List outputVariables = node.getSource().getOutputVariables(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java index d31a9c7c3425f..fdbcca5d44bbe 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/relational/SqlFunctionUtils.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.sql.relational; +import com.facebook.presto.Session; import com.facebook.presto.common.function.SqlFunctionProperties; import com.facebook.presto.common.type.Type; import com.facebook.presto.expressions.RowExpressionRewriter; @@ -45,6 +46,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import static com.facebook.presto.spi.function.FunctionImplementationType.SQL; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.analyzeSqlFunctionExpression; @@ -113,6 +115,34 @@ public static RowExpression getSqlFunctionRowExpression( argumentVariables); } + public static RowExpression sqlFunctionToRowExpression(String functionBody, + Set variables, + FunctionAndTypeManager functionAndTypeManager, + Session session) + { + Expression expression = parseSqlFunctionExpression( + new SqlInvokedScalarFunctionImplementation(functionBody), + session.getSqlFunctionProperties()); + return SqlToRowExpressionTranslator.translate( + expression, + analyzeSqlFunctionExpression( + functionAndTypeManager.getFunctionAndTypeResolver(), + session.getSqlFunctionProperties(), + expression, + variables.stream() + .collect(toImmutableMap( + VariableReferenceExpression::getName, + VariableReferenceExpression::getType))) + .getExpressionTypes(), + ImmutableMap.of(), + functionAndTypeManager, + Optional.empty(), + Optional.empty(), + session.getSqlFunctionProperties(), + session.getSessionFunctions(), + new SqlToRowExpressionTranslator.Context()); + } + private static Expression getSqlFunctionImplementationExpression( FunctionMetadata functionMetadata, SqlInvokedScalarFunctionImplementation implementation, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java deleted file mode 100644 index fe4081cb64252..0000000000000 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/plan/TestStatisticAggregationsDescriptor.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto.sql.planner.plan; - -import com.facebook.presto.spi.statistics.ColumnStatisticMetadata; -import com.facebook.presto.spi.statistics.ColumnStatisticType; -import com.google.common.collect.ImmutableList; -import org.testng.annotations.Test; - -import static com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor.ColumnStatisticMetadataKeyDeserializer.deserialize; -import static com.facebook.presto.sql.planner.plan.StatisticAggregationsDescriptor.ColumnStatisticMetadataKeySerializer.serialize; -import static com.facebook.presto.testing.assertions.Assert.assertEquals; - -public class TestStatisticAggregationsDescriptor -{ - private static final ImmutableList COLUMNS = ImmutableList.of("", "col1", "$:###:;", "abc+dddd___"); - - @Test - public void testColumnStatisticMetadataKeySerializationRoundTrip() - { - for (String column : COLUMNS) { - for (ColumnStatisticType type : ColumnStatisticType.values()) { - ColumnStatisticMetadata expected = type.getColumnStatisticMetadata(column); - assertEquals(deserialize(serialize(expected)), expected); - } - } - } -} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java index ff32b0167ed62..474e6d00de1b4 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticMetadata.java @@ -16,26 +16,56 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; import java.util.Objects; import static java.util.Objects.requireNonNull; +/** + *

+ * Represents a column statistic that should be computed during an {@code ANALYZE} query. The + * field {@code isSqlExpression} denotes whether the value of the {@code function} field + * represents a simple function name, or a complex SQL expression. + *

+ *

+ * In the case of a SQL expression, it should parse to a valid function call that may use + * constant-value arguments or references to specific column names in the arguments. Column + * references must be properly quoted, as the string should be valid SQL. The SQL itself parses + * as a function, so it should always begin with a {@code RETURN} keyword. Any references + * columns used in the SQL expression as function arguments should be added to the + * {@code columnArguments} field. + *

+ *

+ * Example: Suppose you want to compute the column statistic is of a t-digest with a + * configurable weight and compression for a column named "x". The {@code function} should be: + *
+ * {@code "RETURN tdigest_agg("x", 1, 200)"} + *

+ * Additionally, the {@code columnArguments} field should be populated with {@code ["x"]} + *

+ *

+ */ public class ColumnStatisticMetadata { private final String columnName; private final ColumnStatisticType statisticType; - - private final String functionName; + private final String function; + private final List columnArguments; + private final boolean isSqlExpression; @JsonCreator public ColumnStatisticMetadata( @JsonProperty("columnName") String columnName, @JsonProperty("statisticType") ColumnStatisticType statisticType, - @JsonProperty("functionName") String functionName) + @JsonProperty("function") String function, + @JsonProperty("columnArguments") List columnArguments, + @JsonProperty("isSqlExpression") boolean isSqlExpression) { this.columnName = requireNonNull(columnName, "columnName is null"); this.statisticType = requireNonNull(statisticType, "statisticType is null"); - this.functionName = requireNonNull(functionName, "functionName is null"); + this.function = requireNonNull(function, "functionName is null"); + this.columnArguments = requireNonNull(columnArguments, "additionalArguments is null"); + this.isSqlExpression = isSqlExpression; } @JsonProperty @@ -51,9 +81,21 @@ public ColumnStatisticType getStatisticType() } @JsonProperty - public String getFunctionName() + public String getFunction() + { + return function; + } + + @JsonProperty + public List getColumnArguments() + { + return columnArguments; + } + + @JsonProperty + public boolean isSqlExpression() { - return functionName; + return isSqlExpression; } @Override @@ -68,13 +110,15 @@ public boolean equals(Object o) ColumnStatisticMetadata that = (ColumnStatisticMetadata) o; return Objects.equals(columnName, that.columnName) && statisticType == that.statisticType && - Objects.equals(functionName, that.functionName); + Objects.equals(function, that.function) && + Objects.equals(columnArguments, that.columnArguments) && + Objects.equals(isSqlExpression, that.isSqlExpression); } @Override public int hashCode() { - return Objects.hash(columnName, statisticType, functionName); + return Objects.hash(columnName, statisticType, function, columnArguments, isSqlExpression); } @Override @@ -83,7 +127,9 @@ public String toString() return "ColumnStatisticMetadata{" + "columnName='" + columnName + '\'' + ", statisticType=" + statisticType + - ", functionName=" + functionName + + ", function=" + function + + ", columnArguments=" + columnArguments + + ", isSqlExpression=" + isSqlExpression + '}'; } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java index 20b9261e7bc25..c196076d15f32 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/statistics/ColumnStatisticType.java @@ -13,6 +13,10 @@ */ package com.facebook.presto.spi.statistics; +import java.util.List; + +import static java.util.Collections.emptyList; + public enum ColumnStatisticType { MAX_VALUE("max"), @@ -32,11 +36,11 @@ public enum ColumnStatisticType public ColumnStatisticMetadata getColumnStatisticMetadata(String columnName) { - return new ColumnStatisticMetadata(columnName, this, this.functionName); + return new ColumnStatisticMetadata(columnName, this, this.functionName, false, emptyList()); } - public ColumnStatisticMetadata getColumnStatisticMetadataWithCustomFunction(String columnName, String functionName) + public ColumnStatisticMetadata getColumnStatisticMetadataWithCustomFunction(String columnName, String functionSql, List columnArguments) { - return new ColumnStatisticMetadata(columnName, this, functionName); + return new ColumnStatisticMetadata(columnName, this, functionSql, true, columnArguments); } }