From 9a2afb653a62c33ab09aa0480a7f5655712153f6 Mon Sep 17 00:00:00 2001 From: samtvlabs <112424909+samtvlabs@users.noreply.github.com> Date: Tue, 14 Nov 2023 17:11:19 +0400 Subject: [PATCH] feat: hub auth (#605) --- .github/workflows/rust.yml | 4 +-- src/bin/ezkl.rs | 4 +-- src/commands.rs | 12 ++++++++ src/execute.rs | 59 +++++++++++++++++++++++++++++++------- src/python.rs | 31 ++++++++++++++------ 5 files changed, 87 insertions(+), 23 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4077168b7..58d9fd2b4 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -666,8 +666,8 @@ jobs: run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_27_expects - name: KZG Vis demo run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_26_expects - # - name: Simple hub demo - # run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_25_expects + - name: Simple hub demo + run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_25_expects - name: Hashed DA tutorial run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_24_expects - name: Little transformer tutorial diff --git a/src/bin/ezkl.rs b/src/bin/ezkl.rs index b2e284767..ba4b210ab 100644 --- a/src/bin/ezkl.rs +++ b/src/bin/ezkl.rs @@ -15,10 +15,10 @@ use log::{error, info}; #[cfg(not(target_arch = "wasm32"))] use rand::prelude::SliceRandom; #[cfg(not(target_arch = "wasm32"))] -use std::error::Error; -#[cfg(not(target_arch = "wasm32"))] #[cfg(feature = "icicle")] use std::env; +#[cfg(not(target_arch = "wasm32"))] +use std::error::Error; #[tokio::main(flavor = "current_thread")] #[cfg(not(target_arch = "wasm32"))] diff --git a/src/commands.rs b/src/commands.rs index f61ee9519..6a4a0b2cb 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -672,6 +672,9 @@ pub enum Commands { #[command(name = "get-hub-credentials", arg_required_else_help = true)] #[cfg(not(target_arch = "wasm32"))] GetHubCredentials { + /// The user's api key + #[arg(short = 'K', long)] + api_key: Option, /// The path to the model file #[arg(short = 'N', long)] username: String, @@ -684,6 +687,9 @@ pub enum Commands { #[command(name = "create-hub-artifact", arg_required_else_help = true)] #[cfg(not(target_arch = "wasm32"))] CreateHubArtifact { + /// The user's api key + #[arg(short = 'K', long)] + api_key: Option, /// The path to the model file #[arg(short = 'M', long)] uncompiled_circuit: PathBuf, @@ -711,6 +717,9 @@ pub enum Commands { #[command(name = "prove-hub", arg_required_else_help = true)] #[cfg(not(target_arch = "wasm32"))] ProveHub { + /// The user's api key + #[arg(short = 'K', long)] + api_key: Option, /// The path to the model file #[arg(short = 'A', long)] artifact_id: String, @@ -727,6 +736,9 @@ pub enum Commands { #[command(name = "get-hub-proof", arg_required_else_help = true)] #[cfg(not(target_arch = "wasm32"))] GetHubProof { + /// The user's api key + #[arg(short = 'K', long)] + api_key: Option, /// The path to the model file #[arg(short = 'A', long)] artifact_id: String, diff --git a/src/execute.rs b/src/execute.rs index d9ff5a8e0..177dc3599 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -358,14 +358,17 @@ pub async fn run(cli: Cli) -> Result<(), Box> { } => verify_evm(proof_path, addr_verifier, rpc_url, addr_da).await, Commands::PrintProofHex { proof_path } => print_proof_hex(proof_path), #[cfg(not(target_arch = "wasm32"))] - Commands::GetHubCredentials { username, url } => { - get_hub_credentials(url.as_deref(), &username) - .await - .map(|_| ()) - } + Commands::GetHubCredentials { + api_key, + username, + url, + } => get_hub_credentials(api_key.as_deref(), url.as_deref(), &username) + .await + .map(|_| ()), #[cfg(not(target_arch = "wasm32"))] Commands::CreateHubArtifact { + api_key, uncompiled_circuit, data, organization_id, @@ -374,6 +377,7 @@ pub async fn run(cli: Cli) -> Result<(), Box> { args, target, } => deploy_model( + api_key.as_deref(), url.as_deref(), &uncompiled_circuit, &data, @@ -385,16 +389,22 @@ pub async fn run(cli: Cli) -> Result<(), Box> { .await .map(|_| ()), #[cfg(not(target_arch = "wasm32"))] - Commands::GetHubProof { artifact_id, url } => get_hub_proof(url.as_deref(), &artifact_id) + Commands::GetHubProof { + api_key, + artifact_id, + url, + } => get_hub_proof(api_key.as_deref(), url.as_deref(), &artifact_id) .await .map(|_| ()), #[cfg(not(target_arch = "wasm32"))] Commands::ProveHub { + api_key, artifact_id, data, transcript_type, url, } => prove_hub( + api_key.as_deref(), url.as_deref(), &artifact_id, &data, @@ -1753,6 +1763,7 @@ pub(crate) fn verify_aggr( /// Retrieves the user's credentials from the hub pub(crate) async fn get_hub_credentials( + api_key: Option<&str>, url: Option<&str>, username: &str, ) -> Result> { @@ -1771,8 +1782,15 @@ pub(crate) async fn get_hub_credentials( } }); let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); + + let response = client + .post(url) + .header("Authorization", format!("Bearer {}", api_key)) + .json(&request_body) + .send() + .await?; - let response = client.post(url).json(&request_body).send().await?; let response_body = response.json::().await?; let organizations: crate::hub::Organizations = @@ -1787,6 +1805,7 @@ pub(crate) async fn get_hub_credentials( /// Deploy a model pub(crate) async fn deploy_model( + api_key: Option<&str>, url: Option<&str>, model: &Path, input: &Path, @@ -1862,8 +1881,14 @@ pub(crate) async fn deploy_model( let client = reqwest::Client::new(); let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); //send request - let response = client.post(url).multipart(form).send().await?; + let response = client + .post(url) + .header("Authorization", format!("Bearer {}", api_key)) + .multipart(form) + .send() + .await?; let response_body = response.json::().await?; println!("{}", response_body.to_string()); let artifact_id: crate::hub::Artifact = @@ -1877,6 +1902,7 @@ pub(crate) async fn deploy_model( /// Generates proofs on the hub pub async fn prove_hub( + api_key: Option<&str>, url: Option<&str>, id: &str, input: &Path, @@ -1916,8 +1942,14 @@ pub async fn prove_hub( .text("map", map) .part("input", input_file); let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); let client = reqwest::Client::new(); - let response = client.post(url).multipart(form).send().await?; + let response = client + .post(url) + .header("Authorization", format!("Bearer {}", api_key)) + .multipart(form) + .send() + .await?; let response_body = response.json::().await?; let proof_id: crate::hub::Proof = serde_json::from_value(response_body["data"]["initiateProof"].clone())?; @@ -1927,6 +1959,7 @@ pub async fn prove_hub( /// Fetches proofs from the hub pub(crate) async fn get_hub_proof( + api_key: Option<&str>, url: Option<&str>, id: &str, ) -> Result> { @@ -1947,8 +1980,14 @@ pub(crate) async fn get_hub_proof( "#, id), }); let url = url.unwrap_or("https://hub-staging.ezkl.xyz/graphql"); + let api_key = api_key.unwrap_or("ed896983-2ec3-4aaf-afa7-f01299f3d61f"); - let response = client.post(url).json(&request_body).send().await?; + let response = client + .post(url) + .header("Authorization", format!("Bearer {:?}", api_key)) + .json(&request_body) + .send() + .await?; let response_body = response.json::().await?; let proof: crate::hub::Proof = diff --git a/src/python.rs b/src/python.rs index 80c72708f..684aa6146 100644 --- a/src/python.rs +++ b/src/python.rs @@ -1163,12 +1163,13 @@ fn print_proof_hex(proof_path: PathBuf) -> Result { } /// deploys a model to the hub -#[pyfunction(signature = (model, input, name, organization_id, target=None, py_run_args=None, url=None))] +#[pyfunction(signature = (model, input, name, organization_id, api_key=None,target=None, py_run_args=None, url=None))] fn create_hub_artifact( model: PathBuf, input: PathBuf, name: String, organization_id: String, + api_key: Option<&str>, target: Option, py_run_args: Option, url: Option<&str>, @@ -1180,6 +1181,7 @@ fn create_hub_artifact( let output = Runtime::new() .unwrap() .block_on(crate::execute::deploy_model( + api_key, url, &model, &input, @@ -1196,16 +1198,23 @@ fn create_hub_artifact( } /// Generate a proof on the hub. -#[pyfunction(signature = (id, input, url=None, transcript_type=None))] +#[pyfunction(signature = ( id, input,api_key=None, url=None, transcript_type=None))] fn prove_hub( id: &str, input: PathBuf, + api_key: Option<&str>, url: Option<&str>, transcript_type: Option<&str>, ) -> PyResult { let output = Runtime::new() .unwrap() - .block_on(crate::execute::prove_hub(url, id, &input, transcript_type)) + .block_on(crate::execute::prove_hub( + api_key, + url, + id, + &input, + transcript_type, + )) .map_err(|e| { let err_str = format!("Failed to generate proof on hub: {}", e); PyRuntimeError::new_err(err_str) @@ -1214,11 +1223,11 @@ fn prove_hub( } /// Fetches proof from hub -#[pyfunction(signature = (id, url=None))] -fn get_hub_proof(id: &str, url: Option<&str>) -> PyResult { +#[pyfunction(signature = ( id, api_key=None,url=None))] +fn get_hub_proof(id: &str, api_key: Option<&str>, url: Option<&str>) -> PyResult { let output = Runtime::new() .unwrap() - .block_on(crate::execute::get_hub_proof(url, id)) + .block_on(crate::execute::get_hub_proof(api_key, url, id)) .map_err(|e| { let err_str = format!("Failed to get proof from hub: {}", e); PyRuntimeError::new_err(err_str) @@ -1227,11 +1236,15 @@ fn get_hub_proof(id: &str, url: Option<&str>) -> PyResult { } /// Gets hub credentials -#[pyfunction(signature = (username, url=None))] -fn get_hub_credentials(username: &str, url: Option<&str>) -> PyResult { +#[pyfunction(signature = (username,api_key=None, url=None))] +fn get_hub_credentials( + username: &str, + api_key: Option<&str>, + url: Option<&str>, +) -> PyResult { let output = Runtime::new() .unwrap() - .block_on(crate::execute::get_hub_credentials(url, username)) + .block_on(crate::execute::get_hub_credentials(api_key, url, username)) .map_err(|e| { let err_str = format!("Failed to get hub credentials: {}", e); PyRuntimeError::new_err(err_str)