Skip to content

Commit

Permalink
[Enhancement] Reduce the memory useage of TableStatistic (backport #5…
Browse files Browse the repository at this point in the history
…0316) (#50371)

Co-authored-by: Seaven <[email protected]>
  • Loading branch information
mergify[bot] and Seaven committed Sep 7, 2024
1 parent 95f82da commit 286198b
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class CachedStatisticStorage implements StatisticStorage {
private final Executor statsCacheRefresherExecutor = Executors.newFixedThreadPool(Config.statistic_cache_thread_pool_size,
new ThreadFactoryBuilder().setDaemon(true).setNameFormat("stats-cache-refresher-%d").build());

AsyncLoadingCache<TableStatsCacheKey, Optional<TableStatistic>> tableStatsCache = Caffeine.newBuilder()
AsyncLoadingCache<TableStatsCacheKey, Optional<Long>> tableStatsCache = Caffeine.newBuilder()
.expireAfterWrite(Config.statistic_update_interval_sec * 2, TimeUnit.SECONDS)
.refreshAfterWrite(Config.statistic_update_interval_sec, TimeUnit.SECONDS)
.maximumSize(Config.statistic_cache_columns)
Expand Down Expand Up @@ -88,30 +88,29 @@ public class CachedStatisticStorage implements StatisticStorage {
.buildAsync(new ConnectorHistogramColumnStatsCacheLoader());

@Override
public Map<Long, TableStatistic> getTableStatistics(Long tableId, Collection<Partition> partitions) {
public Map<Long, Optional<Long>> getTableStatistics(Long tableId, Collection<Partition> partitions) {
// get Statistics Table column info, just return default column statistics
if (StatisticUtils.statisticTableBlackListCheck(tableId)) {
return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> TableStatistic.unknown()));
return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> Optional.empty()));
}

List<TableStatsCacheKey> keys = partitions.stream().map(p -> new TableStatsCacheKey(tableId, p.getId()))
.collect(Collectors.toList());

try {
CompletableFuture<Map<TableStatsCacheKey, Optional<TableStatistic>>> result = tableStatsCache.getAll(keys);
CompletableFuture<Map<TableStatsCacheKey, Optional<Long>>> result = tableStatsCache.getAll(keys);
if (result.isDone()) {
Map<TableStatsCacheKey, Optional<TableStatistic>> data = result.get();
return keys.stream().collect(Collectors.toMap(
TableStatsCacheKey::getPartitionId,
k -> data.getOrDefault(k, Optional.empty()).orElse(TableStatistic.unknown())));
Map<TableStatsCacheKey, Optional<Long>> data = result.get();
return keys.stream().collect(Collectors.toMap(TableStatsCacheKey::getPartitionId,
k -> data.getOrDefault(k, Optional.empty())));
}
} catch (InterruptedException e) {
LOG.warn("Failed to execute tableStatsCache.getAll", e);
Thread.currentThread().interrupt();
} catch (Exception e) {
LOG.warn("Faied to execute tableStatsCache.getAll", e);
}
return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> TableStatistic.unknown()));
return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> Optional.empty()));
}

@Override
Expand All @@ -122,7 +121,7 @@ public void refreshTableStatistic(Table table) {
}

try {
CompletableFuture<Map<TableStatsCacheKey, Optional<TableStatistic>>> completableFuture
CompletableFuture<Map<TableStatsCacheKey, Optional<Long>>> completableFuture
= tableStatsCache.getAll(statsCacheKeyList);
if (completableFuture.isDone()) {
completableFuture.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,13 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

public interface StatisticStorage {
default TableStatistic getTableStatistic(Long tableId, Long partitionId) {
return TableStatistic.unknown();
}

// partitionId: TableStatistic
default Map<Long, TableStatistic> getTableStatistics(Long tableId, Collection<Partition> partitions) {
return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> TableStatistic.unknown()));
// partitionId: RowCount
default Map<Long, Optional<Long>> getTableStatistics(Long tableId, Collection<Partition> partitions) {
return partitions.stream().collect(Collectors.toMap(Partition::getId, p -> Optional.empty()));
}

default void refreshTableStatistic(Table table) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

public class StatisticsCalcUtils {
Expand Down Expand Up @@ -121,20 +122,20 @@ public static long getTableRowCount(Table table, Operator node, OptimizerContext
// For example, a large amount of data LOAD may cause the number of rows to change greatly.
// This leads to very inaccurate row counts.
long deltaRows = deltaRows(table, basicStatsMeta.getUpdateRows());
Map<Long, TableStatistic> tableStatisticMap = GlobalStateMgr.getCurrentState().getStatisticStorage()
Map<Long, Optional<Long>> tableStatisticMap = GlobalStateMgr.getCurrentState().getStatisticStorage()
.getTableStatistics(table.getId(), selectedPartitions);
for (Partition partition : selectedPartitions) {
long partitionRowCount;
TableStatistic tableStatistic =
tableStatisticMap.getOrDefault(partition.getId(), TableStatistic.unknown());
Optional<Long> tableStatistic =
tableStatisticMap.getOrDefault(partition.getId(), Optional.empty());
LocalDateTime updateDatetime = StatisticUtils.getPartitionLastUpdateTime(partition);
if (tableStatistic.equals(TableStatistic.unknown())) {
if (tableStatistic.isEmpty()) {
partitionRowCount = partition.getRowCount();
if (updateDatetime.isAfter(lastWorkTimestamp)) {
partitionRowCount += deltaRows;
}
} else {
partitionRowCount = tableStatistic.getRowCount();
partitionRowCount = tableStatistic.get();
if (updateDatetime.isAfter(basicStatsMeta.getUpdateTime())) {
partitionRowCount += deltaRows;
}
Expand Down Expand Up @@ -184,17 +185,13 @@ private static void updateQueryDumpInfo(OptimizerContext optimizerContext, Table

private static long deltaRows(Table table, long totalRowCount) {
long tblRowCount = 0L;
Map<Long, TableStatistic> tableStatisticMap = GlobalStateMgr.getCurrentState().getStatisticStorage()
Map<Long, Optional<Long>> tableStatisticMap = GlobalStateMgr.getCurrentState().getStatisticStorage()
.getTableStatistics(table.getId(), table.getPartitions());

for (Partition partition : table.getPartitions()) {
long partitionRowCount;
TableStatistic statistic = tableStatisticMap.getOrDefault(partition.getId(), TableStatistic.unknown());
if (statistic.equals(TableStatistic.unknown())) {
partitionRowCount = partition.getRowCount();
} else {
partitionRowCount = statistic.getRowCount();
}
Optional<Long> statistic = tableStatisticMap.getOrDefault(partition.getId(), Optional.empty());
partitionRowCount = statistic.orElseGet(partition::getRowCount);
tblRowCount += partitionRowCount;
}
if (tblRowCount < totalRowCount) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package com.starrocks.sql.optimizer.statistics;

import com.github.benmanes.caffeine.cache.AsyncCacheLoader;
import com.google.api.client.util.Lists;
import com.starrocks.common.Config;
import com.starrocks.qe.ConnectContext;
import com.starrocks.statistic.StatisticExecutor;
import com.starrocks.statistic.StatisticUtils;
Expand All @@ -29,22 +31,22 @@
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;

public class TableStatsCacheLoader implements AsyncCacheLoader<TableStatsCacheKey, Optional<TableStatistic>> {
public class TableStatsCacheLoader implements AsyncCacheLoader<TableStatsCacheKey, Optional<Long>> {
private final StatisticExecutor statisticExecutor = new StatisticExecutor();

@Override
public @NonNull CompletableFuture<Optional<TableStatistic>> asyncLoad(@NonNull TableStatsCacheKey cacheKey, @
public @NonNull CompletableFuture<Optional<Long>> asyncLoad(@NonNull TableStatsCacheKey cacheKey, @
NonNull Executor executor) {
return CompletableFuture.supplyAsync(() -> {
try {
ConnectContext connectContext = StatisticUtils.buildConnectContext();
connectContext.setThreadLocalInfo();
List<TStatisticData> statisticData = queryStatisticsData(connectContext, cacheKey.tableId, cacheKey.partitionId);
if (statisticData.size() == 0) {
return Optional.of(new TableStatistic(cacheKey.getTableId(), cacheKey.getPartitionId(), 0L));
List<TStatisticData> statisticData =
statisticExecutor.queryTableStats(connectContext, cacheKey.tableId, cacheKey.partitionId);
if (statisticData.isEmpty()) {
return Optional.empty();
} else {
return Optional.of(new TableStatistic(cacheKey.getTableId(), cacheKey.getPartitionId(),
statisticData.get(0).rowCount));
return Optional.of(statisticData.get(0).rowCount);
}
} catch (RuntimeException e) {
throw e;
Expand All @@ -57,31 +59,40 @@ public class TableStatsCacheLoader implements AsyncCacheLoader<TableStatsCacheKe
}

@Override
public @NonNull CompletableFuture<Map<@NonNull TableStatsCacheKey, @NonNull Optional<TableStatistic>>> asyncLoadAll(
public @NonNull CompletableFuture<Map<@NonNull TableStatsCacheKey, @NonNull Optional<Long>>> asyncLoadAll(
@NonNull Iterable<? extends @NonNull TableStatsCacheKey> cacheKey,
@NonNull Executor executor) {
return CompletableFuture.supplyAsync(() -> {
try {
TableStatsCacheKey tableStatsCacheKey = cacheKey.iterator().next();
long tableId = tableStatsCacheKey.getTableId();

ConnectContext connectContext = StatisticUtils.buildConnectContext();
connectContext.setThreadLocalInfo();
List<TStatisticData> statisticData = queryStatisticsData(connectContext, tableStatsCacheKey.getTableId());

Map<TableStatsCacheKey, Optional<TableStatistic>> result = new HashMap<>();
for (TStatisticData tStatisticData : statisticData) {
result.put(new TableStatsCacheKey(tableId, tStatisticData.partitionId),
Optional.of(new TableStatistic(tableId, tStatisticData.partitionId, tStatisticData.rowCount)));
Map<TableStatsCacheKey, Optional<Long>> result = new HashMap<>();
List<Long> pids = Lists.newArrayList();
long tableId = -1;
for (TableStatsCacheKey statsCacheKey : cacheKey) {
pids.add(statsCacheKey.getPartitionId());
tableId = statsCacheKey.getTableId();
if (pids.size() > Config.expr_children_limit / 2) {
List<TStatisticData> statisticData =
statisticExecutor.queryTableStats(connectContext, statsCacheKey.getTableId(), pids);

statisticData.forEach(tStatisticData -> result.put(
new TableStatsCacheKey(statsCacheKey.getTableId(), tStatisticData.partitionId),
Optional.of(tStatisticData.rowCount)));
pids.clear();
}
}
List<TStatisticData> statisticData = statisticExecutor.queryTableStats(connectContext, tableId, pids);
for (TStatisticData data : statisticData) {
result.put(new TableStatsCacheKey(tableId, data.partitionId), Optional.of(data.rowCount));
}
for (TableStatsCacheKey key : cacheKey) {
if (!result.containsKey(key)) {
result.put(key, Optional.empty());
}
}

return result;

} catch (RuntimeException e) {
throw e;
} catch (Exception e) {
Expand All @@ -91,12 +102,4 @@ public class TableStatsCacheLoader implements AsyncCacheLoader<TableStatsCacheKe
}
}, executor);
}

private List<TStatisticData> queryStatisticsData(ConnectContext context, long tableId) {
return statisticExecutor.queryTableStats(context, tableId);
}

private List<TStatisticData> queryStatisticsData(ConnectContext context, long tableId, long partitionId) {
return statisticExecutor.queryTableStats(context, tableId, partitionId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import com.starrocks.common.io.Writable;
import com.starrocks.persist.gson.GsonUtils;
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.sql.optimizer.statistics.TableStatistic;
import org.apache.commons.collections4.MapUtils;

import java.io.DataInput;
Expand All @@ -32,6 +31,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class BasicStatsMeta implements Writable {
@SerializedName("dbId")
Expand Down Expand Up @@ -132,16 +132,13 @@ public double getHealthy() {
long updatePartitionRowCount = 0L;
long updatePartitionCount = 0L;

Map<Long, TableStatistic> tableStatistics = GlobalStateMgr.getCurrentState().getStatisticStorage()
Map<Long, Optional<Long>> tableStatistics = GlobalStateMgr.getCurrentState().getStatisticStorage()
.getTableStatistics(table.getId(), table.getPartitions());

for (Partition partition : table.getPartitions()) {
TableStatistic statistic = tableStatistics.getOrDefault(partition.getId(), TableStatistic.unknown());

tableRowCount += partition.getRowCount();
if (!statistic.equals(TableStatistic.unknown())) {
cachedTableRowCount += tableStatistics.get(partition.getId()).getRowCount();
}
Optional<Long> statistic = tableStatistics.getOrDefault(partition.getId(), Optional.empty());
cachedTableRowCount += statistic.orElse(0L);
LocalDateTime loadTime = StatisticUtils.getPartitionLastUpdateTime(partition);

if (partition.hasData() && !isUpdatedAfterLoad(loadTime)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ public static Pair<List<TStatisticData>, Status> queryDictSync(Long dbId, Long t
}
}

public List<TStatisticData> queryTableStats(ConnectContext context, Long tableId) {
String sql = StatisticSQLBuilder.buildQueryTableStatisticsSQL(tableId);
public List<TStatisticData> queryTableStats(ConnectContext context, Long tableId, List<Long> partitions) {
String sql = StatisticSQLBuilder.buildQueryTableStatisticsSQL(tableId, partitions);
return executeStatisticDQL(context, sql);
}

Expand All @@ -234,7 +234,7 @@ public List<TStatisticData> queryTableStats(ConnectContext context, Long tableId
private static List<TStatisticData> deserializerStatisticData(List<TResultBatch> sqlResult) throws TException {
List<TStatisticData> statistics = Lists.newArrayList();

if (sqlResult.size() < 1) {
if (sqlResult.isEmpty()) {
return statistics;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,13 @@ public class StatisticSQLBuilder {
DEFAULT_VELOCITY_ENGINE.setProperty(VelocityEngine.RUNTIME_LOG_REFERENCE_LOG_INVALID, false);
}

public static String buildQueryTableStatisticsSQL(Long tableId) {
public static String buildQueryTableStatisticsSQL(Long tableId, List<Long> partitionIds) {
VelocityContext context = new VelocityContext();
context.put("predicate", "table_id = " + tableId);
if (!partitionIds.isEmpty()) {
context.put("predicate", "table_id = " + tableId + " and partition_id in (" +
partitionIds.stream().map(String::valueOf).collect(Collectors.joining(", ")) + ")");
}
return build(context, QUERY_TABLE_STATISTIC_TEMPLATE);
}

Expand Down
Loading

0 comments on commit 286198b

Please sign in to comment.