From 661f4bae384c61ce7a1555728433c77c47ff8545 Mon Sep 17 00:00:00 2001
From: Marius Grama
Date: Wed, 2 Aug 2023 22:55:28 +0200
Subject: [PATCH] Add concurrent writes reconciliation for `INSERT` in Delta Lake

---
 .../plugin/deltalake/DeltaLakeMetadata.java   | 101 ++++++++++++++----
 .../BaseDeltaLakeConnectorSmokeTest.java      |  58 +++++++++-
 2 files changed, 135 insertions(+), 24 deletions(-)

diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java
index ead8ecbdb032..da08462690f7 100644
--- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java
+++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java
@@ -14,6 +14,7 @@
 package io.trino.plugin.deltalake;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
+import com.google.common.base.Throwables;
 import com.google.common.base.VerifyException;
 import com.google.common.collect.Comparators;
 import com.google.common.collect.ImmutableList;
@@ -22,6 +23,8 @@
 import com.google.common.collect.ImmutableTable;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
+import dev.failsafe.Failsafe;
+import dev.failsafe.RetryPolicy;
 import io.airlift.json.JsonCodec;
 import io.airlift.log.Logger;
 import io.airlift.slice.Slice;
@@ -138,6 +141,7 @@
 import java.io.IOException;
 import java.net.URI;
 import java.net.URISyntaxException;
+import java.time.Duration;
 import java.time.Instant;
 import java.util.ArrayDeque;
 import java.util.Collection;
@@ -157,6 +161,7 @@
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.stream.Collectors;
+import java.util.stream.LongStream;
 import java.util.stream.Stream;
 
 import static com.google.common.base.Preconditions.checkArgument;
@@ -187,6 +192,7 @@
 import static io.trino.plugin.deltalake.DeltaLakeColumnType.PARTITION_KEY;
 import static io.trino.plugin.deltalake.DeltaLakeColumnType.REGULAR;
 import static io.trino.plugin.deltalake.DeltaLakeColumnType.SYNTHESIZED;
+import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_DATA;
 import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_BAD_WRITE;
 import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_FILESYSTEM_ERROR;
 import static io.trino.plugin.deltalake.DeltaLakeErrorCode.DELTA_LAKE_INVALID_SCHEMA;
@@ -241,6 +247,7 @@
 import static io.trino.plugin.deltalake.transactionlog.MetadataEntry.configurationForNewTable;
 import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.getMandatoryCurrentVersion;
 import static io.trino.plugin.deltalake.transactionlog.TransactionLogUtil.getTransactionLogDir;
+import static io.trino.plugin.deltalake.transactionlog.checkpoint.TransactionLogTail.getEntriesFromJson;
 import static io.trino.plugin.hive.HiveMetadata.PRESTO_QUERY_ID_NAME;
 import static io.trino.plugin.hive.TableType.EXTERNAL_TABLE;
 import static io.trino.plugin.hive.TableType.MANAGED_TABLE;
@@ -258,6 +265,7 @@
 import static io.trino.spi.StandardErrorCode.INVALID_SCHEMA_PROPERTY;
 import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
 import static io.trino.spi.StandardErrorCode.QUERY_REJECTED;
+import static io.trino.spi.StandardErrorCode.TRANSACTION_CONFLICT;
 import static io.trino.spi.connector.RetryMode.NO_RETRIES;
 import static io.trino.spi.connector.RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW;
 import static io.trino.spi.connector.SchemaTableName.schemaTableName;
@@ -324,6 +332,12 @@ public class DeltaLakeMetadata
     private static final int CDF_SUPPORTED_WRITER_VERSION = 4;
     private static final int COLUMN_MAPPING_MODE_SUPPORTED_READER_VERSION = 2;
     private static final int COLUMN_MAPPING_MODE_SUPPORTED_WRITER_VERSION = 5;
+    private static final RetryPolicy<Object> TRANSACTION_CONFLICT_RETRY_POLICY = RetryPolicy.builder()
+            .handleIf(throwable -> Throwables.getRootCause(throwable) instanceof TransactionConflictException)
+            .withDelay(Duration.ofMillis(200))
+            .withJitter(Duration.ofMillis(100))
+            .withMaxRetries(5)
+            .build();
 
     // Matches the dummy column Databricks stores in the metastore
     private static final List<ColumnMetadata> DUMMY_DATA_COLUMNS = ImmutableList.of(
@@ -1701,30 +1715,10 @@ public Optional<ConnectorOutputMetadata> finishInsert(
 
         boolean writeCommitted = false;
         try {
-            TransactionLogWriter transactionLogWriter = transactionLogWriterFactory.newWriter(session, handle.getLocation());
-
-            long createdTime = Instant.now().toEpochMilli();
-
-            TrinoFileSystem fileSystem = fileSystemFactory.create(session);
-            long commitVersion = getMandatoryCurrentVersion(fileSystem, handle.getLocation()) + 1;
-            if (commitVersion != handle.getReadVersion() + 1) {
-                throw new TransactionConflictException(format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s",
-                        handle.getReadVersion(),
-                        commitVersion - 1));
-            }
-            Optional<Long> checkpointInterval = handle.getMetadataEntry().getCheckpointInterval();
-            // it is not obvious why we need to persist this readVersion
-            transactionLogWriter.appendCommitInfoEntry(getCommitInfoEntry(session, commitVersion, createdTime, INSERT_OPERATION, handle.getReadVersion()));
-
-            ColumnMappingMode columnMappingMode = getColumnMappingMode(handle.getMetadataEntry());
-            List<String> partitionColumns = getPartitionColumns(
-                    handle.getMetadataEntry().getOriginalPartitionColumns(),
-                    handle.getInputColumns(),
-                    columnMappingMode);
-            appendAddFileEntries(transactionLogWriter, dataFileInfos, partitionColumns, true);
-
-            transactionLogWriter.flush();
+            long commitVersion = Failsafe.with(TRANSACTION_CONFLICT_RETRY_POLICY)
+                    .get(() -> getInsertCommitVersion(session, handle, dataFileInfos));
             writeCommitted = true;
+            Optional<Long> checkpointInterval = handle.getMetadataEntry().getCheckpointInterval();
             writeCheckpointIfNeeded(session, handle.getTableName(), handle.getLocation(), checkpointInterval, commitVersion);
 
             if (isCollectExtendedStatisticsColumnStatisticsOnWrite(session) && !computedStatistics.isEmpty() && !dataFileInfos.isEmpty()) {
@@ -1755,6 +1749,67 @@ public Optional<ConnectorOutputMetadata> finishInsert(
         return Optional.empty();
     }
 
+    private long getInsertCommitVersion(ConnectorSession session, DeltaLakeInsertTableHandle handle, List<DataFileInfo> dataFileInfos)
+            throws IOException
+    {
+        long createdTime = Instant.now().toEpochMilli();
+
+        TrinoFileSystem fileSystem = fileSystemFactory.create(session);
+
+        String transactionLogDirectory = getTransactionLogDir(handle.getLocation());
+        long currentVersion = getMandatoryCurrentVersion(fileSystem, handle.getLocation());
+        if (currentVersion < handle.getReadVersion()) {
+            throw new TrinoException(TRANSACTION_CONFLICT, format("Conflicting concurrent writes found. Expected transaction log version: %s, actual version: %s",
+                    handle.getReadVersion(),
+                    currentVersion));
+        }
+        else if (currentVersion > handle.getReadVersion()) {
+            // Ensure there are no structural changes on the table if concurrent writes finished in the meantime
+            List<DeltaLakeTransactionLogEntry> transactionLogEntries = LongStream.rangeClosed(handle.getReadVersion() + 1, currentVersion)
+                    .boxed()
+                    .flatMap(version -> {
+                        try {
+                            return getEntriesFromJson(version, transactionLogDirectory, fileSystem)
+                                    .orElseThrow(() -> new TrinoException(DELTA_LAKE_BAD_DATA, "Delta Lake log entries are missing for version " + version))
+                                    .stream();
+                        }
+                        catch (IOException e) {
+                            throw new TrinoException(DELTA_LAKE_FILESYSTEM_ERROR, "Failed to access table metadata", e);
+                        }
+                    })
+                    .collect(toImmutableList());
+            Optional<MetadataEntry> currentMetadataEntry = transactionLogEntries.stream()
+                    .map(DeltaLakeTransactionLogEntry::getMetaData)
+                    .filter(Objects::nonNull)
+                    .findFirst();
+            if (currentMetadataEntry.isPresent()) {
+                throw new TrinoException(TRANSACTION_CONFLICT, format("Conflicting concurrent writes found. Metadata changed since the version: %s", handle.getReadVersion()));
+            }
+            Optional<ProtocolEntry> currentProtocolEntry = transactionLogEntries.stream()
+                    .map(DeltaLakeTransactionLogEntry::getProtocol)
+                    .filter(Objects::nonNull)
+                    .findFirst();
+            if (currentProtocolEntry.isPresent()) {
+                throw new TrinoException(TRANSACTION_CONFLICT, format("Conflicting concurrent writes found. Protocol changed since the version: %s", handle.getReadVersion()));
+            }
+        }
+        long commitVersion = currentVersion + 1;
+        // it is not obvious why we need to persist this readVersion
+        TransactionLogWriter transactionLogWriter = transactionLogWriterFactory.newWriter(session, handle.getLocation());
+        transactionLogWriter.appendCommitInfoEntry(getCommitInfoEntry(session, commitVersion, createdTime, INSERT_OPERATION, currentVersion));
+
+        ColumnMappingMode columnMappingMode = getColumnMappingMode(handle.getMetadataEntry());
+        List<String> partitionColumns = getPartitionColumns(
+                handle.getMetadataEntry().getOriginalPartitionColumns(),
+                handle.getInputColumns(),
+                columnMappingMode);
+        appendAddFileEntries(transactionLogWriter, dataFileInfos, partitionColumns, true);
+
+        transactionLogWriter.flush();
+
+        return commitVersion;
+    }
+
     private static List<String> getPartitionColumns(List<String> originalPartitionColumns, List<DeltaLakeColumnHandle> dataColumns, ColumnMappingMode columnMappingMode)
     {
         return switch (columnMappingMode) {
diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java
index 18974d218837..e127e3895292 100644
--- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java
+++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java
@@ -17,6 +17,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import io.airlift.concurrent.MoreFutures;
 import io.airlift.units.DataSize;
 import io.airlift.units.Duration;
 import io.trino.Session;
@@ -32,6 +33,7 @@
 import io.trino.spi.QueryId;
 import io.trino.sql.planner.OptimizerConfig.JoinDistributionType;
 import io.trino.testing.BaseConnectorSmokeTest;
+import io.trino.testing.DataProviders;
 import io.trino.testing.DistributedQueryRunner;
 import io.trino.testing.MaterializedResult;
 import io.trino.testing.MaterializedResultWithQueryId;
@@ -49,6 +51,9 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.ExecutorService;
 import java.util.function.BiConsumer;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -80,6 +85,7 @@
 import static io.trino.tpch.TpchTable.ORDERS;
 import static java.lang.String.format;
 import static java.util.Comparator.comparing;
+import static java.util.concurrent.Executors.newFixedThreadPool;
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
 import static java.util.concurrent.TimeUnit.SECONDS;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -1324,7 +1330,8 @@ public Object[][] testCheckpointWriteStatsAsStructDataProvider()
                 {"varchar", "'test'", "'ŻŻŻŻŻŻŻŻŻŻ'", "0.0", "null", "null"},
                 {"varbinary", "X'65683F'", "X'ffffffffffffffffffff'", "0.0", "null", "null"},
                 {"date", "date '2021-02-03'", "date '9999-12-31'", "0.0", "'2021-02-03'", "'9999-12-31'"},
-                {"timestamp(3) with time zone", "timestamp '2001-08-22 03:04:05.321 -08:00'", "timestamp '9999-12-31 23:59:59.999 +12:00'", "0.0", "'2001-08-22 11:04:05.321 UTC'", "'9999-12-31 11:59:59.999 UTC'"},
+                {"timestamp(3) with time zone", "timestamp '2001-08-22 03:04:05.321 -08:00'", "timestamp '9999-12-31 23:59:59.999 +12:00'", "0.0", "'2001-08-22 11:04:05.321 UTC'",
+                        "'9999-12-31 11:59:59.999 UTC'"},
                 {"array(int)", "array[1]", "array[2147483647]", "null", "null", "null"},
                 {"map(varchar,int)", "map(array['foo', 'bar'], array[1, 2])", "map(array['foo', 'bar'], array[-2147483648, 2147483647])", "null", "null", "null"},
                 {"row(x bigint)", "cast(row(1) as row(x bigint))", "cast(row(9223372036854775807) as row(x bigint))", "null", "null", "null"},
@@ -2151,6 +2158,55 @@ public void testPartitionFilterIncluded()
         }
     }
 
+    @Test(dataProviderClass = DataProviders.class, dataProvider = "trueFalse")
+    public void testConcurrentModificationsReconciliation(boolean partitioned)
+            throws Exception
+    {
+        int threads = 3;
+
+        CyclicBarrier barrier = new CyclicBarrier(threads);
+        ExecutorService executor = newFixedThreadPool(threads);
+        String tableName = "test_concurrent_inserts_table_" + randomNameSuffix();
+
+        assertUpdate("CREATE TABLE " + tableName + " (a INT, part INT) " +
+                (partitioned ? " WITH (partitioned_by = ARRAY['part'])" : ""));
+
+        try {
+            // insert data concurrently
+            executor.invokeAll(ImmutableList.<Callable<Void>>builder()
+                            .add(() -> {
+                                barrier.await(20, SECONDS);
+                                getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (1, 10)");
+                                return null;
+                            })
+                            .add(() -> {
+                                barrier.await(20, SECONDS);
+                                getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (11, 20)");
+                                return null;
+                            })
+                            .add(() -> {
+                                barrier.await(20, SECONDS);
+                                getQueryRunner().execute("INSERT INTO " + tableName + " VALUES (21, 30)");
+                                return null;
+                            })
+                            .build())
+                    .forEach(MoreFutures::getDone);
+
+            assertThat(query("SELECT SUM(a) FROM " + tableName)).matches("VALUES BIGINT '33'");
+            assertQuery("SELECT version, operation, read_version FROM \"" + tableName + "$history\"",
+                    """
+                            VALUES
+                                (0, 'CREATE TABLE', 0),
+                                (1, 'WRITE', 0),
+                                (2, 'WRITE', 1),
+                                (3, 'WRITE', 2)
+                            """);
+        }
+        finally {
+            assertUpdate("DROP TABLE " + tableName);
+        }
+    }
+
     private Set<String> getActiveFiles(String tableName)
     {
         return getActiveFiles(tableName, getQueryRunner().getDefaultSession());
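
For reference, the standalone sketch below (not part of the patch) illustrates the Failsafe retry pattern that the new TRANSACTION_CONFLICT_RETRY_POLICY relies on: retry the commit attempt only while the root cause of the failure is a transaction conflict, with a jittered delay and a bounded number of attempts. The nested TransactionConflictException, attemptCommit, and the conflict counter are hypothetical stand-ins for the connector's own classes and for getInsertCommitVersion; it assumes only Guava and dev.failsafe on the classpath.

import com.google.common.base.Throwables;
import dev.failsafe.Failsafe;
import dev.failsafe.RetryPolicy;

import java.time.Duration;
import java.util.concurrent.atomic.AtomicInteger;

public class RetryOnConflictSketch
{
    // Hypothetical stand-in for the connector's transaction conflict exception
    static class TransactionConflictException
            extends RuntimeException
    {
        TransactionConflictException(String message)
        {
            super(message);
        }
    }

    public static void main(String[] args)
    {
        // Same shape as TRANSACTION_CONFLICT_RETRY_POLICY in DeltaLakeMetadata:
        // retry only when the root cause is a transaction conflict, with jittered delays
        RetryPolicy<Object> retryPolicy = RetryPolicy.builder()
                .handleIf(throwable -> Throwables.getRootCause(throwable) instanceof TransactionConflictException)
                .withDelay(Duration.ofMillis(200))
                .withJitter(Duration.ofMillis(100))
                .withMaxRetries(5)
                .build();

        AtomicInteger conflicts = new AtomicInteger(2); // simulate two concurrent writers winning the race first
        long readVersion = 0;

        // Mirrors the call site in finishInsert: the commit attempt is retried under the policy
        long commitVersion = Failsafe.with(retryPolicy)
                .get(() -> attemptCommit(readVersion, conflicts));
        System.out.println("committed version " + commitVersion);
    }

    // Hypothetical commit attempt: fails while another writer keeps claiming the next log version
    private static long attemptCommit(long readVersion, AtomicInteger conflicts)
    {
        if (conflicts.getAndDecrement() > 0) {
            throw new TransactionConflictException("another writer committed version " + (readVersion + 1));
        }
        return readVersion + 1;
    }
}

Running the sketch prints "committed version 1" after two simulated conflicts. Failures whose root cause is not a conflict, such as the TrinoException raised when metadata or protocol entries changed between the read and the commit, are not retried by a policy of this shape.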