From 5d63e7e4d8908ac94555132b1de1a78abb2c9bc2 Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Tue, 2 Jul 2024 20:01:04 +0200 Subject: [PATCH] feat(task db): implement a task DB (#208) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * task-manager: dump fight vs sqlite * task-manager: SQL tables+views+triggers success - but arguments passed in execute are 'NULL' * task-manager: passing enqueue_task tests * task-manager: cleanup - ease copy pasting to SQL script, Registered status, persistent views, remove debug print * task-manager: add DB size query * task-manager: id_proof is unneeded + prettify queries * task-manager: change DB schema - allow multiple provers and status for same task in case of failures-retry * task-manager: allow task updates * task-manager: retrieve cached proofs from DB * task-manager: add status check * task-manager: add progress reports * chore(task_manager): Run cargo fmt * feat: address small lints * feat(task-manager): use result type with thiserror * chore(task-db): fix typos * refactor(task-manager): clean up tests * fix(docker): unignore task manager * [WIP](task_manager): write initial task handler stubs * chore(task_manager): run cargo fmt * [WIP](task_manager): write status and proof get handlers * refactor(host): use merge instead of nest * chore(format): format workflow files * chore(deps): use consistent dependency style * chore(host): rename tx to task_channel * [WIP](task_manager): add initial submit logic * chore(clippy): remove unused parameter * chore(clippy): remove unused imports * refactor(core): add copy trait to proof types * feat(task_manager): simplify db and adapt tests * fix(clippy): fix dereference issue * [WIP]: handle proof request by worker and update task status * [WIP]: add block fetching and initial blockhash getting for submit * [WIP]: handle task creation, status and proof retrieval * fix(host): fix route sub-path * feat(raiko): abstract task manager and impl a mem db for easy integration (#296) * impl a mem db for easy integration Signed-off-by: smtmfft * fix clippy and unit test Signed-off-by: smtmfft * fix fmt Signed-off-by: smtmfft --------- Signed-off-by: smtmfft * fix: throw error instead of panicing on runtime checks * fix(core,task_manager): add custom ensure and require fns * feat(task_db): sqlite and in memory abstraction (#301) * enable sqlite db by feature Signed-off-by: smtmfft * debug lifetime Signed-off-by: smtmfft * resolve lifetime issue and make all tests pass Signed-off-by: smtmfft * refactor(task_db): simplify structure for sqlite and use cached statements * feat(task_db): abstract task db implementation into wrapper * fix(task_db): add await to test call * fix(task_db): fix import declaration * fix(task_db): add async and mutable variables * fix(host): fix task manager usage * fix(task_db): fix test for async * Update Cargo.toml use in-mem as default. --------- Signed-off-by: smtmfft Co-authored-by: smtmfft Co-authored-by: smtmfft <99081233+smtmfft@users.noreply.github.com> * feat(task_manager): return empty list on key not found * feat(host,task_manager): add tracing and handle workers * feat(host): fix response structure * chore(clippy): remove unused imports * fix(ci): remove git merge added lines * fix(task_manager): add blob proof type field --------- Signed-off-by: smtmfft Co-authored-by: Petar Vujović Co-authored-by: smtmfft <99081233+smtmfft@users.noreply.github.com> Co-authored-by: smtmfft --- .dockerignore | 1 + .github/workflows/ci-native.yml | 13 +- .github/workflows/ci-risc0.yml | 12 +- .github/workflows/ci-sgx-hardware.yml | 10 +- .github/workflows/ci-sp1.yml | 12 +- .github/workflows/openapi-deploy.yml | 2 +- .gitignore | 7 + Cargo.lock | 122 +++- Cargo.toml | 7 + core/src/interfaces.rs | 21 +- core/src/lib.rs | 18 +- core/src/preflight.rs | 9 +- core/src/provider/mod.rs | 29 +- host/Cargo.toml | 1 + host/src/interfaces.rs | 47 ++ host/src/lib.rs | 225 ++++++- host/src/server/api/mod.rs | 9 +- host/src/server/api/v1/mod.rs | 27 +- host/src/server/api/v1/proof.rs | 9 +- host/src/server/api/v2/mod.rs | 73 +++ host/src/server/api/v2/proof.rs | 137 ++++ lib/src/protocol_instance.rs | 24 +- task_manager/Cargo.toml | 36 ++ task_manager/src/adv_sqlite.rs | 873 ++++++++++++++++++++++++++ task_manager/src/lib.rs | 383 +++++++++++ task_manager/src/mem_db.rs | 310 +++++++++ task_manager/tests/main.rs | 474 ++++++++++++++ 27 files changed, 2812 insertions(+), 79 deletions(-) create mode 100644 host/src/server/api/v2/mod.rs create mode 100644 host/src/server/api/v2/proof.rs create mode 100644 task_manager/Cargo.toml create mode 100644 task_manager/src/adv_sqlite.rs create mode 100644 task_manager/src/lib.rs create mode 100644 task_manager/src/mem_db.rs create mode 100644 task_manager/tests/main.rs diff --git a/.dockerignore b/.dockerignore index 0d744e0f2..438e4ad7a 100644 --- a/.dockerignore +++ b/.dockerignore @@ -19,3 +19,4 @@ !/provers/sgx/setup !/kzg_settings_raw.bin !/core +!/task_manager diff --git a/.github/workflows/ci-native.yml b/.github/workflows/ci-native.yml index 4621e7226..0c1ab7874 100644 --- a/.github/workflows/ci-native.yml +++ b/.github/workflows/ci-native.yml @@ -2,11 +2,10 @@ name: CI - Native on: workflow_call - jobs: - build-test-native: - name: Build and test native - uses: ./.github/workflows/ci-build-test-reusable.yml - with: - version_name: "native" - version_toolchain: "nightly-2024-04-17" + build-test-native: + name: Build and test native + uses: ./.github/workflows/ci-build-test-reusable.yml + with: + version_name: "native" + version_toolchain: "nightly-2024-04-17" diff --git a/.github/workflows/ci-risc0.yml b/.github/workflows/ci-risc0.yml index f0590c027..5ac10b864 100644 --- a/.github/workflows/ci-risc0.yml +++ b/.github/workflows/ci-risc0.yml @@ -12,9 +12,9 @@ on: merge_group: jobs: - build-test-risc0: - name: Build and test risc0 - uses: ./.github/workflows/ci-build-test-reusable.yml - with: - version_name: "risc0" - version_toolchain: "stable" + build-test-risc0: + name: Build and test risc0 + uses: ./.github/workflows/ci-build-test-reusable.yml + with: + version_name: "risc0" + version_toolchain: "stable" diff --git a/.github/workflows/ci-sgx-hardware.yml b/.github/workflows/ci-sgx-hardware.yml index 53c648d8b..6efa67ae8 100644 --- a/.github/workflows/ci-sgx-hardware.yml +++ b/.github/workflows/ci-sgx-hardware.yml @@ -11,7 +11,7 @@ jobs: TARGET: sgx CI: 1 EDMM: 0 - + steps: - uses: actions/checkout@v4 with: @@ -21,15 +21,15 @@ jobs: with: toolchain: stable profile: minimal - + - name: Install cargo-binstall uses: cargo-bins/cargo-binstall@v1.6.4 - + - name: Install sgx run: make install - + - name: Build sgx prover run: make build - + - name: Test sgx prover run: make test diff --git a/.github/workflows/ci-sp1.yml b/.github/workflows/ci-sp1.yml index 4da98dd70..8d8ee3a60 100644 --- a/.github/workflows/ci-sp1.yml +++ b/.github/workflows/ci-sp1.yml @@ -12,9 +12,9 @@ on: merge_group: jobs: - build-test-sgx: - name: Build and test sp1 - uses: ./.github/workflows/ci-build-test-reusable.yml - with: - version_name: "sp1" - version_toolchain: "nightly-2024-04-18" + build-test-sgx: + name: Build and test sp1 + uses: ./.github/workflows/ci-build-test-reusable.yml + with: + version_name: "sp1" + version_toolchain: "nightly-2024-04-18" diff --git a/.github/workflows/openapi-deploy.yml b/.github/workflows/openapi-deploy.yml index ee83ae59e..01c47cb86 100644 --- a/.github/workflows/openapi-deploy.yml +++ b/.github/workflows/openapi-deploy.yml @@ -44,7 +44,7 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/upload-pages-artifact@v2 with: - path: './openapi' + path: "./openapi" - name: Deploy to GitHub Pages if: github.ref == 'refs/heads/main' diff --git a/.gitignore b/.gitignore index 8be9c66e9..1053afbdb 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,13 @@ target/ # MSVC Windows builds of rustc generate these, which store debugging information *.pdb +# SQLite +# ----------------------------------------------------------------------------------------- +*.sqlite +*.sqlite-shm +*.sqlite-wal +*.sqlite-journal + # Temp files, swap, debug, log, perf, cache # ----------------------------------------------------------------------------------------- *.swp diff --git a/Cargo.lock b/Cargo.lock index 2e39fe927..16e8689cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -116,7 +116,7 @@ dependencies = [ "getrandom 0.2.15", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.34", ] [[package]] @@ -2316,7 +2316,7 @@ dependencies = [ "enr 0.12.1", "fnv", "futures", - "hashlink", + "hashlink 0.8.4", "hex", "hkdf", "lazy_static", @@ -2977,6 +2977,12 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.1.0" @@ -3528,6 +3534,15 @@ dependencies = [ "hashbrown 0.14.5", ] +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "hdrhistogram" version = "7.5.4" @@ -4251,6 +4266,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "libz-sys" version = "1.1.18" @@ -5778,6 +5804,7 @@ dependencies = [ "proptest", "raiko-core", "raiko-lib", + "raiko-task-manager", "reqwest 0.11.27", "reqwest 0.12.5", "reth-evm", @@ -5914,6 +5941,29 @@ dependencies = [ "url", ] +[[package]] +name = "raiko-task-manager" +version = "0.1.0" +dependencies = [ + "alloy-primitives", + "anyhow", + "async-trait", + "chrono", + "hex", + "num_enum 0.7.2", + "raiko-core", + "raiko-lib", + "rand 0.9.0-alpha.1", + "rand_chacha 0.9.0-alpha.1", + "rusqlite", + "serde", + "serde_json", + "tempfile", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "rand" version = "0.7.3" @@ -5938,6 +5988,17 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand" +version = "0.9.0-alpha.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d31e63ea85be51c423e52ba8f2e68a3efd53eed30203ee029dd09947333693e" +dependencies = [ + "rand_chacha 0.9.0-alpha.1", + "rand_core 0.9.0-alpha.1", + "zerocopy 0.8.0-alpha.6", +] + [[package]] name = "rand_chacha" version = "0.2.2" @@ -5958,6 +6019,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_chacha" +version = "0.9.0-alpha.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78674ef918c19451dbd250f8201f8619b494f64c9aa6f3adb28fd8a0f1f6da46" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.0-alpha.1", +] + [[package]] name = "rand_core" version = "0.5.1" @@ -5976,6 +6047,16 @@ dependencies = [ "getrandom 0.2.15", ] +[[package]] +name = "rand_core" +version = "0.9.0-alpha.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc89dffba8377c5ec847d12bb41492bda235dba31a25e8b695cd0fe6589eb8c9" +dependencies = [ + "getrandom 0.2.15", + "zerocopy 0.8.0-alpha.6", +] + [[package]] name = "rand_hc" version = "0.2.0" @@ -7397,6 +7478,21 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48fd7bd8a6377e15ad9d42a8ec25371b94ddc67abe7c8b9127bec79bebaaae18" +[[package]] +name = "rusqlite" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" +dependencies = [ + "bitflags 2.6.0", + "chrono", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink 0.9.1", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rust-embed" version = "8.4.0" @@ -10018,7 +10114,16 @@ version = "0.7.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" dependencies = [ - "zerocopy-derive", + "zerocopy-derive 0.7.34", +] + +[[package]] +name = "zerocopy" +version = "0.8.0-alpha.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db678a6ee512bd06adf35c35be471cae2f9c82a5aed2b5d15e03628c98bddd57" +dependencies = [ + "zerocopy-derive 0.8.0-alpha.6", ] [[package]] @@ -10032,6 +10137,17 @@ dependencies = [ "syn 2.0.68", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.0-alpha.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "201585ea96d37ee69f2ac769925ca57160cef31acb137c16f38b02b76f4c1e62" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.68", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 8c9cec791..d9626f6b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ members = [ "provers/sgx/setup", "pipeline", "core", + "task_manager", ] # Always optimize; building and running the guest takes much longer without optimization. @@ -35,6 +36,7 @@ opt-level = 3 # raiko raiko-lib = { path = "./lib", features = ["std"] } raiko-core = { path = "./core" } +raiko-task-manager = { path = "./task_manager" } # reth reth-primitives = { git = "https://github.com/taikoxyz/taiko-reth.git", branch = "v1.0.0-rc.2-taiko", default-features = false, features = ["alloy-compat", "taiko"] } @@ -109,6 +111,7 @@ base64-serde = "0.7.0" base64 = "0.21.7" libflate = { version = "2.0.0" } typetag = { version = "0.2.15" } +num_enum = "0.7.2" # tracing, logging tracing = "0.1" @@ -135,6 +138,7 @@ tokio = { version = "^1.23", features = ["full"] } hyper = { version = "0.14.27", features = ["server"] } reqwest = { version = "0.11.22", features = ["json"] } url = "2.5.0" +async-trait = "0.1.80" # crypto kzg = { package = "rust-kzg-zkcrypto", git = "https://github.com/brechtpd/rust-kzg.git", branch = "sp1-patch", default-features = false } @@ -156,6 +160,9 @@ anyhow = "1.0" thiserror = "1.0" thiserror-no-std = "2.0.2" +# SQLite +rusqlite = { version = "0.31.0", features = ["bundled"] } + # misc hashbrown = { version = "0.14", features = ["inline-more"] } tempfile = "3.8" diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs index e68d5ce58..e07619732 100644 --- a/core/src/interfaces.rs +++ b/core/src/interfaces.rs @@ -79,10 +79,23 @@ impl From for RaikoError { pub type RaikoResult = Result; #[derive( - PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Deserialize, Serialize, ToSchema, Hash, ValueEnum, + PartialEq, + Eq, + PartialOrd, + Ord, + Clone, + Debug, + Default, + Deserialize, + Serialize, + ToSchema, + Hash, + ValueEnum, + Copy, )] /// Available proof types. pub enum ProofType { + #[default] /// # Native /// /// This builds the block the same way the node does and then runs the result. @@ -144,7 +157,7 @@ impl ProofType { .await .map_err(|e| e.into()); #[cfg(not(feature = "sp1"))] - Err(RaikoError::FeatureNotSupportedError(self.clone())) + Err(RaikoError::FeatureNotSupportedError(*self)) } ProofType::Risc0 => { #[cfg(feature = "risc0")] @@ -152,7 +165,7 @@ impl ProofType { .await .map_err(|e| e.into()); #[cfg(not(feature = "risc0"))] - Err(RaikoError::FeatureNotSupportedError(self.clone())) + Err(RaikoError::FeatureNotSupportedError(*self)) } ProofType::Sgx => { #[cfg(feature = "sgx")] @@ -160,7 +173,7 @@ impl ProofType { .await .map_err(|e| e.into()); #[cfg(not(feature = "sgx"))] - Err(RaikoError::FeatureNotSupportedError(self.clone())) + Err(RaikoError::FeatureNotSupportedError(*self)) } }?; diff --git a/core/src/lib.rs b/core/src/lib.rs index 1e640e73b..71e99b8ae 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -158,18 +158,22 @@ fn check_eq(expected: &T, actual: &T, let _ = black_box(require_eq(expected, actual, message)); } +fn require(expression: bool, message: &str) -> RaikoResult<()> { + if !expression { + let msg = format!("Assertion failed: {message}"); + error!("{msg}"); + return Err(anyhow::Error::msg(msg).into()); + } + Ok(()) +} + fn require_eq( expected: &T, actual: &T, message: &str, ) -> RaikoResult<()> { - if expected != actual { - let msg = - format!("Assertion failed: {message} - Expected: {expected:?}, Found: {actual:?}",); - error!("{}", msg); - return Err(anyhow::Error::msg(msg).into()); - } - Ok(()) + let msg = format!("{message} - Expected: {expected:?}, Found: {actual:?}"); + require(expected == actual, &msg) } /// Merges two json's together, overwriting `a` with the values of `b` diff --git a/core/src/preflight.rs b/core/src/preflight.rs index 48687aa0a..b426de656 100644 --- a/core/src/preflight.rs +++ b/core/src/preflight.rs @@ -1,12 +1,13 @@ use crate::{ interfaces::{RaikoError, RaikoResult}, provider::{db::ProviderDb, rpc::RpcBlockDataProvider, BlockDataProvider}, + require, }; pub use alloy_primitives::*; use alloy_provider::{Provider, ReqwestProvider}; use alloy_rpc_types::{Filter, Transaction as AlloyRpcTransaction}; use alloy_sol_types::{SolCall, SolEvent}; -use anyhow::{anyhow, bail, Result}; +use anyhow::{anyhow, bail, ensure, Result}; use kzg_traits::{ eip_4844::{blob_to_kzg_commitment_rust, Blob}, G1, @@ -240,7 +241,7 @@ async fn prepare_taiko_chain_input( debug!("blob active"); // Get the blob hashes attached to the propose tx let blob_hashes = proposal_tx.blob_versioned_hashes.unwrap_or_default(); - assert!(!blob_hashes.is_empty()); + require(!blob_hashes.is_empty(), "blob hashes are empty")?; // Currently the protocol enforces the first blob hash to be used let blob_hash = blob_hashes[0]; // Get the blob data for this block @@ -363,7 +364,7 @@ async fn get_blob_data_beacon( let response = reqwest::get(url.clone()).await?; if response.status().is_success() { let blobs: GetBlobsResponse = response.json().await?; - assert!(!blobs.data.is_empty(), "blob data not available anymore"); + ensure!(!blobs.data.is_empty(), "blob data not available anymore"); // Get the blob data for the blob storing the tx list let tx_blob = blobs .data @@ -373,7 +374,7 @@ async fn get_blob_data_beacon( blob_hash == calc_blob_versioned_hash(&blob.blob) }) .cloned(); - assert!(tx_blob.is_some()); + ensure!(tx_blob.is_some()); Ok(blob_to_bytes(&tx_blob.unwrap().blob)) } else { warn!( diff --git a/core/src/provider/mod.rs b/core/src/provider/mod.rs index ee2849ad4..3d7f30ce6 100644 --- a/core/src/provider/mod.rs +++ b/core/src/provider/mod.rs @@ -1,9 +1,14 @@ -use alloy_primitives::{Address, U256}; +use alloy_primitives::{Address, B256, U256}; use alloy_rpc_types::Block; +use raiko_lib::consts::SupportedChainSpecs; use reth_primitives::revm_primitives::AccountInfo; use std::collections::HashMap; -use crate::{interfaces::RaikoResult, MerkleProof}; +use crate::{ + interfaces::{RaikoError, RaikoResult}, + provider::rpc::RpcBlockDataProvider, + MerkleProof, +}; pub mod db; pub mod rpc; @@ -24,3 +29,23 @@ pub trait BlockDataProvider { num_storage_proofs: usize, ) -> RaikoResult; } + +pub async fn get_task_data( + network: &str, + block_number: u64, + chain_specs: &SupportedChainSpecs, +) -> RaikoResult<(u64, B256)> { + let taiko_chain_spec = chain_specs + .get_chain_spec(network) + .ok_or_else(|| RaikoError::InvalidRequestConfig("Unsupported raiko network".to_string()))?; + let provider = RpcBlockDataProvider::new(&taiko_chain_spec.rpc.clone(), block_number - 1)?; + let blocks = provider.get_blocks(&[(block_number, true)]).await?; + let block = blocks + .first() + .ok_or_else(|| RaikoError::RPC("No block for requested block number".to_string()))?; + let blockhash = block + .header + .hash + .ok_or_else(|| RaikoError::RPC("No block hash for requested block".to_string()))?; + Ok((taiko_chain_spec.chain_id, blockhash)) +} diff --git a/host/Cargo.toml b/host/Cargo.toml index 3926be2f5..6e997bb34 100644 --- a/host/Cargo.toml +++ b/host/Cargo.toml @@ -14,6 +14,7 @@ sgx-prover = { path = "../provers/sgx/prover", optional = true } # raiko raiko-lib = { workspace = true } raiko-core = { workspace = true } +raiko-task-manager = { workspace = true } # alloy alloy-rlp = { workspace = true } diff --git a/host/src/interfaces.rs b/host/src/interfaces.rs index f9d2b9696..69d42cfa6 100644 --- a/host/src/interfaces.rs +++ b/host/src/interfaces.rs @@ -1,11 +1,21 @@ use axum::response::IntoResponse; use raiko_core::interfaces::ProofType; use raiko_lib::prover::ProverError; +use raiko_task_manager::{TaskManagerError, TaskStatus}; +use tokio::sync::mpsc::error::TrySendError; use utoipa::ToSchema; /// The standardized error returned by the Raiko host. #[derive(thiserror::Error, Debug, ToSchema)] pub enum HostError { + /// For unexpectedly dropping task handle. + #[error("Task handle unexpectedly dropped")] + HandleDropped, + + /// For full prover capacity. + #[error("Capacity full")] + CapacityFull, + /// For invalid address. #[error("Invalid address: {0}")] InvalidAddress(String), @@ -56,6 +66,10 @@ pub enum HostError { #[error("There was an unexpected error: {0}")] #[schema(value_type = Value)] Anyhow(#[from] anyhow::Error), + + /// For task manager errors. + #[error("There was an error with the task manager: {0}")] + TaskManager(#[from] TaskManagerError), } impl IntoResponse for HostError { @@ -74,11 +88,44 @@ impl IntoResponse for HostError { ("feature_not_supported_error".to_string(), t.to_string()) } HostError::Anyhow(e) => ("anyhow_error".to_string(), e.to_string()), + HostError::HandleDropped => ("handle_dropped".to_string(), "".to_string()), + HostError::CapacityFull => ("capacity_full".to_string(), "".to_string()), + HostError::TaskManager(e) => ("task_manager".to_string(), e.to_string()), }; axum::Json(serde_json::json!({ "status": "error", "error": error, "message": message })) .into_response() } } +impl From> for HostError { + fn from(value: TrySendError) -> Self { + match value { + TrySendError::Full(_) => HostError::CapacityFull, + TrySendError::Closed(_) => HostError::HandleDropped, + } + } +} + /// A type alias for the standardized result type returned by the Raiko host. pub type HostResult = axum::response::Result; + +impl From for TaskStatus { + fn from(value: HostError) -> Self { + match value { + HostError::HandleDropped + | HostError::CapacityFull + | HostError::JoinHandle(_) + | HostError::InvalidAddress(_) + | HostError::InvalidRequestConfig(_) => unreachable!(), + HostError::Conversion(_) + | HostError::Serde(_) + | HostError::Core(_) + | HostError::Anyhow(_) + | HostError::FeatureNotSupportedError(_) + | HostError::Io(_) => TaskStatus::UnspecifiedFailureReason, + HostError::RPC(_) => TaskStatus::NetworkFailure, + HostError::Guest(_) => TaskStatus::ProofFailure_Generic, + HostError::TaskManager(_) => TaskStatus::SqlDbCorruption, + } + } +} diff --git a/host/src/lib.rs b/host/src/lib.rs index 3cb6839db..80b59fe20 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -7,12 +7,30 @@ use std::{alloc, path::PathBuf}; use anyhow::Context; use cap::Cap; use clap::Parser; -use raiko_core::{interfaces::ProofRequestOpt, merge}; -use raiko_lib::consts::SupportedChainSpecs; +use raiko_core::{ + interfaces::{ProofRequest, ProofRequestOpt, RaikoError}, + merge, + provider::{get_task_data, rpc::RpcBlockDataProvider}, + Raiko, +}; +use raiko_lib::{consts::SupportedChainSpecs, Measurement}; +use raiko_task_manager::{get_task_manager, TaskManager, TaskManagerOpts, TaskStatus}; use serde::{Deserialize, Serialize}; use serde_json::Value; +use tokio::sync::mpsc; +use tracing::{error, info}; -use crate::interfaces::HostResult; +use crate::{ + interfaces::{HostError, HostResult}, + metrics::{ + inc_guest_error, inc_guest_req_count, inc_guest_success, inc_host_error, + inc_host_req_count, observe_guest_time, observe_prepare_input_time, observe_total_time, + }, + server::api::v1::{ + proof::{get_cached_input, set_cached_input, validate_cache_input}, + ProofResponse, + }, +}; #[global_allocator] static ALLOCATOR: Cap = Cap::new(alloc::System, usize::MAX); @@ -90,6 +108,13 @@ pub struct Cli { #[arg(long, require_equals = true)] /// Set jwt secret for auth jwt_secret: Option, + + #[arg(long, require_equals = true, default_value = "raiko.sqlite")] + /// Set the path to the sqlite db file + sqlite_file: PathBuf, + + #[arg(long, require_equals = true, default_value = "1048576")] + max_db_size: usize, } impl Cli { @@ -106,10 +131,31 @@ impl Cli { } } +type TaskChannelOpts = (ProofRequest, Cli, SupportedChainSpecs); + #[derive(Debug, Clone)] pub struct ProverState { pub opts: Cli, pub chain_specs: SupportedChainSpecs, + pub task_channel: mpsc::Sender, +} + +impl From for TaskManagerOpts { + fn from(val: Cli) -> Self { + Self { + sqlite_file: val.sqlite_file, + max_db_size: val.max_db_size, + } + } +} + +impl From<&Cli> for TaskManagerOpts { + fn from(val: &Cli) -> Self { + Self { + sqlite_file: val.sqlite_file.clone(), + max_db_size: val.max_db_size, + } + } } impl ProverState { @@ -132,10 +178,181 @@ impl ProverState { } } - Ok(Self { opts, chain_specs }) + let (task_channel, mut receiver) = mpsc::channel::(opts.concurrency_limit); + + let _spawn = tokio::spawn(async move { + while let Some((proof_request, opts, chain_specs)) = receiver.recv().await { + let Ok((chain_id, blockhash)) = get_task_data( + &proof_request.network, + proof_request.block_number, + &chain_specs, + ) + .await + else { + error!("Could not retrieve chain ID and blockhash"); + continue; + }; + let mut manager = get_task_manager(&opts.clone().into()); + if manager + .update_task_progress( + chain_id, + blockhash, + proof_request.proof_type, + Some(proof_request.prover.to_string()), + TaskStatus::WorkInProgress, + None, + ) + .await + .is_err() + { + error!("Could not update task to work in progress via task manager"); + } + match handle_proof(&proof_request, &opts, &chain_specs).await { + Ok(proof) => { + let proof = proof.proof.unwrap_or_default(); + let proof = proof.as_bytes(); + if manager + .update_task_progress( + chain_id, + blockhash, + proof_request.proof_type, + Some(proof_request.prover.to_string()), + TaskStatus::Success, + Some(proof), + ) + .await + .is_err() + { + error!("Could not update task progress to success via task manager"); + } + } + Err(error) => { + if manager + .update_task_progress( + chain_id, + blockhash, + proof_request.proof_type, + Some(proof_request.prover.to_string()), + error.into(), + None, + ) + .await + .is_err() + { + error!( + "Could not update task progress to error state via task manager" + ); + } + } + } + } + }); + + Ok(Self { + opts, + chain_specs, + task_channel, + }) } } +pub async fn handle_proof( + proof_request: &ProofRequest, + opts: &Cli, + chain_specs: &SupportedChainSpecs, +) -> HostResult { + inc_host_req_count(proof_request.block_number); + inc_guest_req_count(&proof_request.proof_type, proof_request.block_number); + + info!( + "# Generating proof for block {} on {}", + proof_request.block_number, proof_request.network + ); + + // Check for a cached input for the given request config. + let cached_input = get_cached_input( + &opts.cache_path, + proof_request.block_number, + &proof_request.network.to_string(), + ); + + let l1_chain_spec = chain_specs + .get_chain_spec(&proof_request.l1_network.to_string()) + .ok_or_else(|| HostError::InvalidRequestConfig("Unsupported l1 network".to_string()))?; + + let taiko_chain_spec = chain_specs + .get_chain_spec(&proof_request.network.to_string()) + .ok_or_else(|| HostError::InvalidRequestConfig("Unsupported raiko network".to_string()))?; + + // Execute the proof generation. + let total_time = Measurement::start("", false); + + let raiko = Raiko::new( + l1_chain_spec.clone(), + taiko_chain_spec.clone(), + proof_request.clone(), + ); + let provider = RpcBlockDataProvider::new( + &taiko_chain_spec.rpc.clone(), + proof_request.block_number - 1, + )?; + let input = match validate_cache_input(cached_input, &provider).await { + Ok(cache_input) => cache_input, + Err(_) => { + // no valid cache + memory::reset_stats(); + let measurement = Measurement::start("Generating input...", false); + let input = raiko.generate_input(provider).await?; + let input_time = measurement.stop_with("=> Input generated"); + observe_prepare_input_time(proof_request.block_number, input_time, true); + memory::print_stats("Input generation peak memory used: "); + input + } + }; + memory::reset_stats(); + let output = raiko.get_output(&input)?; + memory::print_stats("Guest program peak memory used: "); + + memory::reset_stats(); + let measurement = Measurement::start("Generating proof...", false); + let proof = raiko.prove(input.clone(), &output).await.map_err(|e| { + let total_time = total_time.stop_with("====> Proof generation failed"); + observe_total_time(proof_request.block_number, total_time, false); + match e { + RaikoError::Guest(e) => { + inc_guest_error(&proof_request.proof_type, proof_request.block_number); + HostError::Core(e.into()) + } + e => { + inc_host_error(proof_request.block_number); + e.into() + } + } + })?; + let guest_time = measurement.stop_with("=> Proof generated"); + observe_guest_time( + &proof_request.proof_type, + proof_request.block_number, + guest_time, + true, + ); + memory::print_stats("Prover peak memory used: "); + + inc_guest_success(&proof_request.proof_type, proof_request.block_number); + let total_time = total_time.stop_with("====> Complete proof generated"); + observe_total_time(proof_request.block_number, total_time, true); + + // Cache the input for future use. + set_cached_input( + &opts.cache_path, + proof_request.block_number, + &proof_request.network.to_string(), + &input, + )?; + + ProofResponse::try_from(proof) +} + mod memory { use tracing::debug; diff --git a/host/src/server/api/mod.rs b/host/src/server/api/mod.rs index 11e3e394e..806698a95 100644 --- a/host/src/server/api/mod.rs +++ b/host/src/server/api/mod.rs @@ -16,7 +16,8 @@ use tower_http::{ use crate::ProverState; -mod v1; +pub mod v1; +pub mod v2; pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Router { let cors = CorsLayer::new() @@ -35,10 +36,12 @@ pub fn create_router(concurrency_limit: usize, jwt_secret: Option<&str>) -> Rout let trace = TraceLayer::new_for_http(); let v1_api = v1::create_router(concurrency_limit); + let v2_api = v2::create_router(); let router = Router::new() - .nest("/v1", v1_api.clone()) - .merge(v1_api) + .nest("/v1", v1_api) + .nest("/v2", v2_api.clone()) + .merge(v2_api) .layer(middleware) .layer(middleware::from_fn(check_max_body_size)) .layer(trace) diff --git a/host/src/server/api/v1/mod.rs b/host/src/server/api/v1/mod.rs index 3977e49c8..d84113816 100644 --- a/host/src/server/api/v1/mod.rs +++ b/host/src/server/api/v1/mod.rs @@ -9,9 +9,9 @@ use utoipa_swagger_ui::SwaggerUi; use crate::{interfaces::HostError, ProverState}; -mod health; -mod metrics; -mod proof; +pub mod health; +pub mod metrics; +pub mod proof; #[derive(OpenApi)] #[openapi( @@ -53,20 +53,25 @@ pub struct Docs; pub struct ProofResponse { #[schema(value_type = Option)] /// The output of the prover. - output: Option, + pub output: Option, /// The proof. - proof: Option, + pub proof: Option, /// The quote. - quote: Option, + pub quote: Option, } -impl IntoResponse for ProofResponse { - fn into_response(self) -> axum::response::Response { - axum::Json(serde_json::json!({ +impl ProofResponse { + pub fn to_response(&self) -> Value { + serde_json::json!({ "status": "ok", "data": self - })) - .into_response() + }) + } +} + +impl IntoResponse for ProofResponse { + fn into_response(self) -> axum::response::Response { + axum::Json(self.to_response()).into_response() } } diff --git a/host/src/server/api/v1/proof.rs b/host/src/server/api/v1/proof.rs index 4a306231a..7c22b6769 100644 --- a/host/src/server/api/v1/proof.rs +++ b/host/src/server/api/v1/proof.rs @@ -26,7 +26,7 @@ use crate::{ ProverState, }; -fn get_cached_input( +pub fn get_cached_input( cache_path: &Option, block_number: u64, network: &str, @@ -40,7 +40,7 @@ fn get_cached_input( bincode::deserialize_from(file).ok() } -fn set_cached_input( +pub fn set_cached_input( cache_path: &Option, block_number: u64, network: &str, @@ -57,7 +57,7 @@ fn set_cached_input( bincode::serialize_into(file, input).map_err(|e| HostError::Anyhow(e.into())) } -async fn validate_cache_input( +pub async fn validate_cache_input( cached_input: Option, provider: &RpcBlockDataProvider, ) -> HostResult { @@ -92,10 +92,11 @@ async fn validate_cache_input( } } -async fn handle_proof( +pub async fn handle_proof( ProverState { opts, chain_specs: support_chain_specs, + .. }: ProverState, req: Value, ) -> HostResult { diff --git a/host/src/server/api/v2/mod.rs b/host/src/server/api/v2/mod.rs new file mode 100644 index 000000000..47d9b49ff --- /dev/null +++ b/host/src/server/api/v2/mod.rs @@ -0,0 +1,73 @@ +use axum::Router; +use utoipa::OpenApi; +use utoipa_scalar::{Scalar, Servable}; +use utoipa_swagger_ui::SwaggerUi; + +use crate::{ + server::api::v1::{self, GuestOutputDoc, ProofResponse, Status}, + ProverState, +}; + +mod proof; + +#[derive(OpenApi)] +#[openapi( + info( + title = "Raiko Proverd Server API", + version = "2.0", + description = "Raiko Proverd Server API", + contact( + name = "API Support", + url = "https://community.taiko.xyz", + email = "info@taiko.xyz", + ), + license( + name = "MIT", + url = "https://github.com/taikoxyz/raiko/blob/taiko/unstable/LICENSE" + ), + ), + components( + schemas( + raiko_core::interfaces::ProofRequestOpt, + raiko_core::interfaces::ProverSpecificOpts, + crate::interfaces::HostError, + GuestOutputDoc, + ProofResponse, + Status, + ) + ), + tags( + (name = "Proving", description = "Routes that handle proving requests"), + (name = "Health", description = "Routes that report the server health status"), + (name = "Metrics", description = "Routes that give detailed insight into the server") + ) +)] +/// The root API struct which is generated from the `OpenApi` derive macro. +pub struct Docs; + +#[must_use] +pub fn create_docs() -> utoipa::openapi::OpenApi { + [ + v1::health::create_docs(), + v1::metrics::create_docs(), + proof::create_docs(), + ] + .into_iter() + .fold(Docs::openapi(), |mut doc, sub_doc| { + doc.merge(sub_doc); + doc + }) +} + +pub fn create_router() -> Router { + let docs = create_docs(); + + Router::new() + // Only add the concurrency limit to the proof route. We want to still be able to call + // healthchecks and metrics to have insight into the system. + .nest("/proof", proof::create_router()) + .nest("/health", v1::health::create_router()) + .nest("/metrics", v1::metrics::create_router()) + .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", docs.clone())) + .merge(Scalar::with_url("/scalar", docs)) +} diff --git a/host/src/server/api/v2/proof.rs b/host/src/server/api/v2/proof.rs new file mode 100644 index 000000000..6460ffddc --- /dev/null +++ b/host/src/server/api/v2/proof.rs @@ -0,0 +1,137 @@ +use axum::{debug_handler, extract::State, routing::post, Json, Router}; +use raiko_core::interfaces::ProofRequest; +use raiko_core::provider::get_task_data; +use raiko_task_manager::{get_task_manager, EnqueueTaskParams, TaskManager, TaskStatus}; +use serde_json::Value; +use tracing::info; +use utoipa::OpenApi; + +use crate::{ + interfaces::HostResult, + metrics::{inc_current_req, inc_guest_req_count, inc_host_req_count}, + server::api::v1::ProofResponse, + ProverState, +}; + +#[utoipa::path(post, path = "/proof", + tag = "Proving", + request_body = ProofRequestOpt, + responses ( + (status = 200, description = "Successfully submitted proof task", body = Status) + ) +)] +#[debug_handler(state = ProverState)] +/// Submit a proof task with requested config, get task status or get proof value. +/// +/// Accepts a proof request and creates a proving task with the specified guest prover. +/// The guest provers currently available are: +/// - native - constructs a block and checks for equality +/// - sgx - uses the sgx environment to construct a block and produce proof of execution +/// - sp1 - uses the sp1 prover +/// - risc0 - uses the risc0 prover +async fn proof_handler( + State(prover_state): State, + Json(req): Json, +) -> HostResult> { + inc_current_req(); + // Override the existing proof request config from the config file and command line + // options with the request from the client. + let mut config = prover_state.opts.proof_request_opt.clone(); + config.merge(&req)?; + + // Construct the actual proof request from the available configs. + let proof_request = ProofRequest::try_from(config)?; + inc_host_req_count(proof_request.block_number); + inc_guest_req_count(&proof_request.proof_type, proof_request.block_number); + + let (chain_id, block_hash) = get_task_data( + &proof_request.network, + proof_request.block_number, + &prover_state.chain_specs, + ) + .await?; + + let mut manager = get_task_manager(&(&prover_state.opts).into()); + let status = manager + .get_task_proving_status( + chain_id, + block_hash, + proof_request.proof_type, + Some(proof_request.prover.to_string()), + ) + .await?; + + if status.is_empty() { + info!( + "# Generating proof for block {} on {}", + proof_request.block_number, proof_request.network + ); + + manager + .enqueue_task(&EnqueueTaskParams { + chain_id, + blockhash: block_hash, + proof_type: proof_request.proof_type, + prover: proof_request.prover.to_string(), + block_number: proof_request.block_number, + }) + .await?; + + prover_state.task_channel.try_send(( + proof_request.clone(), + prover_state.opts, + prover_state.chain_specs, + ))?; + + return Ok(Json(serde_json::json!( + { + "status": "ok", + "data": { + "status": TaskStatus::Registered, + } + } + ))); + } + + let status = status.last().unwrap().0; + + if matches!(status, TaskStatus::Success) { + let proof = manager + .get_task_proof( + chain_id, + block_hash, + proof_request.proof_type, + Some(proof_request.prover.to_string()), + ) + .await?; + + let response = ProofResponse { + proof: Some(String::from_utf8(proof).unwrap()), + output: None, + quote: None, + }; + + return Ok(Json(response.to_response())); + } + + Ok(Json(serde_json::json!( + { + "status": "ok", + "data": { + "status": status, + } + } + ))) +} + +#[derive(OpenApi)] +#[openapi(paths(proof_handler))] +struct Docs; + +pub fn create_docs() -> utoipa::openapi::OpenApi { + Docs::openapi() +} + +pub fn create_router() -> Router { + Router::new().route("/", post(proof_handler)) +} diff --git a/lib/src/protocol_instance.rs b/lib/src/protocol_instance.rs index f6e7ba0e1..2fc50453d 100644 --- a/lib/src/protocol_instance.rs +++ b/lib/src/protocol_instance.rs @@ -70,28 +70,28 @@ impl ProtocolInstance { if let Some(verified_chain_spec) = SupportedChainSpecs::default().get_chain_spec_with_chain_id(input.chain_spec.chain_id) { - assert_eq!( - input.chain_spec.max_spec_id, verified_chain_spec.max_spec_id, + ensure!( + input.chain_spec.max_spec_id == verified_chain_spec.max_spec_id, "unexpected max_spec_id" ); - assert_eq!( - input.chain_spec.hard_forks, verified_chain_spec.hard_forks, + ensure!( + input.chain_spec.hard_forks == verified_chain_spec.hard_forks, "unexpected hard_forks" ); - assert_eq!( - input.chain_spec.eip_1559_constants, verified_chain_spec.eip_1559_constants, + ensure!( + input.chain_spec.eip_1559_constants == verified_chain_spec.eip_1559_constants, "unexpected eip_1559_constants" ); - assert_eq!( - input.chain_spec.l1_contract, verified_chain_spec.l1_contract, + ensure!( + input.chain_spec.l1_contract == verified_chain_spec.l1_contract, "unexpected l1_contract" ); - assert_eq!( - input.chain_spec.l2_contract, verified_chain_spec.l2_contract, + ensure!( + input.chain_spec.l2_contract == verified_chain_spec.l2_contract, "unexpected l2_contract" ); - assert_eq!( - input.chain_spec.is_taiko, verified_chain_spec.is_taiko, + ensure!( + input.chain_spec.is_taiko == verified_chain_spec.is_taiko, "unexpected eip_1559_constants" ); } diff --git a/task_manager/Cargo.toml b/task_manager/Cargo.toml new file mode 100644 index 000000000..ec888c35e --- /dev/null +++ b/task_manager/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "raiko-task-manager" +version = "0.1.0" +authors = ["Mamy Ratsimbazafy "] +edition = "2021" # { workspace = true } + +[dependencies] +raiko-lib = { workspace = true } +raiko-core = { workspace = true } +rusqlite = { workspace = true, features = ["chrono"] } +num_enum = { workspace = true } +chrono = { workspace = true, features = ["serde"] } +thiserror = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +hex = { workspace = true } +tracing = { workspace = true } +anyhow = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } + +[dev-dependencies] +rand = "0.9.0-alpha.1" # This is an alpha version, that has rng.gen_iter::() +rand_chacha = "0.9.0-alpha.1" +tempfile = "3.10.1" +alloy-primitives = { workspace = true, features = ["getrandom"] } +rusqlite = { workspace = true, features = ["trace"] } + +[features] +default = ["in-memory"] +sqlite = [] +in-memory = [] + +[[test]] +name = "task_manager_tests" +path = "tests/main.rs" diff --git a/task_manager/src/adv_sqlite.rs b/task_manager/src/adv_sqlite.rs new file mode 100644 index 000000000..25cd215a4 --- /dev/null +++ b/task_manager/src/adv_sqlite.rs @@ -0,0 +1,873 @@ +// Raiko +// Copyright (c) 2024 Taiko Labs +// Licensed and distributed under either of +// * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +// * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +// at your option. This file may not be copied, modified, or distributed except according to those terms. + +//! # Raiko Task Manager +//! +//! At the moment (Apr '24) proving requires a significant amount of time +//! and maintaining a connection with a potentially external party. +//! +//! By design Raiko is stateless, it prepares inputs and forward to the various proof systems. +//! However some proving backend like Risc0's Bonsai are also stateless, +//! and only accepts proofs and return result. +//! Hence to handle crashes, networking losses and restarts, we need to persist +//! the status of proof requests, task submitted, proof received, proof forwarded. +//! +//! In the diagram: +//! _____________ ______________ _______________ +//! Taiko L2 -> | Taiko-geth | ======> | Raiko-host | =========> | Raiko-guests | +//! | Taiko-reth | | | | Risc0 | +//! |____________| |_____________| | SGX | +//! | SP1 | +//! |______________| +//! _____________________________ +//! =========> | Prover Networks | +//! | Risc0's Bonsai | +//! | Succinct's Prover Network | +//! |____________________________| +//! _________________________ +//! =========> | Raiko-dist | +//! | Distributed Risc0 | +//! | Distributed SP1 | +//! |_______________________| +//! +//! We would position Raiko task manager either before Raiko-host or after Raiko-host. +//! +//! ## Implementation +//! +//! The task manager is a set of tables and KV-stores. +//! - Keys for table joins are prefixed with id +//! - KV-stores for (almost) immutable data +//! - KV-store for large inputs and indistinguishable from random proofs +//! - Tables for tasks and their metadata. +//! +//! __________________________ +//! | metadata | +//! |_________________________| A simple KV-store with the DB version for migration/upgrade detection. +//! | Key | Value | Future version may add new fields, without breaking older versions. +//! |_________________|_______| +//! | task_db_version | 0 | +//! |_________________|_______| +//! +//! ________________________ +//! | Proof systems | +//! |______________________| A map: ID -> proof systems +//! | id_proofsys | Desc | +//! |_____________|________| +//! | 0 | Risc0 | (0 for Risc0 and 1 for SP1 is intentional) +//! | 1 | SP1 | +//! | 2 | SGX | +//! |_____________|________| +//! +//! _________________________________________________ +//! | Task Status code | +//! |________________________________________________| +//! | id_status | Desc | +//! |_____________|__________________________________| +//! | 0 | Success | +//! | 1000 | Registered | +//! | 2000 | Work-in-progress | +//! | | | +//! | -1000 | Proof failure (prover - generic) | +//! | -1100 | Proof failure (OOM) | +//! | | | +//! | -2000 | Network failure | +//! | | | +//! | -3000 | Cancelled | +//! | -3100 | Cancelled (never started) | +//! | -3200 | Cancelled (aborted) | +//! | -3210 | Cancellation in progress | (Yes -3210 is intentional ;)) +//! | | | +//! | -4000 | Invalid or unsupported block | +//! | | | +//! | -9999 | Unspecified failure reason | +//! |_____________|__________________________________| +//! +//! Rationale: +//! - Convention, failures use negative status code. +//! - We leave space for new status codes +//! - -X000 status code are for generic failures segregated by failures: +//! on the networking side, the prover side or trying to prove an invalid block. +//! +//! A catchall -9999 error code is provided if a failure is not due to +//! either the network, the prover or the requester invalid block. +//! They should not exist in the DB and a proper analysis +//! and eventually status code should be assigned. +//! +//! ________________________________________________________________________________________________ +//! | Tasks metadata | +//! |________________________________________________________________________________________________| +//! | id_task | chain_id | block_number | blockhash | parent_hash | state_root | # of txs | gas_used | +//! |_________|__________|______________|___________|_____________|____________|__________|__________| +//! ____________________________________ +//! | Task queue | +//! |___________________________________| +//! | id_task | blockhash | id_proofsys | +//! |_________|___________|_____________| +//! ______________________________________ +//! | Task payloads | +//! |_____________________________________| +//! | id_task | inputs (serialized) | +//! |_________|___________________________| +//! _____________________________________ +//! | Task requests | +//! |____________________________________| +//! | id_task | id_submitter | timestamp | +//! |_________|______________|___________| +//! ___________________________________________________________________________________ +//! | Task progress trail | +//! |__________________________________________________________________________________| +//! | id_task | third_party | id_status | timestamp | +//! |_________|________________________|_________________________|_____________________| +//! | 101 | 'Based Proposer" | 1000 (Registered) | 2024-01-01 00:00:01 | +//! | 101 | 'A Prover Network' | 2000 (WIP) | 2024-01-01 00:00:01 | +//! | 101 | 'A Prover Network' | -2000 (Network failure) | 2024-01-01 00:02:00 | +//! | 101 | 'Proof in the Pudding' | 2000 (WIP) | 2024-01-01 00:02:30 | +//!·| 101 | 'Proof in the Pudding' | 0 (Success) | 2024-01-01 01:02:30 | +//! +//! Rationale: +//! - payloads are very large and warrant a dedicated table, with pruning +//! - metadata is useful to audit block building and prover efficiency +//! - Due to failures and retries, we may submit the same task to multiple fulfillers +//! or retry with the same fulfiller so we keep an audit trail of events. +//! +//! ____________________________ +//! | Proof cache | A map: ID -> proof +//! |___________________________| +//! | id_task | proof_value | +//! |__________|________________| A Groth16 proof is 2G₁+1G₂ elements +//! | 0 | 0xabcd...6789 | On BN254: 2*(2*32)+1*(2*2*32) = 256 bytes +//! | 1 | 0x1234...cdef | +//! | ... | ... | A SGX proof is ... +//! |__________|________________| A Stark proof (not wrapped in Groth16) would be several kilobytes +//! +//! Do we need pruning? +//! There are 60s * 60min * 24h * 30j = 2592000s in a month +//! dividing by 12, that's 216000 Ethereum slots. +//! Assuming 1kB of proofs per block (Stark-to-Groth16 Risc0 & SP1 + SGX, SGX size to be verified) +//! That's only 216MB per month. + +// Imports +// ---------------------------------------------------------------- +use std::{ + fs::File, + path::Path, + sync::{Arc, Once}, +}; + +use chrono::{DateTime, Utc}; +use raiko_core::interfaces::ProofType; +use raiko_lib::primitives::{ChainId, B256}; +use rusqlite::{ + named_params, {Connection, OpenFlags}, +}; +use tokio::sync::Mutex; + +use crate::{ + EnqueueTaskParams, TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, + TaskProvingStatus, TaskProvingStatusRecords, TaskStatus, +}; + +// Types +// ---------------------------------------------------------------- + +#[derive(Debug)] +pub struct TaskDb { + conn: Connection, +} + +pub struct SqliteTaskManager { + arc_task_db: Arc>, +} + +// Implementation +// ---------------------------------------------------------------- + +impl TaskDb { + fn open(path: &Path) -> TaskManagerResult { + let conn = Connection::open_with_flags(path, OpenFlags::SQLITE_OPEN_READ_WRITE)?; + conn.pragma_update(None, "foreign_keys", true)?; + conn.pragma_update(None, "locking_mode", "EXCLUSIVE")?; + conn.pragma_update(None, "journal_mode", "WAL")?; + conn.pragma_update(None, "synchronous", "NORMAL")?; + conn.pragma_update(None, "temp_store", "MEMORY")?; + Ok(conn) + } + + fn create(path: &Path) -> TaskManagerResult { + let _file = File::options() + .write(true) + .read(true) + .create_new(true) + .open(path)?; + + let conn = Self::open(path)?; + Self::create_tables(&conn)?; + Self::create_views(&conn)?; + + Ok(conn) + } + + /// Open an existing TaskDb database at "path" + /// If a database does not exist at the path, one is created. + pub fn open_or_create(path: &Path) -> TaskManagerResult { + let conn = if path.exists() { + Self::open(path) + } else { + Self::create(path) + }?; + Ok(Self { conn }) + } + + // SQL + // ---------------------------------------------------------------- + + fn create_tables(conn: &Connection) -> TaskManagerResult<()> { + // Change the task_db_version if backward compatibility is broken + // and introduce a migration on DB opening ... if conserving history is important. + conn.execute_batch( + r#" + -- Metadata and mappings + ----------------------------------------------- + CREATE TABLE metadata( + key BLOB UNIQUE NOT NULL PRIMARY KEY, + value BLOB + ); + + INSERT INTO + metadata(key, value) + VALUES + ('task_db_version', 0); + + CREATE TABLE proofsys( + id INTEGER UNIQUE NOT NULL PRIMARY KEY, + desc TEXT NOT NULL + ); + + INSERT INTO + proofsys(id, desc) + VALUES + (0, 'Native'), + (1, 'Risc0'), + (2, 'SP1'), + (3, 'SGX'); + + CREATE TABLE status_codes( + id INTEGER UNIQUE NOT NULL PRIMARY KEY, + desc TEXT NOT NULL + ); + + INSERT INTO + status_codes(id, desc) + VALUES + (0, 'Success'), + (1000, 'Registered'), + (2000, 'Work-in-progress'), + (-1000, 'Proof failure (generic)'), + (-1100, 'Proof failure (Out-Of-Memory)'), + (-2000, 'Network failure'), + (-3000, 'Cancelled'), + (-3100, 'Cancelled (never started)'), + (-3200, 'Cancelled (aborted)'), + (-3210, 'Cancellation in progress'), + (-4000, 'Invalid or unsupported block'), + (-9999, 'Unspecified failure reason'); + + -- Data + ----------------------------------------------- + -- Notes: + -- 1. a blockhash may appear as many times as there are prover backends. + -- 2. For query speed over (chain_id, blockhash) + -- there is no need to create an index as the UNIQUE constraint + -- has an implied index, see: + -- - https://sqlite.org/lang_createtable.html#uniqueconst + -- - https://www.sqlite.org/fileformat2.html#representation_of_sql_indices + CREATE TABLE tasks( + id INTEGER UNIQUE NOT NULL PRIMARY KEY, + chain_id INTEGER NOT NULL, + blockhash BLOB NOT NULL, + proofsys_id INTEGER NOT NULL, + prover TEXT NOT NULL, + FOREIGN KEY(proofsys_id) REFERENCES proofsys(id), + UNIQUE (chain_id, blockhash, proofsys_id) + ); + + -- Proofs might also be large, so we isolate them in a dedicated table + CREATE TABLE task_proofs( + task_id INTEGER UNIQUE NOT NULL PRIMARY KEY, + proof BLOB NOT NULL, + FOREIGN KEY(task_id) REFERENCES tasks(id) + ); + + CREATE TABLE task_status( + task_id INTEGER NOT NULL, + status_id INTEGER NOT NULL, + timestamp TIMESTAMP DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL, + FOREIGN KEY(task_id) REFERENCES tasks(id), + FOREIGN KEY(status_id) REFERENCES status_codes(id), + UNIQUE (task_id, timestamp) + ); + "#, + )?; + + Ok(()) + } + + fn create_views(conn: &Connection) -> TaskManagerResult<()> { + // By convention, views will use an action verb as name. + conn.execute_batch( + r#" + CREATE VIEW enqueue_task AS + SELECT + t.id, + t.chain_id, + t.blockhash, + t.proofsys_id, + t.prover + FROM + tasks t + LEFT JOIN task_status ts on ts.task_id = t.id; + + CREATE VIEW update_task_progress AS + SELECT + t.id, + t.chain_id, + t.blockhash, + t.proofsys_id, + t.prover, + ts.status_id, + tpf.proof + FROM + tasks t + LEFT JOIN task_status ts on ts.task_id = t.id + LEFT JOIN task_proofs tpf on tpf.task_id = t.id; + "#, + )?; + + Ok(()) + } + + /// Set a tracer to debug SQL execution + /// for example: + /// db.set_tracer(Some(|stmt| println!("sqlite:\n-------\n{}\n=======", stmt))); + #[cfg(test)] + #[allow(dead_code)] + pub fn set_tracer(&mut self, trace_fn: Option) { + self.conn.trace(trace_fn); + } + + pub fn manage(&self) -> TaskManagerResult<()> { + // To update all the tables with the task_id assigned by Sqlite + // we require row IDs for the tasks table + // and we use last_insert_rowid() which is not reentrant and need a transaction lock + // and store them in a temporary table, configured to be in-memory. + // + // Alternative approaches considered: + // 1. Sqlite does not support variables (because it's embedded and significantly less overhead than other SQL "Client-Server" DBs). + // 2. using AUTOINCREMENT and/or the sqlite_sequence table + // - sqlite recommends not using AUTOINCREMENT for performance + // https://www.sqlite.org/autoinc.html + // 3. INSERT INTO ... RETURNING nested in a WITH clause (CTE / Common Table Expression) + // - Sqlite can only do RETURNING to the application, it cannot be nested in another query or diverted to another table + // https://sqlite.org/lang_returning.html#limitations_and_caveats + // 4. CREATE TEMPORARY TABLE AS with an INSERT INTO ... RETURNING nested + // - Same limitation AND CREATE TABLEAS seems to only support SELECT statements (but if we could nest RETURNING we can workaround that + // https://www.sqlite.org/lang_createtable.html#create_table_as_select_statements + // + // Hence we have to use row IDs and last_insert_rowid() + // + // Furthermore we use a view and an INSTEAD OF trigger to update the tables, + // the alternative being + // + // 5. Direct insert into tables + // This does not work as SQLite `execute` and `prepare` + // only process the first statement. + // + // And lastly, we need the view and trigger to be temporary because + // otherwise they can't access the temporary table: + // 6. https://sqlite.org/forum/info/4f998eeec510bceee69404541e5c9ca0a301868d59ec7c3486ecb8084309bba1 + // "Triggers in any schema other than temp may only access objects in their own schema. However, triggers in temp may access any object by name, even cross-schema." + self.conn.execute_batch( + r#" + -- PRAGMA temp_store = 'MEMORY'; + CREATE TEMPORARY TABLE IF NOT EXISTS temp.current_task(task_id INTEGER); + + CREATE TEMPORARY TRIGGER IF NOT EXISTS enqueue_task_insert_trigger INSTEAD OF + INSERT + ON enqueue_task + BEGIN + INSERT INTO + tasks(chain_id, blockhash, proofsys_id, prover) + VALUES + ( + new.chain_id, + new.blockhash, + new.proofsys_id, + new.prover + ); + + INSERT INTO + current_task + SELECT + id + FROM + tasks + WHERE + rowid = last_insert_rowid() + LIMIT + 1; + + -- Tasks are initialized at status 1000 - registered + -- timestamp is auto-filled with datetime('now'), see its field definition + INSERT INTO + task_status(task_id, status_id) + SELECT + tmp.task_id, + 1000 + FROM + current_task tmp; + + DELETE FROM + current_task; + END; + + CREATE TEMPORARY TRIGGER IF NOT EXISTS update_task_progress_trigger INSTEAD OF + INSERT + ON update_task_progress + BEGIN + INSERT INTO + current_task + SELECT + id + FROM + tasks + WHERE + chain_id = new.chain_id + AND blockhash = new.blockhash + AND proofsys_id = new.proofsys_id + LIMIT + 1; + + -- timestamp is auto-filled with datetime('now'), see its field definition + INSERT INTO + task_status(task_id, status_id) + SELECT + tmp.task_id, + new.status_id + FROM + current_task tmp + LIMIT + 1; + + INSERT + OR REPLACE INTO task_proofs + SELECT + task_id, + new.proof + FROM + current_task + WHERE + new.proof IS NOT NULL + LIMIT + 1; + + DELETE FROM + current_task; + END; + "#, + )?; + + Ok(()) + } + + pub fn enqueue_task( + &self, + EnqueueTaskParams { + chain_id, + blockhash, + proof_type, + prover, + .. + }: &EnqueueTaskParams, + ) -> TaskManagerResult> { + let mut statement = self.conn.prepare_cached( + r#" + INSERT INTO + enqueue_task( + chain_id, + blockhash, + proofsys_id, + prover + ) + VALUES + ( + :chain_id, + :blockhash, + :proofsys_id, + :prover + ); + "#, + )?; + statement.execute(named_params! { + ":chain_id": chain_id, + ":blockhash": blockhash.to_vec(), + ":proofsys_id": *proof_type as u8, + ":prover": prover, + })?; + + Ok(vec![TaskProvingStatus( + TaskStatus::Registered, + Some(prover.clone()), + Utc::now(), + )]) + } + + pub fn update_task_progress( + &self, + chain_id: ChainId, + blockhash: B256, + proof_type: ProofType, + prover: Option, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + let mut statement = self.conn.prepare_cached( + r#" + INSERT INTO + update_task_progress( + chain_id, + blockhash, + proofsys_id, + status_id, + prover, + proof + ) + VALUES + ( + :chain_id, + :blockhash, + :proofsys_id, + :status_id, + :prover, + :proof + ); + "#, + )?; + statement.execute(named_params! { + ":chain_id": chain_id, + ":blockhash": blockhash.to_vec(), + ":proofsys_id": proof_type as u8, + ":status_id": status as i32, + ":prover": prover.unwrap_or_default(), + ":proof": proof + })?; + + Ok(()) + } + + pub fn get_task_proving_status( + &self, + chain_id: ChainId, + blockhash: B256, + proof_type: ProofType, + prover: Option, + ) -> TaskManagerResult { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + ts.status_id, + t.prover, + timestamp + FROM + task_status ts + LEFT JOIN tasks t ON ts.task_id = t.id + WHERE + t.chain_id = :chain_id + AND t.blockhash = :blockhash + AND t.proofsys_id = :proofsys_id + AND t.prover = :prover + ORDER BY + ts.timestamp; + "#, + )?; + let query = statement.query_map( + named_params! { + ":chain_id": chain_id, + ":blockhash": blockhash.to_vec(), + ":proofsys_id": proof_type as u8, + ":prover": prover.unwrap_or_default(), + }, + |row| { + Ok(TaskProvingStatus( + TaskStatus::from(row.get::<_, i32>(0)?), + Some(row.get::<_, String>(1)?), + row.get::<_, DateTime>(2)?, + )) + }, + )?; + + Ok(query.collect::, _>>()?) + } + + pub fn get_task_proving_status_by_id( + &self, + task_id: u64, + ) -> TaskManagerResult { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + ts.status_id, + t.prover, + timestamp + FROM + task_status ts + LEFT JOIN tasks t ON ts.task_id = t.id + WHERE + t.id = :task_id + ORDER BY + ts.timestamp; + "#, + )?; + let query = statement.query_map( + named_params! { + ":task_id": task_id, + }, + |row| { + Ok(TaskProvingStatus( + TaskStatus::from(row.get::<_, i32>(0)?), + Some(row.get::<_, String>(1)?), + row.get::<_, DateTime>(2)?, + )) + }, + )?; + + Ok(query.collect::, _>>()?) + } + + pub fn get_task_proof( + &self, + chain_id: ChainId, + blockhash: B256, + proof_type: ProofType, + prover: Option, + ) -> TaskManagerResult> { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + proof + FROM + task_proofs tp + LEFT JOIN tasks t ON tp.task_id = t.id + WHERE + t.chain_id = :chain_id + AND t.prover = :prover + AND t.blockhash = :blockhash + AND t.proofsys_id = :proofsys_id + LIMIT + 1; + "#, + )?; + let query = statement.query_row( + named_params! { + ":chain_id": chain_id, + ":blockhash": blockhash.to_vec(), + ":proofsys_id": proof_type as u8, + ":prover": prover.unwrap_or_default(), + }, + |row| row.get(0), + )?; + + Ok(query) + } + + pub fn get_task_proof_by_id(&self, task_id: u64) -> TaskManagerResult> { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + proof + FROM + task_proofs tp + LEFT JOIN tasks t ON tp.task_id = t.id + WHERE + t.id = :task_id + LIMIT + 1; + "#, + )?; + let query = statement.query_row( + named_params! { + ":task_id": task_id, + }, + |row| row.get(0), + )?; + + Ok(query) + } + + pub fn get_db_size(&self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + name as table_name, + SUM(pgsize) as table_size + FROM + dbstat + GROUP BY + table_name + ORDER BY + SUM(pgsize) DESC; + "#, + )?; + let query = statement.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?; + let details = query.collect::, _>>()?; + let total = details.iter().fold(0, |acc, (_, size)| acc + size); + + Ok((total, details)) + } + + pub fn prune_db(&self) -> TaskManagerResult<()> { + let mut statement = self.conn.prepare_cached( + r#" + DELETE FROM + tasks; + + DELETE FROM + task_proofs; + + DELETE FROM + task_status; + "#, + )?; + statement.execute([])?; + + Ok(()) + } +} + +#[async_trait::async_trait] +impl TaskManager for SqliteTaskManager { + fn new(opts: &TaskManagerOpts) -> Self { + static INIT: Once = Once::new(); + static mut CONN: Option>> = None; + INIT.call_once(|| { + unsafe { + CONN = Some(Arc::new(Mutex::new({ + let db = TaskDb::open_or_create(&opts.sqlite_file).unwrap(); + db.manage().unwrap(); + db + }))) + }; + }); + Self { + arc_task_db: unsafe { CONN.clone().unwrap() }, + } + } + + async fn enqueue_task( + &mut self, + params: &EnqueueTaskParams, + ) -> Result, TaskManagerError> { + let task_db = self.arc_task_db.lock().await; + task_db.enqueue_task(params) + } + + async fn update_task_progress( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_type: ProofType, + prover: Option, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + let task_db = self.arc_task_db.lock().await; + task_db.update_task_progress(chain_id, blockhash, proof_type, prover, status, proof) + } + + /// Returns the latest triplet (submitter or fulfiller, status, last update time) + async fn get_task_proving_status( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_type: ProofType, + prover: Option, + ) -> TaskManagerResult { + let task_db = self.arc_task_db.lock().await; + task_db.get_task_proving_status(chain_id, blockhash, proof_type, prover) + } + + /// Returns the latest triplet (submitter or fulfiller, status, last update time) + async fn get_task_proving_status_by_id( + &mut self, + task_id: u64, + ) -> TaskManagerResult { + let task_db = self.arc_task_db.lock().await; + task_db.get_task_proving_status_by_id(task_id) + } + + async fn get_task_proof( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_type: ProofType, + prover: Option, + ) -> TaskManagerResult> { + let task_db = self.arc_task_db.lock().await; + task_db.get_task_proof(chain_id, blockhash, proof_type, prover) + } + + async fn get_task_proof_by_id(&mut self, task_id: u64) -> TaskManagerResult> { + let task_db = self.arc_task_db.lock().await; + task_db.get_task_proof_by_id(task_id) + } + + /// Returns the total and detailed database size + async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { + let task_db = self.arc_task_db.lock().await; + task_db.get_db_size() + } + + async fn prune_db(&mut self) -> TaskManagerResult<()> { + let task_db = self.arc_task_db.lock().await; + task_db.prune_db() + } +} + +#[cfg(test)] +mod tests { + // We only test private functions here. + // Public API will be tested in a dedicated tests folder + + use super::*; + use tempfile::tempdir; + + #[test] + fn error_on_missing() { + let dir = tempdir().unwrap(); + let file = dir.path().join("db.sqlite"); + assert!(TaskDb::open(&file).is_err()); + } + + #[test] + fn ensure_exclusive() { + let dir = tempdir().unwrap(); + let file = dir.path().join("db.sqlite"); + + let _db = TaskDb::create(&file).unwrap(); + assert!(TaskDb::open(&file).is_err()); + std::fs::remove_file(&file).unwrap(); + } + + #[test] + fn ensure_unicity() { + let dir = tempdir().unwrap(); + let file = dir.path().join("db.sqlite"); + + let _db = TaskDb::create(&file).unwrap(); + assert!(TaskDb::create(&file).is_err()); + std::fs::remove_file(&file).unwrap(); + } +} diff --git a/task_manager/src/lib.rs b/task_manager/src/lib.rs new file mode 100644 index 000000000..42275c20f --- /dev/null +++ b/task_manager/src/lib.rs @@ -0,0 +1,383 @@ +use std::{ + io::{Error as IOError, ErrorKind as IOErrorKind}, + path::PathBuf, +}; + +use chrono::{DateTime, Utc}; +use num_enum::{FromPrimitive, IntoPrimitive}; +use raiko_core::interfaces::ProofType; +use raiko_lib::primitives::{ChainId, B256}; +use rusqlite::Error as SqlError; +use serde::Serialize; + +use crate::{adv_sqlite::SqliteTaskManager, mem_db::InMemoryTaskManager}; + +mod adv_sqlite; +mod mem_db; + +// Types +// ---------------------------------------------------------------- +#[derive(PartialEq, Debug, thiserror::Error)] +pub enum TaskManagerError { + #[error("IO Error {0}")] + IOError(IOErrorKind), + #[error("SQL Error {0}")] + SqlError(String), + #[error("Anyhow error: {0}")] + Anyhow(String), +} + +pub type TaskManagerResult = Result; + +impl From for TaskManagerError { + fn from(error: IOError) -> TaskManagerError { + TaskManagerError::IOError(error.kind()) + } +} + +impl From for TaskManagerError { + fn from(error: SqlError) -> TaskManagerError { + TaskManagerError::SqlError(error.to_string()) + } +} + +impl From for TaskManagerError { + fn from(error: serde_json::Error) -> TaskManagerError { + TaskManagerError::SqlError(error.to_string()) + } +} + +impl From for TaskManagerError { + fn from(value: anyhow::Error) -> Self { + TaskManagerError::Anyhow(value.to_string()) + } +} + +#[allow(non_camel_case_types)] +#[rustfmt::skip] +#[derive(PartialEq, Debug, Copy, Clone, IntoPrimitive, FromPrimitive, Serialize)] +#[repr(i32)] +pub enum TaskStatus { + Success = 0, + Registered = 1000, + WorkInProgress = 2000, + ProofFailure_Generic = -1000, + ProofFailure_OutOfMemory = -1100, + NetworkFailure = -2000, + Cancelled = -3000, + Cancelled_NeverStarted = -3100, + Cancelled_Aborted = -3200, + CancellationInProgress = -3210, + InvalidOrUnsupportedBlock = -4000, + UnspecifiedFailureReason = -9999, + #[num_enum(default)] + SqlDbCorruption = -99999, +} + +#[derive(Debug, Clone, Default)] +pub struct EnqueueTaskParams { + pub chain_id: ChainId, + pub blockhash: B256, + pub proof_type: ProofType, + pub prover: String, + pub block_number: u64, +} + +#[derive(Debug, Clone, Serialize)] +pub struct TaskDescriptor { + pub chain_id: ChainId, + pub blockhash: B256, + pub proof_system: ProofType, + pub prover: String, +} + +impl TaskDescriptor { + pub fn to_vec(self) -> Vec { + self.into() + } +} + +impl From for Vec { + fn from(val: TaskDescriptor) -> Self { + let mut v = Vec::new(); + v.extend_from_slice(&val.chain_id.to_be_bytes()); + v.extend_from_slice(val.blockhash.as_ref()); + v.extend_from_slice(&(val.proof_system as u8).to_be_bytes()); + v.extend_from_slice(val.prover.as_bytes()); + v + } +} + +// Taskkey from EnqueueTaskParams +impl From<&EnqueueTaskParams> for TaskDescriptor { + fn from(params: &EnqueueTaskParams) -> TaskDescriptor { + TaskDescriptor { + chain_id: params.chain_id, + blockhash: params.blockhash, + proof_system: params.proof_type, + prover: params.prover.clone(), + } + } +} + +impl From<(ChainId, B256, ProofType, Option)> for TaskDescriptor { + fn from( + (chain_id, blockhash, proof_system, prover): (ChainId, B256, ProofType, Option), + ) -> Self { + TaskDescriptor { + chain_id, + blockhash, + proof_system, + prover: prover.unwrap_or_default(), + } + } +} + +#[derive(Debug, Clone)] +pub struct TaskProvingStatus(pub TaskStatus, pub Option, pub DateTime); + +pub type TaskProvingStatusRecords = Vec; + +#[derive(Debug, Clone)] +pub struct TaskManagerOpts { + pub sqlite_file: PathBuf, + pub max_db_size: usize, +} + +#[async_trait::async_trait] +pub trait TaskManager { + /// new a task manager + fn new(opts: &TaskManagerOpts) -> Self; + + /// enqueue_task + async fn enqueue_task( + &mut self, + request: &EnqueueTaskParams, + ) -> TaskManagerResult; + + /// Update the task progress + async fn update_task_progress( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()>; + + /// Returns the latest triplet (submitter or fulfiller, status, last update time) + async fn get_task_proving_status( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + ) -> TaskManagerResult; + + /// Returns the latest triplet (submitter or fulfiller, status, last update time) + async fn get_task_proving_status_by_id( + &mut self, + task_id: u64, + ) -> TaskManagerResult; + + /// Returns the proof for the given task + async fn get_task_proof( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + ) -> TaskManagerResult>; + + async fn get_task_proof_by_id(&mut self, task_id: u64) -> TaskManagerResult>; + + /// Returns the total and detailed database size + async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)>; + + /// Prune old tasks + async fn prune_db(&mut self) -> TaskManagerResult<()>; +} + +pub fn ensure(expression: bool, message: &str) -> TaskManagerResult<()> { + if !expression { + return Err(TaskManagerError::Anyhow(message.to_string())); + } + Ok(()) +} + +enum TaskManagerInstance { + InMemory(InMemoryTaskManager), + Sqlite(SqliteTaskManager), +} + +pub struct TaskManagerWrapper { + manager: TaskManagerInstance, +} + +#[async_trait::async_trait] +impl TaskManager for TaskManagerWrapper { + fn new(opts: &TaskManagerOpts) -> Self { + let manager = if cfg!(feature = "sqlite") { + TaskManagerInstance::Sqlite(SqliteTaskManager::new(opts)) + } else { + TaskManagerInstance::InMemory(InMemoryTaskManager::new(opts)) + }; + + Self { manager } + } + + async fn enqueue_task( + &mut self, + request: &EnqueueTaskParams, + ) -> TaskManagerResult { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => manager.enqueue_task(request).await, + TaskManagerInstance::Sqlite(ref mut manager) => manager.enqueue_task(request).await, + } + } + + async fn update_task_progress( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager + .update_task_progress(chain_id, blockhash, proof_system, prover, status, proof) + .await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager + .update_task_progress(chain_id, blockhash, proof_system, prover, status, proof) + .await + } + } + } + + async fn get_task_proving_status( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + ) -> TaskManagerResult { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager + .get_task_proving_status(chain_id, blockhash, proof_system, prover) + .await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager + .get_task_proving_status(chain_id, blockhash, proof_system, prover) + .await + } + } + } + + async fn get_task_proving_status_by_id( + &mut self, + task_id: u64, + ) -> TaskManagerResult { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager.get_task_proving_status_by_id(task_id).await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager.get_task_proving_status_by_id(task_id).await + } + } + } + + async fn get_task_proof( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + ) -> TaskManagerResult> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager + .get_task_proof(chain_id, blockhash, proof_system, prover) + .await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager + .get_task_proof(chain_id, blockhash, proof_system, prover) + .await + } + } + } + + async fn get_task_proof_by_id(&mut self, task_id: u64) -> TaskManagerResult> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager.get_task_proof_by_id(task_id).await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager.get_task_proof_by_id(task_id).await + } + } + } + + async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => manager.get_db_size().await, + TaskManagerInstance::Sqlite(ref mut manager) => manager.get_db_size().await, + } + } + + async fn prune_db(&mut self) -> TaskManagerResult<()> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => manager.prune_db().await, + TaskManagerInstance::Sqlite(ref mut manager) => manager.prune_db().await, + } + } +} + +pub fn get_task_manager(opts: &TaskManagerOpts) -> TaskManagerWrapper { + TaskManagerWrapper::new(opts) +} + +#[cfg(test)] +mod test { + use super::*; + use std::path::Path; + + #[tokio::test] + async fn test_new_taskmanager() { + let sqlite_file: &Path = Path::new("test.db"); + // remove existed one + if sqlite_file.exists() { + std::fs::remove_file(sqlite_file).unwrap(); + } + + let opts = TaskManagerOpts { + sqlite_file: sqlite_file.to_path_buf(), + max_db_size: 1024 * 1024, + }; + let mut task_manager = get_task_manager(&opts); + + assert_eq!( + task_manager + .enqueue_task(&EnqueueTaskParams { + chain_id: 1, + blockhash: B256::default(), + proof_type: ProofType::Native, + prover: "test".to_string(), + block_number: 1 + }) + .await + .unwrap() + .len(), + 1 + ); + } +} diff --git a/task_manager/src/mem_db.rs b/task_manager/src/mem_db.rs new file mode 100644 index 000000000..e413b7ca0 --- /dev/null +++ b/task_manager/src/mem_db.rs @@ -0,0 +1,310 @@ +// Raiko +// Copyright (c) 2024 Taiko Labs +// Licensed and distributed under either of +// * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +// * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +// at your option. This file may not be copied, modified, or distributed except according to those terms. + +// Imports +// ---------------------------------------------------------------- +use std::{ + collections::HashMap, + sync::{Arc, Once}, +}; + +use crate::{ + ensure, EnqueueTaskParams, TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, + TaskManagerResult, TaskProvingStatus, TaskProvingStatusRecords, TaskStatus, +}; + +use chrono::Utc; +use raiko_core::interfaces::ProofType; +use raiko_lib::primitives::{keccak::keccak, ChainId, B256}; +use tokio::sync::Mutex; +use tracing::{debug, info}; + +#[derive(Debug)] +pub struct InMemoryTaskManager { + db: Arc>, +} + +#[derive(Debug)] +pub struct InMemoryTaskDb { + enqueue_task: HashMap, + task_id_desc: HashMap, + task_id: u64, +} + +impl InMemoryTaskDb { + fn new() -> InMemoryTaskDb { + InMemoryTaskDb { + enqueue_task: HashMap::new(), + task_id_desc: HashMap::new(), + task_id: 0, + } + } + + fn enqueue_task(&mut self, params: &EnqueueTaskParams) { + let key: B256 = keccak(TaskDescriptor::from(params).to_vec()).into(); + let task_status = TaskProvingStatus( + TaskStatus::Registered, + Some(params.prover.clone()), + Utc::now(), + ); + + match self.enqueue_task.get(&key) { + Some(task_proving_records) => { + debug!( + "Task already exists: {:?}", + task_proving_records.last().unwrap().0 + ); + } // do nothing + None => { + info!("Enqueue new task: {:?}", params); + self.enqueue_task.insert(key, vec![task_status]); + self.task_id_desc.insert(self.task_id, key); + self.task_id += 1; + } + } + } + + fn update_task_progress( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + let key: B256 = keccak( + TaskDescriptor::from((chain_id, blockhash, proof_system, prover.clone())).to_vec(), + ) + .into(); + ensure(self.enqueue_task.contains_key(&key), "no task found")?; + + let task_proving_records = self.enqueue_task.get(&key).unwrap(); + let task_status = task_proving_records.last().unwrap().0; + if status != task_status { + let new_records = task_proving_records + .iter() + .cloned() + .chain(std::iter::once(TaskProvingStatus( + status, + proof.map(hex::encode), + Utc::now(), + ))) + .collect(); + self.enqueue_task.insert(key, new_records); + } + Ok(()) + } + + fn get_task_proving_status( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + ) -> TaskManagerResult { + let key: B256 = keccak( + TaskDescriptor::from((chain_id, blockhash, proof_system, prover.clone())).to_vec(), + ) + .into(); + + match self.enqueue_task.get(&key) { + Some(proving_status_records) => Ok(proving_status_records.clone()), + None => Ok(vec![]), + } + } + + fn get_task_proving_status_by_id( + &mut self, + task_id: u64, + ) -> TaskManagerResult { + ensure(self.task_id_desc.contains_key(&task_id), "no task found")?; + let key = self.task_id_desc.get(&task_id).unwrap(); + let task_status = self.enqueue_task.get(key).unwrap(); + Ok(task_status.clone()) + } + + fn get_task_proof( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + ) -> TaskManagerResult> { + let key: B256 = keccak( + TaskDescriptor::from((chain_id, blockhash, proof_system, prover.clone())).to_vec(), + ) + .into(); + ensure(self.enqueue_task.contains_key(&key), "no task found")?; + + let proving_status_records = self.enqueue_task.get(&key).unwrap(); + let task_status = proving_status_records.last().unwrap(); + if task_status.0 == TaskStatus::Success { + let proof = task_status.1.clone().unwrap(); + Ok(hex::decode(proof).unwrap()) + } else { + Err(TaskManagerError::SqlError("working in process".to_owned())) + } + } + + fn get_task_proof_by_id(&mut self, task_id: u64) -> TaskManagerResult> { + ensure(self.task_id_desc.contains_key(&task_id), "no task found")?; + let key = self.task_id_desc.get(&task_id).unwrap(); + let task_records = self.enqueue_task.get(key).unwrap(); + let task_status = task_records.last().unwrap(); + if task_status.0 == TaskStatus::Success { + let proof = task_status.1.clone().unwrap(); + Ok(hex::decode(proof).unwrap()) + } else { + Err(TaskManagerError::SqlError("working in process".to_owned())) + } + } + + fn size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { + Ok((self.enqueue_task.len() + self.task_id_desc.len(), vec![])) + } + + fn prune(&mut self) -> TaskManagerResult<()> { + Ok(()) + } +} + +#[async_trait::async_trait] +impl TaskManager for InMemoryTaskManager { + fn new(_opts: &TaskManagerOpts) -> Self { + static INIT: Once = Once::new(); + static mut SHARED_TASK_MANAGER: Option>> = None; + + INIT.call_once(|| { + let task_manager: Arc> = + Arc::new(Mutex::new(InMemoryTaskDb::new())); + unsafe { + SHARED_TASK_MANAGER = Some(Arc::clone(&task_manager)); + } + }); + + InMemoryTaskManager { + db: unsafe { SHARED_TASK_MANAGER.clone().unwrap() }, + } + } + + async fn enqueue_task( + &mut self, + params: &EnqueueTaskParams, + ) -> TaskManagerResult { + let mut db = self.db.lock().await; + let status = db.get_task_proving_status( + params.chain_id, + params.blockhash, + params.proof_type, + Some(params.prover.to_string()), + )?; + if status.is_empty() { + db.enqueue_task(params); + db.get_task_proving_status( + params.chain_id, + params.blockhash, + params.proof_type, + Some(params.prover.clone()), + ) + } else { + Ok(status) + } + } + + async fn update_task_progress( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + let mut db = self.db.lock().await; + db.update_task_progress(chain_id, blockhash, proof_system, prover, status, proof) + } + + /// Returns the latest triplet (submitter or fulfiller, status, last update time) + async fn get_task_proving_status( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + ) -> TaskManagerResult { + let mut db = self.db.lock().await; + db.get_task_proving_status(chain_id, blockhash, proof_system, prover) + } + + /// Returns the latest triplet (submitter or fulfiller, status, last update time) + async fn get_task_proving_status_by_id( + &mut self, + task_id: u64, + ) -> TaskManagerResult { + let mut db = self.db.lock().await; + db.get_task_proving_status_by_id(task_id) + } + + async fn get_task_proof( + &mut self, + chain_id: ChainId, + blockhash: B256, + proof_system: ProofType, + prover: Option, + ) -> TaskManagerResult> { + let mut db = self.db.lock().await; + db.get_task_proof(chain_id, blockhash, proof_system, prover) + } + + async fn get_task_proof_by_id(&mut self, task_id: u64) -> TaskManagerResult> { + let mut db = self.db.lock().await; + db.get_task_proof_by_id(task_id) + } + + /// Returns the total and detailed database size + async fn get_db_size(&mut self) -> TaskManagerResult<(usize, Vec<(String, usize)>)> { + let mut db = self.db.lock().await; + db.size() + } + + async fn prune_db(&mut self) -> TaskManagerResult<()> { + let mut db = self.db.lock().await; + db.prune() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ProofType; + + #[test] + fn test_db_open() { + assert!(InMemoryTaskDb::new().size().is_ok()); + } + + #[test] + fn test_db_enqueue() { + let mut db = InMemoryTaskDb::new(); + let params = EnqueueTaskParams { + chain_id: 1, + blockhash: B256::default(), + proof_type: ProofType::Native, + prover: "0x1234".to_owned(), + ..Default::default() + }; + db.enqueue_task(¶ms); + let status = db.get_task_proving_status( + params.chain_id, + params.blockhash, + params.proof_type, + Some(params.prover.clone()), + ); + assert!(status.is_ok()); + } +} diff --git a/task_manager/tests/main.rs b/task_manager/tests/main.rs new file mode 100644 index 000000000..e5b85addd --- /dev/null +++ b/task_manager/tests/main.rs @@ -0,0 +1,474 @@ +// Raiko +// Copyright (c) 2024 Taiko Labs +// Licensed and distributed under either of +// * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +// * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +// at your option. This file may not be copied, modified, or distributed except according to those terms. + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, time::Duration}; + + use alloy_primitives::Address; + use raiko_core::interfaces::{ProofRequest, ProofType}; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaCha8Rng; + + use raiko_lib::{input::BlobProofType, primitives::B256}; + use raiko_task_manager::{ + get_task_manager, EnqueueTaskParams, TaskManager, TaskManagerOpts, TaskStatus, + }; + + fn create_random_task(rng: &mut ChaCha8Rng) -> (u64, B256, ProofRequest) { + let chain_id = 100; + let proof_type = match rng.gen_range(0..4) { + 0 => ProofType::Native, + 1 => ProofType::Sgx, + 2 => ProofType::Sp1, + _ => ProofType::Risc0, + }; + let block_number = rng.gen_range(1..4_000_000); + let block_hash = B256::random(); + let graffiti = B256::random(); + let prover_args = HashMap::new(); + let prover = Address::random(); + + ( + chain_id, + block_hash, + ProofRequest { + block_number, + network: "network".to_string(), + l1_network: "l1_network".to_string(), + graffiti, + prover, + proof_type, + prover_args, + blob_proof_type: BlobProofType::ProofOfEquivalence, + }, + ) + } + + #[tokio::test] + async fn test_enqueue_task() { + // // Materialized local DB + // let dir = std::env::current_dir().unwrap().join("tests"); + // let file = dir.as_path().join("test_enqueue_task.sqlite"); + // if file.exists() { + // std::fs::remove_file(&file).unwrap() + // }; + + // temp dir DB + use tempfile::tempdir; + let dir = tempdir().unwrap(); + let file = dir.path().join("test_enqueue_task.sqlite"); + + let mut tama = get_task_manager(&TaskManagerOpts { + sqlite_file: file, + max_db_size: 1_000_000, + }); + + let (chain_id, block_hash, request) = + create_random_task(&mut ChaCha8Rng::seed_from_u64(123)); + tama.enqueue_task(&EnqueueTaskParams { + chain_id, + blockhash: block_hash, + proof_type: request.proof_type, + prover: request.prover.to_string(), + block_number: request.block_number, + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_update_query_tasks_progress() { + // Materialized local DB + let dir = std::env::current_dir().unwrap().join("tests"); + let file = dir + .as_path() + .join("test_update_query_tasks_progress.sqlite"); + if file.exists() { + std::fs::remove_file(&file).unwrap() + }; + + // // temp dir DB + // use tempfile::tempdir; + // let dir = tempdir().unwrap(); + // let file = dir.path().join("test_update_task_progress.sqlite"); + + let mut tama = get_task_manager(&TaskManagerOpts { + sqlite_file: file, + max_db_size: 1_000_000, + }); + + let mut rng = ChaCha8Rng::seed_from_u64(123); + let mut tasks = vec![]; + + for _ in 0..5 { + let (chain_id, block_hash, request) = create_random_task(&mut rng); + + tama.enqueue_task(&EnqueueTaskParams { + chain_id, + blockhash: block_hash, + proof_type: request.proof_type, + prover: request.prover.to_string(), + block_number: request.block_number, + }) + .await + .unwrap(); + + let task_status = tama + .get_task_proving_status( + chain_id, + block_hash, + request.proof_type, + Some(request.prover.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 1); + let status = task_status + .first() + .expect("Already confirmed there is exactly 1 element"); + assert_eq!(status.0, TaskStatus::Registered); + + tasks.push(( + chain_id, + block_hash, + request.block_number, + request.proof_type, + request.prover, + )); + } + + std::thread::sleep(Duration::from_millis(1)); + + { + let task_status = tama + .get_task_proving_status( + tasks[0].0, + tasks[0].1, + tasks[0].3, + Some(tasks[0].4.to_string()), + ) + .await + .unwrap(); + println!("{task_status:?}"); + tama.update_task_progress( + tasks[0].0, + tasks[0].1, + tasks[0].3, + Some(tasks[0].4.to_string()), + TaskStatus::Cancelled_NeverStarted, + None, + ) + .await + .unwrap(); + + let task_status = tama + .get_task_proving_status( + tasks[0].0, + tasks[0].1, + tasks[0].3, + Some(tasks[0].4.to_string()), + ) + .await + .unwrap(); + println!("{task_status:?}"); + assert_eq!(task_status.len(), 2); + assert_eq!(task_status[1].0, TaskStatus::Cancelled_NeverStarted); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + // ----------------------- + { + tama.update_task_progress( + tasks[1].0, + tasks[1].1, + tasks[1].3, + Some(tasks[1].4.to_string()), + TaskStatus::WorkInProgress, + None, + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[1].0, + tasks[1].1, + tasks[1].3, + Some(tasks[1].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 2); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + + std::thread::sleep(Duration::from_millis(1)); + + tama.update_task_progress( + tasks[1].0, + tasks[1].1, + tasks[1].3, + Some(tasks[1].4.to_string()), + TaskStatus::CancellationInProgress, + None, + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[1].0, + tasks[1].1, + tasks[1].3, + Some(tasks[1].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 3); + assert_eq!(task_status[2].0, TaskStatus::CancellationInProgress); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + + std::thread::sleep(Duration::from_millis(1)); + + tama.update_task_progress( + tasks[1].0, + tasks[1].1, + tasks[1].3, + Some(tasks[1].4.to_string()), + TaskStatus::Cancelled, + None, + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[1].0, + tasks[1].1, + tasks[1].3, + Some(tasks[1].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 4); + assert_eq!(task_status[3].0, TaskStatus::Cancelled); + assert_eq!(task_status[2].0, TaskStatus::CancellationInProgress); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + } + + // ----------------------- + { + tama.update_task_progress( + tasks[2].0, + tasks[2].1, + tasks[2].3, + Some(tasks[2].4.to_string()), + TaskStatus::WorkInProgress, + None, + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[2].0, + tasks[2].1, + tasks[2].3, + Some(tasks[2].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 2); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + + std::thread::sleep(Duration::from_millis(1)); + + let proof: Vec<_> = (&mut rng).gen_iter::().take(128).collect(); + tama.update_task_progress( + tasks[2].0, + tasks[2].1, + tasks[2].3, + Some(tasks[2].4.to_string()), + TaskStatus::Success, + Some(&proof), + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[2].0, + tasks[2].1, + tasks[2].3, + Some(tasks[2].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 3); + assert_eq!(task_status[2].0, TaskStatus::Success); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + + assert_eq!( + proof, + tama.get_task_proof( + tasks[2].0, + tasks[2].1, + tasks[2].3, + Some(tasks[2].4.to_string()) + ) + .await + .unwrap() + ); + } + + // ----------------------- + { + tama.update_task_progress( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()), + TaskStatus::WorkInProgress, + None, + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 2); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + + std::thread::sleep(Duration::from_millis(1)); + + tama.update_task_progress( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()), + TaskStatus::NetworkFailure, + None, + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 3); + assert_eq!(task_status[2].0, TaskStatus::NetworkFailure); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + + std::thread::sleep(Duration::from_millis(1)); + + tama.update_task_progress( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()), + TaskStatus::WorkInProgress, + None, + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 4); + assert_eq!(task_status[3].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[2].0, TaskStatus::NetworkFailure); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + + std::thread::sleep(Duration::from_millis(1)); + + let proof: Vec<_> = (&mut rng).gen_iter::().take(128).collect(); + tama.update_task_progress( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()), + TaskStatus::Success, + Some(proof.as_slice()), + ) + .await + .unwrap(); + + { + let task_status = tama + .get_task_proving_status( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()), + ) + .await + .unwrap(); + assert_eq!(task_status.len(), 5); + assert_eq!(task_status[4].0, TaskStatus::Success); + assert_eq!(task_status[3].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[2].0, TaskStatus::NetworkFailure); + assert_eq!(task_status[1].0, TaskStatus::WorkInProgress); + assert_eq!(task_status[0].0, TaskStatus::Registered); + } + + assert_eq!( + proof, + tama.get_task_proof( + tasks[3].0, + tasks[3].1, + tasks[3].3, + Some(tasks[3].4.to_string()) + ) + .await + .unwrap() + ); + } + } +}