From e6ab9916ec5d536ac62b9b1cf723e93f8c244eb2 Mon Sep 17 00:00:00 2001 From: Xun Li Date: Tue, 22 Oct 2024 19:04:52 -0700 Subject: [PATCH] Support multiple sponsor addresses --- Cargo.lock | 1 + Cargo.toml | 3 +- src/command.rs | 6 +- src/config.rs | 8 +- src/gas_pool/gas_pool_core.rs | 48 ++++---- src/gas_pool/mod.rs | 14 +-- src/gas_pool_initializer.rs | 98 +++++++++------- src/storage/mod.rs | 204 ++++++++++++++++++++++------------ src/storage/redis/mod.rs | 187 ++++++++++++++++++------------- src/test_env.rs | 2 +- src/tx_signer.rs | 85 +++++++++----- 11 files changed, 409 insertions(+), 247 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e2a5abe..ab5c0c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9451,6 +9451,7 @@ dependencies = [ "test-cluster", "tokio", "tokio-retry", + "tokio-util 0.7.11", "tracing", ] diff --git a/Cargo.toml b/Cargo.toml index 2d33c7c..a9b0432 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ telemetry-subscribers = { git = "https://github.com/MystenLabs/sui", branch = "t anyhow = "1.0.75" async-trait = "0.1.51" -axum = {version = "0.6.6", features = ["headers"]} +axum = { version = "0.6.6", features = ["headers"] } bcs = "0.1.6" clap = "4.4.10" chrono = "0.4.19" @@ -53,6 +53,7 @@ tracing = "0.1.40" tokio = { version = "1.36.0", features = ["full"] } tokio-retry = "0.3.0" serde_json = "1.0.108" +tokio-util = "0.7.10" [dev-dependencies] rand = "0.8.5" diff --git a/src/command.rs b/src/command.rs index 1bf957c..001440b 100644 --- a/src/command.rs +++ b/src/command.rs @@ -53,9 +53,9 @@ impl Command { let signer = signer_config.new_signer().await; let storage_metrics = StorageMetrics::new(&prometheus_registry); - let sponsor_address = signer.get_address(); - info!("Sponsor address: {:?}", sponsor_address); - let storage = connect_storage(&gas_pool_config, sponsor_address, storage_metrics).await; + let sponsor_addresses = signer.get_all_addresses(); + info!("Sponsor addresses: {:?}", sponsor_addresses); + let storage = connect_storage(&gas_pool_config, sponsor_addresses, storage_metrics).await; let sui_client = SuiClient::new(&fullnode_url, fullnode_basic_auth).await; let _coin_init_task = if let Some(coin_init_config) = coin_init_config { let task = GasPoolInitializer::start( diff --git a/src/config.rs b/src/config.rs index 88b42c2..c159e84 100644 --- a/src/config.rs +++ b/src/config.rs @@ -82,6 +82,7 @@ impl Default for GasPoolStorageConfig { pub enum TxSignerConfig { Local { keypair: SuiKeyPair }, Sidecar { sidecar_url: String }, + MultiSidecar { sidecar_urls: Vec }, } impl Default for TxSignerConfig { @@ -97,7 +98,12 @@ impl TxSignerConfig { pub async fn new_signer(self) -> Arc { match self { TxSignerConfig::Local { keypair } => TestTxSigner::new(keypair), - TxSignerConfig::Sidecar { sidecar_url } => SidecarTxSigner::new(sidecar_url).await, + TxSignerConfig::Sidecar { sidecar_url } => { + SidecarTxSigner::new(vec![sidecar_url]).await + } + TxSignerConfig::MultiSidecar { sidecar_urls } => { + SidecarTxSigner::new(sidecar_urls).await + } } } } diff --git a/src/gas_pool/gas_pool_core.rs b/src/gas_pool/gas_pool_core.rs index a142bc9..2a46dd5 100644 --- a/src/gas_pool/gas_pool_core.rs +++ b/src/gas_pool/gas_pool_core.rs @@ -20,6 +20,7 @@ use sui_types::transaction::{ }; use tap::TapFallible; use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, info}; use super::gas_usage_cap::GasUsageCap; @@ -28,9 +29,8 @@ const EXPIRATION_JOB_INTERVAL: Duration = Duration::from_secs(1); pub struct GasPoolContainer { inner: Arc, - _coin_unlocker_task: JoinHandle<()>, - // This is always Some. It is None only after the drop method is called. - cancel_sender: Option>, + _coin_unlocker_tasks: Vec>, + cancel: CancellationToken, } pub struct GasPool { @@ -66,10 +66,10 @@ impl GasPool { ) -> anyhow::Result<(SuiAddress, ReservationID, Vec)> { let cur_time = std::time::Instant::now(); self.gas_usage_cap.check_usage().await?; - let sponsor = self.signer.get_address(); + let sponsor = self.signer.get_one_address(); let (reservation_id, gas_coins) = self .gas_pool_store - .reserve_gas_coins(gas_budget, duration.as_millis() as u64) + .reserve_gas_coins(sponsor, gas_budget, duration.as_millis() as u64) .await?; let elapsed = cur_time.elapsed().as_millis(); self.metrics.reserve_gas_latency_ms.observe(elapsed as u64); @@ -106,7 +106,7 @@ impl GasPool { "Payment coins in transaction: {:?}", payment ); self.gas_pool_store - .ready_for_execution(reservation_id) + .ready_for_execution(sponsor, reservation_id) .await?; debug!(?reservation_id, "Reservation is ready for execution"); @@ -161,7 +161,7 @@ impl GasPool { // Regardless of whether the transaction succeeded, we need to release the coins. // Otherwise, we lose track of them. This is because `ready_for_execution` already takes // the coins out of the pool and will not be covered by the auto-release mechanism. - self.release_gas_coins(updated_coins).await; + self.release_gas_coins(sponsor, updated_coins).await; if smashed_coin_count > 0 { info!( ?reservation_id, @@ -260,11 +260,11 @@ impl GasPool { } /// Release gas coins back to the gas pool, by adding them to the storage. - async fn release_gas_coins(&self, gas_coins: Vec) { + async fn release_gas_coins(&self, sponsor: SuiAddress, gas_coins: Vec) { debug!("Trying to release gas coins: {:?}", gas_coins); retry_forever!(async { self.gas_pool_store - .add_new_coins(gas_coins.clone()) + .add_new_coins(sponsor, gas_coins.clone()) .await .tap_err(|err| error!("Failed to call update_gas_coins on storage: {:?}", err)) }) @@ -293,11 +293,12 @@ impl GasPool { async fn start_coin_unlock_task( self: Arc, - mut cancel_receiver: tokio::sync::oneshot::Receiver<()>, + sponsor: SuiAddress, + cancel: CancellationToken, ) -> JoinHandle<()> { tokio::task::spawn(async move { loop { - let expire_results = self.gas_pool_store.expire_coins().await; + let expire_results = self.gas_pool_store.expire_coins(sponsor).await; let unlocked_coins = expire_results.unwrap_or_else(|err| { error!("Failed to call expire_coins to the storage: {:?}", err); vec![] @@ -312,12 +313,12 @@ impl GasPool { .flatten() .collect(); let count = latest_coins.len(); - self.release_gas_coins(latest_coins).await; + self.release_gas_coins(sponsor, latest_coins).await; info!("Released {:?} coins after expiration", count); } tokio::select! { _ = tokio::time::sleep(EXPIRATION_JOB_INTERVAL) => {} - _ = &mut cancel_receiver => { + _ = cancel.cancelled() => { info!("Coin unlocker task is cancelled"); break; } @@ -326,9 +327,9 @@ impl GasPool { }) } - pub async fn query_pool_available_coin_count(&self) -> usize { + pub async fn query_pool_available_coin_count(&self, sponsor: SuiAddress) -> usize { self.gas_pool_store - .get_available_coin_count() + .get_available_coin_count(sponsor) .await .unwrap() } @@ -342,6 +343,7 @@ impl GasPoolContainer { gas_usage_daily_cap: u64, metrics: Arc, ) -> Self { + let sponsor_addresses = signer.get_all_addresses(); let inner = GasPool::new( signer, gas_pool_store, @@ -350,13 +352,19 @@ impl GasPoolContainer { Arc::new(GasUsageCap::new(gas_usage_daily_cap)), ) .await; - let (cancel_sender, cancel_receiver) = tokio::sync::oneshot::channel(); - let _coin_unlocker_task = inner.clone().start_coin_unlock_task(cancel_receiver).await; + let cancel = CancellationToken::new(); + + let mut _coin_unlocker_tasks = vec![]; + for sponsor in sponsor_addresses { + let inner = inner.clone(); + let task = inner.start_coin_unlock_task(sponsor, cancel.clone()).await; + _coin_unlocker_tasks.push(task); + } Self { inner, - _coin_unlocker_task, - cancel_sender: Some(cancel_sender), + _coin_unlocker_tasks, + cancel, } } @@ -367,6 +375,6 @@ impl GasPoolContainer { impl Drop for GasPoolContainer { fn drop(&mut self) { - self.cancel_sender.take().unwrap().send(()).unwrap(); + self.cancel.cancel(); } } diff --git a/src/gas_pool/mod.rs b/src/gas_pool/mod.rs index 0b887aa..5769eb0 100644 --- a/src/gas_pool/mod.rs +++ b/src/gas_pool/mod.rs @@ -27,14 +27,14 @@ mod tests { .await .unwrap(); assert_eq!(gas_coins.len(), 3); - assert_eq!(station.query_pool_available_coin_count().await, 7); + assert_eq!(station.query_pool_available_coin_count(sponsor1).await, 7); let (sponsor2, _res_id2, gas_coins) = station .reserve_gas(MIST_PER_SUI * 7, Duration::from_secs(10)) .await .unwrap(); assert_eq!(gas_coins.len(), 7); assert_eq!(sponsor1, sponsor2); - assert_eq!(station.query_pool_available_coin_count().await, 0); + assert_eq!(station.query_pool_available_coin_count(sponsor2).await, 0); assert!(station .reserve_gas(1, Duration::from_secs(10)) .await @@ -55,7 +55,7 @@ mod tests { .await .unwrap(); assert_eq!(gas_coins.len(), 1); - assert_eq!(station.query_pool_available_coin_count().await, 0); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 0); assert!(station .reserve_gas(1, Duration::from_secs(10)) .await @@ -67,7 +67,7 @@ mod tests { .await .unwrap(); assert!(effects.status().is_ok()); - assert_eq!(station.query_pool_available_coin_count().await, 1); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 1); } #[tokio::test] @@ -93,7 +93,7 @@ mod tests { .await; println!("{:?}", result); assert!(result.is_err()); - assert_eq!(station.query_pool_available_coin_count().await, 1); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 1); } #[tokio::test] @@ -106,14 +106,14 @@ mod tests { .await .unwrap(); assert_eq!(gas_coins.len(), 1); - assert_eq!(station.query_pool_available_coin_count().await, 0); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 0); assert!(station .reserve_gas(1, Duration::from_secs(1)) .await .is_err()); // Sleep a little longer to give it enough time to expire. tokio::time::sleep(Duration::from_secs(5)).await; - assert_eq!(station.query_pool_available_coin_count().await, 1); + assert_eq!(station.query_pool_available_coin_count(sponsor).await, 1); let (tx_data, user_sig) = create_test_transaction(&test_cluster, sponsor, gas_coins).await; assert!(station .execute_transaction(reservation_id, tx_data, user_sig) diff --git a/src/gas_pool_initializer.rs b/src/gas_pool_initializer.rs index a2d5702..3f6eb08 100644 --- a/src/gas_pool_initializer.rs +++ b/src/gas_pool_initializer.rs @@ -161,7 +161,7 @@ enum RunMode { } pub struct GasPoolInitializer { - _task_handle: JoinHandle<()>, + _fund_task_handle: JoinHandle<()>, // This is always Some. It is None only after the drop method is called. cancel_sender: Option>, } @@ -179,19 +179,22 @@ impl GasPoolInitializer { coin_init_config: CoinInitConfig, signer: Arc, ) -> Self { - if !storage.is_initialized().await.unwrap() { - // If the pool has never been initialized, always run once at the beginning to make sure we have enough coins. - Self::run_once( - sui_client.clone(), - &storage, - RunMode::Init, - coin_init_config.target_init_balance, - &signer, - ) - .await; + for address in signer.get_all_addresses() { + if !storage.is_initialized(address).await.unwrap() { + // If the pool has never been initialized, always run once at the beginning to make sure we have enough coins. + Self::run_once( + address, + sui_client.clone(), + &storage, + RunMode::Init, + coin_init_config.target_init_balance, + &signer, + ) + .await; + } } let (cancel_sender, cancel_receiver) = tokio::sync::oneshot::channel(); - let _task_handle = tokio::spawn(Self::run( + let _fund_task_handle = tokio::spawn(Self::run( sui_client, storage, coin_init_config, @@ -199,7 +202,7 @@ impl GasPoolInitializer { cancel_receiver, )); Self { - _task_handle, + _fund_task_handle, cancel_sender: Some(cancel_sender), } } @@ -220,38 +223,50 @@ impl GasPoolInitializer { } } info!("Coin init task waking up and looking for new coins to initialize"); - Self::run_once( - sui_client.clone(), - &storage, - RunMode::Refresh, - coin_init_config.target_init_balance, - &signer, - ) - .await; + for address in signer.get_all_addresses() { + Self::run_once( + address, + sui_client.clone(), + &storage, + RunMode::Refresh, + coin_init_config.target_init_balance, + &signer, + ) + .await; + } } } async fn run_once( + sponsor_address: SuiAddress, sui_client: SuiClient, storage: &Arc, mode: RunMode, target_init_coin_balance: u64, signer: &Arc, ) { - let sponsor_address = signer.get_address(); if storage - .acquire_init_lock(MAX_INIT_DURATION_SEC) + .acquire_init_lock(sponsor_address, MAX_INIT_DURATION_SEC) .await .unwrap() { - info!("Acquired init lock. Starting new coin initialization"); + info!( + ?sponsor_address, + "Acquired init lock. Starting new coin initialization" + ); } else { - info!("Another task is already initializing the pool. Skipping this round"); + info!( + ?sponsor_address, + "Another task is already initializing the pool. Skipping this round" + ); return; } let start = Instant::now(); let balance_threshold = if matches!(mode, RunMode::Init) { - info!("The pool has never been initialized. Initializing it for the first time"); + info!( + ?sponsor_address, + "The pool has never been initialized. Initializing it for the first time" + ); 0 } else { target_init_coin_balance * NEW_COIN_BALANCE_FACTOR_THRESHOLD @@ -261,10 +276,11 @@ impl GasPoolInitializer { .await; if coins.is_empty() { info!( + ?sponsor_address, "No coins with balance above {} found. Skipping new coin initialization", balance_threshold ); - storage.release_init_lock().await.unwrap(); + storage.release_init_lock(sponsor_address).await.unwrap(); return; } let total_coin_count = Arc::new(AtomicUsize::new(coins.len())); @@ -288,10 +304,14 @@ impl GasPoolInitializer { ) .await; for chunk in result.chunks(5000) { - storage.add_new_coins(chunk.to_vec()).await.unwrap(); + storage + .add_new_coins(sponsor_address, chunk.to_vec()) + .await + .unwrap(); } - storage.release_init_lock().await.unwrap(); + storage.release_init_lock(sponsor_address).await.unwrap(); info!( + ?sponsor_address, "New coin initialization took {:?}s", start.elapsed().as_secs() ); @@ -343,7 +363,8 @@ mod tests { telemetry_subscribers::init_for_testing(); let (cluster, signer) = start_sui_cluster(vec![1000 * MIST_PER_SUI]).await; let fullnode_url = cluster.fullnode_handle.rpc_url; - let storage = connect_storage_for_testing(signer.get_address()).await; + let sponsor = signer.get_one_address(); + let storage = connect_storage_for_testing(sponsor).await; let sui_client = SuiClient::new(&fullnode_url, None).await; let _ = GasPoolInitializer::start( sui_client, @@ -355,15 +376,16 @@ mod tests { signer, ) .await; - assert!(storage.get_available_coin_count().await.unwrap() > 900); + assert!(storage.get_available_coin_count(sponsor).await.unwrap() > 900); } #[tokio::test] async fn test_init_non_even_split() { telemetry_subscribers::init_for_testing(); let (cluster, signer) = start_sui_cluster(vec![10000000 * MIST_PER_SUI]).await; + let sponsor = signer.get_one_address(); let fullnode_url = cluster.fullnode_handle.rpc_url; - let storage = connect_storage_for_testing(signer.get_address()).await; + let storage = connect_storage_for_testing(sponsor).await; let target_init_balance = 12345 * MIST_PER_SUI; let sui_client = SuiClient::new(&fullnode_url, None).await; let _ = GasPoolInitializer::start( @@ -376,16 +398,16 @@ mod tests { signer, ) .await; - assert!(storage.get_available_coin_count().await.unwrap() > 800); + assert!(storage.get_available_coin_count(sponsor).await.unwrap() > 800); } #[tokio::test] async fn test_add_new_funds_to_pool() { telemetry_subscribers::init_for_testing(); let (cluster, signer) = start_sui_cluster(vec![1000 * MIST_PER_SUI]).await; - let sponsor = signer.get_address(); + let sponsor = signer.get_one_address(); let fullnode_url = cluster.fullnode_handle.rpc_url.clone(); - let storage = connect_storage_for_testing(signer.get_address()).await; + let storage = connect_storage_for_testing(sponsor).await; let sui_client = SuiClient::new(&fullnode_url, None).await; let _init_task = GasPoolInitializer::start( sui_client, @@ -397,8 +419,8 @@ mod tests { signer, ) .await; - assert!(storage.is_initialized().await.unwrap()); - let available_coin_count = storage.get_available_coin_count().await.unwrap(); + assert!(storage.is_initialized(sponsor).await.unwrap()); + let available_coin_count = storage.get_available_coin_count(sponsor).await.unwrap(); tracing::debug!("Available coin count: {}", available_coin_count); // Transfer some new SUI into the sponsor account. @@ -420,7 +442,7 @@ mod tests { // Give it some time for the task to pick up the new coin and split it. tokio::time::sleep(std::time::Duration::from_secs(30)).await; - let new_available_coin_count = storage.get_available_coin_count().await.unwrap(); + let new_available_coin_count = storage.get_available_coin_count(sponsor).await.unwrap(); assert!( // In an ideal world we should have NEW_COIN_BALANCE_FACTOR_THRESHOLD more coins // since we just send a new coin with balance NEW_COIN_BALANCE_FACTOR_THRESHOLD and split diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 7918179..8314a23 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -26,15 +26,24 @@ pub trait Storage: Sync + Send { /// 3. It should never return more than 256 coins at a time since that's the upper bound of gas. async fn reserve_gas_coins( &self, + sponsor_address: SuiAddress, target_budget: u64, reserved_duration_ms: u64, ) -> anyhow::Result<(ReservationID, Vec)>; - async fn ready_for_execution(&self, reservation_id: ReservationID) -> anyhow::Result<()>; + async fn ready_for_execution( + &self, + sponsor_address: SuiAddress, + reservation_id: ReservationID, + ) -> anyhow::Result<()>; - async fn add_new_coins(&self, new_coins: Vec) -> anyhow::Result<()>; + async fn add_new_coins( + &self, + sponsor_address: SuiAddress, + new_coins: Vec, + ) -> anyhow::Result<()>; - async fn expire_coins(&self) -> anyhow::Result>; + async fn expire_coins(&self, sponsor_address: SuiAddress) -> anyhow::Result>; /// Initialize some of the gas pool statistics at the startup. /// Such as the total number of gas coins and the total balance. @@ -43,48 +52,57 @@ pub trait Storage: Sync + Send { /// We only need this once ever though. /// 2. To make sure we start reporting the correct metrics from the beginning. /// Returns the total number of gas coins and the total balance. - async fn init_coin_stats_at_startup(&self) -> anyhow::Result<(u64, u64)>; + async fn init_coin_stats_at_startup( + &self, + sponsor_address: SuiAddress, + ) -> anyhow::Result<(u64, u64)>; /// Whether the gas pool for the given sponsor address is initialized. - async fn is_initialized(&self) -> anyhow::Result; + async fn is_initialized(&self, sponsor_address: SuiAddress) -> anyhow::Result; /// Acquire a lock to initialize the gas pool for the given sponsor address for a certain duration. /// Returns true if the lock is acquired, false otherwise. /// Once the lock is acquired, until it expires, no other caller can acquire the lock. /// The reason we use a lock duration is such that in case the server crashed while holding the lock, /// the lock will be automatically considered as released after the lock duration. - async fn acquire_init_lock(&self, lock_duration_sec: u64) -> anyhow::Result; + async fn acquire_init_lock( + &self, + sponsor_address: SuiAddress, + lock_duration_sec: u64, + ) -> anyhow::Result; - async fn release_init_lock(&self) -> anyhow::Result<()>; + async fn release_init_lock(&self, sponsor_address: SuiAddress) -> anyhow::Result<()>; async fn check_health(&self) -> anyhow::Result<()>; #[cfg(test)] async fn flush_db(&self); - async fn get_available_coin_count(&self) -> anyhow::Result; + async fn get_available_coin_count(&self, sponsor_address: SuiAddress) -> anyhow::Result; - async fn get_available_coin_total_balance(&self) -> u64; + async fn get_available_coin_total_balance(&self, sponsor_address: SuiAddress) -> u64; #[cfg(test)] - async fn get_reserved_coin_count(&self) -> usize; + async fn get_reserved_coin_count(&self, sponsor_address: SuiAddress) -> usize; } pub async fn connect_storage( config: &GasPoolStorageConfig, - sponsor_address: SuiAddress, + sponsor_addresses: Vec, metrics: Arc, ) -> Arc { let storage: Arc = match config { GasPoolStorageConfig::Redis { redis_url } => { - Arc::new(RedisStorage::new(redis_url, sponsor_address, metrics).await) + Arc::new(RedisStorage::new(redis_url, metrics).await) } }; storage .check_health() .await .expect("Unable to connect to the storage layer"); - storage.init_coin_stats_at_startup().await.unwrap(); + for address in sponsor_addresses { + storage.init_coin_stats_at_startup(address).await.unwrap(); + } storage } @@ -98,12 +116,20 @@ pub async fn connect_storage_for_testing_with_config( static IS_FIRST_CALL: AtomicBool = AtomicBool::new(true); let is_first_call = IS_FIRST_CALL.fetch_and(false, Ordering::SeqCst); - let storage = connect_storage(config, sponsor_address, StorageMetrics::new_for_testing()).await; + let storage = connect_storage( + config, + vec![sponsor_address], + StorageMetrics::new_for_testing(), + ) + .await; if is_first_call { // Make sure that we only flush the DB once at the beginning of each test run. storage.flush_db().await; // Re-init coin stats again since we just flushed. - storage.init_coin_stats_at_startup().await.unwrap(); + storage + .init_coin_stats_at_startup(sponsor_address) + .await + .unwrap(); } storage } @@ -124,9 +150,17 @@ mod tests { use sui_types::base_types::{random_object_ref, ObjectID, SequenceNumber, SuiAddress}; use sui_types::digests::ObjectDigest; - async fn assert_coin_count(storage: &Arc, available: usize, reserved: usize) { - assert_eq!(storage.get_available_coin_count().await.unwrap(), available); - assert_eq!(storage.get_reserved_coin_count().await, reserved); + async fn assert_coin_count( + sponsor: SuiAddress, + storage: &Arc, + available: usize, + reserved: usize, + ) { + assert_eq!( + storage.get_available_coin_count(sponsor).await.unwrap(), + available + ); + assert_eq!(storage.get_reserved_coin_count(sponsor).await, reserved); } async fn setup(sponsor: SuiAddress, init_balances: Vec) -> Arc { @@ -143,7 +177,10 @@ mod tests { }) .collect::>(); for chunk in gas_coins.chunks(5000) { - storage.add_new_coins(chunk.to_vec()).await.unwrap(); + storage + .add_new_coins(sponsor, chunk.to_vec()) + .await + .unwrap(); } storage } @@ -152,18 +189,21 @@ mod tests { async fn test_gas_pool_init() { let sponsor = SuiAddress::random_for_testing_only(); let storage = connect_storage_for_testing(sponsor).await; - assert!(!storage.is_initialized().await.unwrap()); - storage.add_new_coins(vec![]).await.unwrap(); + assert!(!storage.is_initialized(sponsor).await.unwrap()); + storage.add_new_coins(sponsor, vec![]).await.unwrap(); // Still not initialized because we are not adding any coins. - assert!(!storage.is_initialized().await.unwrap()); + assert!(!storage.is_initialized(sponsor).await.unwrap()); storage - .add_new_coins(vec![GasCoin { - object_ref: random_object_ref(), - balance: 1, - }]) + .add_new_coins( + sponsor, + vec![GasCoin { + object_ref: random_object_ref(), + balance: 1, + }], + ) .await .unwrap(); - assert!(storage.is_initialized().await.unwrap()); + assert!(storage.is_initialized(sponsor).await.unwrap()); } #[tokio::test] @@ -171,18 +211,20 @@ mod tests { // Create a gas pool of 100000 coins, each with balance of 1. let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100000]).await; - assert_coin_count(&storage, 100000, 0).await; + assert_coin_count(sponsor, &storage, 100000, 0).await; let mut cur_available = 100000; let mut expected_res_id = 1; for i in 1..=MAX_GAS_PER_QUERY { - let (res_id, reserved_gas_coins) = - storage.reserve_gas_coins(i as u64, 1000).await.unwrap(); + let (res_id, reserved_gas_coins) = storage + .reserve_gas_coins(sponsor, i as u64, 1000) + .await + .unwrap(); assert_eq!(expected_res_id, res_id); assert_eq!(reserved_gas_coins.len(), i); expected_res_id += 1; cur_available -= i; } - assert_coin_count(&storage, cur_available, 100000 - cur_available).await; + assert_coin_count(sponsor, &storage, cur_available, 100000 - cur_available).await; } #[tokio::test] @@ -190,18 +232,18 @@ mod tests { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; MAX_GAS_PER_QUERY + 1]).await; assert!(storage - .reserve_gas_coins((MAX_GAS_PER_QUERY + 1) as u64, 1000) + .reserve_gas_coins(sponsor, (MAX_GAS_PER_QUERY + 1) as u64, 1000) .await .is_err()); - assert_coin_count(&storage, MAX_GAS_PER_QUERY + 1, 0).await; + assert_coin_count(sponsor, &storage, MAX_GAS_PER_QUERY + 1, 0).await; } #[tokio::test] async fn test_insufficient_pool_budget() { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100]).await; - assert!(storage.reserve_gas_coins(101, 1000).await.is_err()); - assert_coin_count(&storage, 100, 0).await; + assert!(storage.reserve_gas_coins(sponsor, 101, 1000).await.is_err()); + assert_coin_count(sponsor, &storage, 100, 0).await; } #[tokio::test] @@ -211,12 +253,16 @@ mod tests { for _ in 0..100 { // Keep reserving and putting them back. // Should be able to repeat this process indefinitely if balance are not changed. - let (res_id, reserved_gas_coins) = storage.reserve_gas_coins(99, 1000).await.unwrap(); + let (res_id, reserved_gas_coins) = + storage.reserve_gas_coins(sponsor, 99, 1000).await.unwrap(); assert_eq!(reserved_gas_coins.len(), 99); - assert_coin_count(&storage, 1, 99).await; - storage.ready_for_execution(res_id).await.unwrap(); - storage.add_new_coins(reserved_gas_coins).await.unwrap(); - assert_coin_count(&storage, 100, 0).await; + assert_coin_count(sponsor, &storage, 1, 99).await; + storage.ready_for_execution(sponsor, res_id).await.unwrap(); + storage + .add_new_coins(sponsor, reserved_gas_coins) + .await + .unwrap(); + assert_coin_count(sponsor, &storage, 100, 0).await; } } @@ -226,7 +272,7 @@ mod tests { let storage = setup(sponsor, vec![1; 100]).await; for _ in 0..10 { let (res_id, mut reserved_gas_coins) = - storage.reserve_gas_coins(10, 1000).await.unwrap(); + storage.reserve_gas_coins(sponsor, 10, 1000).await.unwrap(); assert_eq!( reserved_gas_coins.iter().map(|c| c.balance).sum::(), 10 @@ -236,46 +282,56 @@ mod tests { reserved_gas_coin.balance -= 1; } } - storage.ready_for_execution(res_id).await.unwrap(); - storage.add_new_coins(reserved_gas_coins).await.unwrap(); + storage.ready_for_execution(sponsor, res_id).await.unwrap(); + storage + .add_new_coins(sponsor, reserved_gas_coins) + .await + .unwrap(); } - assert_coin_count(&storage, 100, 0).await; - assert_eq!(storage.get_available_coin_total_balance().await, 0); - assert!(storage.reserve_gas_coins(1, 1000).await.is_err()); + assert_coin_count(sponsor, &storage, 100, 0).await; + assert_eq!(storage.get_available_coin_total_balance(sponsor).await, 0); + assert!(storage.reserve_gas_coins(sponsor, 1, 1000).await.is_err()); } #[tokio::test] async fn test_deleted_objects() { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100]).await; - let (res_id, mut reserved_gas_coins) = storage.reserve_gas_coins(100, 1000).await.unwrap(); + let (res_id, mut reserved_gas_coins) = + storage.reserve_gas_coins(sponsor, 100, 1000).await.unwrap(); assert_eq!(reserved_gas_coins.len(), 100); - storage.ready_for_execution(res_id).await.unwrap(); + storage.ready_for_execution(sponsor, res_id).await.unwrap(); reserved_gas_coins.drain(0..50); - storage.add_new_coins(reserved_gas_coins).await.unwrap(); - assert_coin_count(&storage, 50, 0).await; + storage + .add_new_coins(sponsor, reserved_gas_coins) + .await + .unwrap(); + assert_coin_count(sponsor, &storage, 50, 0).await; } #[tokio::test] async fn test_coin_expiration() { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100]).await; - let (_res_id1, reserved_gas_coins1) = storage.reserve_gas_coins(10, 900).await.unwrap(); + let (_res_id1, reserved_gas_coins1) = + storage.reserve_gas_coins(sponsor, 10, 900).await.unwrap(); assert_eq!(reserved_gas_coins1.len(), 10); - let (_res_id2, reserved_gas_coins2) = storage.reserve_gas_coins(30, 1900).await.unwrap(); + let (_res_id2, reserved_gas_coins2) = + storage.reserve_gas_coins(sponsor, 30, 1900).await.unwrap(); assert_eq!(reserved_gas_coins2.len(), 30); // Just to make sure these two reservations will have a different expiration timestamp. tokio::time::sleep(Duration::from_millis(1)).await; - let (_res_id3, reserved_gas_coins3) = storage.reserve_gas_coins(50, 1900).await.unwrap(); + let (_res_id3, reserved_gas_coins3) = + storage.reserve_gas_coins(sponsor, 50, 1900).await.unwrap(); assert_eq!(reserved_gas_coins3.len(), 50); - assert_coin_count(&storage, 10, 90).await; + assert_coin_count(sponsor, &storage, 10, 90).await; - assert!(storage.expire_coins().await.unwrap().is_empty()); - assert_coin_count(&storage, 10, 90).await; + assert!(storage.expire_coins(sponsor).await.unwrap().is_empty()); + assert_coin_count(sponsor, &storage, 10, 90).await; tokio::time::sleep(Duration::from_secs(1)).await; - let expired1 = storage.expire_coins().await.unwrap(); + let expired1 = storage.expire_coins(sponsor).await.unwrap(); assert_eq!(expired1.len(), 10); assert_eq!( expired1.iter().cloned().collect::>(), @@ -284,13 +340,13 @@ mod tests { .map(|coin| coin.object_ref.0) .collect::>() ); - assert_coin_count(&storage, 10, 80).await; + assert_coin_count(sponsor, &storage, 10, 80).await; - assert!(storage.expire_coins().await.unwrap().is_empty()); - assert_coin_count(&storage, 10, 80).await; + assert!(storage.expire_coins(sponsor).await.unwrap().is_empty()); + assert_coin_count(sponsor, &storage, 10, 80).await; tokio::time::sleep(Duration::from_secs(1)).await; - let expired2 = storage.expire_coins().await.unwrap(); + let expired2 = storage.expire_coins(sponsor).await.unwrap(); assert_eq!(expired2.len(), 80); assert_eq!( expired2.iter().cloned().collect::>(), @@ -300,7 +356,7 @@ mod tests { .map(|coin| coin.object_ref.0) .collect::>() ); - assert_coin_count(&storage, 10, 0).await; + assert_coin_count(sponsor, &storage, 10, 0).await; } #[tokio::test] @@ -309,13 +365,13 @@ mod tests { .map(|_| SuiAddress::random_for_testing_only()) .collect::>(); let mut storages = vec![]; - for sponsor in sponsors { - storages.push(setup(sponsor, vec![1; 100]).await); + for sponsor in &sponsors { + storages.push(setup(*sponsor, vec![1; 100]).await); } - for storage in storages { - let (_, gas_coins) = storage.reserve_gas_coins(50, 1000).await.unwrap(); + for (storage, sponsor) in storages.into_iter().zip(sponsors) { + let (_, gas_coins) = storage.reserve_gas_coins(sponsor, 50, 1000).await.unwrap(); assert_eq!(gas_coins.len(), 50); - assert_coin_count(&storage, 50, 50).await; + assert_coin_count(sponsor, &storage, 50, 50).await; } } @@ -329,7 +385,8 @@ mod tests { handles.push(tokio::spawn(async move { let mut reserved_gas_coins = vec![]; for _ in 0..100 { - let (_, newly_reserved) = storage.reserve_gas_coins(3, 1000).await.unwrap(); + let (_, newly_reserved) = + storage.reserve_gas_coins(sponsor, 3, 1000).await.unwrap(); reserved_gas_coins.extend(newly_reserved); } reserved_gas_coins @@ -344,17 +401,17 @@ mod tests { reserved_gas_coins.sort_by_key(|c| c.object_ref.0); reserved_gas_coins.dedup_by_key(|c| c.object_ref.0); assert_eq!(reserved_gas_coins.len(), count); - assert_coin_count(&storage, 100000 - count, count).await; + assert_coin_count(sponsor, &storage, 100000 - count, count).await; } #[tokio::test] async fn test_acquire_init_lock() { let sponsor = SuiAddress::random_for_testing_only(); let storage = setup(sponsor, vec![1; 100]).await; - assert!(storage.acquire_init_lock(5).await.unwrap()); - assert!(!storage.acquire_init_lock(1).await.unwrap()); + assert!(storage.acquire_init_lock(sponsor, 5).await.unwrap()); + assert!(!storage.acquire_init_lock(sponsor, 1).await.unwrap()); tokio::time::sleep(Duration::from_secs(6)).await; - assert!(storage.acquire_init_lock(5).await.unwrap()); + assert!(storage.acquire_init_lock(sponsor, 5).await.unwrap()); } #[tokio::test] @@ -363,7 +420,8 @@ mod tests { let storage = setup(sponsor, vec![1; 100]).await; // init_coin_stats_at_startup has already been called in setup. // Calling it again should not change anything. - let (coin_count, total_balance) = storage.init_coin_stats_at_startup().await.unwrap(); + let (coin_count, total_balance) = + storage.init_coin_stats_at_startup(sponsor).await.unwrap(); assert_eq!(coin_count, 100); assert_eq!(total_balance, 100); } diff --git a/src/storage/redis/mod.rs b/src/storage/redis/mod.rs index fa9f15a..904d034 100644 --- a/src/storage/redis/mod.rs +++ b/src/storage/redis/mod.rs @@ -18,22 +18,15 @@ use tracing::{debug, info}; pub struct RedisStorage { conn_manager: ConnectionManager, - // String format of the sponsor address to avoid converting it to string multiple times. - sponsor_str: String, metrics: Arc, } impl RedisStorage { - pub async fn new( - redis_url: &str, - sponsor_address: SuiAddress, - metrics: Arc, - ) -> Self { + pub async fn new(redis_url: &str, metrics: Arc) -> Self { let client = redis::Client::open(redis_url).unwrap(); let conn_manager = ConnectionManager::new(client).await.unwrap(); Self { conn_manager, - sponsor_str: sponsor_address.to_string(), metrics, } } @@ -43,10 +36,12 @@ impl RedisStorage { impl Storage for RedisStorage { async fn reserve_gas_coins( &self, + sponsor_address: SuiAddress, target_budget: u64, reserved_duration_ms: u64, ) -> anyhow::Result<(ReservationID, Vec)> { self.metrics.num_reserve_gas_coins_requests.inc(); + let sponsor_str = sponsor_address.to_string(); let expiration_time = Utc::now() .add(Duration::from_millis(reserved_duration_ms)) @@ -58,7 +53,7 @@ impl Storage for RedisStorage { i64, i64, ) = ScriptManager::reserve_gas_coins_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str.clone()) .arg(target_budget) .arg(expiration_time) .invoke_async(&mut conn) @@ -89,22 +84,27 @@ impl Storage for RedisStorage { self.metrics .gas_pool_available_gas_coin_count - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(new_coin_count); self.metrics .gas_pool_available_gas_total_balance - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(new_total_balance); self.metrics.num_successful_reserve_gas_coins_requests.inc(); Ok((reservation_id, gas_coins)) } - async fn ready_for_execution(&self, reservation_id: ReservationID) -> anyhow::Result<()> { + async fn ready_for_execution( + &self, + sponsor_address: SuiAddress, + reservation_id: ReservationID, + ) -> anyhow::Result<()> { self.metrics.num_ready_for_execution_requests.inc(); + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); ScriptManager::ready_for_execution_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .arg(reservation_id) .invoke_async::<_, ()>(&mut conn) .await?; @@ -115,8 +115,13 @@ impl Storage for RedisStorage { Ok(()) } - async fn add_new_coins(&self, new_coins: Vec) -> anyhow::Result<()> { + async fn add_new_coins( + &self, + sponsor_address: SuiAddress, + new_coins: Vec, + ) -> anyhow::Result<()> { self.metrics.num_add_new_coins_requests.inc(); + let sponsor_str = sponsor_address.to_string(); let formatted_coins = new_coins .iter() .map(|c| { @@ -135,7 +140,7 @@ impl Storage for RedisStorage { let mut conn = self.conn_manager.clone(); let (new_total_balance, new_coin_count): (i64, i64) = ScriptManager::add_new_coins_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str.clone()) .arg(serde_json::to_string(&formatted_coins)?) .invoke_async(&mut conn) .await?; @@ -146,23 +151,24 @@ impl Storage for RedisStorage { ); self.metrics .gas_pool_available_gas_coin_count - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(new_coin_count); self.metrics .gas_pool_available_gas_total_balance - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(new_total_balance); self.metrics.num_successful_add_new_coins_requests.inc(); Ok(()) } - async fn expire_coins(&self) -> anyhow::Result> { + async fn expire_coins(&self, sponsor_address: SuiAddress) -> anyhow::Result> { self.metrics.num_expire_coins_requests.inc(); + let sponsor_str = sponsor_address.to_string(); let now = Utc::now().timestamp_millis() as u64; let mut conn = self.conn_manager.clone(); let expired_coin_strings: Vec = ScriptManager::expire_coins_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .arg(now) .invoke_async(&mut conn) .await?; @@ -176,26 +182,30 @@ impl Storage for RedisStorage { Ok(expired_coin_ids) } - async fn init_coin_stats_at_startup(&self) -> anyhow::Result<(u64, u64)> { + async fn init_coin_stats_at_startup( + &self, + sponsor_address: SuiAddress, + ) -> anyhow::Result<(u64, u64)> { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); let (available_coin_count, available_coin_total_balance): (i64, i64) = ScriptManager::init_coin_stats_at_startup_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str.clone()) .invoke_async(&mut conn) .await?; info!( - sponsor_address=?self.sponsor_str, + sponsor_address=?sponsor_str, "Number of available gas coins in the pool: {}, total balance: {}", available_coin_count, available_coin_total_balance ); self.metrics .gas_pool_available_gas_coin_count - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(available_coin_count); self.metrics .gas_pool_available_gas_total_balance - .with_label_values(&[&self.sponsor_str]) + .with_label_values(&[&sponsor_str]) .set(available_coin_total_balance); Ok(( available_coin_count as u64, @@ -203,16 +213,22 @@ impl Storage for RedisStorage { )) } - async fn is_initialized(&self) -> anyhow::Result { + async fn is_initialized(&self, sponsor_address: SuiAddress) -> anyhow::Result { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); let result = ScriptManager::get_is_initialized_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, bool>(&mut conn) .await?; Ok(result) } - async fn acquire_init_lock(&self, lock_duration_sec: u64) -> anyhow::Result { + async fn acquire_init_lock( + &self, + sponsor_address: SuiAddress, + lock_duration_sec: u64, + ) -> anyhow::Result { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); let cur_timestamp = Utc::now().timestamp() as u64; debug!( @@ -220,7 +236,7 @@ impl Storage for RedisStorage { cur_timestamp, lock_duration_sec ); let result = ScriptManager::acquire_init_lock_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .arg(cur_timestamp) .arg(lock_duration_sec) .invoke_async::<_, bool>(&mut conn) @@ -228,11 +244,12 @@ impl Storage for RedisStorage { Ok(result) } - async fn release_init_lock(&self) -> anyhow::Result<()> { + async fn release_init_lock(&self, sponsor_address: SuiAddress) -> anyhow::Result<()> { debug!("Releasing the init lock."); + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); ScriptManager::release_init_lock_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, ()>(&mut conn) .await?; Ok(()) @@ -253,29 +270,32 @@ impl Storage for RedisStorage { .unwrap(); } - async fn get_available_coin_count(&self) -> anyhow::Result { + async fn get_available_coin_count(&self, sponsor_address: SuiAddress) -> anyhow::Result { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); let count = ScriptManager::get_available_coin_count_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, usize>(&mut conn) .await?; Ok(count) } - async fn get_available_coin_total_balance(&self) -> u64 { + async fn get_available_coin_total_balance(&self, sponsor_address: SuiAddress) -> u64 { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); ScriptManager::get_available_coin_total_balance_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, u64>(&mut conn) .await .unwrap() } #[cfg(test)] - async fn get_reserved_coin_count(&self) -> usize { + async fn get_reserved_coin_count(&self, sponsor_address: SuiAddress) -> usize { + let sponsor_str = sponsor_address.to_string(); let mut conn = self.conn_manager.clone(); ScriptManager::get_reserved_coin_count_script() - .arg(self.sponsor_str.clone()) + .arg(sponsor_str) .invoke_async::<_, usize>(&mut conn) .await .unwrap() @@ -295,72 +315,83 @@ mod tests { #[tokio::test] async fn test_init_coin_stats_at_startup() { let storage = setup_storage().await; + let sponsor = SuiAddress::ZERO; storage - .add_new_coins(vec![ - GasCoin { - balance: 100, - object_ref: random_object_ref(), - }, - GasCoin { - balance: 200, - object_ref: random_object_ref(), - }, - ]) + .add_new_coins( + sponsor, + vec![ + GasCoin { + balance: 100, + object_ref: random_object_ref(), + }, + GasCoin { + balance: 200, + object_ref: random_object_ref(), + }, + ], + ) .await .unwrap(); - let (coin_count, total_balance) = storage.init_coin_stats_at_startup().await.unwrap(); + let (coin_count, total_balance) = + storage.init_coin_stats_at_startup(sponsor).await.unwrap(); assert_eq!(coin_count, 2); assert_eq!(total_balance, 300); } #[tokio::test] async fn test_add_new_coins() { + let sponsor = SuiAddress::ZERO; let storage = setup_storage().await; storage - .add_new_coins(vec![ - GasCoin { - balance: 100, - object_ref: random_object_ref(), - }, - GasCoin { - balance: 200, - object_ref: random_object_ref(), - }, - ]) + .add_new_coins( + sponsor, + vec![ + GasCoin { + balance: 100, + object_ref: random_object_ref(), + }, + GasCoin { + balance: 200, + object_ref: random_object_ref(), + }, + ], + ) .await .unwrap(); - let coin_count = storage.get_available_coin_count().await.unwrap(); + let coin_count = storage.get_available_coin_count(sponsor).await.unwrap(); assert_eq!(coin_count, 2); - let total_balance = storage.get_available_coin_total_balance().await; + let total_balance = storage.get_available_coin_total_balance(sponsor).await; assert_eq!(total_balance, 300); storage - .add_new_coins(vec![ - GasCoin { - balance: 300, - object_ref: random_object_ref(), - }, - GasCoin { - balance: 400, - object_ref: random_object_ref(), - }, - ]) + .add_new_coins( + sponsor, + vec![ + GasCoin { + balance: 300, + object_ref: random_object_ref(), + }, + GasCoin { + balance: 400, + object_ref: random_object_ref(), + }, + ], + ) .await .unwrap(); - let coin_count = storage.get_available_coin_count().await.unwrap(); + let coin_count = storage.get_available_coin_count(sponsor).await.unwrap(); assert_eq!(coin_count, 4); - let total_balance = storage.get_available_coin_total_balance().await; + let total_balance = storage.get_available_coin_total_balance(sponsor).await; assert_eq!(total_balance, 1000); } async fn setup_storage() -> RedisStorage { - let storage = RedisStorage::new( - "redis://127.0.0.1:6379", - SuiAddress::ZERO, - StorageMetrics::new_for_testing(), - ) - .await; + let storage = + RedisStorage::new("redis://127.0.0.1:6379", StorageMetrics::new_for_testing()).await; storage.flush_db().await; - let (coin_count, total_balance) = storage.init_coin_stats_at_startup().await.unwrap(); + let (coin_count, total_balance) = storage + .init_coin_stats_at_startup(SuiAddress::ZERO) + .await + .unwrap(); assert_eq!(coin_count, 0); assert_eq!(total_balance, 0); storage diff --git a/src/test_env.rs b/src/test_env.rs index a6d4bc6..d713858 100644 --- a/src/test_env.rs +++ b/src/test_env.rs @@ -47,7 +47,7 @@ pub async fn start_gas_station( debug!("Starting Sui cluster.."); let (test_cluster, signer) = start_sui_cluster(init_gas_amounts).await; let fullnode_url = test_cluster.fullnode_handle.rpc_url.clone(); - let sponsor_address = signer.get_address(); + let sponsor_address = signer.get_one_address(); debug!("Starting storage. Sponsor address: {:?}", sponsor_address); let storage = connect_storage_for_testing(sponsor_address).await; let sui_client = SuiClient::new(&fullnode_url, None).await; diff --git a/src/tx_signer.rs b/src/tx_signer.rs index 1c3ba8f..dadf418 100644 --- a/src/tx_signer.rs +++ b/src/tx_signer.rs @@ -7,21 +7,22 @@ use reqwest::Client; use serde::Deserialize; use serde_json::json; use shared_crypto::intent::{Intent, IntentMessage}; +use std::collections::HashMap; use std::str::FromStr; -use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::{atomic, Arc}; use sui_types::base_types::SuiAddress; use sui_types::crypto::{Signature, SuiKeyPair}; use sui_types::signature::GenericSignature; -use sui_types::transaction::TransactionData; +use sui_types::transaction::{TransactionData, TransactionDataAPI}; #[async_trait::async_trait] pub trait TxSigner: Send + Sync { async fn sign_transaction(&self, tx_data: &TransactionData) -> anyhow::Result; - fn get_address(&self) -> SuiAddress; - fn is_valid_address(&self, address: &SuiAddress) -> bool { - self.get_address() == *address - } + fn get_one_address(&self) -> SuiAddress; + fn get_all_addresses(&self) -> Vec; + fn is_valid_address(&self, address: &SuiAddress) -> bool; } #[derive(Deserialize)] @@ -36,29 +37,38 @@ struct SuiAddressResponse { sui_pubkey_address: SuiAddress, } +// TODO: Add a mock side car server with tests for multi-address support. pub struct SidecarTxSigner { - sidecar_url: String, client: Client, - sui_address: SuiAddress, + sidecar_url_map: HashMap, + sui_addresses: Vec, + next_address_idx: AtomicUsize, } impl SidecarTxSigner { - pub async fn new(sidecar_url: String) -> Arc { + pub async fn new(sidecar_urls: Vec) -> Arc { let client = Client::new(); - let resp = client - .get(format!("{}/{}", sidecar_url, "get-pubkey-address")) - .send() - .await - .unwrap_or_else(|err| panic!("Failed to get pubkey address: {}", err)); - let sui_address = resp - .json::() - .await - .unwrap_or_else(|err| panic!("Failed to parse address response: {}", err)) - .sui_pubkey_address; + let mut sidecar_url_map = HashMap::new(); + let mut sui_addresses = vec![]; + for sidecar_url in sidecar_urls { + let resp = client + .get(format!("{}/{}", &sidecar_url, "get-pubkey-address")) + .send() + .await + .unwrap_or_else(|err| panic!("Failed to get pubkey address: {}", err)); + let sui_address = resp + .json::() + .await + .unwrap_or_else(|err| panic!("Failed to parse address response: {}", err)) + .sui_pubkey_address; + sui_addresses.push(sui_address); + sidecar_url_map.insert(sui_address, sidecar_url); + } Arc::new(Self { - sidecar_url, client, - sui_address, + sidecar_url_map, + sui_addresses, + next_address_idx: AtomicUsize::new(0), }) } } @@ -69,10 +79,15 @@ impl TxSigner for SidecarTxSigner { &self, tx_data: &TransactionData, ) -> anyhow::Result { + let sponsor_address = tx_data.gas_data().owner; + let sidecar_url = self + .sidecar_url_map + .get(&sponsor_address) + .ok_or_else(|| anyhow!("Address is not a valid sponsor: {:?}", sponsor_address))?; let bytes = Base64::encode(bcs::to_bytes(&tx_data)?); let resp = self .client - .post(format!("{}/{}", self.sidecar_url, "sign-transaction")) + .post(format!("{}/{}", sidecar_url, "sign-transaction")) .header("Content-Type", "application/json") .json(&json!({"txBytes": bytes})) .send() @@ -83,8 +98,20 @@ impl TxSigner for SidecarTxSigner { Ok(sig) } - fn get_address(&self) -> SuiAddress { - self.sui_address + fn get_one_address(&self) -> SuiAddress { + // Round robin the address we are using. + let idx = self + .next_address_idx + .fetch_add(1, atomic::Ordering::Relaxed); + self.sui_addresses[idx % self.sui_addresses.len()] + } + + fn get_all_addresses(&self) -> Vec { + self.sui_addresses.clone() + } + + fn is_valid_address(&self, address: &SuiAddress) -> bool { + self.sidecar_url_map.contains_key(address) } } @@ -109,7 +136,15 @@ impl TxSigner for TestTxSigner { Ok(sponsor_sig) } - fn get_address(&self) -> SuiAddress { + fn get_one_address(&self) -> SuiAddress { (&self.keypair.public()).into() } + + fn get_all_addresses(&self) -> Vec { + vec![self.get_one_address()] + } + + fn is_valid_address(&self, address: &SuiAddress) -> bool { + address == &self.get_one_address() + } }