Skip to content

Commit

Permalink
feat: allow manager to interact with different rav and receipt
Browse files Browse the repository at this point in the history
Signed-off-by: Gustavo Inacio <[email protected]>
  • Loading branch information
gusinacio committed Jan 17, 2025
1 parent eaef977 commit 28baab5
Show file tree
Hide file tree
Showing 17 changed files with 351 additions and 201 deletions.
6 changes: 3 additions & 3 deletions tap_core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::result::Result as StdResult;
use alloy::primitives::{Address, SignatureError};
use thiserror::Error as ThisError;

use crate::{rav::ReceiptAggregateVoucher, receipt::ReceiptError};
use crate::receipt::ReceiptError;

/// Error type for the TAP protocol
#[derive(ThisError, Debug)]
Expand Down Expand Up @@ -38,8 +38,8 @@ pub enum Error {
/// Error when the received RAV does not match the expected RAV
#[error("Received RAV does not match expexted RAV")]
InvalidReceivedRAV {
received_rav: ReceiptAggregateVoucher,
expected_rav: ReceiptAggregateVoucher,
received_rav: String,
expected_rav: String,
},
/// Generic error from the adapter
#[error("Error from adapter.\n Caused by: {source_error}")]
Expand Down
15 changes: 8 additions & 7 deletions tap_core/src/manager/adapters/escrow.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Copyright 2023-, Semiotic AI, Inc.
// SPDX-License-Identifier: Apache-2.0

use alloy::{dyn_abi::Eip712Domain, primitives::Address};
use alloy::{dyn_abi::Eip712Domain, primitives::Address, sol_types::SolStruct};
use async_trait::async_trait;

use crate::{
rav::SignedRAV,
manager::WithValueAndTimestamp,
receipt::{state::AwaitingReserve, ReceiptError, ReceiptResult, ReceiptWithState},
signed_message::EIP712SignedMessage,
Error,
};

Expand Down Expand Up @@ -41,9 +42,9 @@ pub trait EscrowHandler: Send + Sync {
async fn verify_signer(&self, signer_address: Address) -> Result<bool, Self::AdapterError>;

/// Checks and reserves escrow for the received receipt
async fn check_and_reserve_escrow(
async fn check_and_reserve_escrow<T: SolStruct + WithValueAndTimestamp + Sync>(
&self,
received_receipt: &ReceiptWithState<AwaitingReserve>,
received_receipt: &ReceiptWithState<AwaitingReserve, T>,
domain_separator: &Eip712Domain,
) -> ReceiptResult<()> {
let signed_receipt = &received_receipt.signed_receipt;
Expand All @@ -55,7 +56,7 @@ pub trait EscrowHandler: Send + Sync {
})?;

if self
.subtract_escrow(receipt_signer_address, signed_receipt.message.value)
.subtract_escrow(receipt_signer_address, signed_receipt.message.value())
.await
.is_err()
{
Expand All @@ -66,9 +67,9 @@ pub trait EscrowHandler: Send + Sync {
}

/// Checks the signature of the RAV
async fn check_rav_signature(
async fn check_rav_signature<R: SolStruct + Sync>(
&self,
signed_rav: &SignedRAV,
signed_rav: &EIP712SignedMessage<R>,
domain_separator: &Eip712Domain,
) -> Result<(), Error> {
let recovered_address = signed_rav.recover_signer(domain_separator)?;
Expand Down
17 changes: 12 additions & 5 deletions tap_core/src/manager/adapters/rav.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright 2023-, Semiotic AI, Inc.
// SPDX-License-Identifier: Apache-2.0

use alloy::sol_types::SolStruct;
use async_trait::async_trait;

use crate::rav::SignedRAV;
use crate::signed_message::EIP712SignedMessage;

/// Stores the latest RAV in the storage.
///
Expand All @@ -12,7 +13,10 @@ use crate::rav::SignedRAV;
/// For example code see [crate::manager::context::memory::RAVStorage]
#[async_trait]
pub trait RAVStore {
pub trait RAVStore<T>
where
T: SolStruct,
{
/// Defines the user-specified error type.
///
/// This error type should implement the `Error` and `Debug` traits from
Expand All @@ -25,7 +29,7 @@ pub trait RAVStore {
/// This method should be implemented to store the most recent validated
/// `SignedRAV` into your chosen storage system. Any errors that occur
/// during this process should be captured and returned as an `AdapterError`.
async fn update_last_rav(&self, rav: SignedRAV) -> Result<(), Self::AdapterError>;
async fn update_last_rav(&self, rav: EIP712SignedMessage<T>) -> Result<(), Self::AdapterError>;
}

/// Reads the RAV from storage
Expand All @@ -35,7 +39,10 @@ pub trait RAVStore {
/// For example code see [crate::manager::context::memory::RAVStorage]
#[async_trait]
pub trait RAVRead {
pub trait RAVRead<T>
where
T: SolStruct,
{
/// Defines the user-specified error type.
///
/// This error type should implement the `Error` and `Debug` traits from
Expand All @@ -46,5 +53,5 @@ pub trait RAVRead {
/// Retrieves the latest `SignedRAV` from the storage.
///
/// If no `SignedRAV` is available, this method should return `None`.
async fn last_rav(&self) -> Result<Option<SignedRAV>, Self::AdapterError>;
async fn last_rav(&self) -> Result<Option<EIP712SignedMessage<T>>, Self::AdapterError>;
}
39 changes: 24 additions & 15 deletions tap_core/src/manager/adapters/receipt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@

use std::ops::RangeBounds;

use alloy::sol_types::SolStruct;
use async_trait::async_trait;

use crate::receipt::{
state::{Checking, ReceiptState},
ReceiptWithState,
use crate::{
manager::WithValueAndTimestamp,
receipt::{
state::{Checking, ReceiptState},
ReceiptWithState,
},
};

/// Stores receipts in the storage.
Expand All @@ -16,7 +20,10 @@ use crate::receipt::{
///
/// For example code see [crate::manager::context::memory::ReceiptStorage]
#[async_trait]
pub trait ReceiptStore {
pub trait ReceiptStore<T>
where
T: SolStruct,
{
/// Defines the user-specified error type.
///
/// This error type should implement the `Error` and `Debug` traits from the standard library.
Expand All @@ -29,7 +36,7 @@ pub trait ReceiptStore {
/// this process should be captured and returned as an `AdapterError`.
async fn store_receipt(
&self,
receipt: ReceiptWithState<Checking>,
receipt: ReceiptWithState<Checking, T>,
) -> Result<u64, Self::AdapterError>;
}

Expand Down Expand Up @@ -62,7 +69,10 @@ pub trait ReceiptDelete {
///
/// For example code see [crate::manager::context::memory::ReceiptStorage]
#[async_trait]
pub trait ReceiptRead {
pub trait ReceiptRead<T>
where
T: SolStruct,
{
/// Defines the user-specified error type.
///
/// This error type should implement the `Error` and `Debug` traits from
Expand Down Expand Up @@ -92,15 +102,15 @@ pub trait ReceiptRead {
&self,
timestamp_range_ns: R,
limit: Option<u64>,
) -> Result<Vec<ReceiptWithState<Checking>>, Self::AdapterError>;
) -> Result<Vec<ReceiptWithState<Checking, T>>, Self::AdapterError>;
}

/// See [`ReceiptRead::retrieve_receipts_in_timestamp_range()`] for details.
///
/// WARNING: Will sort the receipts by timestamp using
/// [vec::sort_unstable](https://doc.rust-lang.org/std/vec/struct.Vec.html#method.sort_unstable).
pub fn safe_truncate_receipts<T: ReceiptState>(
receipts: &mut Vec<ReceiptWithState<T>>,
pub fn safe_truncate_receipts<T: ReceiptState, R: SolStruct + WithValueAndTimestamp>(
receipts: &mut Vec<ReceiptWithState<T, R>>,
limit: u64,
) {
if receipts.len() <= limit as usize {
Expand All @@ -110,27 +120,26 @@ pub fn safe_truncate_receipts<T: ReceiptState>(
return;
}

receipts.sort_unstable_by_key(|rx_receipt| rx_receipt.signed_receipt().message.timestamp_ns);
receipts.sort_unstable_by_key(|rx_receipt| rx_receipt.signed_receipt().message.timestamp());

// This one will be the last timestamp in `receipts` after naive truncation
let last_timestamp = receipts[limit as usize - 1]
.signed_receipt()
.message
.timestamp_ns;
.timestamp();
// This one is the timestamp that comes just after the one above
let after_last_timestamp = receipts[limit as usize]
.signed_receipt()
.message
.timestamp_ns;
.timestamp();

receipts.truncate(limit as usize);

if last_timestamp == after_last_timestamp {
// If the last timestamp is the same as the one that came after it, we need to
// remove all the receipts with the same timestamp as the last one, because
// otherwise we would leave behind part of the receipts for that timestamp.
receipts.retain(|rx_receipt| {
rx_receipt.signed_receipt().message.timestamp_ns != last_timestamp
});
receipts
.retain(|rx_receipt| rx_receipt.signed_receipt().message.timestamp() != last_timestamp);
}
}
46 changes: 27 additions & 19 deletions tap_core/src/manager/context/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ use async_trait::async_trait;

use crate::{
manager::adapters::*,
rav::SignedRAV,
receipt::{checks::StatefulTimestampCheck, state::Checking, ReceiptWithState},
rav::{ReceiptAggregateVoucher, SignedRAV},
receipt::{checks::StatefulTimestampCheck, state::Checking, Receipt, ReceiptWithState},
signed_message::MessageId,
};

pub type EscrowStorage = Arc<RwLock<HashMap<Address, u128>>>;
pub type QueryAppraisals = Arc<RwLock<HashMap<MessageId, u128>>>;
pub type ReceiptStorage = Arc<RwLock<HashMap<u64, ReceiptWithState<Checking>>>>;
pub type ReceiptStorage = Arc<RwLock<HashMap<u64, ReceiptWithState<Checking, Receipt>>>>;
pub type RAVStorage = Arc<RwLock<Option<SignedRAV>>>;

use thiserror::Error;
Expand Down Expand Up @@ -71,7 +71,7 @@ impl InMemoryContext {
pub async fn retrieve_receipt_by_id(
&self,
receipt_id: u64,
) -> Result<ReceiptWithState<Checking>, InMemoryError> {
) -> Result<ReceiptWithState<Checking, Receipt>, InMemoryError> {
let receipt_storage = self.receipt_storage.read().unwrap();

receipt_storage
Expand All @@ -85,7 +85,7 @@ impl InMemoryContext {
pub async fn retrieve_receipts_by_timestamp(
&self,
timestamp_ns: u64,
) -> Result<Vec<(u64, ReceiptWithState<Checking>)>, InMemoryError> {
) -> Result<Vec<(u64, ReceiptWithState<Checking, Receipt>)>, InMemoryError> {
let receipt_storage = self.receipt_storage.read().unwrap();
Ok(receipt_storage
.iter()
Expand All @@ -99,7 +99,7 @@ impl InMemoryContext {
pub async fn retrieve_receipts_upto_timestamp(
&self,
timestamp_ns: u64,
) -> Result<Vec<ReceiptWithState<Checking>>, InMemoryError> {
) -> Result<Vec<ReceiptWithState<Checking, Receipt>>, InMemoryError> {
self.retrieve_receipts_in_timestamp_range(..=timestamp_ns, None)
.await
}
Expand All @@ -125,7 +125,7 @@ impl InMemoryContext {
}

#[async_trait]
impl RAVStore for InMemoryContext {
impl RAVStore<ReceiptAggregateVoucher> for InMemoryContext {
type AdapterError = InMemoryError;

async fn update_last_rav(&self, rav: SignedRAV) -> Result<(), Self::AdapterError> {
Expand All @@ -138,7 +138,7 @@ impl RAVStore for InMemoryContext {
}

#[async_trait]
impl RAVRead for InMemoryContext {
impl RAVRead<ReceiptAggregateVoucher> for InMemoryContext {
type AdapterError = InMemoryError;

async fn last_rav(&self) -> Result<Option<SignedRAV>, Self::AdapterError> {
Expand All @@ -147,12 +147,12 @@ impl RAVRead for InMemoryContext {
}

#[async_trait]
impl ReceiptStore for InMemoryContext {
impl ReceiptStore<Receipt> for InMemoryContext {
type AdapterError = InMemoryError;

async fn store_receipt(
&self,
receipt: ReceiptWithState<Checking>,
receipt: ReceiptWithState<Checking, Receipt>,
) -> Result<u64, Self::AdapterError> {
let mut id_pointer = self.unique_id.write().unwrap();
let id_previous = *id_pointer;
Expand All @@ -179,15 +179,15 @@ impl ReceiptDelete for InMemoryContext {
}
}
#[async_trait]
impl ReceiptRead for InMemoryContext {
impl ReceiptRead<Receipt> for InMemoryContext {
type AdapterError = InMemoryError;
async fn retrieve_receipts_in_timestamp_range<R: RangeBounds<u64> + std::marker::Send>(
&self,
timestamp_range_ns: R,
limit: Option<u64>,
) -> Result<Vec<ReceiptWithState<Checking>>, Self::AdapterError> {
) -> Result<Vec<ReceiptWithState<Checking, Receipt>>, Self::AdapterError> {
let receipt_storage = self.receipt_storage.read().unwrap();
let mut receipts_in_range: Vec<ReceiptWithState<Checking>> = receipt_storage
let mut receipts_in_range: Vec<ReceiptWithState<Checking, Receipt>> = receipt_storage
.iter()
.filter(|(_, rx_receipt)| {
timestamp_range_ns.contains(&rx_receipt.signed_receipt().message.timestamp_ns)
Expand Down Expand Up @@ -274,7 +274,7 @@ pub mod checks {
receipt::{
checks::{Check, CheckError, CheckResult, ReceiptCheck},
state::Checking,
Context, ReceiptError, ReceiptWithState,
Context, Receipt, ReceiptError, ReceiptWithState,
},
signed_message::MessageId,
};
Expand All @@ -284,7 +284,7 @@ pub mod checks {
valid_signers: HashSet<Address>,
allocation_ids: Arc<RwLock<HashSet<Address>>>,
_query_appraisals: Arc<RwLock<HashMap<MessageId, u128>>>,
) -> Vec<ReceiptCheck> {
) -> Vec<ReceiptCheck<Receipt>> {
vec![
// Arc::new(UniqueCheck ),
// Arc::new(ValueCheck { query_appraisals }),
Expand All @@ -301,8 +301,12 @@ pub mod checks {
}

#[async_trait::async_trait]
impl Check for AllocationIdCheck {
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
impl Check<Receipt> for AllocationIdCheck {
async fn check(
&self,
_: &Context,
receipt: &ReceiptWithState<Checking, Receipt>,
) -> CheckResult {
let received_allocation_id = receipt.signed_receipt().message.allocation_id;
if self
.allocation_ids
Expand All @@ -328,8 +332,12 @@ pub mod checks {
}

#[async_trait::async_trait]
impl Check for SignatureCheck {
async fn check(&self, _: &Context, receipt: &ReceiptWithState<Checking>) -> CheckResult {
impl Check<Receipt> for SignatureCheck {
async fn check(
&self,
_: &Context,
receipt: &ReceiptWithState<Checking, Receipt>,
) -> CheckResult {
let recovered_address = receipt
.signed_receipt()
.recover_signer(&self.domain_separator)
Expand Down
Loading

0 comments on commit 28baab5

Please sign in to comment.