Skip to content

Commit

Permalink
feat: follow from block interval (#582)
Browse files Browse the repository at this point in the history
* fix: unify block interval stream api

* wip: refactor of proving logic

* feat: update leader

* fix: test scripts

* fix: ci

* fix: redundand short arguments

* fix: build

* update: follow from

* fix: cleanup

* chore: update block polling time

* fix: error handling

* fix: improve error output

* fix: use cli block_time

* fix: reviews

* fix: comment

* fix: nit

* chore: passing runtime

* fix: tests

* fix: optimize

* fix: tests

* fix: clean up
  • Loading branch information
atanmarko committed Sep 4, 2024
1 parent 8b75549 commit 9126ece
Show file tree
Hide file tree
Showing 14 changed files with 378 additions and 307 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
*.iml
.idea/
.vscode
**/output.log

67 changes: 37 additions & 30 deletions zero_bin/common/src/block_interval.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::pin::Pin;
use std::sync::Arc;

use alloy::primitives::B256;
use alloy::rpc::types::eth::BlockId;
use alloy::{hex, providers::Provider, transports::Transport};
Expand All @@ -7,8 +10,11 @@ use futures::Stream;
use tracing::info;

use crate::parsing;
use crate::provider::CachedProvider;

const DEFAULT_BLOCK_TIME: u64 = 1000;
/// The async stream of block numbers.
/// The second bool flag indicates if the element is last in the interval.
pub type BlockIntervalStream = Pin<Box<dyn Stream<Item = Result<(u64, bool), anyhow::Error>>>>;

/// Range of blocks to be processed and proven.
#[derive(Debug, PartialEq, Clone)]
Expand All @@ -21,9 +27,6 @@ pub enum BlockInterval {
FollowFrom {
// Interval starting block number
start_block: u64,
// Block time specified in milliseconds.
// If not set, use the default block time to poll node.
block_time: Option<u64>,
},
}

Expand All @@ -44,7 +47,7 @@ impl BlockInterval {
/// assert_eq!(BlockInterval::new("0..10").unwrap(), BlockInterval::Range(0..10));
/// assert_eq!(BlockInterval::new("0..=10").unwrap(), BlockInterval::Range(0..11));
/// assert_eq!(BlockInterval::new("32141").unwrap(), BlockInterval::SingleBlockId(BlockId::Number(32141.into())));
/// assert_eq!(BlockInterval::new("100..").unwrap(), BlockInterval::FollowFrom{start_block: 100, block_time: None});
/// assert_eq!(BlockInterval::new("100..").unwrap(), BlockInterval::FollowFrom{start_block: 100});
/// ```
pub fn new(s: &str) -> anyhow::Result<BlockInterval> {
if (s.starts_with("0x") && s.len() == 66) || s.len() == 64 {
Expand Down Expand Up @@ -77,10 +80,7 @@ impl BlockInterval {
.map_err(|_| anyhow!("invalid block number '{num}'"))
})
.ok_or(anyhow!("invalid block interval range '{s}'"))??;
return Ok(BlockInterval::FollowFrom {
start_block: num,
block_time: None,
});
return Ok(BlockInterval::FollowFrom { start_block: num });
}
// Only single block number is left to try to parse
else {
Expand All @@ -92,16 +92,24 @@ impl BlockInterval {
}
}

/// Convert the block interval into an async stream of block numbers.
pub fn into_bounded_stream(self) -> anyhow::Result<impl Stream<Item = u64>> {
/// Convert the block interval into an async stream of block numbers. The
/// second bool flag indicates if the element is last in the interval.
pub fn into_bounded_stream(self) -> Result<BlockIntervalStream, anyhow::Error> {
match self {
BlockInterval::SingleBlockId(BlockId::Number(num)) => {
let num = num
.as_number()
.ok_or(anyhow!("invalid block number '{num}'"))?;
Ok(futures::stream::iter(num..num + 1))
let range = (num..num + 1).map(|it| Ok((it, true))).collect::<Vec<_>>();

Ok(Box::pin(futures::stream::iter(range)))
}
BlockInterval::Range(range) => {
let mut range = range.map(|it| Ok((it, false))).collect::<Vec<_>>();
// Set last element indicator to true
range.last_mut().map(|it| it.as_mut().map(|it| it.1 = true));
Ok(Box::pin(futures::stream::iter(range)))
}
BlockInterval::Range(range) => Ok(futures::stream::iter(range)),
_ => Err(anyhow!(
"could not create bounded stream from unbounded follow-from interval",
)),
Expand All @@ -126,36 +134,33 @@ impl BlockInterval {
/// numbers. Query the blockchain node for the latest block number.
pub async fn into_unbounded_stream<ProviderT, TransportT>(
self,
provider: ProviderT,
) -> Result<impl Stream<Item = Result<u64, anyhow::Error>>, anyhow::Error>
cached_provider: Arc<CachedProvider<ProviderT, TransportT>>,
block_time: u64,
) -> Result<BlockIntervalStream, anyhow::Error>
where
ProviderT: Provider<TransportT>,
ProviderT: Provider<TransportT> + 'static,
TransportT: Transport + Clone,
{
match self {
BlockInterval::FollowFrom {
start_block,
block_time,
} => Ok(try_stream! {
BlockInterval::FollowFrom { start_block } => Ok(Box::pin(try_stream! {
let mut current = start_block;
loop {
let last_block_number = provider.get_block_number().await.map_err(|e: alloy::transports::RpcError<_>| {
let last_block_number = cached_provider.get_provider().await?.get_block_number().await.map_err(|e: alloy::transports::RpcError<_>| {
anyhow!("could not retrieve latest block number from the provider: {e}")
})?;

if current < last_block_number {
current += 1;
yield current;
yield (current, false);
} else {
info!("Waiting for the new blocks to be mined, requested block number: {current}, \
latest block number: {last_block_number}");
let block_time = block_time.unwrap_or(DEFAULT_BLOCK_TIME);
// No need to poll the node too frequently, waiting
// a block time interval for a block to be mined should be enough
tokio::time::sleep(tokio::time::Duration::from_millis(block_time)).await;
}
}
}),
})),
_ => Err(anyhow!(
"could not create unbounded follow-from stream from fixed bounded interval",
)),
Expand Down Expand Up @@ -214,10 +219,7 @@ mod test {
fn can_create_follow_from_block_interval() {
assert_eq!(
BlockInterval::new("100..").unwrap(),
BlockInterval::FollowFrom {
start_block: 100,
block_time: None
}
BlockInterval::FollowFrom { start_block: 100 }
);
}

Expand Down Expand Up @@ -270,9 +272,14 @@ mod test {
.into_bounded_stream()
.unwrap();
while let Some(val) = stream.next().await {
result.push(val);
result.push(val.unwrap());
}
assert_eq!(result, Vec::from_iter(1u64..10u64));
let mut expected = Vec::from_iter(1u64..10u64)
.into_iter()
.map(|it| (it, false))
.collect::<Vec<_>>();
expected.last_mut().unwrap().1 = true;
assert_eq!(result, expected);
}

#[test]
Expand Down
6 changes: 6 additions & 0 deletions zero_bin/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,9 @@ pub mod pre_checks;
pub mod prover_state;
pub mod provider;
pub mod version;

/// Size of the channel used to send block prover inputs to the per block
/// proving task. If the proving task is slow and can not consume inputs fast
/// enough retrieval of the block prover inputs will block until the proving
/// task consumes some of the inputs.
pub const BLOCK_CHANNEL_SIZE: usize = 16;
15 changes: 1 addition & 14 deletions zero_bin/leader/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,10 @@ pub(crate) enum Command {
/// The previous proof output.
#[arg(long, short = 'f', value_hint = ValueHint::FilePath)]
previous_proof: Option<PathBuf>,
/// If provided, write the generated proofs to this directory instead of
/// stdout.
#[arg(long, short = 'o', value_hint = ValueHint::FilePath)]
proof_output_dir: Option<PathBuf>,
/// Network block time in milliseconds. This value is used
/// Blockchain network block time in milliseconds. This value is used
/// to determine the blockchain node polling interval.
#[arg(short, long, env = "ZERO_BIN_BLOCK_TIME", default_value_t = 2000)]
block_time: u64,
/// Keep intermediate proofs. Default action is to
/// delete them after the final proof is generated.
#[arg(
short,
long,
env = "ZERO_BIN_KEEP_INTERMEDIATE_PROOFS",
default_value_t = false
)]
keep_intermediate_proofs: bool,
/// Backoff in milliseconds for retry requests
#[arg(long, default_value_t = 0)]
backoff: u64,
Expand Down
137 changes: 62 additions & 75 deletions zero_bin/leader/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;

use alloy::rpc::types::{BlockId, BlockNumberOrTag, BlockTransactionsKind};
use alloy::transports::http::reqwest::Url;
use anyhow::Result;
use anyhow::{anyhow, Result};
use paladin::runtime::Runtime;
use proof_gen::proof_types::GeneratedBlockProof;
use prover::ProverConfig;
use prover::{BlockProverInput, ProverConfig};
use rpc::{retry::build_http_retry_provider, RpcType};
use tracing::{error, info, warn};
use zero_bin_common::block_interval::BlockInterval;
use zero_bin_common::fs::generate_block_proof_file_name;
use tokio::sync::mpsc;
use tracing::info;
use zero_bin_common::block_interval::{BlockInterval, BlockIntervalStream};
use zero_bin_common::pre_checks::check_previous_proof_and_checkpoint;

#[derive(Debug)]
Expand All @@ -20,25 +18,24 @@ pub struct RpcParams {
pub rpc_type: RpcType,
pub backoff: u64,
pub max_retries: u32,
pub block_time: u64,
}

#[derive(Debug)]
pub struct ProofParams {
pub struct LeaderConfig {
pub checkpoint_block_number: u64,
pub previous_proof: Option<GeneratedBlockProof>,
pub proof_output_dir: Option<PathBuf>,
pub prover_config: ProverConfig,
pub keep_intermediate_proofs: bool,
}

/// The main function for the client.
pub(crate) async fn client_main(
runtime: Runtime,
runtime: Arc<Runtime>,
rpc_params: RpcParams,
block_interval: BlockInterval,
mut params: ProofParams,
mut leader_config: LeaderConfig,
) -> Result<()> {
use futures::{FutureExt, StreamExt};
use futures::StreamExt;

let cached_provider = Arc::new(zero_bin_common::provider::CachedProvider::new(
build_http_retry_provider(
Expand All @@ -48,94 +45,84 @@ pub(crate) async fn client_main(
)?,
));
check_previous_proof_and_checkpoint(
params.checkpoint_block_number,
&params.previous_proof,
leader_config.checkpoint_block_number,
&leader_config.previous_proof,
block_interval.get_start_block()?,
)?;
// Grab interval checkpoint block state trie.
let checkpoint_state_trie_root = cached_provider
.get_block(
params.checkpoint_block_number.into(),
leader_config.checkpoint_block_number.into(),
BlockTransactionsKind::Hashes,
)
.await?
.header
.state_root;

let mut block_prover_inputs = Vec::new();
let mut block_interval = block_interval.into_bounded_stream()?;
while let Some(block_num) = block_interval.next().await {
// Create a channel for block prover input and use it to send prover input to
// the proving task. The second element of the tuple is a flag indicating
// whether the block is the last one in the interval.
let (block_tx, block_rx) =
mpsc::channel::<(BlockProverInput, bool)>(zero_bin_common::BLOCK_CHANNEL_SIZE);
let test_only = leader_config.prover_config.test_only;

// Run proving task
let runtime_ = runtime.clone();
let proving_task = tokio::spawn(prover::prove(
block_rx,
runtime_,
leader_config.previous_proof.take(),
Arc::new(leader_config.prover_config),
));

// Create block interval stream. Could be bounded or unbounded.
let mut block_interval_stream: BlockIntervalStream = match block_interval {
block_interval @ BlockInterval::FollowFrom { .. } => {
block_interval
.into_unbounded_stream(cached_provider.clone(), rpc_params.block_time)
.await?
}
_ => block_interval.into_bounded_stream()?,
};

// Iterate over the block interval, retrieve prover input
// and send it to the proving task
while let Some(block_interval_elem) = block_interval_stream.next().await {
let (block_num, is_last_block) = block_interval_elem?;
let block_id = BlockId::Number(BlockNumberOrTag::Number(block_num));
// Get future of prover input for particular block.
// Get prover input for particular block.
let block_prover_input = rpc::block_prover_input(
cached_provider.clone(),
block_id,
checkpoint_state_trie_root,
rpc_params.rpc_type,
)
.boxed();
block_prover_inputs.push(block_prover_input);
.await?;
block_tx
.send((block_prover_input, is_last_block))
.await
.map_err(|e| anyhow!("failed to send block prover input through the channel: {e}"))?;
}

match proving_task.await {
Ok(Ok(_)) => {
info!("Proving task successfully finished");
}
Ok(Err(e)) => {
anyhow::bail!("Proving task finished with error: {e:?}");
}
Err(e) => {
anyhow::bail!("Unable to join proving task, error: {e:?}");
}
}

// If `keep_intermediate_proofs` is not set we only keep the last block
// proof from the interval. It contains all the necessary information to
// verify the whole sequence.
let proved_blocks = prover::prove(
block_prover_inputs,
&runtime,
params.previous_proof.take(),
params.prover_config,
params.proof_output_dir.clone(),
)
.await;
runtime.close().await?;
let proved_blocks = proved_blocks?;

if params.prover_config.test_only {
if test_only {
info!("All proof witnesses have been generated successfully.");
} else {
info!("All proofs have been generated successfully.");
}

if !params.prover_config.test_only {
if params.keep_intermediate_proofs {
if params.proof_output_dir.is_some() {
// All proof files (including intermediary) are written to disk and kept
warn!("Skipping cleanup, intermediate proof files are kept");
} else {
// Output all proofs to stdout
std::io::stdout().write_all(&serde_json::to_vec(
&proved_blocks
.into_iter()
.filter_map(|(_, block)| block)
.collect::<Vec<_>>(),
)?)?;
}
} else if let Some(proof_output_dir) = params.proof_output_dir.as_ref() {
// Remove intermediary proof files
proved_blocks
.into_iter()
.rev()
.skip(1)
.map(|(block_number, _)| {
generate_block_proof_file_name(&proof_output_dir.to_str(), block_number)
})
.for_each(|path| {
if let Err(e) = std::fs::remove_file(path) {
error!("Failed to remove intermediate proof file: {e}");
}
});
} else {
// Output only last proof to stdout
if let Some(last_block) = proved_blocks
.into_iter()
.filter_map(|(_, block)| block)
.last()
{
std::io::stdout().write_all(&serde_json::to_vec(&last_block)?)?;
}
}
}

Ok(())
}
Loading

0 comments on commit 9126ece

Please sign in to comment.