diff --git a/CHANGELOG.md b/CHANGELOG.md index 86472a5f87f..4db65e1a3a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ As a minor extension, we have adopted a slightly different versioning convention - **UNSTABLE** Cardano transactions certification: - Optimize the performances of the computation of the proof with a Merkle map. + - Handle rollback events from the Cardano chain by removing stale data. - Crates versions: diff --git a/Cargo.lock b/Cargo.lock index 23c6af53084..bed58ee7bf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3531,7 +3531,7 @@ dependencies = [ [[package]] name = "mithril-aggregator" -version = "0.5.19" +version = "0.5.20" dependencies = [ "anyhow", "async-trait", @@ -3687,7 +3687,7 @@ dependencies = [ [[package]] name = "mithril-common" -version = "0.4.15" +version = "0.4.16" dependencies = [ "anyhow", "async-trait", @@ -3785,7 +3785,7 @@ dependencies = [ [[package]] name = "mithril-persistence" -version = "0.2.5" +version = "0.2.6" dependencies = [ "anyhow", "async-trait", @@ -3832,7 +3832,7 @@ dependencies = [ [[package]] name = "mithril-signer" -version = "0.2.142" +version = "0.2.143" dependencies = [ "anyhow", "async-trait", diff --git a/internal/mithril-persistence/Cargo.toml b/internal/mithril-persistence/Cargo.toml index b1a83aab618..9b0fb09210d 100644 --- a/internal/mithril-persistence/Cargo.toml +++ b/internal/mithril-persistence/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-persistence" -version = "0.2.5" +version = "0.2.6" description = "Common types, interfaces, and utilities to persist data for Mithril nodes." 
authors = { workspace = true } edition = { workspace = true } diff --git a/internal/mithril-persistence/src/database/query/block_range_root/delete_block_range_root.rs b/internal/mithril-persistence/src/database/query/block_range_root/delete_block_range_root.rs new file mode 100644 index 00000000000..b0e0a2aeed8 --- /dev/null +++ b/internal/mithril-persistence/src/database/query/block_range_root/delete_block_range_root.rs @@ -0,0 +1,139 @@ +use anyhow::Context; +use sqlite::Value; + +use mithril_common::entities::{BlockNumber, BlockRange}; +use mithril_common::StdResult; + +use crate::database::record::BlockRangeRootRecord; +use crate::sqlite::{Query, SourceAlias, SqLiteEntity, WhereCondition}; + +/// Query to delete old [BlockRangeRootRecord] from the sqlite database +pub struct DeleteBlockRangeRootQuery { + condition: WhereCondition, +} + +impl Query for DeleteBlockRangeRootQuery { + type Entity = BlockRangeRootRecord; + + fn filters(&self) -> WhereCondition { + self.condition.clone() + } + + fn get_definition(&self, condition: &str) -> String { + // it is important to alias the fields with the same name as the table + // since the table cannot be aliased in a RETURNING statement in SQLite. 
+ let aliases = SourceAlias::new(&[("{:block_range_root:}", "block_range_root")]); + let projection = Self::Entity::get_projection().expand(aliases); + + format!("delete from block_range_root where {condition} returning {projection}") + } +} + +impl DeleteBlockRangeRootQuery { + pub fn contains_or_above_block_number_threshold( + block_number_threshold: BlockNumber, + ) -> StdResult { + let block_range = BlockRange::from_block_number(block_number_threshold); + let threshold = Value::Integer(block_range.start.try_into().with_context(|| { + format!("Failed to convert threshold `{block_number_threshold}` to i64") + })?); + + Ok(Self { + condition: WhereCondition::new("start >= ?*", vec![threshold]), + }) + } +} + +#[cfg(test)] +mod tests { + use mithril_common::crypto_helper::MKTreeNode; + use mithril_common::entities::BlockRange; + + use crate::database::query::{GetBlockRangeRootQuery, InsertBlockRangeRootQuery}; + use crate::database::test_helper::cardano_tx_db_connection; + use crate::sqlite::{ConnectionExtensions, SqliteConnection}; + + use super::*; + + fn insert_block_range_roots(connection: &SqliteConnection, records: Vec) { + connection + .fetch_first(InsertBlockRangeRootQuery::insert_many(records).unwrap()) + .unwrap(); + } + + fn block_range_root_dataset() -> Vec { + [ + ( + BlockRange::from_block_number(BlockRange::LENGTH), + MKTreeNode::from_hex("AAAA").unwrap(), + ), + ( + BlockRange::from_block_number(BlockRange::LENGTH * 2), + MKTreeNode::from_hex("BBBB").unwrap(), + ), + ( + BlockRange::from_block_number(BlockRange::LENGTH * 3), + MKTreeNode::from_hex("CCCC").unwrap(), + ), + ] + .into_iter() + .map(BlockRangeRootRecord::from) + .collect() + } + + #[test] + fn test_prune_work_even_without_block_range_root_in_db() { + let connection = cardano_tx_db_connection().unwrap(); + + let cursor = connection + .fetch( + DeleteBlockRangeRootQuery::contains_or_above_block_number_threshold(100).unwrap(), + ) + .expect("pruning shouldn't crash without block range root 
stored"); + assert_eq!(0, cursor.count()); + } + + #[test] + fn test_prune_all_data_if_given_block_number_is_lower_than_stored_number_of_block() { + parameterized_test_prune_block_range(0, block_range_root_dataset().len()); + } + + #[test] + fn test_prune_keep_all_block_range_root_if_given_number_of_block_is_greater_than_the_highest_one( + ) { + parameterized_test_prune_block_range(100_000, 0); + } + + #[test] + fn test_prune_block_range_when_block_number_is_block_range_start() { + parameterized_test_prune_block_range(BlockRange::LENGTH * 2, 2); + } + + #[test] + fn test_prune_block_range_when_block_number_is_in_block_range() { + parameterized_test_prune_block_range(BlockRange::LENGTH * 2 + 1, 2); + } + + #[test] + fn test_keep_block_range_when_block_number_is_just_before_range_start() { + parameterized_test_prune_block_range(BlockRange::LENGTH * 2 - 1, 3); + } + + fn parameterized_test_prune_block_range( + block_threshold: BlockNumber, + delete_record_number: usize, + ) { + let connection = cardano_tx_db_connection().unwrap(); + let dataset = block_range_root_dataset(); + insert_block_range_roots(&connection, dataset.clone()); + + let query = + DeleteBlockRangeRootQuery::contains_or_above_block_number_threshold(block_threshold) + .unwrap(); + let cursor = connection.fetch(query).unwrap(); + assert_eq!(delete_record_number, cursor.count()); + + let cursor = connection.fetch(GetBlockRangeRootQuery::all()).unwrap(); + assert_eq!(dataset.len() - delete_record_number, cursor.count()); + } +} diff --git a/internal/mithril-persistence/src/database/query/block_range_root/mod.rs b/internal/mithril-persistence/src/database/query/block_range_root/mod.rs index 7254c773a88..d10fb75852a 100644 --- a/internal/mithril-persistence/src/database/query/block_range_root/mod.rs +++ b/internal/mithril-persistence/src/database/query/block_range_root/mod.rs @@ -1,7 +1,9 @@ +mod delete_block_range_root; mod get_block_range_root; mod get_interval_without_block_range; mod insert_block_range; 
+pub use delete_block_range_root::*; pub use get_block_range_root::*; pub use get_interval_without_block_range::*; pub use insert_block_range::*; diff --git a/internal/mithril-persistence/src/database/query/cardano_transaction/delete_cardano_transaction.rs b/internal/mithril-persistence/src/database/query/cardano_transaction/delete_cardano_transaction.rs index 65f56fb427d..a5f9342ecfe 100644 --- a/internal/mithril-persistence/src/database/query/cardano_transaction/delete_cardano_transaction.rs +++ b/internal/mithril-persistence/src/database/query/cardano_transaction/delete_cardano_transaction.rs @@ -39,6 +39,16 @@ impl DeleteCardanoTransactionQuery { condition: WhereCondition::new("block_number < ?*", vec![threshold]), }) } + + pub fn above_block_number_threshold(block_number_threshold: BlockNumber) -> StdResult { + let threshold = Value::Integer(block_number_threshold.try_into().with_context(|| { + format!("Failed to convert threshold `{block_number_threshold}` to i64") + })?); + + Ok(Self { + condition: WhereCondition::new("block_number > ?*", vec![threshold]), + }) + } } #[cfg(test)] @@ -66,52 +76,112 @@ mod tests { ] } - #[test] - fn test_prune_work_even_without_transactions_in_db() { - let connection = cardano_tx_db_connection().unwrap(); - - let cursor = connection - .fetch(DeleteCardanoTransactionQuery::below_block_number_threshold(100).unwrap()) - .expect("pruning shouldn't crash without transactions stored"); - assert_eq!(0, cursor.count()); + mod prune_below_threshold_tests { + use super::*; + + #[test] + fn test_prune_work_even_without_transactions_in_db() { + let connection = cardano_tx_db_connection().unwrap(); + + let cursor = connection + .fetch(DeleteCardanoTransactionQuery::below_block_number_threshold(100).unwrap()) + .expect("pruning shouldn't crash without transactions stored"); + assert_eq!(0, cursor.count()); + } + + #[test] + fn test_prune_all_data_if_given_block_number_is_larger_than_stored_number_of_block() { + let connection = 
cardano_tx_db_connection().unwrap(); + insert_transactions(&connection, test_transaction_set()); + + let query = + DeleteCardanoTransactionQuery::below_block_number_threshold(100_000).unwrap(); + let cursor = connection.fetch(query).unwrap(); + assert_eq!(test_transaction_set().len(), cursor.count()); + + let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); + assert_eq!(0, cursor.count()); + } + + #[test] + fn test_prune_keep_all_tx_of_last_block_if_given_number_of_block_is_zero() { + let connection = cardano_tx_db_connection().unwrap(); + insert_transactions(&connection, test_transaction_set()); + + let query = DeleteCardanoTransactionQuery::below_block_number_threshold(0).unwrap(); + let cursor = connection.fetch(query).unwrap(); + assert_eq!(0, cursor.count()); + + let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); + assert_eq!(test_transaction_set().len(), cursor.count()); + } + + #[test] + fn test_prune_data_of_below_given_blocks() { + let connection = cardano_tx_db_connection().unwrap(); + insert_transactions(&connection, test_transaction_set()); + + let query = DeleteCardanoTransactionQuery::below_block_number_threshold(12).unwrap(); + let cursor = connection.fetch(query).unwrap(); + assert_eq!(4, cursor.count()); + + let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); + assert_eq!(2, cursor.count()); + } } - #[test] - fn test_prune_all_data_if_given_block_number_is_larger_than_stored_number_of_block() { - let connection = cardano_tx_db_connection().unwrap(); - insert_transactions(&connection, test_transaction_set()); - - let query = DeleteCardanoTransactionQuery::below_block_number_threshold(100_000).unwrap(); - let cursor = connection.fetch(query).unwrap(); - assert_eq!(test_transaction_set().len(), cursor.count()); - - let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); - assert_eq!(0, cursor.count()); - } - - #[test] - fn 
test_prune_keep_all_tx_of_last_block_if_given_number_of_block_is_zero() { - let connection = cardano_tx_db_connection().unwrap(); - insert_transactions(&connection, test_transaction_set()); - - let query = DeleteCardanoTransactionQuery::below_block_number_threshold(0).unwrap(); - let cursor = connection.fetch(query).unwrap(); - assert_eq!(0, cursor.count()); - - let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); - assert_eq!(test_transaction_set().len(), cursor.count()); - } - - #[test] - fn test_prune_data_of_below_given_blocks() { - let connection = cardano_tx_db_connection().unwrap(); - insert_transactions(&connection, test_transaction_set()); - - let query = DeleteCardanoTransactionQuery::below_block_number_threshold(12).unwrap(); - let cursor = connection.fetch(query).unwrap(); - assert_eq!(4, cursor.count()); - - let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); - assert_eq!(2, cursor.count()); + mod prune_above_threshold_tests { + use super::*; + + #[test] + fn test_prune_work_even_without_transactions_in_db() { + let connection = cardano_tx_db_connection().unwrap(); + + let cursor = connection + .fetch(DeleteCardanoTransactionQuery::above_block_number_threshold(100).unwrap()) + .expect("pruning shouldn't crash without transactions stored"); + assert_eq!(0, cursor.count()); + } + + #[test] + fn test_prune_all_data_if_given_block_number_is_lower_than_stored_number_of_block() { + let connection = cardano_tx_db_connection().unwrap(); + insert_transactions(&connection, test_transaction_set()); + + let query = DeleteCardanoTransactionQuery::above_block_number_threshold(0).unwrap(); + let cursor = connection.fetch(query).unwrap(); + assert_eq!(test_transaction_set().len(), cursor.count()); + + let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); + assert_eq!(0, cursor.count()); + } + + #[test] + fn test_prune_keep_all_tx_of_last_block_if_given_number_of_block_is_greater_than_the_highest_one( + 
) { + let connection = cardano_tx_db_connection().unwrap(); + insert_transactions(&connection, test_transaction_set()); + + let query = + DeleteCardanoTransactionQuery::above_block_number_threshold(100_000).unwrap(); + let cursor = connection.fetch(query).unwrap(); + assert_eq!(0, cursor.count()); + + let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); + assert_eq!(test_transaction_set().len(), cursor.count()); + } + + #[test] + fn test_prune_data_of_above_given_blocks() { + let connection = cardano_tx_db_connection().unwrap(); + insert_transactions(&connection, test_transaction_set()); + + let query = DeleteCardanoTransactionQuery::above_block_number_threshold(10).unwrap(); + let cursor = connection.fetch(query).unwrap(); + assert_eq!(4, cursor.count()); + + let cursor = connection.fetch(GetCardanoTransactionQuery::all()).unwrap(); + assert_eq!(2, cursor.count()); + } } } diff --git a/internal/mithril-persistence/src/database/query/cardano_transaction/get_cardano_transaction.rs b/internal/mithril-persistence/src/database/query/cardano_transaction/get_cardano_transaction.rs index 28adde68c32..2049debf84e 100644 --- a/internal/mithril-persistence/src/database/query/cardano_transaction/get_cardano_transaction.rs +++ b/internal/mithril-persistence/src/database/query/cardano_transaction/get_cardano_transaction.rs @@ -31,11 +31,14 @@ impl GetCardanoTransactionQuery { pub fn by_transaction_hashes( transactions_hashes: Vec, - up_to: BlockNumber, + up_to_or_equal: BlockNumber, ) -> Self { let hashes_values = transactions_hashes.into_iter().map(Value::String).collect(); let condition = WhereCondition::where_in("transaction_hash", hashes_values).and_where( - WhereCondition::new("block_number <= ?*", vec![Value::Integer(up_to as i64)]), + WhereCondition::new( + "block_number <= ?*", + vec![Value::Integer(up_to_or_equal as i64)], + ), ); Self { condition } diff --git a/internal/mithril-persistence/src/database/repository/cardano_transaction_repository.rs 
b/internal/mithril-persistence/src/database/repository/cardano_transaction_repository.rs index 034a509a3e5..cdf224d5810 100644 --- a/internal/mithril-persistence/src/database/repository/cardano_transaction_repository.rs +++ b/internal/mithril-persistence/src/database/repository/cardano_transaction_repository.rs @@ -14,8 +14,8 @@ use mithril_common::signable_builder::BlockRangeRootRetriever; use mithril_common::StdResult; use crate::database::query::{ - DeleteCardanoTransactionQuery, GetBlockRangeRootQuery, GetCardanoTransactionQuery, - GetIntervalWithoutBlockRangeRootQuery, InsertBlockRangeRootQuery, + DeleteBlockRangeRootQuery, DeleteCardanoTransactionQuery, GetBlockRangeRootQuery, + GetCardanoTransactionQuery, GetIntervalWithoutBlockRangeRootQuery, InsertBlockRangeRootQuery, InsertCardanoTransactionQuery, }; use crate::database::record::{BlockRangeRootRecord, CardanoTransactionRecord}; @@ -256,6 +256,26 @@ impl CardanoTransactionRepository { Ok(()) } + + /// Remove transactions and block range roots that are in a rolled-back fork + /// + /// * Remove transactions with block number strictly greater than the given block number + /// * Remove block range roots that have lower bound range strictly above the given block number + pub async fn remove_rolled_back_transactions_and_block_range( + &self, + block_number: BlockNumber, + ) -> StdResult<()> { + let transaction = self.connection.begin_transaction()?; + let query = DeleteCardanoTransactionQuery::above_block_number_threshold(block_number)?; + self.connection.fetch_first(query)?; + + let query = + DeleteBlockRangeRootQuery::contains_or_above_block_number_threshold(block_number)?; + self.connection.fetch_first(query)?; + transaction.commit()?; + + Ok(()) + } } #[async_trait] @@ -991,4 +1011,56 @@ mod tests { let highest_beacon = repository.find_lower_bound().await.unwrap(); assert_eq!(Some(100), highest_beacon); } + + #[tokio::test] + async fn remove_transactions_and_block_range_greater_than_given_block_number() { + 
let connection = Arc::new(cardano_tx_db_connection().unwrap()); + let repository = CardanoTransactionRepository::new(connection); + + let cardano_transactions = vec![ + CardanoTransaction::new("tx-hash-123", BlockRange::LENGTH, 50, "block-hash-123", 50), + CardanoTransaction::new( + "tx-hash-123", + BlockRange::LENGTH * 3 - 1, + 50, + "block-hash-123", + 50, + ), + CardanoTransaction::new( + "tx-hash-456", + BlockRange::LENGTH * 3, + 51, + "block-hash-456", + 100, + ), + ]; + repository + .create_transactions(cardano_transactions) + .await + .unwrap(); + repository + .create_block_range_roots(vec![ + ( + BlockRange::from_block_number(BlockRange::LENGTH), + MKTreeNode::from_hex("AAAA").unwrap(), + ), + ( + BlockRange::from_block_number(BlockRange::LENGTH * 2), + MKTreeNode::from_hex("AAAA").unwrap(), + ), + ( + BlockRange::from_block_number(BlockRange::LENGTH * 3), + MKTreeNode::from_hex("AAAA").unwrap(), + ), + ]) + .await + .unwrap(); + + repository + .remove_rolled_back_transactions_and_block_range(BlockRange::LENGTH * 3) + .await + .unwrap(); + assert_eq!(2, repository.get_all_transactions().await.unwrap().len()); + assert_eq!(2, repository.get_all_block_range_root().unwrap().len()); + } } diff --git a/mithril-aggregator/Cargo.toml b/mithril-aggregator/Cargo.toml index e52ce204025..e68908adf46 100644 --- a/mithril-aggregator/Cargo.toml +++ b/mithril-aggregator/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-aggregator" -version = "0.5.19" +version = "0.5.20" description = "A Mithril Aggregator server" authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-aggregator/src/database/repository/cardano_transaction_repository.rs b/mithril-aggregator/src/database/repository/cardano_transaction_repository.rs index 20f01f28e62..fc7987cf209 100644 --- a/mithril-aggregator/src/database/repository/cardano_transaction_repository.rs +++ b/mithril-aggregator/src/database/repository/cardano_transaction_repository.rs @@ -47,6 +47,14 @@ impl 
TransactionStore for CardanoTransactionRepository { } Ok(()) } + + async fn remove_rolled_back_transactions_and_block_range( + &self, + block_number: BlockNumber, + ) -> StdResult<()> { + self.remove_rolled_back_transactions_and_block_range(block_number) + .await + } } #[async_trait] diff --git a/mithril-aggregator/src/services/cardano_transactions_importer.rs b/mithril-aggregator/src/services/cardano_transactions_importer.rs index 77e2c7f04a7..4335971f9d7 100644 --- a/mithril-aggregator/src/services/cardano_transactions_importer.rs +++ b/mithril-aggregator/src/services/cardano_transactions_importer.rs @@ -3,7 +3,6 @@ use std::ops::Range; use std::path::{Path, PathBuf}; use std::sync::Arc; -use anyhow::anyhow; use async_trait::async_trait; use slog::{debug, Logger}; @@ -39,6 +38,15 @@ pub trait TransactionStore: Send + Sync { &self, block_ranges: Vec<(BlockRange, MKTreeNode)>, ) -> StdResult<()>; + + /// Remove transactions and block range roots that are in a rolled-back fork + /// + /// * Remove transactions with block number strictly greater than the given block number + /// * Remove block range roots that have lower bound range strictly above the given block number + async fn remove_rolled_back_transactions_and_block_range( + &self, + block_number: BlockNumber, + ) -> StdResult<()>; } /// Import and store [CardanoTransaction]. 
@@ -107,8 +115,10 @@ impl CardanoTransactionsImporter { .store_transactions(parsed_transactions) .await?; } - ChainScannedBlocks::RollBackward(_) => { - return Err(anyhow!("RollBackward not supported")); + ChainScannedBlocks::RollBackward(chain_point) => { + self.transaction_store + .remove_rolled_back_transactions_and_block_range(chain_point.block_number) + .await?; } } } @@ -264,7 +274,9 @@ mod tests { scanner_mock .expect_scan() .withf(move |_, from, until| from.is_none() && until == &up_to_block_number) - .return_once(move |_, _, _| Ok(Box::new(DumbBlockStreamer::new(vec![blocks])))); + .return_once(move |_, _, _| { + Ok(Box::new(DumbBlockStreamer::new().forwards(vec![blocks]))) + }); CardanoTransactionsImporter::new_for_test(Arc::new(scanner_mock), repository.clone()) }; @@ -376,10 +388,10 @@ mod tests { let up_to_block_number = 12; let connection = cardano_tx_db_connection().unwrap(); let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(connection))); - let scanner = DumbBlockScanner::new(vec![ + let scanner = DumbBlockScanner::new().forwards(vec![vec![ ScannedBlock::new("block_hash-1", 10, 15, 10, vec!["tx_hash-1", "tx_hash-2"]), ScannedBlock::new("block_hash-2", 20, 25, 11, vec!["tx_hash-3", "tx_hash-4"]), - ]); + ]]); let last_tx = CardanoTransaction::new("tx-20", 30, 35, "block_hash-3", up_to_block_number); repository @@ -436,7 +448,9 @@ mod tests { && *until == up_to_block_number }) .return_once(move |_, _, _| { - Ok(Box::new(DumbBlockStreamer::new(vec![scanned_blocks]))) + Ok(Box::new( + DumbBlockStreamer::new().forwards(vec![scanned_blocks]), + )) }) .once(); CardanoTransactionsImporter::new_for_test(Arc::new(scanner_mock), repository.clone()) @@ -619,7 +633,7 @@ mod tests { let connection = Arc::new(cardano_tx_db_connection().unwrap()); let repository = Arc::new(CardanoTransactionRepository::new(connection.clone())); let importer = CardanoTransactionsImporter::new_for_test( - Arc::new(DumbBlockScanner::new(blocks.clone())), + 
Arc::new(DumbBlockScanner::new().forwards(vec![blocks.clone()])), Arc::new(CardanoTransactionRepository::new(connection.clone())), ); (importer, repository) @@ -640,4 +654,103 @@ mod tests { assert_eq!(transactions, cold_imported_transactions); assert_eq!(cold_imported_transactions, warm_imported_transactions); } + + #[tokio::test] + async fn when_rollbackward_should_remove_transactions() { + let connection = cardano_tx_db_connection().unwrap(); + let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(connection))); + + let expected_remaining_transactions = + ScannedBlock::new("block_hash-130", 130, 5, 1, vec!["tx_hash-6", "tx_hash-7"]) + .into_transactions(); + repository + .store_transactions(expected_remaining_transactions.clone()) + .await + .unwrap(); + repository + .store_transactions( + ScannedBlock::new( + "block_hash-131", + 131, + 10, + 2, + vec!["tx_hash-8", "tx_hash-9", "tx_hash-10"], + ) + .into_transactions(), + ) + .await + .unwrap(); + + let chain_point = ChainPoint::new(1, 130, "block_hash-131"); + let scanner = DumbBlockScanner::new().backward(chain_point); + + let importer = + CardanoTransactionsImporter::new_for_test(Arc::new(scanner), repository.clone()); + + importer + .import(3000) + .await + .expect("Transactions Importer should succeed"); + + let stored_transactions = repository.get_all().await.unwrap(); + assert_eq!(expected_remaining_transactions, stored_transactions); + } + + #[tokio::test] + async fn when_rollbackward_should_remove_block_ranges() { + let connection = cardano_tx_db_connection().unwrap(); + let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(connection))); + + let expected_remaining_block_ranges = vec![ + BlockRange::from_block_number(0), + BlockRange::from_block_number(BlockRange::LENGTH), + BlockRange::from_block_number(BlockRange::LENGTH * 2), + ]; + + repository + .store_block_range_roots( + expected_remaining_block_ranges + .iter() + .map(|b| (b.clone(), 
MKTreeNode::from_hex("AAAA").unwrap())) + .collect(), + ) + .await + .unwrap(); + repository + .store_block_range_roots( + [ + BlockRange::from_block_number(BlockRange::LENGTH * 3), + BlockRange::from_block_number(BlockRange::LENGTH * 4), + BlockRange::from_block_number(BlockRange::LENGTH * 5), + ] + .iter() + .map(|b| (b.clone(), MKTreeNode::from_hex("AAAA").unwrap())) + .collect(), + ) + .await + .unwrap(); + + let block_range_roots = repository.get_all_block_range_root().unwrap(); + assert_eq!(6, block_range_roots.len()); + + let chain_point = ChainPoint::new(1, BlockRange::LENGTH * 3, "block_hash-131"); + let scanner = DumbBlockScanner::new().backward(chain_point); + + let importer = + CardanoTransactionsImporter::new_for_test(Arc::new(scanner), repository.clone()); + + importer + .import(3000) + .await + .expect("Transactions Importer should succeed"); + + let block_range_roots = repository.get_all_block_range_root().unwrap(); + assert_eq!( + expected_remaining_block_ranges, + block_range_roots + .into_iter() + .map(|r| r.range) + .collect::>() + ); + } } diff --git a/mithril-aggregator/tests/certificate_chain.rs b/mithril-aggregator/tests/certificate_chain.rs index 7a769dcc119..37ec1a57e33 100644 --- a/mithril-aggregator/tests/certificate_chain.rs +++ b/mithril-aggregator/tests/certificate_chain.rs @@ -185,7 +185,7 @@ async fn certificate_chain() { let next_epoch_verification_keys = tester .dependencies .verification_key_store - .get_verification_keys(new_epoch + 1) + .get_verification_keys(new_epoch.offset_to_recording_epoch()) .await .expect("get_verification_keys should not fail"); assert_eq!( diff --git a/mithril-aggregator/tests/create_certificate.rs b/mithril-aggregator/tests/create_certificate.rs index 1729e0fdccb..c7e9b93d8b1 100644 --- a/mithril-aggregator/tests/create_certificate.rs +++ b/mithril-aggregator/tests/create_certificate.rs @@ -3,8 +3,8 @@ mod test_extensions; use mithril_aggregator::Configuration; use mithril_common::{ entities::{ - 
CardanoDbBeacon, ChainPoint, Epoch, ProtocolParameters, SignedEntityType, - SignedEntityTypeDiscriminants, StakeDistributionParty, TimePoint, + CardanoDbBeacon, CardanoTransactionsSigningConfig, ChainPoint, Epoch, ProtocolParameters, + SignedEntityType, SignedEntityTypeDiscriminants, StakeDistributionParty, TimePoint, }, test_utils::MithrilFixtureBuilder, }; @@ -19,11 +19,24 @@ async fn create_certificate() { }; let configuration = Configuration { protocol_parameters: protocol_parameters.clone(), + signed_entity_types: Some(SignedEntityTypeDiscriminants::CardanoTransactions.to_string()), data_stores_directory: get_test_dir("create_certificate"), + cardano_transactions_signing_config: CardanoTransactionsSigningConfig { + security_parameter: 0, + step: 30, + }, ..Configuration::new_sample() }; let mut tester = RuntimeTester::build( - TimePoint::new(1, 1, ChainPoint::new(10, 1, "block_hash-1")), + TimePoint { + epoch: Epoch(1), + immutable_file_number: 1, + chain_point: ChainPoint { + slot_number: 10, + block_number: 100, + block_hash: "block_hash-100".to_string(), + }, + }, configuration, ) .await; @@ -127,6 +140,76 @@ async fn create_certificate() { ) ); + comment!( + "Increase cardano chain block number to 185, + the state machine should be signing CardanoTransactions for block 180" + ); + tester.increase_block_number(85, 185).await.unwrap(); + cycle!(tester, "signing"); + let signers_for_transaction = &fixture.signers_fixture()[2..=6]; + tester + .send_single_signatures( + SignedEntityTypeDiscriminants::CardanoTransactions, + signers_for_transaction, + ) + .await + .unwrap(); + + comment!("The state machine should issue a certificate for the CardanoTransactions"); + cycle!(tester, "ready"); + assert_last_certificate_eq!( + tester, + ExpectedCertificate::new( + CardanoDbBeacon::new("devnet".to_string(), 1, 3), + &signers_for_transaction + .iter() + .map(|s| s.signer_with_stake.clone().into()) + .collect::>(), + fixture.compute_and_encode_avk(), + 
SignedEntityType::CardanoTransactions(Epoch(1), 180), + ExpectedCertificate::genesis_identifier(&CardanoDbBeacon::new( + "devnet".to_string(), + 1, + 1 + )), + ) + ); + + comment!( + "Got rollback to block number 149 from cardano chain, + the state machine should be signing CardanoTransactions for block 120" + ); + tester.cardano_chain_send_rollback(149).await.unwrap(); + cycle!(tester, "signing"); + let signers_for_transaction = &fixture.signers_fixture()[2..=6]; + tester + .send_single_signatures( + SignedEntityTypeDiscriminants::CardanoTransactions, + signers_for_transaction, + ) + .await + .unwrap(); + + comment!("The state machine should issue a certificate for the CardanoTransactions"); + cycle!(tester, "ready"); + assert_last_certificate_eq!( + tester, + ExpectedCertificate::new( + CardanoDbBeacon::new("devnet".to_string(), 1, 3), + &signers_for_transaction + .iter() + .map(|s| s.signer_with_stake.clone().into()) + .collect::>(), + fixture.compute_and_encode_avk(), + SignedEntityType::CardanoTransactions(Epoch(1), 120), + ExpectedCertificate::genesis_identifier(&CardanoDbBeacon::new( + "devnet".to_string(), + 1, + 1 + )), + ) + ); + comment!("Change the epoch while signing"); tester.increase_immutable_number().await.unwrap(); cycle!(tester, "signing"); diff --git a/mithril-aggregator/tests/test_extensions/runtime_tester.rs b/mithril-aggregator/tests/test_extensions/runtime_tester.rs index d58e10b7a6e..43d6beee62e 100644 --- a/mithril-aggregator/tests/test_extensions/runtime_tester.rs +++ b/mithril-aggregator/tests/test_extensions/runtime_tester.rs @@ -8,11 +8,12 @@ use mithril_aggregator::{ SignerRegistrationError, }; use mithril_common::{ - chain_observer::FakeObserver, + cardano_block_scanner::{DumbBlockScanner, ScannedBlock}, + chain_observer::{ChainObserver, FakeObserver}, crypto_helper::ProtocolGenesisSigner, - digesters::{DumbImmutableDigester, DumbImmutableFileObserver}, + digesters::{DumbImmutableDigester, DumbImmutableFileObserver, 
ImmutableFileObserver}, entities::{ - Certificate, CertificateSignature, Epoch, ImmutableFileNumber, + BlockNumber, Certificate, CertificateSignature, ChainPoint, Epoch, ImmutableFileNumber, SignedEntityTypeDiscriminants, Snapshot, StakeDistribution, TimePoint, }, era::{adapters::EraReaderDummyAdapter, EraMarker, EraReader, SupportedEra}, @@ -70,6 +71,7 @@ pub struct RuntimeTester { pub era_reader_adapter: Arc, pub observer: Arc, pub open_message_repository: Arc, + pub block_scanner: Arc, _logs_guard: slog_scope::GlobalLoggerGuard, } @@ -98,6 +100,7 @@ impl RuntimeTester { &SupportedEra::dummy().to_string(), Some(Epoch(0)), )])); + let block_scanner = Arc::new(DumbBlockScanner::new()); let mut deps_builder = DependenciesBuilder::new(configuration); deps_builder.snapshot_uploader = Some(snapshot_uploader.clone()); deps_builder.chain_observer = Some(chain_observer.clone()); @@ -105,6 +108,7 @@ impl RuntimeTester { deps_builder.immutable_digester = Some(digester.clone()); deps_builder.snapshotter = Some(snapshotter.clone()); deps_builder.era_reader = Some(Arc::new(EraReader::new(era_reader_adapter.clone()))); + deps_builder.block_scanner = Some(block_scanner.clone()); let dependencies = deps_builder.build_dependency_container().await.unwrap(); let runtime = deps_builder.create_aggregator_runner().await.unwrap(); @@ -126,6 +130,7 @@ impl RuntimeTester { era_reader_adapter, observer, open_message_repository, + block_scanner, _logs_guard: logger, } } @@ -243,6 +248,74 @@ impl RuntimeTester { Ok(new_epoch) } + /// increase the block number in the fake observer + pub async fn increase_block_number(&mut self, increment: u64, expected: u64) -> StdResult<()> { + let new_block_number = self + .chain_observer + .increase_block_number(increment) + .await + .ok_or_else(|| anyhow!("no block number returned".to_string()))?; + + anyhow::ensure!( + expected == new_block_number, + "expected to increase block number up to {expected}, got {new_block_number}", + ); + + // Make the block 
scanner return new blocks + let current_immutable = self + .immutable_file_observer + .get_last_immutable_number() + .await?; + let blocks_to_scan: Vec = ((expected - increment + 1)..=expected) + .map(|block_number| { + let block_hash = format!("block_hash-{block_number}"); + let slot_number = 10 * block_number; + ScannedBlock::new( + block_hash, + block_number, + slot_number, + current_immutable, + vec![format!("tx_hash-{block_number}-1")], + ) + }) + .collect(); + self.block_scanner.add_forwards(vec![blocks_to_scan]); + + Ok(()) + } + + pub async fn cardano_chain_send_rollback( + &mut self, + rollback_to_block_number: BlockNumber, + ) -> StdResult<()> { + let actual_block_number = self + .chain_observer + .get_current_chain_point() + .await? + .map(|c| c.block_number) + .ok_or_else(|| anyhow!("no block number returned".to_string()))?; + let decrement = actual_block_number - rollback_to_block_number; + let new_block_number = self + .chain_observer + .decrease_block_number(decrement) + .await + .ok_or_else(|| anyhow!("no block number returned".to_string()))?; + + anyhow::ensure!( + rollback_to_block_number == new_block_number, + "expected to increase block number up to {rollback_to_block_number}, got {new_block_number}", + ); + + let chain_point = ChainPoint { + slot_number: 1, + block_number: rollback_to_block_number, + block_hash: format!("block_hash-{rollback_to_block_number}"), + }; + self.block_scanner.add_backward(chain_point); + + Ok(()) + } + /// Register the given signers in the registerer pub async fn register_signers(&mut self, signers: &[SignerFixture]) -> StdResult<()> { let registration_epoch = self diff --git a/mithril-common/Cargo.toml b/mithril-common/Cargo.toml index 0692d419aff..936b384d8d5 100644 --- a/mithril-common/Cargo.toml +++ b/mithril-common/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-common" -version = "0.4.15" +version = "0.4.16" description = "Common types, interfaces, and utilities for Mithril nodes." 
authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-common/src/cardano_block_scanner/dumb_block_scanner.rs b/mithril-common/src/cardano_block_scanner/dumb_block_scanner.rs index eb47dc40f40..39f8754a77a 100644 --- a/mithril-common/src/cardano_block_scanner/dumb_block_scanner.rs +++ b/mithril-common/src/cardano_block_scanner/dumb_block_scanner.rs @@ -1,8 +1,8 @@ use std::collections::VecDeque; use std::path::Path; +use std::sync::RwLock; use async_trait::async_trait; -use tokio::sync::RwLock; use crate::cardano_block_scanner::ChainScannedBlocks; use crate::cardano_block_scanner::{BlockScanner, BlockStreamer, ScannedBlock}; @@ -11,21 +11,49 @@ use crate::StdResult; /// Dumb block scanner pub struct DumbBlockScanner { - blocks: RwLock>, + streamer: RwLock, } impl DumbBlockScanner { /// Factory - pub fn new(blocks: Vec) -> Self { + pub fn new() -> Self { Self { - blocks: RwLock::new(blocks), + streamer: RwLock::new(DumbBlockStreamer::new()), } } - /// Update blocks returned used the streamer constructed by `scan` - pub async fn update_blocks(&self, new_blocks: Vec) { - let mut blocks = self.blocks.write().await; - *blocks = new_blocks; + /// Add to the inner streamer several [ChainScannedBlocks::RollForwards] responses at the end of the + /// its queue. + pub fn forwards(self, blocks: Vec>) -> Self { + self.add_forwards(blocks); + self + } + + /// Add to the inner streamer a [ChainScannedBlocks::RollBackward] response at the end of the + /// its queue. + pub fn backward(self, chain_point: ChainPoint) -> Self { + self.add_backward(chain_point); + self + } + + /// Add to the inner streamer several [ChainScannedBlocks::RollForwards] responses at the end of the + /// its queue. + pub fn add_forwards(&self, blocks: Vec>) { + let mut streamer = self.streamer.write().unwrap(); + *streamer = streamer.clone().forwards(blocks); + } + + /// Add to the inner streamer a [ChainScannedBlocks::RollBackward] response at the end of the + /// its queue. 
+ pub fn add_backward(&self, chain_point: ChainPoint) { + let mut streamer = self.streamer.write().unwrap(); + *streamer = streamer.clone().rollback(chain_point); + } +} + +impl Default for DumbBlockScanner { + fn default() -> Self { + Self::new() } } @@ -37,32 +65,56 @@ impl BlockScanner for DumbBlockScanner { _from: Option, _until: BlockNumber, ) -> StdResult> { - let blocks = self.blocks.read().await.clone(); - Ok(Box::new(DumbBlockStreamer::new(vec![blocks]))) + let streamer = self.streamer.read().unwrap(); + Ok(Box::new(streamer.clone())) } } /// Dumb block streamer +#[derive(Clone)] pub struct DumbBlockStreamer { - blocks: VecDeque>, + streamer_responses: VecDeque, } impl DumbBlockStreamer { /// Factory - the resulting streamer can be polled one time for each list of blocks given - pub fn new(blocks: Vec>) -> Self { + pub fn new() -> Self { Self { - blocks: VecDeque::from(blocks), + streamer_responses: VecDeque::new(), } } + + /// Add to the streamer several [ChainScannedBlocks::RollForwards] responses at the end of the + /// its queue. + pub fn forwards(mut self, blocks: Vec>) -> Self { + let mut source: VecDeque<_> = blocks + .into_iter() + .map(ChainScannedBlocks::RollForwards) + .collect(); + self.streamer_responses.append(&mut source); + + self + } + + /// Add to the streamer a [ChainScannedBlocks::RollBackward] response at the end of the + /// its queue. 
+ pub fn rollback(mut self, chain_point: ChainPoint) -> Self { + self.streamer_responses + .push_back(ChainScannedBlocks::RollBackward(chain_point)); + self + } +} + +impl Default for DumbBlockStreamer { + fn default() -> Self { + Self::new() + } } #[async_trait] impl BlockStreamer for DumbBlockStreamer { async fn poll_next(&mut self) -> StdResult> { - Ok(self - .blocks - .pop_front() - .map(ChainScannedBlocks::RollForwards)) + Ok(self.streamer_responses.pop_front()) } } @@ -74,7 +126,7 @@ mod tests { #[tokio::test] async fn polling_without_set_of_block_return_none() { - let mut streamer = DumbBlockStreamer::new(vec![]); + let mut streamer = DumbBlockStreamer::new().forwards(vec![]); let blocks = streamer.poll_next().await.unwrap(); assert_eq!(blocks, None); } @@ -82,7 +134,7 @@ mod tests { #[tokio::test] async fn polling_with_one_set_of_block_returns_some_once() { let expected_blocks = vec![ScannedBlock::new("hash-1", 1, 10, 20, Vec::<&str>::new())]; - let mut streamer = DumbBlockStreamer::new(vec![expected_blocks.clone()]); + let mut streamer = DumbBlockStreamer::new().forwards(vec![expected_blocks.clone()]); let blocks = streamer.poll_next().await.unwrap(); assert_eq!( @@ -104,7 +156,7 @@ mod tests { ], vec![ScannedBlock::new("hash-4", 4, 13, 23, Vec::<&str>::new())], ]; - let mut streamer = DumbBlockStreamer::new(expected_blocks.clone()); + let mut streamer = DumbBlockStreamer::new().forwards(expected_blocks.clone()); let blocks = streamer.poll_next().await.unwrap(); assert_eq!( @@ -132,10 +184,72 @@ mod tests { async fn dumb_scanned_construct_a_streamer_based_on_its_stored_blocks() { let expected_blocks = vec![ScannedBlock::new("hash-1", 1, 10, 20, Vec::<&str>::new())]; - let scanner = DumbBlockScanner::new(expected_blocks.clone()); + let scanner = DumbBlockScanner::new().forwards(vec![expected_blocks.clone()]); let mut streamer = scanner.scan(Path::new("dummy"), None, 5).await.unwrap(); let blocks = streamer.poll_all().await.unwrap(); assert_eq!(blocks, 
expected_blocks); } + + #[tokio::test] + async fn dumb_scanned_construct_a_streamer_based_on_its_stored_chain_scanned_blocks() { + let expected_blocks = vec![ScannedBlock::new("hash-1", 1, 10, 20, Vec::<&str>::new())]; + let expected_chain_point = ChainPoint::new(10, 2, "block-hash"); + + let scanner = DumbBlockScanner::new() + .forwards(vec![expected_blocks.clone()]) + .backward(expected_chain_point.clone()); + let mut streamer = scanner.scan(Path::new("dummy"), None, 5).await.unwrap(); + + let blocks = streamer.poll_next().await.unwrap(); + assert_eq!( + blocks, + Some(ChainScannedBlocks::RollForwards(expected_blocks.clone())) + ); + + let blocks = streamer.poll_next().await.unwrap(); + assert_eq!( + blocks, + Some(ChainScannedBlocks::RollBackward( + expected_chain_point.clone() + )) + ); + } + + #[tokio::test] + async fn polling_with_can_return_roll_backward() { + let expected_blocks = vec![ + vec![ScannedBlock::new("hash-1", 1, 10, 20, Vec::<&str>::new())], + vec![ScannedBlock::new("hash-4", 4, 13, 23, Vec::<&str>::new())], + ]; + + let expected_chain_point = ChainPoint::new(10, 2, "block-hash"); + + let mut streamer = DumbBlockStreamer::new() + .forwards(expected_blocks.clone()) + .rollback(expected_chain_point.clone()); + + let blocks = streamer.poll_next().await.unwrap(); + assert_eq!( + blocks, + Some(ChainScannedBlocks::RollForwards(expected_blocks[0].clone())) + ); + + let blocks = streamer.poll_next().await.unwrap(); + assert_eq!( + blocks, + Some(ChainScannedBlocks::RollForwards(expected_blocks[1].clone())) + ); + + let blocks = streamer.poll_next().await.unwrap(); + assert_eq!( + blocks, + Some(ChainScannedBlocks::RollBackward( + expected_chain_point.clone() + )) + ); + + let blocks = streamer.poll_next().await.unwrap(); + assert_eq!(blocks, None); + } } diff --git a/mithril-common/src/cardano_block_scanner/interface.rs b/mithril-common/src/cardano_block_scanner/interface.rs index 1059ea32e0b..3746f3b8e6a 100644 --- 
a/mithril-common/src/cardano_block_scanner/interface.rs +++ b/mithril-common/src/cardano_block_scanner/interface.rs @@ -56,7 +56,7 @@ pub trait BlockScanner: Sync + Send { } /// [ChainScannedBlocks] allows to scan new blocks and handle rollbacks -#[derive(Debug, PartialEq)] +#[derive(Debug, Clone, PartialEq)] pub enum ChainScannedBlocks { /// Roll forward on the chain to the next list of [ScannedBlock] RollForwards(Vec), diff --git a/mithril-common/src/chain_observer/fake_observer.rs b/mithril-common/src/chain_observer/fake_observer.rs index fc2487c56a9..38e47f5b9e1 100644 --- a/mithril-common/src/chain_observer/fake_observer.rs +++ b/mithril-common/src/chain_observer/fake_observer.rs @@ -18,9 +18,6 @@ pub struct FakeObserver { /// [get_current_epoch]: ChainObserver::get_current_epoch pub current_time_point: RwLock>, - /// The current chain point - pub current_chain_point: RwLock>, - /// A list of [TxDatum], used by [get_current_datums] /// /// [get_current_datums]: ChainObserver::get_current_datums @@ -33,7 +30,6 @@ impl FakeObserver { Self { signers: RwLock::new(vec![]), current_time_point: RwLock::new(current_time_point.clone()), - current_chain_point: RwLock::new(current_time_point.map(|t| t.chain_point)), datums: RwLock::new(vec![]), } } @@ -49,6 +45,39 @@ impl FakeObserver { current_time_point.as_ref().map(|b| b.epoch) } + /// Increase the block number of the [current_time_point][`FakeObserver::current_time_point`] by + /// the given increment. + pub async fn increase_block_number(&self, increment: BlockNumber) -> Option { + self.change_block_number(|actual_block_number| actual_block_number + increment) + .await + } + + /// Decrease the block number of the [current_time_point][`FakeObserver::current_time_point`] by + /// the given decrement. 
+ pub async fn decrease_block_number(&self, decrement: BlockNumber) -> Option { + self.change_block_number(|actual_block_number| actual_block_number - decrement) + .await + } + + async fn change_block_number( + &self, + change_to_apply: impl Fn(u64) -> u64, + ) -> Option { + let mut current_time_point = self.current_time_point.write().await; + + *current_time_point = current_time_point.as_ref().map(|time_point| TimePoint { + chain_point: ChainPoint { + block_number: change_to_apply(time_point.chain_point.block_number), + ..time_point.chain_point.clone() + }, + ..time_point.clone() + }); + + current_time_point + .as_ref() + .map(|b| b.chain_point.block_number) + } + /// Set the signers that will use to compute the result of /// [get_current_stake_distribution][ChainObserver::get_current_stake_distribution]. pub async fn set_signers(&self, new_signers: Vec) { @@ -56,11 +85,10 @@ impl FakeObserver { *signers = new_signers; } - /// Set the chain point that will use to compute the result of - /// [get_current_chain_point][ChainObserver::get_current_chain_point]. 
- pub async fn set_current_chain_point(&self, new_current_chain_point: Option) { - let mut current_chain_point = self.current_chain_point.write().await; - *current_chain_point = new_current_chain_point; + /// Set the time point + pub async fn set_current_time_point(&self, new_current_time_point: Option) { + let mut current_time_point = self.current_time_point.write().await; + *current_time_point = new_current_time_point; } /// Set the datums that will use to compute the result of @@ -100,7 +128,12 @@ impl ChainObserver for FakeObserver { } async fn get_current_chain_point(&self) -> Result, ChainObserverError> { - Ok(self.current_chain_point.read().await.clone()) + Ok(self + .current_time_point + .read() + .await + .as_ref() + .map(|time_point| time_point.chain_point.clone())) } async fn get_current_stake_distribution( @@ -143,12 +176,12 @@ mod tests { async fn test_get_current_chain_point() { let fake_observer = FakeObserver::new(None); fake_observer - .set_current_chain_point(Some(fake_data::chain_point())) + .set_current_time_point(Some(TimePoint::dummy())) .await; let chain_point = fake_observer.get_current_chain_point().await.unwrap(); assert_eq!( - Some(fake_data::chain_point()), + Some(TimePoint::dummy().chain_point), chain_point, "get current chain point should not fail" ); @@ -185,4 +218,48 @@ mod tests { assert_eq!(fake_datums, datums); } + + #[tokio::test] + async fn test_increase_block_number() { + let fake_observer = FakeObserver::new(None); + fake_observer + .set_current_time_point(Some(TimePoint::dummy())) + .await; + fake_observer.increase_block_number(375).await; + + let chain_point = fake_observer.get_current_chain_point().await.unwrap(); + assert_eq!( + Some(ChainPoint { + block_number: TimePoint::dummy().chain_point.block_number + 375, + ..TimePoint::dummy().chain_point + }), + chain_point, + "get current chain point should not fail" + ); + } + + #[tokio::test] + async fn test_decrease_block_number() { + let fake_observer = 
FakeObserver::new(None); + fake_observer + .set_current_time_point(Some(TimePoint { + chain_point: ChainPoint { + block_number: 1000, + ..TimePoint::dummy().chain_point + }, + ..TimePoint::dummy() + })) + .await; + fake_observer.decrease_block_number(800).await; + + let chain_point = fake_observer.get_current_chain_point().await.unwrap(); + assert_eq!( + Some(ChainPoint { + block_number: 200, + ..TimePoint::dummy().chain_point + }), + chain_point, + "get current chain point should not fail" + ); + } } diff --git a/mithril-common/src/crypto_helper/merkle_tree.rs b/mithril-common/src/crypto_helper/merkle_tree.rs index 540a808c0c6..4d13e595c7a 100644 --- a/mithril-common/src/crypto_helper/merkle_tree.rs +++ b/mithril-common/src/crypto_helper/merkle_tree.rs @@ -1,6 +1,4 @@ -use anyhow::anyhow; -#[cfg(any(test, feature = "test_tools"))] -use anyhow::Context; +use anyhow::{anyhow, Context}; use blake2::{Blake2s256, Digest}; use ckb_merkle_mountain_range::{ MMRStoreReadOps, MMRStoreWriteOps, Merge, MerkleProof, Result as MMRResult, MMR, @@ -293,7 +291,11 @@ impl MKTree { /// Generate root of the Merkle tree pub fn compute_root(&self) -> StdResult { - Ok((*self.inner_tree.get_root()?).clone()) + Ok((*self + .inner_tree + .get_root() + .with_context(|| "Could not compute Merkle Tree root")?) 
+ .clone()) } /// Generate Merkle proof of memberships in the tree diff --git a/mithril-signer/Cargo.toml b/mithril-signer/Cargo.toml index e7730d01f11..aff543fc0f6 100644 --- a/mithril-signer/Cargo.toml +++ b/mithril-signer/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-signer" -version = "0.2.142" +version = "0.2.143" description = "A Mithril Signer" authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-signer/src/cardano_transactions_importer.rs b/mithril-signer/src/cardano_transactions_importer.rs index b853a9146ea..6afb01dfacf 100644 --- a/mithril-signer/src/cardano_transactions_importer.rs +++ b/mithril-signer/src/cardano_transactions_importer.rs @@ -3,7 +3,6 @@ use std::ops::Range; use std::path::{Path, PathBuf}; use std::sync::Arc; -use anyhow::anyhow; use async_trait::async_trait; use slog::{debug, Logger}; @@ -39,6 +38,15 @@ pub trait TransactionStore: Send + Sync { &self, block_ranges: Vec<(BlockRange, MKTreeNode)>, ) -> StdResult<()>; + + /// Remove transactions and block range roots that are in a rolled-back fork + /// + /// * Remove transactions with block number strictly greater than the given block number + /// * Remove block range roots that have lower bound range strictly above the given block number + async fn remove_rolled_back_transactions_and_block_range( + &self, + block_number: BlockNumber, + ) -> StdResult<()>; } /// Import and store [CardanoTransaction]. 
@@ -107,8 +115,10 @@ impl CardanoTransactionsImporter { .store_transactions(parsed_transactions) .await?; } - ChainScannedBlocks::RollBackward(_) => { - return Err(anyhow!("RollBackward not supported")); + ChainScannedBlocks::RollBackward(chain_point) => { + self.transaction_store + .remove_rolled_back_transactions_and_block_range(chain_point.block_number) + .await?; } } } @@ -264,7 +274,9 @@ mod tests { scanner_mock .expect_scan() .withf(move |_, from, until| from.is_none() && until == &up_to_block_number) - .return_once(move |_, _, _| Ok(Box::new(DumbBlockStreamer::new(vec![blocks])))); + .return_once(move |_, _, _| { + Ok(Box::new(DumbBlockStreamer::new().forwards(vec![blocks]))) + }); CardanoTransactionsImporter::new_for_test(Arc::new(scanner_mock), repository.clone()) }; @@ -376,10 +388,10 @@ mod tests { let up_to_block_number = 12; let connection = cardano_tx_db_connection().unwrap(); let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(connection))); - let scanner = DumbBlockScanner::new(vec![ + let scanner = DumbBlockScanner::new().forwards(vec![vec![ ScannedBlock::new("block_hash-1", 10, 15, 10, vec!["tx_hash-1", "tx_hash-2"]), ScannedBlock::new("block_hash-2", 20, 25, 11, vec!["tx_hash-3", "tx_hash-4"]), - ]); + ]]); let last_tx = CardanoTransaction::new("tx-20", 30, 35, "block_hash-3", up_to_block_number); repository @@ -436,7 +448,9 @@ mod tests { && *until == up_to_block_number }) .return_once(move |_, _, _| { - Ok(Box::new(DumbBlockStreamer::new(vec![scanned_blocks]))) + Ok(Box::new( + DumbBlockStreamer::new().forwards(vec![scanned_blocks]), + )) }) .once(); CardanoTransactionsImporter::new_for_test(Arc::new(scanner_mock), repository.clone()) @@ -619,7 +633,7 @@ mod tests { let connection = Arc::new(cardano_tx_db_connection().unwrap()); let repository = Arc::new(CardanoTransactionRepository::new(connection.clone())); let importer = CardanoTransactionsImporter::new_for_test( - Arc::new(DumbBlockScanner::new(blocks.clone())), + 
Arc::new(DumbBlockScanner::new().forwards(vec![blocks.clone()])), Arc::new(CardanoTransactionRepository::new(connection.clone())), ); (importer, repository) @@ -640,4 +654,103 @@ mod tests { assert_eq!(transactions, cold_imported_transactions); assert_eq!(cold_imported_transactions, warm_imported_transactions); } + + #[tokio::test] + async fn when_rollbackward_should_remove_transactions() { + let connection = cardano_tx_db_connection().unwrap(); + let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(connection))); + + let expected_remaining_transactions = + ScannedBlock::new("block_hash-130", 130, 5, 1, vec!["tx_hash-6", "tx_hash-7"]) + .into_transactions(); + repository + .store_transactions(expected_remaining_transactions.clone()) + .await + .unwrap(); + repository + .store_transactions( + ScannedBlock::new( + "block_hash-131", + 131, + 10, + 2, + vec!["tx_hash-8", "tx_hash-9", "tx_hash-10"], + ) + .into_transactions(), + ) + .await + .unwrap(); + + let chain_point = ChainPoint::new(1, 130, "block_hash-131"); + let scanner = DumbBlockScanner::new().backward(chain_point); + + let importer = + CardanoTransactionsImporter::new_for_test(Arc::new(scanner), repository.clone()); + + importer + .import(3000) + .await + .expect("Transactions Importer should succeed"); + + let stored_transactions = repository.get_all().await.unwrap(); + assert_eq!(expected_remaining_transactions, stored_transactions); + } + + #[tokio::test] + async fn when_rollbackward_should_remove_block_ranges() { + let connection = cardano_tx_db_connection().unwrap(); + let repository = Arc::new(CardanoTransactionRepository::new(Arc::new(connection))); + + let expected_remaining_block_ranges = vec![ + BlockRange::from_block_number(0), + BlockRange::from_block_number(BlockRange::LENGTH), + BlockRange::from_block_number(BlockRange::LENGTH * 2), + ]; + + repository + .store_block_range_roots( + expected_remaining_block_ranges + .iter() + .map(|b| (b.clone(), 
MKTreeNode::from_hex("AAAA").unwrap())) + .collect(), + ) + .await + .unwrap(); + repository + .store_block_range_roots( + [ + BlockRange::from_block_number(BlockRange::LENGTH * 3), + BlockRange::from_block_number(BlockRange::LENGTH * 4), + BlockRange::from_block_number(BlockRange::LENGTH * 5), + ] + .iter() + .map(|b| (b.clone(), MKTreeNode::from_hex("AAAA").unwrap())) + .collect(), + ) + .await + .unwrap(); + + let block_range_roots = repository.get_all_block_range_root().unwrap(); + assert_eq!(6, block_range_roots.len()); + + let chain_point = ChainPoint::new(1, BlockRange::LENGTH * 3, "block_hash-131"); + let scanner = DumbBlockScanner::new().backward(chain_point); + + let importer = + CardanoTransactionsImporter::new_for_test(Arc::new(scanner), repository.clone()); + + importer + .import(3000) + .await + .expect("Transactions Importer should succeed"); + + let block_range_roots = repository.get_all_block_range_root().unwrap(); + assert_eq!( + expected_remaining_block_ranges, + block_range_roots + .into_iter() + .map(|r| r.range) + .collect::>() + ); + } } diff --git a/mithril-signer/src/database/repository/cardano_transaction_repository.rs b/mithril-signer/src/database/repository/cardano_transaction_repository.rs index 68030c1eeb1..521b8f77da0 100644 --- a/mithril-signer/src/database/repository/cardano_transaction_repository.rs +++ b/mithril-signer/src/database/repository/cardano_transaction_repository.rs @@ -45,6 +45,14 @@ impl TransactionStore for CardanoTransactionRepository { } Ok(()) } + + async fn remove_rolled_back_transactions_and_block_range( + &self, + block_number: BlockNumber, + ) -> StdResult<()> { + self.remove_rolled_back_transactions_and_block_range(block_number) + .await + } } #[async_trait] diff --git a/mithril-signer/src/runtime/runner.rs b/mithril-signer/src/runtime/runner.rs index c4ea4465dfc..3703ab25032 100644 --- a/mithril-signer/src/runtime/runner.rs +++ b/mithril-signer/src/runtime/runner.rs @@ -535,7 +535,7 @@ mod tests { )); let 
mithril_stake_distribution_signable_builder = Arc::new(MithrilStakeDistributionSignableBuilder::default()); - let transaction_parser = Arc::new(DumbBlockScanner::new(vec![])); + let transaction_parser = Arc::new(DumbBlockScanner::new()); let transaction_store = Arc::new(MockTransactionStore::new()); let transaction_importer = Arc::new(CardanoTransactionsImporter::new( transaction_parser.clone(), diff --git a/mithril-signer/tests/create_cardano_transaction_single_signature.rs b/mithril-signer/tests/create_cardano_transaction_single_signature.rs new file mode 100644 index 00000000000..289432b9172 --- /dev/null +++ b/mithril-signer/tests/create_cardano_transaction_single_signature.rs @@ -0,0 +1,78 @@ +mod test_extensions; + +use mithril_common::{ + crypto_helper::tests_setup, + entities::{ChainPoint, Epoch, SignedEntityTypeDiscriminants, TimePoint}, + test_utils::MithrilFixtureBuilder, +}; + +use test_extensions::StateMachineTester; + +#[rustfmt::skip] +#[tokio::test] +async fn test_create_cardano_transaction_single_signature() { + let protocol_parameters = tests_setup::setup_protocol_parameters(); + let fixture = MithrilFixtureBuilder::default() + .with_signers(10) + .with_protocol_parameters(protocol_parameters.into()) + .build(); + let signers_with_stake = fixture.signers_with_stake(); + let initial_time_point = TimePoint { + epoch: Epoch(1), + immutable_file_number: 1, + chain_point: ChainPoint { + slot_number: 1, + // Note: the starting block number must be greater than the cardano_transactions_signing_config.step + // so first block range root computation is not on block 0. 
+ block_number: 100, + block_hash: "block_hash-100".to_string(), + }, + }; + let mut tester = StateMachineTester::init(&signers_with_stake, initial_time_point) + .await + .expect("state machine tester init should not fail"); + let total_signer_registrations_expected = 3; + let total_signature_registrations_expected = 2; + + tester + .comment("state machine starts in Init and transit to Unregistered state.") + .is_init().await.unwrap() + .aggregator_send_signed_entity(SignedEntityTypeDiscriminants::CardanoTransactions).await + .cycle_unregistered().await.unwrap() + + .comment("getting an epoch settings changes the state → Registered") + .aggregator_send_epoch_settings().await + .cycle_registered().await.unwrap() + .register_signers(&signers_with_stake[..2]).await.unwrap() + .check_protocol_initializer(Epoch(2)).await.unwrap() + .check_stake_store(Epoch(2)).await.unwrap() + + .comment("waiting 2 epoch for the registration to be effective") + .increase_epoch(2).await.unwrap() + .cycle_unregistered().await.unwrap() + .cycle_registered().await.unwrap() + + .increase_epoch(3).await.unwrap() + .cycle_unregistered().await.unwrap() + + .comment("creating a new certificate pending with a cardano transaction signed entity → Registered") + .increase_block_number(70, 170).await.unwrap() + .cycle_registered().await.unwrap() + + .comment("signer can now create a single signature → Signed") + .cycle_signed().await.unwrap() + + .comment("more cycles do not change the state = Signed") + .cycle_signed().await.unwrap() + .cycle_signed().await.unwrap() + + .comment("new blocks means a new signature with the same stake distribution → Signed") + .increase_block_number(125, 295).await.unwrap() + .cardano_chain_send_rollback(230).await.unwrap() + .cycle_registered().await.unwrap() + .cycle_signed().await.unwrap() + + .comment("metrics should be correctly computed") + .check_metrics(total_signer_registrations_expected,total_signature_registrations_expected).await.unwrap() + ; +} diff --git 
a/mithril-signer/tests/state_machine.rs b/mithril-signer/tests/create_immutable_files_full_single_signature.rs similarity index 89% rename from mithril-signer/tests/state_machine.rs rename to mithril-signer/tests/create_immutable_files_full_single_signature.rs index 809513f1124..c7666a3450b 100644 --- a/mithril-signer/tests/state_machine.rs +++ b/mithril-signer/tests/create_immutable_files_full_single_signature.rs @@ -1,19 +1,30 @@ mod test_extensions; use mithril_common::{ - crypto_helper::tests_setup, entities::Epoch, test_utils::MithrilFixtureBuilder, + crypto_helper::tests_setup, + entities::{ChainPoint, Epoch, TimePoint}, + test_utils::MithrilFixtureBuilder, }; use test_extensions::StateMachineTester; #[rustfmt::skip] #[tokio::test] -async fn test_create_single_signature() { +async fn test_create_immutable_files_full_single_signature() { let protocol_parameters = tests_setup::setup_protocol_parameters(); let fixture = MithrilFixtureBuilder::default().with_signers(10).with_protocol_parameters(protocol_parameters.into()).build(); let signers_with_stake = fixture.signers_with_stake(); - let mut tester = StateMachineTester::init(&signers_with_stake).await.expect("state machine tester init should not fail"); + let initial_time_point = TimePoint { + epoch: Epoch(1), + immutable_file_number: 1, + chain_point: ChainPoint { + slot_number: 1, + block_number: 100, + block_hash: "block_hash-100".to_string(), + }, + }; + let mut tester = StateMachineTester::init(&signers_with_stake, initial_time_point).await.expect("state machine tester init should not fail"); let total_signer_registrations_expected = 4; let total_signature_registrations_expected = 3; diff --git a/mithril-signer/tests/era_switch.rs b/mithril-signer/tests/era_switch.rs index ea9c8941481..d351112e376 100644 --- a/mithril-signer/tests/era_switch.rs +++ b/mithril-signer/tests/era_switch.rs @@ -2,7 +2,7 @@ mod test_extensions; use mithril_common::{ crypto_helper::tests_setup, - entities::Epoch, + 
entities::{ChainPoint, Epoch, TimePoint}, era::{EraMarker, SupportedEra}, test_utils::MithrilFixtureBuilder, }; @@ -15,7 +15,16 @@ async fn era_fail_at_startup() { let protocol_parameters = tests_setup::setup_protocol_parameters(); let fixture = MithrilFixtureBuilder::default().with_signers(10).with_protocol_parameters(protocol_parameters.into()).build(); let signers_with_stake = fixture.signers_with_stake(); - let mut tester = StateMachineTester::init(&signers_with_stake) + let initial_time_point = TimePoint { + epoch: Epoch(1), + immutable_file_number: 1, + chain_point: ChainPoint { + slot_number: 1, + block_number: 100, + block_hash: "block_hash-100".to_string(), + }, + }; + let mut tester = StateMachineTester::init(&signers_with_stake, initial_time_point) .await.expect("state machine tester init should not fail"); tester.set_era_markers(vec![EraMarker::new("whatever", Some(Epoch(0)))]); diff --git a/mithril-signer/tests/test_extensions/certificate_handler.rs b/mithril-signer/tests/test_extensions/certificate_handler.rs index a150e4bd4ff..5a52452db1b 100644 --- a/mithril-signer/tests/test_extensions/certificate_handler.rs +++ b/mithril-signer/tests/test_extensions/certificate_handler.rs @@ -4,29 +4,36 @@ use anyhow::anyhow; use async_trait::async_trait; use mithril_common::{ entities::{ - CardanoDbBeacon, CertificatePending, Epoch, EpochSettings, SignedEntityType, Signer, - SingleSignatures, TimePoint, + CertificatePending, Epoch, EpochSettings, SignedEntityConfig, SignedEntityType, + SignedEntityTypeDiscriminants, Signer, SingleSignatures, TimePoint, }, test_utils::fake_data, - CardanoNetwork, MithrilTickerService, TickerService, + MithrilTickerService, TickerService, }; use mithril_signer::{AggregatorClient, AggregatorClientError}; use tokio::sync::RwLock; pub struct FakeAggregator { - network: CardanoNetwork, + signed_entity_config: SignedEntityConfig, registered_signers: RwLock>>, ticker_service: Arc, + current_certificate_pending_signed_entity: RwLock, 
withhold_epoch_settings: RwLock, } impl FakeAggregator { - pub fn new(network: CardanoNetwork, ticker_service: Arc) -> Self { + pub fn new( + signed_entity_config: SignedEntityConfig, + ticker_service: Arc, + ) -> Self { Self { - network, - withhold_epoch_settings: RwLock::new(true), + signed_entity_config, registered_signers: RwLock::new(HashMap::new()), ticker_service, + current_certificate_pending_signed_entity: RwLock::new( + SignedEntityTypeDiscriminants::CardanoImmutableFilesFull, + ), + withhold_epoch_settings: RwLock::new(true), } } @@ -41,6 +48,14 @@ impl FakeAggregator { *settings = false; } + pub async fn change_certificate_pending_signed_entity( + &self, + discriminant: SignedEntityTypeDiscriminants, + ) { + let mut signed_entity = self.current_certificate_pending_signed_entity.write().await; + *signed_entity = discriminant; + } + async fn get_time_point(&self) -> Result { let time_point = self .ticker_service @@ -76,15 +91,14 @@ impl AggregatorClient for FakeAggregator { if store.is_empty() { return Ok(None); } + + let current_signed_entity = *self.current_certificate_pending_signed_entity.read().await; let time_point = self.get_time_point().await?; - let beacon = CardanoDbBeacon::new( - self.network.to_string(), - *time_point.epoch, - time_point.immutable_file_number, - ); let mut certificate_pending = CertificatePending { epoch: time_point.epoch, - signed_entity_type: SignedEntityType::CardanoImmutableFilesFull(beacon), + signed_entity_type: self + .signed_entity_config + .time_point_to_signed_entity(current_signed_entity, &time_point), ..fake_data::certificate_pending() }; @@ -131,7 +145,6 @@ mod tests { use mithril_common::digesters::DumbImmutableFileObserver; use mithril_common::entities::ChainPoint; use mithril_common::test_utils::fake_data; - use mithril_common::CardanoNetwork; use super::*; @@ -150,7 +163,7 @@ mod tests { ( chain_observer, - FakeAggregator::new(CardanoNetwork::DevNet(42), ticker_service), + 
FakeAggregator::new(SignedEntityConfig::dummy(), ticker_service), ) } diff --git a/mithril-signer/tests/test_extensions/state_machine_tester.rs b/mithril-signer/tests/test_extensions/state_machine_tester.rs index ca7f4b512cb..bf4a879c57e 100644 --- a/mithril-signer/tests/test_extensions/state_machine_tester.rs +++ b/mithril-signer/tests/test_extensions/state_machine_tester.rs @@ -1,4 +1,5 @@ #![allow(dead_code)] +use anyhow::anyhow; use prometheus_parse::Value; use slog::Drain; use slog_scope::debug; @@ -7,10 +8,13 @@ use thiserror::Error; use mithril_common::{ api_version::APIVersionProvider, - cardano_block_scanner::DumbBlockScanner, + cardano_block_scanner::{DumbBlockScanner, ScannedBlock}, chain_observer::{ChainObserver, FakeObserver}, digesters::{DumbImmutableDigester, DumbImmutableFileObserver, ImmutableFileObserver}, - entities::{ChainPoint, Epoch, SignerWithStake, TimePoint}, + entities::{ + BlockNumber, CardanoTransactionsSigningConfig, ChainPoint, Epoch, SignedEntityConfig, + SignedEntityTypeDiscriminants, SignerWithStake, TimePoint, + }, era::{adapters::EraReaderDummyAdapter, EraChecker, EraMarker, EraReader, SupportedEra}, signable_builder::{ CardanoImmutableFilesFullSignableBuilder, CardanoTransactionsSignableBuilder, @@ -57,10 +61,11 @@ pub struct StateMachineTester { stake_store: Arc<StakeStore>, era_checker: Arc<EraChecker>, era_reader_adapter: Arc<EraReaderDummyAdapter>, - comment_no: u32, - _logs_guard: slog_scope::GlobalLoggerGuard, + block_scanner: Arc<DumbBlockScanner>, metrics_service: Arc<MetricsService>, expected_metrics_service: Arc<MetricsService>, + comment_no: u32, + _logs_guard: slog_scope::GlobalLoggerGuard, } impl Debug for StateMachineTester { @@ -71,7 +76,10 @@ impl StateMachineTester { - pub async fn init(signers_with_stake: &[SignerWithStake]) -> Result<Self> { + pub async fn init( + signers_with_stake: &[SignerWithStake], + initial_time_point: TimePoint, + ) -> Result<Self> { let selected_signer_with_stake = signers_with_stake.first().ok_or_else(|| { TestError::AssertFailed("there should be at least
one signer with stakes".to_string()) })?; @@ -95,21 +103,21 @@ impl StateMachineTester { let immutable_observer = Arc::new(DumbImmutableFileObserver::new()); immutable_observer.shall_return(Some(1)).await; - let chain_observer = Arc::new(FakeObserver::new(Some(TimePoint { - epoch: Epoch(1), - immutable_file_number: 1, - chain_point: ChainPoint { - slot_number: 1, - block_number: 1, - block_hash: "block_hash-1".to_string(), - }, - }))); + + let chain_observer = Arc::new(FakeObserver::new(Some(initial_time_point))); let ticker_service = Arc::new(MithrilTickerService::new( chain_observer.clone(), immutable_observer.clone(), )); let certificate_handler = Arc::new(FakeAggregator::new( - config.get_network().unwrap(), + SignedEntityConfig { + allowed_discriminants: SignedEntityTypeDiscriminants::all(), + network: config.get_network().unwrap(), + cardano_transactions_signing_config: CardanoTransactionsSigningConfig { + security_parameter: 0, + step: 30, + }, + }, ticker_service.clone(), )); let digester = Arc::new(DumbImmutableDigester::new("DIGEST", true)); @@ -150,12 +158,12 @@ impl StateMachineTester { )); let mithril_stake_distribution_signable_builder = Arc::new(MithrilStakeDistributionSignableBuilder::default()); - let transaction_parser = Arc::new(DumbBlockScanner::new(vec![])); + let block_scanner = Arc::new(DumbBlockScanner::new()); let transaction_store = Arc::new(CardanoTransactionRepository::new( transaction_sqlite_connection, )); let transaction_importer = Arc::new(CardanoTransactionsImporter::new( - transaction_parser.clone(), + block_scanner.clone(), transaction_store.clone(), Path::new(""), slog_scope::logger(), @@ -211,10 +219,11 @@ impl StateMachineTester { stake_store, era_checker, era_reader_adapter, - comment_no: 0, - _logs_guard: logs_guard, + block_scanner, metrics_service, expected_metrics_service, + comment_no: 0, + _logs_guard: logs_guard, }) } @@ -291,6 +300,17 @@ impl StateMachineTester { self } + /// make the aggregator send the certificate 
pending with the given signed entity from now on + pub async fn aggregator_send_signed_entity( + &mut self, + discriminant: SignedEntityTypeDiscriminants, + ) -> &mut Self { + self.certificate_handler + .change_certificate_pending_signed_entity(discriminant) + .await; + self + } + /// check there is a protocol initializer for the given Epoch pub async fn check_protocol_initializer(&mut self, epoch: Epoch) -> Result<&mut Self> { let maybe_protocol_initializer = self @@ -323,6 +343,20 @@ impl StateMachineTester { ) } + /// increase the epoch in the chain observer + pub async fn increase_epoch(&mut self, expected: u64) -> Result<&mut Self> { + let new_epoch = self + .chain_observer + .next_epoch() + .await + .ok_or_else(|| TestError::ValueError("no epoch returned".to_string()))?; + + self.assert( + expected == new_epoch, + format!("Epoch increased by 1 to {new_epoch} ({expected} expected)"), + ) + } + /// increase the immutable file number in the dumb beacon provider pub async fn increase_immutable(&mut self, increment: u64, expected: u64) -> Result<&mut Self> { let immutable_number = self @@ -342,18 +376,74 @@ impl StateMachineTester { Ok(self) } - /// increase the epoch in the chain observer - pub async fn increase_epoch(&mut self, expected: u64) -> Result<&mut Self> { - let new_epoch = self + /// increase the block number in the fake observer + pub async fn increase_block_number( + &mut self, + increment: u64, + expected: u64, + ) -> Result<&mut Self> { + let new_block_number = self .chain_observer - .next_epoch() + .increase_block_number(increment) .await - .ok_or_else(|| TestError::ValueError("no epoch returned".to_string()))?; + .ok_or_else(|| TestError::ValueError("no block number returned".to_string()))?; self.assert( - expected == new_epoch, - format!("Epoch increased by 1 to {new_epoch} ({expected} expected)"), - ) + expected == new_block_number, + format!("expected to increase block number up to {expected}, got {new_block_number}"), + )?; + + // Make the 
block scanner return new blocks + let current_immutable = self.immutable_observer.get_last_immutable_number().await?; + let blocks_to_scan: Vec<ScannedBlock> = ((expected - increment + 1)..=expected) + .map(|block_number| { + let block_hash = format!("block_hash-{block_number}"); + let slot_number = 10 * block_number; + ScannedBlock::new( + block_hash, + block_number, + slot_number, + current_immutable, + vec![format!("tx_hash-{block_number}-1")], + ) + }) + .collect(); + self.block_scanner.add_forwards(vec![blocks_to_scan]); + + Ok(self) + } + + pub async fn cardano_chain_send_rollback( + &mut self, + rollback_to_block_number: BlockNumber, + ) -> Result<&mut Self> { + let actual_block_number = self + .chain_observer + .get_current_chain_point() + .await + .map_err(|err| TestError::SubsystemError(anyhow!(err)))? + .map(|c| c.block_number) + .ok_or_else(|| TestError::ValueError("no block number returned".to_string()))?; + let decrement = actual_block_number - rollback_to_block_number; + let new_block_number = self + .chain_observer + .decrease_block_number(decrement) + .await + .ok_or_else(|| TestError::ValueError("no block number returned".to_string()))?; + + self.assert( + rollback_to_block_number == new_block_number, + format!("expected to increase block number up to {rollback_to_block_number}, got {new_block_number}"), + )?; + + let chain_point = ChainPoint { + slot_number: 1, + block_number: rollback_to_block_number, + block_hash: format!("block_hash-{rollback_to_block_number}"), + }; + self.block_scanner.add_backward(chain_point); + + Ok(self) } async fn current_epoch(&self) -> Result<Epoch> {