Skip to content

Commit

Permalink
First working version of SP1 Distributed Prover
Browse files Browse the repository at this point in the history
Optimized prototype

Remove unecessary complexity on async send/receive

Fix the commitment of shards public values

Setting the shard_batch_size to 1 and processing multiple checkpoints in the workers

Send only the first checkpoint and reexecute the runtime for the next ones

Make the worker computation stateless

Share the shard_batch_size and shard_size with workers

Make worker able to receive multiple requests

Redistribute a request when a worker fails

Reducing the size of the shard public values

Better request data structure

Keep a single instance of program and machine

Remove debugs about time duration
  • Loading branch information
Champii committed Jul 10, 2024
1 parent e7cb6a9 commit 4db0ce9
Show file tree
Hide file tree
Showing 23 changed files with 1,311 additions and 137 deletions.
184 changes: 119 additions & 65 deletions Cargo.lock

Large diffs are not rendered by default.

25 changes: 22 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,19 @@ risc0-build = { version = "0.21.0" }
risc0-binfmt = { version = "0.21.0" }

# SP1
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" }
sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }
sp1-core = { git = "https://github.com/succinctlabs/sp1.git", rev = "14eb569d41d24721ffbd407d6060e202482d659c" }


# Plonky3
p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }
p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "88ea2b866e41329817e4761429b4a5a2a9751c07" }


# alloy
alloy-rlp = { version = "0.3.4", default-features = false }
Expand Down Expand Up @@ -188,3 +198,12 @@ revm-primitives = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-
revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
secp256k1 = { git = "https://github.com/CeciliaZ030/rust-secp256k1", branch = "sp1-patch" }
blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" }

# Patch Plonky3 for Serialize and Deserialize of DuplexChallenger
[patch."https://github.com/Plonky3/Plonky3.git"]
p3-field = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
p3-challenger = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
p3-poseidon2 = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
p3-baby-bear = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }
p3-symmetric = { git = "https://github.com/Champii/Plonky3.git", branch = "serde_patch" }

14 changes: 14 additions & 0 deletions core/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ pub enum ProofType {
///
/// Uses the SP1 prover to build the block.
Sp1,
/// # Sp1Distributed
///
/// Uses the SP1 prover to build the block in a distributed way.
Sp1Distributed,
/// # Sgx
///
/// Builds the block on a SGX supported CPU to create a proof.
Expand All @@ -119,6 +123,7 @@ impl std::fmt::Display for ProofType {
f.write_str(match self {
ProofType::Native => "native",
ProofType::Sp1 => "sp1",
ProofType::Sp1Distributed => "sp1_distributed",
ProofType::Sgx => "sgx",
ProofType::Risc0 => "risc0",
})
Expand All @@ -132,6 +137,7 @@ impl FromStr for ProofType {
match s.trim().to_lowercase().as_str() {
"native" => Ok(ProofType::Native),
"sp1" => Ok(ProofType::Sp1),
"sp1_distributed" => Ok(ProofType::Sp1Distributed),
"sgx" => Ok(ProofType::Sgx),
"risc0" => Ok(ProofType::Risc0),
_ => Err(RaikoError::InvalidProofType(s.to_string())),
Expand Down Expand Up @@ -159,6 +165,14 @@ impl ProofType {
#[cfg(not(feature = "sp1"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Sp1Distributed => {
#[cfg(feature = "sp1")]
return sp1_driver::Sp1DistributedProver::run(input, output, config)
.await
.map_err(|e| e.into());
#[cfg(not(feature = "sp1"))]
Err(RaikoError::FeatureNotSupportedError(*self))
}
ProofType::Risc0 => {
#[cfg(feature = "risc0")]
return risc0_driver::Risc0Prover::run(input.clone(), output, config)
Expand Down
2 changes: 1 addition & 1 deletion host/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ ethers-core = { workspace = true }

[features]
default = []
sp1 = ["raiko-core/sp1"]
sp1 = ["raiko-core/sp1", "sp1-driver"]
risc0 = ["raiko-core/risc0"]
sgx = ["raiko-core/sgx"]

Expand Down
19 changes: 19 additions & 0 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ fn default_address() -> String {
"0.0.0.0:8080".to_string()
}

fn default_worker_address() -> String {
"0.0.0.0:8081".to_string()
}

fn default_concurrency_limit() -> usize {
16
}
Expand Down Expand Up @@ -69,6 +73,21 @@ pub struct Cli {
/// [default: 0.0.0.0:8080]
address: String,

#[arg(long, require_equals = true, default_value = "0.0.0.0:8081")]
#[serde(default = "default_worker_address")]
/// Distributed SP1 worker listening address
/// [default: 0.0.0.0:8081]
worker_address: String,

#[arg(long, default_value = None)]
/// Distributed SP1 worker orchestrator address
///
/// Setting this will enable the worker and restrict it to only accept requests from
/// this orchestrator
///
/// [default: None]
orchestrator_address: Option<String>,

#[arg(long, require_equals = true, default_value = "16")]
#[serde(default = "default_concurrency_limit")]
/// Limit the max number of in-flight requests
Expand Down
5 changes: 5 additions & 0 deletions host/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@ use tokio::net::TcpListener;
use tracing::info;

pub mod api;
#[cfg(feature = "sp1")]
pub mod worker;

/// Starts the proverd server.
pub async fn serve(state: ProverState) -> anyhow::Result<()> {
#[cfg(feature = "sp1")]
worker::serve(state.clone()).await;

let addr = SocketAddr::from_str(&state.opts.address)
.map_err(|_| HostError::InvalidAddress(state.opts.address.clone()))?;
let listener = TcpListener::bind(addr).await?;
Expand Down
136 changes: 136 additions & 0 deletions host/src/server/worker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use crate::ProverState;
use raiko_lib::prover::{ProverError, WorkerError};
use sp1_driver::{
sp1_specifics::{Challenger, CoreSC, Machine, Program, ProvingKey, RiscvAir},
RequestData, WorkerProtocol, WorkerRequest, WorkerResponse, WorkerSocket, ELF,
};
use tokio::net::TcpListener;
use tracing::{error, info, warn};

pub async fn serve(state: ProverState) {
if state.opts.orchestrator_address.is_some() {
tokio::spawn(listen_worker(state));
}
}

async fn listen_worker(state: ProverState) {
info!(
"Listening as a SP1 worker on: {}",
state.opts.worker_address
);

let listener = TcpListener::bind(state.opts.worker_address).await.unwrap();

loop {
let Ok((socket, addr)) = listener.accept().await else {
error!("Error while accepting connection from orchestrator: Closing socket");

return;
};

if let Some(orchestrator_address) = &state.opts.orchestrator_address {
if addr.ip().to_string() != *orchestrator_address {
warn!("Unauthorized orchestrator connection from: {}", addr);

continue;
}
}

// We purposely don't spawn the task here, as we want to block to limit the number
// of concurrent connections to one.
if let Err(e) = handle_worker_socket(WorkerSocket::from_stream(socket)).await {
error!("Error while handling worker socket: {:?}", e);
}
}
}

async fn handle_worker_socket(mut socket: WorkerSocket) -> Result<(), ProverError> {
let program = Program::from(ELF);
let config = CoreSC::default();

let machine = RiscvAir::machine(config.clone());
let (pk, _vk) = machine.setup(&program);

while let Ok(protocol) = socket.receive().await {
match protocol {
WorkerProtocol::Request(request) => match request {
WorkerRequest::Ping => handle_ping(&mut socket).await?,
WorkerRequest::Commit(request_data) => {
handle_commit(&mut socket, &program, &machine, request_data).await?
}
WorkerRequest::Prove {
request_data,
challenger,
} => {
handle_prove(
&mut socket,
&program,
&machine,
&pk,
request_data,
challenger,
)
.await?
}
},
_ => Err(WorkerError::InvalidRequest)?,
}
}

Ok(())
}

async fn handle_ping(socket: &mut WorkerSocket) -> Result<(), WorkerError> {
socket
.send(WorkerProtocol::Response(WorkerResponse::Pong))
.await
}

async fn handle_commit(
socket: &mut WorkerSocket,
program: &Program,
machine: &Machine,
request_data: RequestData,
) -> Result<(), WorkerError> {
let (commitments, shards_public_values) = sp1_driver::sp1_specifics::commit(
program,
machine,
request_data.checkpoint,
request_data.nb_checkpoints,
request_data.public_values,
request_data.shard_batch_size,
request_data.shard_size,
)?;

socket
.send(WorkerProtocol::Response(WorkerResponse::Commitment {
commitments,
shards_public_values,
}))
.await
}

async fn handle_prove(
socket: &mut WorkerSocket,
program: &Program,
machine: &Machine,
pk: &ProvingKey,
request_data: RequestData,
challenger: Challenger,
) -> Result<(), WorkerError> {
let proof = sp1_driver::sp1_specifics::prove(
program,
machine,
pk,
request_data.checkpoint,
request_data.nb_checkpoints,
request_data.public_values,
request_data.shard_batch_size,
request_data.shard_size,
challenger,
)?;

socket
.send(WorkerProtocol::Response(WorkerResponse::Proof(proof)))
.await
}
2 changes: 1 addition & 1 deletion lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ std = [
sgx = []
sp1 = []
risc0 = []
sp1-cycle-tracker = []
sp1-cycle-tracker = []
20 changes: 20 additions & 0 deletions lib/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub enum ProverError {
FileIo(#[from] std::io::Error),
#[error("ProverError::Param `{0}`")]
Param(#[from] serde_json::Error),
#[error("ProverError::Worker `{0}`")]
Worker(#[from] WorkerError),
}

impl From<String> for ProverError {
Expand All @@ -37,3 +39,21 @@ pub fn to_proof(proof: ProverResult<impl Serialize>) -> ProverResult<Proof> {
serde_json::to_value(res).map_err(|err| ProverError::GuestError(err.to_string()))
})
}

#[derive(ThisError, Debug)]
pub enum WorkerError {
#[error("All workers failed")]
AllWorkersFailed,
#[error("Worker IO error: {0}")]
IO(#[from] std::io::Error),
#[error("Worker Serde error: {0}")]
Serde(#[from] bincode::Error),
#[error("Worker invalid version")]
InvalidVersion,
#[error("Worker invalid request")]
InvalidRequest,
#[error("Worker invalid response")]
InvalidResponse,
#[error("Worker payload too big")]
PayloadTooBig,
}
26 changes: 26 additions & 0 deletions provers/sp1/driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,46 @@ alloy-sol-types = { workspace = true }
serde = { workspace = true , optional = true}
serde_json = { workspace = true , optional = true }
sp1-sdk = { workspace = true, optional = true }
sp1-core = { workspace = true, optional = true }
anyhow = { workspace = true, optional = true }
once_cell = { workspace = true, optional = true }
sha3 = { workspace = true, optional = true, default-features = false}

log = { workspace = true, optional = true }
tokio = { workspace = true, optional = true }
tracing = { workspace = true, optional = true }
tempfile = { workspace = true, optional = true }
bincode = { workspace = true, optional = true }

p3-field = { workspace = true, optional = true }
p3-challenger = { workspace = true, optional = true }
p3-poseidon2 = { workspace = true, optional = true }
p3-baby-bear = { workspace = true, optional = true }
p3-symmetric = { workspace = true, optional = true }


[features]
enable = [
"serde",
"serde_json",
"raiko-lib",
"sp1-sdk",
"sp1-core",
"anyhow",
"alloy-primitives",
"once_cell",
"sha3",

"log",
"tokio",
"tracing",
"tempfile",
"bincode",

"p3-field",
"p3-challenger",
"p3-poseidon2",
"p3-baby-bear",
"p3-symmetric",
]
neon = ["sp1-sdk?/neon"]
9 changes: 9 additions & 0 deletions provers/sp1/driver/src/distributed/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
mod prover;
pub mod sp1_specifics;
mod worker;

pub use prover::Sp1DistributedProver;
pub use worker::{
RequestData, WorkerEnvelope, WorkerPool, WorkerProtocol, WorkerRequest, WorkerResponse,
WorkerSocket,
};
Loading

0 comments on commit 4db0ce9

Please sign in to comment.