Skip to content

Commit

Permalink
Rollback hash index checkpoint (#4559)
Browse files Browse the repository at this point in the history
  • Loading branch information
royi-luo authored Dec 20, 2024
1 parent 3463783 commit 2d73b32
Show file tree
Hide file tree
Showing 19 changed files with 232 additions and 46 deletions.
15 changes: 15 additions & 0 deletions src/include/common/exception/checkpoint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include "common/api.h"
#include "exception.h"

namespace kuzu {
namespace common {

/**
 * Exception that wraps another std::exception.
 *
 * The wrapped exception's what() text is forwarded to the Exception base class, so the
 * original message is preserved while the type changes to CheckpointException.
 */
class KUZU_API CheckpointException : public Exception {
public:
    // `cause` is only read during construction (for its what() text); no reference to it is kept
    // by this class itself.
    explicit CheckpointException(const std::exception& cause) : Exception(cause.what()) {}
};

} // namespace common
} // namespace kuzu
3 changes: 3 additions & 0 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ class KUZU_API ClientContext {

void runFuncInTransaction(const std::function<void(void)>& fun);

std::unique_ptr<QueryResult> handleFailedExecution(
processor::ExecutionContext* executionContext, std::exception& e);

// Client side configurable settings.
ClientConfig clientConfig;
// Database configurable settings.
Expand Down
6 changes: 5 additions & 1 deletion src/include/storage/index/hash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class OnDiskHashIndex {
virtual bool checkpoint() = 0;
virtual bool checkpointInMemory() = 0;
virtual bool rollbackInMemory() = 0;
virtual void rollbackCheckpoint() = 0;
virtual void bulkReserve(uint64_t numValuesToAppend) = 0;
};

Expand Down Expand Up @@ -149,6 +150,7 @@ class HashIndex final : public OnDiskHashIndex {
bool checkpoint() override;
bool checkpointInMemory() override;
bool rollbackInMemory() override;
void rollbackCheckpoint() override;
inline FileHandle* getFileHandle() const { return fileHandle; }

private:
Expand Down Expand Up @@ -378,10 +380,12 @@ class PrimaryKeyIndex {
void delete_(common::ValueVector* keyVector);

void checkpointInMemory();
void checkpoint();
void checkpoint(bool forceCheckpointAll = false);
FileHandle* getFileHandle() const { return fileHandle; }
OverflowFile* getOverflowFile() const { return overflowFile.get(); }

void rollbackCheckpoint();

common::PhysicalTypeID keyTypeID() const { return keyDataTypeID; }

void writeHeaders();
Expand Down
1 change: 1 addition & 0 deletions src/include/storage/storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class StorageManager {
main::ClientContext* context);

void checkpoint(main::ClientContext& clientContext);
void rollbackCheckpoint(main::ClientContext& clientContext);

PrimaryKeyIndex* getPKIndex(common::table_id_t tableID);

Expand Down
6 changes: 6 additions & 0 deletions src/include/storage/storage_structure/disk_array_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class DiskArrayCollection {
headerPagesOnDisk = headersForReadTrx.size();
}

// Overwrites every write-transaction header page with the contents of the read-transaction
// header page at the same position, discarding whatever the write-side copies held.
// NOTE(review): assumes headersForReadTrx has at least as many entries as headersForWriteTrx;
// the read-side vector is indexed using the write-side size — confirm against the code that
// grows these vectors.
void rollbackCheckpoint() {
    const size_t numWriteHeaderPages = headersForWriteTrx.size();
    for (size_t pageIdx = 0; pageIdx != numWriteHeaderPages; ++pageIdx) {
        auto& writeHeaderPage = *headersForWriteTrx[pageIdx];
        writeHeaderPage = *headersForReadTrx[pageIdx];
    }
}

template<typename T>
std::unique_ptr<DiskArray<T>> getDiskArray(uint32_t idx) {
KU_ASSERT(idx < numHeaders);
Expand Down
2 changes: 1 addition & 1 deletion src/include/storage/storage_structure/overflow_file.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class OverflowFile {
OverflowFile(OverflowFile&& other) = delete;

void rollbackInMemory();
void checkpoint();
void checkpoint(bool forceUpdateHeader);
void checkpointInMemory();

OverflowFileHandle* addHandle() {
Expand Down
1 change: 1 addition & 0 deletions src/include/storage/store/node_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class NodeTable final : public Table {

void commit(transaction::Transaction* transaction, LocalTable* localTable) override;
void checkpoint(common::Serializer& ser, catalog::TableCatalogEntry* tableEntry) override;
void rollbackCheckpoint() override;

void rollbackPKIndexInsert(const transaction::Transaction* transaction,
common::row_idx_t startRow, common::row_idx_t numRows_,
Expand Down
1 change: 1 addition & 0 deletions src/include/storage/store/rel_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class RelTable final : public Table {

void commit(transaction::Transaction* transaction, LocalTable* localTable) override;
void checkpoint(common::Serializer& ser, catalog::TableCatalogEntry* tableEntry) override;
void rollbackCheckpoint() override {};

common::row_idx_t getNumTotalRows(const transaction::Transaction* transaction) override;

Expand Down
1 change: 1 addition & 0 deletions src/include/storage/store/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class Table {

virtual void commit(transaction::Transaction* transaction, LocalTable* localTable) = 0;
virtual void checkpoint(common::Serializer& ser, catalog::TableCatalogEntry* tableEntry) = 0;
virtual void rollbackCheckpoint() = 0;

virtual common::row_idx_t getNumTotalRows(const transaction::Transaction* transaction) = 0;

Expand Down
2 changes: 1 addition & 1 deletion src/include/transaction/transaction_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "transaction.h"

namespace kuzu {

namespace main {
class ClientContext;
}
Expand Down Expand Up @@ -49,7 +50,6 @@ class KUZU_API TransactionContext {
bool hasActiveTransaction() const { return activeTransaction != nullptr; }
Transaction* getActiveTransaction() const { return activeTransaction.get(); }

private:
void clearTransaction();

private:
Expand Down
3 changes: 3 additions & 0 deletions src/include/transaction/transaction_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ class TransactionManager {

void commit(main::ClientContext& clientContext);
void rollback(main::ClientContext& clientContext, const Transaction* transaction);

void checkpoint(main::ClientContext& clientContext);

private:
bool canAutoCheckpoint(const main::ClientContext& clientContext) const;
bool canCheckpointNoLock() const;
void checkpointNoLock(main::ClientContext& clientContext);
void rollbackCheckpoint(main::ClientContext& clientContext);

// This function locks the mutex used for starting new transactions.
common::UniqLock stopNewTransactionsAndWaitUntilAllTransactionsLeave();

Expand Down
18 changes: 14 additions & 4 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "main/client_context.h"

#include "binder/binder.h"
#include "common/exception/checkpoint.h"
#include "common/exception/connection.h"
#include "common/exception/runtime.h"
#include "common/random_engine.h"
Expand Down Expand Up @@ -509,13 +510,14 @@ std::unique_ptr<QueryResult> ClientContext::executeNoLock(PreparedStatement* pre
this->transactionContext->commit();
}
}
} catch (CheckpointException& e) {
transactionContext->clearTransaction();
return handleFailedExecution(executionContext.get(), e);
} catch (std::exception& e) {
transactionContext->rollback();
getMemoryManager()->getBufferManager()->getSpillerOrSkip(
[](auto& spiller) { spiller.clearFile(); });
progressBar->endProgress(executionContext->queryID);
return queryResultWithError(e.what());
return handleFailedExecution(executionContext.get(), e);
}

getMemoryManager()->getBufferManager()->getSpillerOrSkip(
[](auto& spiller) { spiller.clearFile(); });
executingTimer.stop();
Expand All @@ -527,6 +529,14 @@ std::unique_ptr<QueryResult> ClientContext::executeNoLock(PreparedStatement* pre
return queryResult;
}

// Common cleanup for a query whose execution threw.
//
// Steps, in order:
//   1. asks the buffer manager's spiller (via getSpillerOrSkip) to clear its file,
//   2. ends the progress bar for the query identified by executionContext->queryID,
//   3. builds the error result from the exception's what() text.
//
// This function does not commit, roll back, or clear the transaction itself.
//
// @param executionContext context of the failed query; only its queryID is read.
// @param e                the exception that aborted execution.
// @return a QueryResult produced by queryResultWithError(e.what()).
std::unique_ptr<QueryResult> ClientContext::handleFailedExecution(
    ExecutionContext* executionContext, std::exception& e) {
    auto* bufferManager = getMemoryManager()->getBufferManager();
    bufferManager->getSpillerOrSkip([](auto& activeSpiller) { activeSpiller.clearFile(); });

    const auto failedQueryID = executionContext->queryID;
    progressBar->endProgress(failedQueryID);

    return queryResultWithError(e.what());
}

// If there is an active transaction in the context, we execute the function in current active
// transaction. If there is no active transaction, we start an auto commit transaction.
void ClientContext::runFuncInTransaction(const std::function<void(void)>& fun) {
Expand Down
30 changes: 27 additions & 3 deletions src/storage/index/hash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ bool HashIndex<T>::rollbackInMemory() {
return true;
}

// Checkpoint-rollback hook for a single hash index: asks both slot containers (pSlots and
// oSlots) to roll back their in-memory state if necessary.
// NOTE(review): pSlots/oSlots are presumably the primary and overflow slot arrays — confirm
// against the class definition.
template<typename T>
void HashIndex<T>::rollbackCheckpoint() {
pSlots->rollbackInMemoryIfNecessary();
oSlots->rollbackInMemoryIfNecessary();
}

template<typename T>
void HashIndex<T>::splitSlots(const Transaction* transaction, HashIndexHeader& header,
slot_id_t numSlotsToSplit) {
Expand Down Expand Up @@ -485,6 +491,12 @@ PrimaryKeyIndex::PrimaryKeyIndex(const DBFileIDAndName& dbFileIDAndName, bool re
}
},
[&](auto) { KU_UNREACHABLE; });

if (newIndex && !inMemMode) {
// checkpoint the creation of the index so that if we need to rollback it will be to a
// state we can retry from (an empty index with the disk arrays initialized)
checkpoint(true /* forceCheckpointAll */);
}
}

bool PrimaryKeyIndex::lookup(const Transaction* trx, ValueVector* keyVector, uint64_t vectorPos,
Expand Down Expand Up @@ -563,19 +575,31 @@ void PrimaryKeyIndex::writeHeaders() {
KU_ASSERT(headerIdx == NUM_HASH_INDEXES);
}

void PrimaryKeyIndex::checkpoint() {
void PrimaryKeyIndex::rollbackCheckpoint() {
for (idx_t i = 0; i < NUM_HASH_INDEXES; ++i) {
hashIndices[i]->rollbackCheckpoint();
}
hashIndexDiskArrays->rollbackCheckpoint();
hashIndexHeadersForWriteTrx.assign(hashIndexHeadersForReadTrx.begin(),
hashIndexHeadersForReadTrx.end());
if (overflowFile) {
overflowFile->rollbackInMemory();
}
}

void PrimaryKeyIndex::checkpoint(bool forceCheckpointAll) {
bool indexChanged = false;
for (auto i = 0u; i < NUM_HASH_INDEXES; i++) {
if (hashIndices[i]->checkpoint()) {
indexChanged = true;
}
}
if (indexChanged) {
if (indexChanged || forceCheckpointAll) {
writeHeaders();
hashIndexDiskArrays->checkpoint();
}
if (overflowFile) {
overflowFile->checkpoint();
overflowFile->checkpoint(forceCheckpointAll);
}
// Make sure that changes which bypassed the WAL are written.
// There is no other mechanism for enforcing that they are flushed
Expand Down
13 changes: 13 additions & 0 deletions src/storage/storage_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,19 @@ void StorageManager::checkpoint(main::ClientContext& clientContext) {
shadowFile->flushAll();
}

// Rolls back checkpoint state for every node table known to the catalog.
//
// Does nothing when the database path denotes an in-memory database. Otherwise takes `mtx`
// for the duration of the call, fetches the node table entries from the catalog using
// DUMMY_CHECKPOINT_TRANSACTION, and calls rollbackCheckpoint() on the table registered under
// each entry's table ID. Only node table entries are iterated here.
//
// @param clientContext used to reach the catalog.
void StorageManager::rollbackCheckpoint(main::ClientContext& clientContext) {
    if (main::DBConfig::isDBPathInMemory(databasePath)) {
        return;
    }
    std::lock_guard guard{mtx};
    const auto nodeEntries =
        clientContext.getCatalog()->getNodeTableEntries(&DUMMY_CHECKPOINT_TRANSACTION);
    for (const auto nodeEntry : nodeEntries) {
        const auto tableID = nodeEntry->getTableID();
        // Every node table in the catalog is expected to have a storage-side table object.
        KU_ASSERT(tables.contains(tableID));
        tables.at(tableID)->rollbackCheckpoint();
    }
}

StorageManager::~StorageManager() = default;

} // namespace storage
Expand Down
4 changes: 2 additions & 2 deletions src/storage/storage_structure/overflow_file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void OverflowFile::writePageToDisk(page_idx_t pageIdx, uint8_t* data) const {
}
}

void OverflowFile::checkpoint() {
void OverflowFile::checkpoint(bool forceUpdateHeader) {
KU_ASSERT(fileHandle);
if (fileHandle->getNumPages() < pageCounter) {
fileHandle->addNewPages(pageCounter - fileHandle->getNumPages());
Expand All @@ -218,7 +218,7 @@ void OverflowFile::checkpoint() {
for (auto& handle : handles) {
handle->checkpoint();
}
if (headerChanged) {
if (headerChanged || forceUpdateHeader) {
uint8_t page[KUZU_PAGE_SIZE];
header.pages = pageCounter;
memcpy(page, &header, sizeof(header));
Expand Down
4 changes: 4 additions & 0 deletions src/storage/store/node_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,10 @@ void NodeTable::rollbackGroupCollectionInsert(common::row_idx_t numRows_) {
nodeGroups->rollbackInsert(numRows_);
}

// Checkpoint-rollback hook for a node table: delegates to the primary key index (pkIndex).
// No other member of the table is touched by this function.
void NodeTable::rollbackCheckpoint() {
pkIndex->rollbackCheckpoint();
}

TableStats NodeTable::getStats(const Transaction* transaction) const {
auto stats = nodeGroups->getStats();
const auto localTable = transaction->getLocalStorage()->getLocalTable(tableID,
Expand Down
59 changes: 36 additions & 23 deletions src/transaction/transaction_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <thread>

#include "common/exception/checkpoint.h"
#include "common/exception/transaction_manager.h"
#include "main/client_context.h"
#include "main/db_config.h"
Expand Down Expand Up @@ -94,8 +95,15 @@ void TransactionManager::rollback(main::ClientContext& clientContext,
}
}

// Forwards a checkpoint rollback to the storage manager.
//
// Nothing is done when the client's database path denotes an in-memory database; otherwise the
// storage manager obtained from `clientContext` is asked to roll back its checkpoint state.
//
// @param clientContext supplies the database path and the storage manager.
void TransactionManager::rollbackCheckpoint(main::ClientContext& clientContext) {
    const bool isInMemoryDB = main::DBConfig::isDBPathInMemory(clientContext.getDatabasePath());
    if (!isInMemoryDB) {
        clientContext.getStorageManager()->rollbackCheckpoint(clientContext);
    }
}

void TransactionManager::checkpoint(main::ClientContext& clientContext) {
std::unique_lock<std::mutex> lck{mtxForSerializingPublicFunctionCalls};
common::UniqLock lck{mtxForSerializingPublicFunctionCalls};
if (main::DBConfig::isDBPathInMemory(clientContext.getDatabasePath())) {
return;
}
Expand Down Expand Up @@ -152,28 +160,33 @@ void TransactionManager::checkpointNoLock(main::ClientContext& clientContext) {
// query stop working on the tasks of the query and these tasks are removed from the
// query.
auto lockForStartingTransaction = stopNewTransactionsAndWaitUntilAllTransactionsLeave();
// Checkpoint node/relTables, which writes the updated/newly-inserted pages and metadata to
// disk.
clientContext.getStorageManager()->checkpoint(clientContext);
// Checkpoint catalog, which serializes a snapshot of the catalog to disk.
clientContext.getCatalog()->checkpoint(clientContext.getDatabasePath(),
clientContext.getVFSUnsafe());
// Log the checkpoint to the WAL and flush WAL. This indicates that all shadow pages and files(
// snapshots of catalog and metadata) have been written to disk. The part is not done is replace
// them with the original pages or catalog and metadata files.
// If the system crashes before this point, the WAL can still be used to recover the system to a
// state where the checkpoint can be redo.
wal.logAndFlushCheckpoint();
// Replace the original pages and catalog and metadata files with the updated/newly-created
// ones.
StorageUtils::overwriteWALVersionFiles(clientContext.getDatabasePath(),
clientContext.getVFSUnsafe());
clientContext.getStorageManager()->getShadowFile().replayShadowPageRecords(clientContext);
// Clear the wal, and also shadowing files.
wal.clearWAL();
clientContext.getStorageManager()->getShadowFile().clearAll(clientContext);
StorageUtils::removeWALVersionFiles(clientContext.getDatabasePath(),
clientContext.getVFSUnsafe());
try {
// Checkpoint node/relTables, which writes the updated/newly-inserted pages and metadata to
// disk.
clientContext.getStorageManager()->checkpoint(clientContext);
// Checkpoint catalog, which serializes a snapshot of the catalog to disk.
clientContext.getCatalog()->checkpoint(clientContext.getDatabasePath(),
clientContext.getVFSUnsafe());
// Log the checkpoint to the WAL and flush the WAL. This indicates that all shadow pages and
// files (snapshots of catalog and metadata) have been written to disk. The part that is not
// yet done is replacing the original pages, catalog and metadata files with them. If the
// system crashes before this point, the WAL can still be used to recover the system to a
// state where the checkpoint can be redone.
wal.logAndFlushCheckpoint();
// Replace the original pages and catalog and metadata files with the updated/newly-created
// ones.
StorageUtils::overwriteWALVersionFiles(clientContext.getDatabasePath(),
clientContext.getVFSUnsafe());
clientContext.getStorageManager()->getShadowFile().replayShadowPageRecords(clientContext);
// Clear the wal, and also shadowing files.
wal.clearWAL();
clientContext.getStorageManager()->getShadowFile().clearAll(clientContext);
StorageUtils::removeWALVersionFiles(clientContext.getDatabasePath(),
clientContext.getVFSUnsafe());
} catch (std::exception& e) {
rollbackCheckpoint(clientContext);
throw CheckpointException{e};
}
}

} // namespace transaction
Expand Down
Loading

0 comments on commit 2d73b32

Please sign in to comment.