diff --git a/Cargo.lock b/Cargo.lock index 4f236b51b..6e324f522 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4108,6 +4108,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "strum", "thiserror", "tokio", "tracing", diff --git a/crates/pool/src/mempool/mod.rs b/crates/pool/src/mempool/mod.rs index 98a8cedc3..0a8827570 100644 --- a/crates/pool/src/mempool/mod.rs +++ b/crates/pool/src/mempool/mod.rs @@ -37,7 +37,6 @@ use ethers::types::{Address, H256, U256}; use mockall::automock; use rundler_sim::{EntityInfos, MempoolConfig, PrecheckSettings, SimulationSettings}; use rundler_types::{Entity, EntityType, EntityUpdate, UserOperation, ValidTimeRange}; -use strum::IntoEnumIterator; use tonic::async_trait; pub(crate) use uo_pool::UoPool; @@ -215,9 +214,11 @@ pub struct PaymasterMetadata { impl PoolOperation { /// Returns true if the operation contains the given entity. pub fn contains_entity(&self, entity: &Entity) -> bool { - self.entity_address(entity.kind) - .map(|address| address == entity.address) - .unwrap_or(false) + if let Some(e) = self.entity_infos.get(entity.kind) { + e.address == entity.address + } else { + false + } } /// Returns true if the operation requires the given entity to stake. @@ -239,48 +240,31 @@ impl PoolOperation { /// Returns an iterator over all entities that are included in this operation. pub fn entities(&'_ self) -> impl Iterator + '_ { - EntityType::iter().filter_map(|entity| { - self.entity_address(entity) - .map(|address| Entity::new(entity, address)) - }) + self.entity_infos + .entities() + .map(|(t, entity)| Entity::new(t, entity.address)) } - /// Returns an iterator over all entities that need stake in this operation. + /// Returns an iterator over all entities that need stake in this operation. This can be a subset of entities that are staked in the operation. pub fn entities_requiring_stake(&'_ self) -> impl Iterator + '_ { - EntityType::iter() - .filter(|entity| self.requires_stake(*entity)) - .filter_map(|entity| { - self.entity_address(entity) - .map(|address| Entity::new(entity, address)) - }) + self.entity_infos.entities().filter_map(|(t, entity)| { + if self.requires_stake(t) { + Entity::new(t, entity.address).into() + } else { + None + } + }) } /// Return all the unstaked entities that are used in this operation. pub fn unstaked_entities(&'_ self) -> impl Iterator + '_ { - let mut unstaked_entities = vec![]; - if !self.entity_infos.sender.is_staked { - unstaked_entities.push(Entity::new( - EntityType::Account, - self.entity_infos.sender.address, - )) - } - if let Some(factory) = self.entity_infos.factory { - if !factory.is_staked { - unstaked_entities.push(Entity::new(EntityType::Factory, factory.address)) + self.entity_infos.entities().filter_map(|(t, entity)| { + if entity.is_staked { + None + } else { + Entity::new(t, entity.address).into() } - } - if let Some(paymaster) = self.entity_infos.paymaster { - if !paymaster.is_staked { - unstaked_entities.push(Entity::new(EntityType::Paymaster, paymaster.address)) - } - } - if let Some(aggregator) = self.entity_infos.aggregator { - if !aggregator.is_staked { - unstaked_entities.push(Entity::new(EntityType::Aggregator, aggregator.address)) - } - } - - unstaked_entities.into_iter() + }) } /// Compute the amount of heap memory the PoolOperation takes up. @@ -289,19 +273,12 @@ impl PoolOperation { + self.uo.heap_size() + self.entities_needing_stake.len() * std::mem::size_of::() } - - fn entity_address(&self, entity: EntityType) -> Option
{ - match entity { - EntityType::Account => Some(self.uo.sender), - EntityType::Paymaster => self.uo.paymaster(), - EntityType::Factory => self.uo.factory(), - EntityType::Aggregator => self.aggregator, - } - } } #[cfg(test)] mod tests { + use rundler_sim::EntityInfo; + use super::*; #[test] @@ -326,7 +303,24 @@ mod tests { sim_block_number: 0, entities_needing_stake: vec![EntityType::Account, EntityType::Aggregator], account_is_staked: true, - entity_infos: EntityInfos::default(), + entity_infos: EntityInfos { + factory: Some(EntityInfo { + address: factory, + is_staked: false, + }), + sender: EntityInfo { + address: sender, + is_staked: false, + }, + paymaster: Some(EntityInfo { + address: paymaster, + is_staked: false, + }), + aggregator: Some(EntityInfo { + address: aggregator, + is_staked: false, + }), + }, }; assert!(po.requires_stake(EntityType::Account)); @@ -334,11 +328,6 @@ mod tests { assert!(!po.requires_stake(EntityType::Factory)); assert!(po.requires_stake(EntityType::Aggregator)); - assert_eq!(po.entity_address(EntityType::Account), Some(sender)); - assert_eq!(po.entity_address(EntityType::Paymaster), Some(paymaster)); - assert_eq!(po.entity_address(EntityType::Factory), Some(factory)); - assert_eq!(po.entity_address(EntityType::Aggregator), Some(aggregator)); - let entities = po.entities().collect::>(); assert_eq!(entities.len(), 4); for e in entities { diff --git a/crates/pool/src/mempool/pool.rs b/crates/pool/src/mempool/pool.rs index 9b317b321..b93aeed4c 100644 --- a/crates/pool/src/mempool/pool.rs +++ b/crates/pool/src/mempool/pool.rs @@ -601,6 +601,8 @@ impl PoolMetrics { #[cfg(test)] mod tests { + use rundler_sim::{EntityInfo, EntityInfos}; + use super::*; #[test] @@ -784,6 +786,10 @@ mod tests { ]; for mut op in ops.into_iter() { op.aggregator = Some(agg); + op.entity_infos.aggregator = Some(EntityInfo { + address: agg, + is_staked: false, + }); pool.add_operation(op.clone(), None).unwrap(); } assert_eq!(pool.by_hash.len(), 3); @@ -805,6 +811,10 @@ mod tests { ]; for mut op in ops.into_iter() { op.uo.paymaster_and_data = paymaster.as_bytes().to_vec().into(); + op.entity_infos.paymaster = Some(EntityInfo { + address: op.uo.paymaster().unwrap(), + is_staked: false, + }); pool.add_operation(op.clone(), None).unwrap(); } assert_eq!(pool.by_hash.len(), 3); @@ -839,8 +849,20 @@ mod tests { let mut op = create_op(sender, 0, 1); op.uo.paymaster_and_data = paymaster.as_bytes().to_vec().into(); + op.entity_infos.paymaster = Some(EntityInfo { + address: op.uo.paymaster().unwrap(), + is_staked: false, + }); op.uo.init_code = factory.as_bytes().to_vec().into(); + op.entity_infos.factory = Some(EntityInfo { + address: op.uo.factory().unwrap(), + is_staked: false, + }); op.aggregator = Some(aggregator); + op.entity_infos.aggregator = Some(EntityInfo { + address: aggregator, + is_staked: false, + }); let count = 5; let mut hashes = vec![]; @@ -937,6 +959,10 @@ mod tests { let mut po1 = create_op(sender, 0, 10); po1.uo.max_priority_fee_per_gas = 10.into(); po1.uo.paymaster_and_data = paymaster1.as_bytes().to_vec().into(); + po1.entity_infos.paymaster = Some(EntityInfo { + address: po1.uo.paymaster().unwrap(), + is_staked: false, + }); let _ = pool.add_operation(po1, None).unwrap(); assert_eq!(pool.address_count(&paymaster1), 1); @@ -944,6 +970,10 @@ mod tests { let mut po2 = create_op(sender, 0, 11); po2.uo.max_priority_fee_per_gas = 11.into(); po2.uo.paymaster_and_data = paymaster2.as_bytes().to_vec().into(); + po2.entity_infos.paymaster = Some(EntityInfo { + address: po2.uo.paymaster().unwrap(), + is_staked: false, + }); let _ = pool.add_operation(po2.clone(), None).unwrap(); assert_eq!(pool.address_count(&sender), 1); @@ -1038,8 +1068,18 @@ mod tests { sender, nonce: nonce.into(), max_fee_per_gas: max_fee_per_gas.into(), + ..UserOperation::default() }, + entity_infos: EntityInfos { + factory: None, + sender: EntityInfo { + address: sender, + is_staked: false, + }, + paymaster: None, + aggregator: None, + }, ..PoolOperation::default() } } diff --git a/crates/sim/Cargo.toml b/crates/sim/Cargo.toml index 43423d71e..7e6ba4faf 100644 --- a/crates/sim/Cargo.toml +++ b/crates/sim/Cargo.toml @@ -27,6 +27,7 @@ reqwest.workspace = true tokio = { workspace = true, features = ["macros"] } tracing.workspace = true url.workspace = true +strum.workspace = true mockall = {workspace = true, optional = true } diff --git a/crates/sim/src/simulation/simulation.rs b/crates/sim/src/simulation/simulation.rs index dd2d7560b..19cb56c0c 100644 --- a/crates/sim/src/simulation/simulation.rs +++ b/crates/sim/src/simulation/simulation.rs @@ -32,6 +32,7 @@ use rundler_types::{ contracts::i_entry_point::FailedOp, Entity, EntityType, StorageSlot, UserOperation, ValidTimeRange, }; +use strum::IntoEnumIterator; use super::{ mempool::{match_mempools, AllowEntity, AllowRule, MempoolConfig, MempoolMatchResult}, @@ -798,6 +799,11 @@ impl EntityInfos { } } + /// Get iterator over the entities + pub fn entities(&'_ self) -> impl Iterator + '_ { + EntityType::iter().filter_map(|t| self.get(t).map(|info| (t, info))) + } + fn override_is_staked(&mut self, allow_unstaked_addresses: &HashSet
) { if let Some(mut factory) = self.factory { factory.override_is_staked(allow_unstaked_addresses) @@ -811,7 +817,8 @@ impl EntityInfos { } } - fn get(self, entity: EntityType) -> Option { + /// Get the EntityInfo of a specific entity + pub fn get(self, entity: EntityType) -> Option { match entity { EntityType::Factory => self.factory, EntityType::Account => Some(self.sender),