diff --git a/crates/network-scheduler/Cargo.toml b/crates/network-scheduler/Cargo.toml index 8e54274..d3ca963 100644 --- a/crates/network-scheduler/Cargo.toml +++ b/crates/network-scheduler/Cargo.toml @@ -44,6 +44,8 @@ sqd-contract-client = { workspace = true } sqd-messages = { workspace = true, features = ["semver"] } sqd-network-transport = { workspace = true, features = ["scheduler", "metrics"] } chrono = "0.4.38" +bs58 = "0.5.1" +crypto_box = "0.9.1" [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.6" diff --git a/crates/network-scheduler/src/assignment.rs b/crates/network-scheduler/src/assignment.rs index a655d4a..f73bb22 100644 --- a/crates/network-scheduler/src/assignment.rs +++ b/crates/network-scheduler/src/assignment.rs @@ -1,6 +1,12 @@ use std::collections::HashMap; +use aws_config::identity; +use crypto_box::{ + aead::{Aead, AeadCore, OsRng}, + SalsaBox, PublicKey, SecretKey +}; use serde::{Deserialize, Serialize}; +use sha3::digest::generic_array::GenericArray; use crate::signature::timed_hmac_now; @@ -23,11 +29,19 @@ pub struct Dataset { #[derive(Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] -struct EncryptedHeaders { +struct Headers { worker_id: String, worker_signature: String, } +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +#[serde(rename_all = "camelCase")] +struct EncryptedHeaders { + identity: Vec, + nonce: Vec, + ciphertext: Vec, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct WorkerAssignment { @@ -118,18 +132,27 @@ impl Assignment { Some(result) } - pub fn headers_for_peer_id(&self, peer_id: String) -> Option> { - let local_assignment = match self.worker_assignments.get(&peer_id) { - Some(worker_assignment) => worker_assignment, - None => { - return None - } + pub fn headers_for_peer_id(&self, peer_id: String, secret_key: Vec) -> Option> { + let Some(local_assignment) = self.worker_assignments.get(&peer_id) else { + return None }; - let headers = match serde_json::to_value(&local_assignment.encrypted_headers) { - Ok(v) => v, - Err(_) => { - return None; - } + let EncryptedHeaders {identity, nonce, ciphertext,} = local_assignment.encrypted_headers.clone(); + let Ok(alice_public_key) = PublicKey::from_slice(identity.as_slice()) else { + return None + }; + let Ok(bob_secret_key) = SecretKey::from_slice(secret_key.as_slice()) else { + return None + }; + let bob_box = SalsaBox::new(&alice_public_key, &bob_secret_key); + let generic_nonce = GenericArray::clone_from_slice(&nonce); + let Ok(decrypted_plaintext) = bob_box.decrypt(&generic_nonce, &ciphertext[..]) else { + return None + }; + let Ok(plaintext_headers) = std::str::from_utf8(&decrypted_plaintext) else { + return None; + }; + let Ok(headers) = serde_json::to_value(&plaintext_headers) else { + return None; }; let mut result: HashMap = Default::default(); for (k,v) in headers.as_object().unwrap() { @@ -154,15 +177,34 @@ impl Assignment { } pub fn regenerate_headers(&mut self, cloudflare_storage_secret: String) { + let alice_secret_key = SecretKey::generate(&mut OsRng); + let alice_public_key_bytes = alice_secret_key.public_key().as_bytes().clone(); + for (worker_id, worker_assignment) in &mut self.worker_assignments { let worker_signature = timed_hmac_now( worker_id, &cloudflare_storage_secret, ); - worker_assignment.encrypted_headers = EncryptedHeaders { + + let headers = Headers { worker_id: worker_id.to_string(), worker_signature, - } + }; + + let pub_key = &bs58::decode(worker_id).into_vec().unwrap()[6..]; + let bob_public_key = PublicKey::from_slice(pub_key).unwrap(); + + let alice_box = SalsaBox::new(&bob_public_key, &alice_secret_key); + let nonce = SalsaBox::generate_nonce(&mut OsRng); + let plaintext = serde_json::to_vec(&headers).unwrap(); + let ciphertext = alice_box.encrypt(&nonce, &plaintext[..]).unwrap(); + + + worker_assignment.encrypted_headers = EncryptedHeaders { + identity: alice_public_key_bytes.to_vec(), + nonce: nonce.to_vec(), + ciphertext, + }; } } } \ No newline at end of file