diff --git a/rpc-client-api/src/filter.rs b/rpc-client-api/src/filter.rs index 4c65f4249e..cf9d85de4d 100644 --- a/rpc-client-api/src/filter.rs +++ b/rpc-client-api/src/filter.rs @@ -17,6 +17,7 @@ pub enum RpcFilterType { DataSize(u64), Memcmp(Memcmp), TokenAccountState, + ValueCmp(ValueCmp), } impl RpcFilterType { @@ -57,6 +58,7 @@ impl RpcFilterType { } } RpcFilterType::TokenAccountState => Ok(()), + RpcFilterType::ValueCmp(_) => Ok(()), } } @@ -69,6 +71,9 @@ impl RpcFilterType { RpcFilterType::DataSize(size) => account.data().len() as u64 == *size, RpcFilterType::Memcmp(compare) => compare.bytes_match(account.data()), RpcFilterType::TokenAccountState => Account::valid_account_data(account.data()), + RpcFilterType::ValueCmp(compare) => { + compare.values_match(account.data()).unwrap_or(false) + } } } } @@ -81,6 +86,8 @@ pub enum RpcFilterError { Base58DecodeError(#[from] bs58::decode::Error), #[error("base64 decode error")] Base64DecodeError(#[from] base64::DecodeError), + #[error("invalid ValueCmp filter")] + InvalidValueCmp, } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize)] @@ -222,6 +229,178 @@ impl Memcmp { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ValueCmp { + pub left: Operand, + comparator: Comparator, + pub right: Operand, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum Operand { + Mem { + offset: usize, + value_type: ValueType, + }, + Constant(String), +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum ValueType { + U8, + U16, + U32, + U64, + U128, +} + +enum WrappedValueType { + U8(u8), + U16(u16), + U32(u32), + U64(u64), + U128(u128), +} + +impl ValueCmp { + fn parse_mem_into_value_type( + o: &Operand, + data: &[u8], + ) -> Result { + match o { + Operand::Mem { offset, value_type } => match value_type { + ValueType::U8 => { + if *offset >= data.len() { + return Err(RpcFilterError::InvalidValueCmp); + } + + Ok(WrappedValueType::U8(data[*offset])) + } + ValueType::U16 => { + if *offset + 1 >= data.len() { + return Err(RpcFilterError::InvalidValueCmp); + } + Ok(WrappedValueType::U16(u16::from_le_bytes( + data[*offset..*offset + 2].try_into().unwrap(), + ))) + } + ValueType::U32 => { + if *offset + 3 >= data.len() { + return Err(RpcFilterError::InvalidValueCmp); + } + Ok(WrappedValueType::U32(u32::from_le_bytes( + data[*offset..*offset + 4].try_into().unwrap(), + ))) + } + ValueType::U64 => { + if *offset + 7 >= data.len() { + return Err(RpcFilterError::InvalidValueCmp); + } + Ok(WrappedValueType::U64(u64::from_le_bytes( + data[*offset..*offset + 8].try_into().unwrap(), + ))) + } + ValueType::U128 => { + if *offset + 15 >= data.len() { + return Err(RpcFilterError::InvalidValueCmp); + } + Ok(WrappedValueType::U128(u128::from_le_bytes( + data[*offset..*offset + 16].try_into().unwrap(), + ))) + } + }, + _ => Err(RpcFilterError::InvalidValueCmp), + } + } + + pub fn values_match(&self, data: &[u8]) -> Result { + match (&self.left, &self.right) { + (left @ Operand::Mem { .. }, right @ Operand::Mem { .. }) => { + let left = Self::parse_mem_into_value_type(left, data)?; + let right = Self::parse_mem_into_value_type(right, data)?; + + match (left, right) { + (WrappedValueType::U8(left), WrappedValueType::U8(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U16(left), WrappedValueType::U16(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U32(left), WrappedValueType::U32(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U64(left), WrappedValueType::U64(right)) => { + Ok(self.comparator.compare(left, right)) + } + (WrappedValueType::U128(left), WrappedValueType::U128(right)) => { + Ok(self.comparator.compare(left, right)) + } + _ => Err(RpcFilterError::InvalidValueCmp), + } + } + (left @ Operand::Mem { .. }, Operand::Constant(constant)) => { + match Self::parse_mem_into_value_type(left, data)? { + WrappedValueType::U8(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidValueCmp)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U16(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidValueCmp)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U32(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidValueCmp)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U64(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidValueCmp)?; + Ok(self.comparator.compare(left, right)) + } + WrappedValueType::U128(left) => { + let right = constant + .parse::() + .map_err(|_| RpcFilterError::InvalidValueCmp)?; + Ok(self.comparator.compare(left, right)) + } + } + } + _ => Err(RpcFilterError::InvalidValueCmp), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub enum Comparator { + Eq = 0, + Ne, + Gt, + Ge, + Lt, + Le, +} + +impl Comparator { + // write a generic function to compare two values + pub fn compare(&self, left: T, right: T) -> bool { + match self { + Comparator::Eq => left == right, + Comparator::Ne => left != right, + Comparator::Gt => left > right, + Comparator::Ge => left >= right, + Comparator::Lt => left < right, + Comparator::Le => left <= right, + } + } +} + #[cfg(test)] mod tests { use { @@ -455,4 +634,56 @@ mod tests { serde_json::from_str::(BYTES_FILTER_WITH_ENCODING).unwrap() ); } + + #[test] + fn test_values_match() { + // test all the ValueCmp cases + let data = vec![1, 2, 3, 4, 5]; + + let filter = ValueCmp { + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8, + }, + comparator: Comparator::Eq, + right: Operand::Constant("2".to_string()), + }; + + assert!(ValueCmp { + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8 + }, + comparator: Comparator::Eq, + right: Operand::Constant("2".to_string()) + } + .values_match(&data) + .unwrap()); + + assert!(ValueCmp { + left: Operand::Mem { + offset: 1, + value_type: ValueType::U8 + }, + comparator: Comparator::Lt, + right: Operand::Constant("3".to_string()) + } + .values_match(&data) + .unwrap()); + + assert!(ValueCmp { + left: Operand::Mem { + offset: 0, + value_type: ValueType::U32 + }, + comparator: Comparator::Eq, + right: Operand::Constant("67305985".to_string()) + } + .values_match(&data) + .unwrap()); + + // serialize + let s = serde_json::to_string(&filter).unwrap(); + println!("{}", s); + } } diff --git a/rpc/src/filter.rs b/rpc/src/filter.rs index 81cebd2f47..4c106d86cf 100644 --- a/rpc/src/filter.rs +++ b/rpc/src/filter.rs @@ -9,5 +9,6 @@ pub fn filter_allows(filter: &RpcFilterType, account: &AccountSharedData) -> boo RpcFilterType::DataSize(size) => account.data().len() as u64 == *size, RpcFilterType::Memcmp(compare) => compare.bytes_match(account.data()), RpcFilterType::TokenAccountState => Account::valid_account_data(account.data()), + RpcFilterType::ValueCmp(compare) => compare.values_match(account.data()).unwrap_or(false), } } diff --git a/rpc/src/rpc.rs b/rpc/src/rpc.rs index 87bb355524..3f10aa7875 100644 --- a/rpc/src/rpc.rs +++ b/rpc/src/rpc.rs @@ -2442,6 +2442,7 @@ fn get_spl_token_owner_filter(program_id: &Pubkey, filters: &[RpcFilterType]) -> } } RpcFilterType::TokenAccountState => token_account_state_filter = true, + RpcFilterType::ValueCmp(_) => {} } } if data_size_filter == Some(account_packed_len as u64) @@ -2493,6 +2494,7 @@ fn get_spl_token_mint_filter(program_id: &Pubkey, filters: &[RpcFilterType]) -> } } RpcFilterType::TokenAccountState => token_account_state_filter = true, + RpcFilterType::ValueCmp(_) => {} } } if data_size_filter == Some(account_packed_len as u64)