Allow configuring k value for histograms
Adds support in the ANALYZE infrastructure for configuring additional
arguments to a particular aggregation. Adds configuration to the
Iceberg connector to utilize the new infrastructure.
ZacBlanco committed Jul 9, 2024
1 parent a29d9c0 commit 0dceef8
Showing 13 changed files with 140 additions and 105 deletions.
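For context: the k parameter made configurable by this commit is the standard accuracy/size knob of the Apache DataSketches KLL sketch that backs the connector's histogram statistics. Below is a minimal, illustrative demo of the tradeoff (not part of the commit; assumes the org.apache.datasketches:datasketches-java dependency). Note that the 8 and 65535 endpoints match the validation bounds added to IcebergConfig below.

import org.apache.datasketches.kll.KllDoublesSketch;

public class KllKParameterDemo
{
    public static void main(String[] args)
    {
        // Larger k lowers the normalized rank error but grows the serialized sketch.
        for (int k : new int[] {8, 4096, 65535}) {
            KllDoublesSketch sketch = KllDoublesSketch.newHeapInstance(k);
            for (int i = 0; i < 1_000_000; i++) {
                sketch.update(i);
            }
            System.out.printf("k=%5d  rankError=%.4f  bytes=%d%n",
                    k, sketch.getNormalizedRankError(false), sketch.toByteArray().length);
        }
    }
}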
@@ -579,7 +579,7 @@ public TableStatisticsMetadata getStatisticsCollectionMetadata(ConnectorSession
{
Set<ColumnStatisticMetadata> columnStatistics = tableMetadata.getColumns().stream()
.filter(column -> !column.isHidden())
.flatMap(meta -> getSupportedColumnStatistics(meta.getName(), meta.getType()).stream())
.flatMap(meta -> getSupportedColumnStatistics(session, meta.getName(), meta.getType()).stream())
.collect(toImmutableSet());

Set<TableStatisticType> tableStatistics = ImmutableSet.of(ROW_COUNT);
@@ -23,6 +23,7 @@

import javax.validation.constraints.DecimalMax;
import javax.validation.constraints.DecimalMin;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;

@@ -53,6 +54,7 @@ public class IcebergConfig
private boolean pushdownFilterEnabled;
private boolean deleteAsJoinRewriteEnabled = true;
private int rowsForMetadataOptimizationThreshold = 1000;
private int histogramKParameter = 4096;

private EnumSet<ColumnStatisticType> hiveStatisticsMergeFlags = EnumSet.noneOf(ColumnStatisticType.class);
private String fileIOImpl = HadoopFileIO.class.getName();
@@ -349,4 +351,19 @@ public IcebergConfig setSplitManagerThreads(int splitManagerThreads)
this.splitManagerThreads = splitManagerThreads;
return this;
}

@Min(8)
@Max(65535)
public int getHistogramKParameter()
{
return histogramKParameter;
}

@Config("iceberg.histogram-k-parameter")
@ConfigDescription("K parameter for KLL sketch used in histogram statistics")
public IcebergConfig setHistogramKParameter(int histogramKParameter)
{
this.histogramKParameter = histogramKParameter;
return this;
}
}
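A quick usage sketch for the new property (hypothetical snippet, not part of the commit; assumes IcebergConfig is on the classpath and that Presto's config binding enforces the @Min(8)/@Max(65535) bounds through bean validation, as it does for other config classes):

public class HistogramKConfigExample
{
    public static void main(String[] args)
    {
        // Programmatic equivalent of iceberg.histogram-k-parameter=8192 in the catalog file.
        IcebergConfig config = new IcebergConfig()
                .setHistogramKParameter(8192);
        System.out.println(config.getHistogramKParameter()); // prints 8192; the default is 4096
    }
}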
@@ -64,6 +64,7 @@ public final class IcebergSessionProperties
public static final String HIVE_METASTORE_STATISTICS_MERGE_STRATEGY = "hive_statistics_merge_strategy";
public static final String STATISTIC_SNAPSHOT_RECORD_DIFFERENCE_WEIGHT = "statistic_snapshot_record_difference_weight";
public static final String ROWS_FOR_METADATA_OPTIMIZATION_THRESHOLD = "rows_for_metadata_optimization_threshold";
public static final String HISTOGRAM_K_PARAMETER = "histogram_k_parameter";

private final List<PropertyMetadata<?>> sessionProperties;

@@ -184,6 +185,11 @@ public IcebergSessionProperties(
"of an Iceberg table exceeds this threshold, metadata optimization would be skipped for " +
"the table. A value of 0 means skip metadata optimization directly.",
icebergConfig.getRowsForMetadataOptimizationThreshold(),
false))
.add(integerProperty(
HISTOGRAM_K_PARAMETER,
"The parameter passed to the KLL sketch function when generating histogram statistics",
icebergConfig.getHistogramKParameter(),
false));

nessieConfig.ifPresent((config) -> propertiesBuilder
@@ -313,4 +319,9 @@ public static String getNessieReferenceHash(ConnectorSession session)
{
return session.getProperty(NESSIE_REFERENCE_HASH, String.class);
}

public static int getHistogramKParameter(ConnectorSession session)
{
return session.getProperty(HISTOGRAM_K_PARAMETER, Integer.class);
}
}
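Per-query overrides flow through the session property. A hedged sketch of how a caller resolves the effective value (the helper class and its name are illustrative; the accessor is the one added above):

import com.facebook.presto.spi.ConnectorSession;

final class HistogramK
{
    private HistogramK() {}

    // Returns the session value if the user ran, e.g.,
    //   SET SESSION iceberg.histogram_k_parameter = 8192;
    // (assuming the catalog is named "iceberg"); otherwise falls back to the
    // iceberg.histogram-k-parameter config default (4096).
    static int effectiveK(ConnectorSession session)
    {
        return IcebergSessionProperties.getHistogramKParameter(session);
    }
}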
@@ -15,6 +15,7 @@

import com.facebook.airlift.log.Logger;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.RunLengthEncodedBlock;
import com.facebook.presto.common.predicate.Range;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.type.FixedWidthType;
@@ -93,12 +94,14 @@
import static com.facebook.presto.iceberg.ExpressionConverter.toIcebergExpression;
import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_FILESYSTEM_ERROR;
import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_INVALID_METADATA;
import static com.facebook.presto.iceberg.IcebergSessionProperties.getHistogramKParameter;
import static com.facebook.presto.iceberg.IcebergSessionProperties.getStatisticSnapshotRecordDifferenceWeight;
import static com.facebook.presto.iceberg.IcebergUtil.getIdentityPartitions;
import static com.facebook.presto.iceberg.Partition.toMap;
import static com.facebook.presto.iceberg.TypeConverter.toPrestoType;
import static com.facebook.presto.iceberg.statistics.KllHistogram.isKllHistogramSupportedType;
import static com.facebook.presto.iceberg.util.StatisticsUtil.calculateAndSetTableSize;
import static com.facebook.presto.spi.relation.ConstantExpression.createConstantExpression;
import static com.facebook.presto.spi.statistics.ColumnStatisticType.HISTOGRAM;
import static com.facebook.presto.spi.statistics.ColumnStatisticType.NUMBER_OF_DISTINCT_VALUES;
import static com.facebook.presto.spi.statistics.ColumnStatisticType.TOTAL_SIZE_IN_BYTES;
@@ -577,18 +580,19 @@ private Map<Integer, ColumnStatistics.Builder> loadStatisticsFile(StatisticsFile
return ImmutableMap.copyOf(result);
}

public static List<ColumnStatisticMetadata> getSupportedColumnStatistics(String columnName, com.facebook.presto.common.type.Type type)
public static List<ColumnStatisticMetadata> getSupportedColumnStatistics(ConnectorSession session, String columnName, com.facebook.presto.common.type.Type type)
{
ImmutableList.Builder<ColumnStatisticMetadata> supportedStatistics = ImmutableList.builder();
// all types which support being passed to the sketch_theta function
if (isNumericType(type) || type.equals(DATE) || isVarcharType(type) ||
type.equals(TIMESTAMP) ||
type.equals(TIMESTAMP_WITH_TIME_ZONE)) {
supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction(columnName, "sketch_theta"));
supportedStatistics.add(NUMBER_OF_DISTINCT_VALUES.getColumnStatisticMetadataWithCustomFunction(columnName, "sketch_theta", ImmutableList.of()));
}

if (isKllHistogramSupportedType(type)) {
supportedStatistics.add(HISTOGRAM.getColumnStatisticMetadataWithCustomFunction(columnName, "sketch_kll"));
supportedStatistics.add(HISTOGRAM.getColumnStatisticMetadataWithCustomFunction(columnName, "sketch_kll_with_k",
ImmutableList.of(createConstantExpression(RunLengthEncodedBlock.create(BIGINT, (long) getHistogramKParameter(session), 1), BIGINT))));
}

if (!(type instanceof FixedWidthType)) {
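The mechanism worth calling out above: the k value travels with the statistic metadata as a ConstantExpression whose backing block is a single-position run-length-encoded BIGINT. Distilled into a standalone helper (same APIs the diff uses; the helper itself is illustrative):

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.RunLengthEncodedBlock;
import com.facebook.presto.spi.relation.ConstantExpression;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.relation.ConstantExpression.createConstantExpression;

final class KConstantSketch
{
    private KConstantSketch() {}

    static ConstantExpression kAsConstant(long k)
    {
        // A single logical position holding k; the planner later projects this
        // constant into a variable so it can become the second argument of
        // sketch_kll_with_k.
        Block block = RunLengthEncodedBlock.create(BIGINT, k, 1);
        return createConstantExpression(block, BIGINT);
    }
}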
@@ -175,7 +175,7 @@ private ComputedStatistics getComputedStatistics(Page page, int position)
descriptor.getTableStatistics().forEach((type, channel) ->
statistics.addTableStatistic(type, page.getBlock(channel).getSingleValueBlock(position)));

descriptor.getColumnStatistics().forEach((metadata, channel) -> statistics.addColumnStatistic(metadata, page.getBlock(channel).getSingleValueBlock(position)));
descriptor.getColumnStatistics().forEach((descriptor) -> statistics.addColumnStatistic(descriptor.getMetadata(), page.getBlock(descriptor.getItem()).getSingleValueBlock(position)));

return statistics.build();
}
@@ -311,7 +311,7 @@ private ComputedStatistics getComputedStatistics(Page page, int position)
descriptor.getTableStatistics().forEach((type, channel) ->
statistics.addTableStatistic(type, page.getBlock(channel).getSingleValueBlock(position)));

descriptor.getColumnStatistics().forEach((metadata, channel) -> statistics.addColumnStatistic(metadata, page.getBlock(channel).getSingleValueBlock(position)));
descriptor.getColumnStatistics().forEach((descriptor) -> statistics.addColumnStatistic(descriptor.getMetadata(), page.getBlock(descriptor.getItem()).getSingleValueBlock(position)));

return statistics.build();
}
@@ -225,22 +225,25 @@ private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStatement)
columnNameToVariable.put(column.getName(), variable);
}

List<VariableReferenceExpression> tableScanOutputs = tableScanOutputsBuilder.build();
TableStatisticsMetadata tableStatisticsMetadata = metadata.getStatisticsCollectionMetadata(
session,
targetTable.getConnectorId().getCatalogName(),
tableMetadata.getMetadata());

TableStatisticAggregation tableStatisticAggregation = statisticsAggregationPlanner.createStatisticsAggregation(tableStatisticsMetadata, columnNameToVariable.build());
StatisticAggregations statisticAggregations = tableStatisticAggregation.getAggregations();

List<VariableReferenceExpression> tableScanOutputs = tableScanOutputsBuilder.build();
Assignments assignments = Assignments.builder()
.putAll(tableScanOutputs.stream().collect(toImmutableMap(k -> k, k -> k)))
.putAll(tableStatisticAggregation.getAdditionalVariables())
.build();
PlanNode planNode = new StatisticsWriterNode(
getSourceLocation(analyzeStatement),
idAllocator.getNextId(),
new AggregationNode(
getSourceLocation(analyzeStatement),
idAllocator.getNextId(),
new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all()),
new ProjectNode(idAllocator.getNextId(), new TableScanNode(getSourceLocation(analyzeStatement), idAllocator.getNextId(), targetTable, tableScanOutputs, variableToColumnHandle.build(), TupleDomain.all(), TupleDomain.all()), assignments),
statisticAggregations.getAggregations(),
singleGroupingSet(statisticAggregations.getGroupingVariables()),
ImmutableList.of(),
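The ProjectNode added above is the key to passing constants into the aggregation: an AggregationNode can only reference input variables, so each additional argument is first bound to a variable in a projection under the aggregation. The resulting plan shape, sketched as a comment (variable names illustrative):

// StatisticsWriterNode
//   └─ AggregationNode          histogram:col := sketch_kll_with_k(col, k_var)
//        └─ ProjectNode         col := col, k_var := bigint constant (session k, default 4096)
//             └─ TableScanNode  outputs: col, ...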
@@ -21,6 +21,8 @@
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.statistics.ColumnStatisticMetadata;
import com.facebook.presto.spi.statistics.ColumnStatisticType;
@@ -44,7 +46,6 @@
import static com.facebook.presto.spi.statistics.TableStatisticType.ROW_COUNT;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;

public class StatisticsAggregationPlanner
@@ -72,6 +73,7 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata
for (int i = 0; i < groupingVariables.size(); i++) {
descriptor.addGrouping(groupingColumns.get(i), groupingVariables.get(i));
}
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> additionalVariables = ImmutableMap.builder();

ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
StandardFunctionResolution functionResolution = new FunctionResolution(functionAndTypeResolver);
@@ -103,48 +105,66 @@ public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata
VariableReferenceExpression inputVariable = columnToVariableMap.get(columnName);
verify(inputVariable != null, "inputVariable is null");
ColumnStatisticsAggregation aggregation = createColumnAggregation(columnStatisticMetadata, inputVariable);
additionalVariables.putAll(aggregation.getAdditionalVariables());
VariableReferenceExpression variable = variableAllocator.newVariable(statisticType + ":" + columnName, aggregation.getOutputType());
aggregations.put(variable, aggregation.getAggregation());
descriptor.addColumnStatistic(columnStatisticMetadata, variable);
}

StatisticAggregations aggregation = new StatisticAggregations(aggregations.build(), groupingVariables);
return new TableStatisticAggregation(aggregation, descriptor.build());
return new TableStatisticAggregation(aggregation, descriptor.build(), additionalVariables.build());
}

private ColumnStatisticsAggregation createColumnAggregation(ColumnStatisticMetadata columnStatisticMetadata, VariableReferenceExpression input)
{
FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunctionName(), TypeSignatureProvider.fromTypes(ImmutableList.of(input.getType())));
FunctionHandle functionHandle = functionAndTypeResolver.lookupFunction(columnStatisticMetadata.getFunctionName(), TypeSignatureProvider.fromTypes(ImmutableList.<Type>builder()
.add(input.getType())
.addAll(columnStatisticMetadata.getAdditionalArguments().stream().map(ConstantExpression::getType).iterator())
.build()));
FunctionMetadata functionMeta = functionAndTypeResolver.getFunctionMetadata(functionHandle);
Type inputType = functionAndTypeResolver.getType(getOnlyElement(functionMeta.getArgumentTypes()));
Type inputType = functionAndTypeResolver.getType(functionMeta.getArgumentTypes().get(0));
Type outputType = functionAndTypeResolver.getType(functionMeta.getReturnType());
verify(inputType.equals(input.getType()) || input.getType().equals(UNKNOWN), "resolved function input type does not match the input type: %s != %s", inputType, input.getType());
ImmutableMap.Builder<VariableReferenceExpression, ConstantExpression> additionalArgVariables = ImmutableMap.builder();
return new ColumnStatisticsAggregation(
new AggregationNode.Aggregation(
new CallExpression(
input.getSourceLocation(),
columnStatisticMetadata.getFunctionName(),
functionHandle,
outputType,
ImmutableList.of(input)),
ImmutableList.<RowExpression>builder()
.add(input)
.addAll(columnStatisticMetadata.getAdditionalArguments().stream()
.map(arg -> {
VariableReferenceExpression newExpr = variableAllocator.newVariable(arg);
additionalArgVariables.put(newExpr, arg);
return newExpr;
})
.iterator())
.build()),
Optional.empty(),
Optional.empty(),
false,
Optional.empty()),
outputType);
outputType,
additionalArgVariables.build());
}

public static class TableStatisticAggregation
{
private final StatisticAggregations aggregations;
private final StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor;
private final Map<VariableReferenceExpression, RowExpression> additionalVariables;

private TableStatisticAggregation(
StatisticAggregations aggregations,
StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor)
StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor,
Map<VariableReferenceExpression, RowExpression> additionalVariables)
{
this.aggregations = requireNonNull(aggregations, "statisticAggregations is null");
this.descriptor = requireNonNull(descriptor, "descriptor is null");
this.additionalVariables = requireNonNull(additionalVariables, "additionalVariables is null");
}

public StatisticAggregations getAggregations()
@@ -156,17 +176,24 @@ public StatisticAggregationsDescriptor<VariableReferenceExpression> getDescriptor()
{
return descriptor;
}

public Map<VariableReferenceExpression, RowExpression> getAdditionalVariables()
{
return additionalVariables;
}
}

public static class ColumnStatisticsAggregation
{
private final AggregationNode.Aggregation aggregation;
private final Type outputType;
private final Map<VariableReferenceExpression, ConstantExpression> additionalVariables;

private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType)
private ColumnStatisticsAggregation(AggregationNode.Aggregation aggregation, Type outputType, Map<VariableReferenceExpression, ConstantExpression> additionalVariables)
{
this.aggregation = requireNonNull(aggregation, "aggregation is null");
this.outputType = requireNonNull(outputType, "outputType is null");
this.additionalVariables = requireNonNull(additionalVariables, "additionalVariables is null");
}

public AggregationNode.Aggregation getAggregation()
@@ -178,5 +205,10 @@ public Type getOutputType()
{
return outputType;
}

public Map<VariableReferenceExpression, ConstantExpression> getAdditionalVariables()
{
return additionalVariables;
}
}
}
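Why the lookupFunction change matters: the resolver must now bind the overload whose signature includes the additional constant argument types — e.g. sketch_kll_with_k(double, bigint) rather than the one-argument sketch_kll. A hedged sketch of that resolution for a DOUBLE column (the wrapper class is illustrative; the resolver call mirrors the diff):

import com.facebook.presto.spi.function.FunctionAndTypeResolver;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.google.common.collect.ImmutableList;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;

final class ResolveSketchKllWithK
{
    private ResolveSketchKllWithK() {}

    // Column type first, then one entry per additional argument (the BIGINT k).
    static FunctionHandle resolve(FunctionAndTypeResolver resolver)
    {
        return resolver.lookupFunction(
                "sketch_kll_with_k",
                TypeSignatureProvider.fromTypes(ImmutableList.of(DOUBLE, BIGINT)));
    }
}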