Skip to content

Commit

Permalink
feat(raiko): bonsai auto scaling (#341)
Browse files Browse the repository at this point in the history
* use bonsai auto scaling api

Signed-off-by: smtmfft <[email protected]>

* update auto-scaling poc

* refine auto scaler set/get logic

* remove dup code

* add missing error return

* Update provers/risc0/driver/src/lib.rs

Co-authored-by: Petar Vujović <[email protected]>

* Update provers/risc0/driver/src/bonsai/auto_scaling.rs

Co-authored-by: Petar Vujović <[email protected]>

* refine auto scaling

* remove useless comments

Signed-off-by: smtmfft <[email protected]>

---------

Signed-off-by: smtmfft <[email protected]>
Co-authored-by: Petar Vujović <[email protected]>
  • Loading branch information
smtmfft and petarvujovic98 authored Aug 16, 2024
1 parent cce1371 commit dc89e60
Show file tree
Hide file tree
Showing 11 changed files with 244 additions and 18 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ risc0-build = { version = "1.0.1" }
risc0-binfmt = { version = "1.0.1" }

# SP1
sp1-sdk = { version = "1.0.1" }
sp1-zkvm = { version = "1.0.1" }
sp1-helper = { version = "1.0.1" }
sp1-sdk = { version = "1.0.1" }
sp1-zkvm = { version = "1.0.1" }
sp1-helper = { version = "1.0.1" }

# alloy
alloy-rlp = { version = "0.3.4", default-features = false }
Expand Down Expand Up @@ -189,4 +189,4 @@ revm-primitives = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-
revm-precompile = { git = "https://github.com/taikoxyz/revm.git", branch = "v36-taiko" }
secp256k1 = { git = "https://github.com/CeciliaZ030/rust-secp256k1", branch = "sp1-patch" }
blst = { git = "https://github.com/CeciliaZ030/blst.git", branch = "v0.3.12-serialize" }
alloy-serde = { git = "https://github.com/CeciliaZ030/alloy.git", branch = "v0.1.4-fix"}
alloy-serde = { git = "https://github.com/CeciliaZ030/alloy.git", branch = "v0.1.4-fix" }
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ARG BUILD_FLAGS=""

WORKDIR /opt/raiko
COPY . .
RUN cargo build --release ${BUILD_FLAGS} --features "sgx" --features "docker_build"
RUN cargo build --release ${BUILD_FLAGS} --features "sgx,risc0" --features "docker_build"

FROM gramineproject/gramine:1.6-jammy as runtime
ENV DEBIAN_FRONTEND=noninteractive
Expand Down
6 changes: 6 additions & 0 deletions provers/risc0/driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ typetag = { workspace = true, optional = true }
serde_with = { workspace = true, optional = true }
serde_json = { workspace = true, optional = true }
hex = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true }
lazy_static = { workspace = true, optional = true }
tokio = { workspace = true }
tokio-util = { workspace = true }

[features]
enable = [
Expand All @@ -57,6 +61,8 @@ enable = [
"serde_with",
"serde_json",
"hex",
"reqwest",
"lazy_static"
]
cuda = ["risc0-zkvm?/cuda"]
metal = ["risc0-zkvm?/metal"]
Expand Down
3 changes: 3 additions & 0 deletions provers/risc0/driver/src/bonsai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use std::{

use crate::Risc0Param;

pub mod auto_scaling;

pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(
image_id: Digest,
expected_output: &O,
Expand Down Expand Up @@ -194,6 +196,7 @@ pub async fn cancel_proof(uuid: String) -> anyhow::Result<()> {
let client = bonsai_sdk::alpha_async::get_client_from_env(risc0_zkvm::VERSION).await?;
let session = bonsai_sdk::alpha::SessionId { uuid };
session.stop(&client)?;
auto_scaling::shutdown_bonsai().await?;
Ok(())
}

Expand Down
204 changes: 204 additions & 0 deletions provers/risc0/driver/src/bonsai/auto_scaling.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
use anyhow::{Error, Ok, Result};
use lazy_static::lazy_static;
use reqwest::{header::HeaderMap, header::HeaderValue, header::CONTENT_TYPE, Client};
use serde::Deserialize;
use std::env;
use tracing::{debug, error as trace_err};

/// Worker (GPU) counts exchanged with the Bonsai scaler `/workers` endpoint.
///
/// Deserialized from the JSON body of a successful status request, and also
/// used locally to remember the last requested setting.
#[derive(Debug, Deserialize, Default)]
struct ScalerResponse {
/// Worker count that has been requested.
desired: u32,
/// Worker count reported as currently active; this is the value polled
/// against `desired` while waiting for a setting to take effect.
current: u32,
/// NOTE(review): presumably workers still being provisioned; the ignored
/// integration test expects `desired == current + pending` — confirm against
/// the Bonsai API.
pending: u32,
}
/// Small HTTP client around the Bonsai scaler endpoint.
struct BonsaiAutoScaler {
/// Full endpoint URL (the Bonsai API base URL with `/workers` appended).
url: String,
/// Headers sent with every request (JSON content type and `x-api-key`).
headers: HeaderMap,
/// Reusable HTTP client.
client: Client,
/// The last setting that was accepted by the endpoint but has not yet been
/// confirmed active; `None` when nothing is pending confirmation.
on_setting_status: Option<ScalerResponse>,
}

impl BonsaiAutoScaler {
fn new(bonsai_api_url: String, api_key: String) -> Self {
let url = bonsai_api_url + "/workers";
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert("x-api-key", HeaderValue::from_str(&api_key).unwrap());

Self {
url,
headers,
client: Client::new(),
on_setting_status: None,
}
}

async fn get_bonsai_gpu_num(&self) -> Result<ScalerResponse> {
debug!("Requesting scaler status from: {}", self.url);
let response = self
.client
.get(self.url.clone())
.headers(self.headers.clone())
.send()
.await?;

// Check if the request was successful
if response.status().is_success() {
// Parse the JSON response
let data: ScalerResponse = response.json().await.unwrap_or_default();
debug!("Scaler status: {data:?}");
Ok(data)
} else {
trace_err!("Request failed with status: {}", response.status());
Err(Error::msg("Failed to get bonsai gpu num".to_string()))
}
}

async fn set_bonsai_gpu_num(&mut self, gpu_num: u32) -> Result<()> {
if self.on_setting_status.is_some() {
// log an err if there is a race adjustment.
trace_err!("Last bonsai setting is not active, please check.");
}

debug!("Requesting scaler status from: {}", self.url);
let response = self
.client
.post(self.url.clone())
.headers(self.headers.clone())
.body(gpu_num.to_string())
.send()
.await?;

// Check if the request was successful
if response.status().is_success() {
self.on_setting_status = Some(ScalerResponse {
desired: gpu_num,
current: 0,
pending: 0,
});
Ok(())
} else {
trace_err!("Request failed with status: {}", response.status());
Err(Error::msg("Failed to get bonsai gpu num".to_string()))
}
}

async fn wait_for_bonsai_config_active(&mut self, time_out_sec: u64) -> Result<()> {
match &self.on_setting_status {
None => Ok(()),
Some(setting) => {
// loop until some timeout
let start_time = std::time::Instant::now();
let mut check_time = std::time::Instant::now();
while check_time.duration_since(start_time).as_secs() < time_out_sec {
tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
check_time = std::time::Instant::now();
let current_bonsai_gpu_num = self.get_bonsai_gpu_num().await?;
if current_bonsai_gpu_num.current == setting.desired {
self.on_setting_status = None;
return Ok(());
}
}
Err(Error::msg(
"checking bonsai config active timeout".to_string(),
))
}
}
}
}

lazy_static! {
    /// Base URL of the Bonsai API, read from `BONSAI_API_URL` on first use.
    /// Panics on first use if the variable is not set.
    static ref BONSAI_API_URL: String =
        env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set");
    /// API key for the Bonsai API, read from `BONSAI_API_KEY` on first use.
    /// Panics on first use if the variable is not set.
    static ref BONSAI_API_KEY: String =
        env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set");
    /// Worker count requested when scaling up, read from `MAX_BONSAI_GPU_NUM`;
    /// defaults to 15 when the variable cannot be read. Panics on first use
    /// with a descriptive message if the value is not a valid `u32`.
    static ref MAX_BONSAI_GPU_NUM: u32 = env::var("MAX_BONSAI_GPU_NUM")
        .ok()
        .map_or(15, |raw| raw
            .parse::<u32>()
            .expect("MAX_BONSAI_GPU_NUM must be a valid u32"));
}

/// Scales Bonsai up to `MAX_BONSAI_GPU_NUM` workers.
///
/// Does nothing when the scaler already reports exactly that many workers as
/// both current and desired with none pending. Otherwise requests the maximum
/// and waits up to 300 seconds for it to become active.
///
/// # Errors
/// Propagates any failure from querying, setting, or waiting on the scaler.
pub(crate) async fn maxpower_bonsai() -> Result<()> {
    let mut scaler =
        BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string());
    let status = scaler.get_bonsai_gpu_num().await?;

    let target = *MAX_BONSAI_GPU_NUM;
    let settled_at_max =
        status.current == target && status.desired == target && status.pending == 0;
    if settled_at_max {
        return Ok(());
    }

    scaler.set_bonsai_gpu_num(target).await?;
    scaler.wait_for_bonsai_config_active(300).await
}

/// Scales Bonsai down to zero workers.
///
/// Does nothing when the scaler already reports zero current, pending and
/// desired workers. Otherwise requests zero workers and waits up to 30
/// seconds for the change to become active.
///
/// # Errors
/// Propagates any failure from querying, setting, or waiting on the scaler.
pub(crate) async fn shutdown_bonsai() -> Result<()> {
    let mut scaler =
        BonsaiAutoScaler::new(BONSAI_API_URL.to_string(), BONSAI_API_KEY.to_string());
    let status = scaler.get_bonsai_gpu_num().await?;

    if matches!((status.current, status.pending, status.desired), (0, 0, 0)) {
        return Ok(());
    }

    scaler.set_bonsai_gpu_num(0).await?;
    // Give the scaler a short window to report the scale-down as active.
    scaler.wait_for_bonsai_config_active(30).await
}

#[cfg(test)]
mod test {
    use super::*;
    use std::env;

    /// Reads the Bonsai endpoint URL and API key from the environment.
    fn bonsai_credentials() -> (String, String) {
        let bonsai_url = env::var("BONSAI_API_URL").expect("BONSAI_API_URL must be set");
        let bonsai_key = env::var("BONSAI_API_KEY").expect("BONSAI_API_KEY must be set");
        (bonsai_url, bonsai_key)
    }

    /// Requests `gpu_num` workers, waits for the setting to become active and
    /// checks that the scaler reports exactly that many current workers.
    async fn scale_and_verify(scaler: &mut BonsaiAutoScaler, gpu_num: u32) {
        scaler
            .set_bonsai_gpu_num(gpu_num)
            .await
            .expect("Failed to set bonsai gpu num");
        scaler.wait_for_bonsai_config_active(300).await.unwrap();
        let reported = scaler.get_bonsai_gpu_num().await.unwrap().current;
        assert_eq!(reported, gpu_num);
    }

    // Ignored by default: talks to the live Bonsai API.
    #[ignore]
    #[tokio::test]
    async fn test_bonsai_auto_scaler_get() {
        let (bonsai_url, bonsai_key) = bonsai_credentials();
        let gpu_limit: u32 = env::var("MAX_BONSAI_GPU_NUM")
            .unwrap_or_else(|_| "15".to_string())
            .parse()
            .unwrap();
        let scaler = BonsaiAutoScaler::new(bonsai_url, bonsai_key);

        let status = scaler.get_bonsai_gpu_num().await.unwrap();
        assert!(status.current <= gpu_limit);
        assert_eq!(status.desired, status.current + status.pending);
    }

    // Ignored by default: talks to the live Bonsai API and changes its scale.
    #[ignore]
    #[tokio::test]
    async fn test_bonsai_auto_scaler_set() {
        let (bonsai_url, bonsai_key) = bonsai_credentials();
        let mut scaler = BonsaiAutoScaler::new(bonsai_url, bonsai_key);

        // Scale up, then back down to zero so no workers are left running.
        scale_and_verify(&mut scaler, 7).await;
        scale_and_verify(&mut scaler, 0).await;
    }
}
22 changes: 18 additions & 4 deletions provers/risc0/driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::fmt::Debug;
use tracing::{debug, info as traicing_info};

use crate::{
bonsai::auto_scaling::{maxpower_bonsai, shutdown_bonsai},
methods::risc0_guest::{RISC0_GUEST_ELF, RISC0_GUEST_ID},
snarks::verify_groth16_snark,
};
Expand Down Expand Up @@ -70,6 +71,13 @@ impl Prover for Risc0Prover {
debug!("elf code length: {}", RISC0_GUEST_ELF.len());
let encoded_input = to_vec(&input).expect("Could not serialize proving input!");

if config.bonsai {
// make max speed bonsai
maxpower_bonsai()
.await
.expect("Failed to set max power on Bonsai");
}

let result = maybe_prove::<GuestInput, B256>(
&config,
encoded_input,
Expand All @@ -83,8 +91,8 @@ impl Prover for Risc0Prover {

let journal: String = result.clone().unwrap().1.journal.encode_hex();

// Create/verify Groth16 SNARK
let snark_proof = if config.snark {
// Create/verify Groth16 SNARK in bonsai
let snark_proof = if config.snark && config.bonsai {
let Some((stark_uuid, stark_receipt)) = result else {
return Err(ProverError::GuestError(
"No STARK data to snarkify!".to_owned(),
Expand All @@ -108,6 +116,13 @@ impl Prover for Risc0Prover {
journal
};

if config.bonsai {
// shutdown max speed bonsai
shutdown_bonsai()
.await
.map_err(|e| ProverError::GuestError(e.to_string()))?;
}

Ok(Risc0Response { proof: snark_proof }.into())
}

Expand All @@ -125,8 +140,7 @@ impl Prover for Risc0Prover {
cancel_proof(uuid)
.await
.map_err(|e| ProverError::GuestError(e.to_string()))?;
id_store.remove_id(key).await?;
Ok(())
id_store.remove_id(key).await
}
}

Expand Down
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/methods/ecdsa.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub const ECDSA_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/ecdsa");
pub const ECDSA_ID: [u32; 8] = [
3314277365, 903638368, 2823387338, 975292771, 2962241176, 3386670094, 1262198564, 423457744,
1166688769, 1407190737, 3347938864, 1261472884, 3997842354, 3752365982, 4108615966, 2506107654,
];
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/methods/sha256.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub const SHA256_ELF: &[u8] =
include_bytes!("../../../guest/target/riscv32im-risc0-zkvm-elf/release/sha256");
pub const SHA256_ID: [u32; 8] = [
3506084161, 1146489446, 485833862, 3404354046, 3626029993, 1928006034, 3833244069, 3073098029,
1030743442, 3697463329, 2083175350, 1726292372, 629109085, 444583534, 849554126, 3148184953,
];
4 changes: 2 additions & 2 deletions provers/risc0/driver/src/methods/test_risc0_guest.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pub const TEST_RISC0_GUEST_ELF: &[u8] = include_bytes!(
"../../../guest/target/riscv32im-risc0-zkvm-elf/release/deps/risc0_guest-4b4f18d42a260659"
"../../../guest/target/riscv32im-risc0-zkvm-elf/release/deps/risc0_guest-3bef88267f07d7e2"
);
pub const TEST_RISC0_GUEST_ID: [u32; 8] = [
3216516244, 2583889163, 799150854, 107525368, 1015178806, 1451965571, 3377528142, 1073775,
947177299, 3433149683, 3077752115, 1716500464, 3011459317, 622725533, 247263939, 1661915565,
];
5 changes: 0 additions & 5 deletions provers/risc0/driver/src/snarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ abigen!(
]"#
);

// /// ABI encoding of the seal.
// pub fn abi_encode(seal: Vec<u8>) -> Result<Vec<u8>> {
// Ok(encode(seal)?.abi_encode())
// }

/// encoding of the seal with selector.
pub fn encode(seal: Vec<u8>) -> Result<Vec<u8>> {
let verifier_parameters_digest = Groth16ReceiptVerifierParameters::default().digest();
Expand Down

0 comments on commit dc89e60

Please sign in to comment.