Skip to content

Commit

Permalink
Merge pull request eqlabs#2258 from eqlabs/sistemd/starknet-subscribe…
Browse files Browse the repository at this point in the history
…-pending-transactions-1

feat(rpc): implement `starknet_subscribePendingTransactions`
  • Loading branch information
sistemd authored Sep 25, 2024
2 parents b5400d3 + 0cd34b1 commit d8cc94a
Show file tree
Hide file tree
Showing 6 changed files with 698 additions and 56 deletions.
112 changes: 73 additions & 39 deletions crates/rpc/src/jsonrpc/router/subscription.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::time::Duration;

use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
Expand All @@ -11,7 +12,7 @@ use crate::context::RpcContext;
use crate::dto::serialize::SerializeForVersion;
use crate::dto::DeserializeForVersion;
use crate::error::ApplicationError;
use crate::jsonrpc::{RpcError, RpcRequest, RpcResponse};
use crate::jsonrpc::{RequestId, RpcError, RpcRequest, RpcResponse};
use crate::{RpcVersion, SubscriptionId};

/// See [`RpcSubscriptionFlow`].
Expand All @@ -20,11 +21,11 @@ pub(super) trait RpcSubscriptionEndpoint: Send + Sync {
// Start the subscription.
async fn invoke(
&self,
state: RpcContext,
router: RpcRouter,
input: serde_json::Value,
subscription_id: SubscriptionId,
subscriptions: Arc<DashMap<SubscriptionId, tokio::task::JoinHandle<()>>>,
version: RpcVersion,
req_id: RequestId,
tx: mpsc::Sender<Result<Message, RpcResponse>>,
) -> Result<(), RpcError>;
}
Expand Down Expand Up @@ -57,12 +58,12 @@ pub(super) trait RpcSubscriptionEndpoint: Send + Sync {
/// - Stream the first active update, and then keep streaming the rest.
#[axum::async_trait]
pub trait RpcSubscriptionFlow: Send + Sync {
type Request: crate::dto::DeserializeForVersion + Send + Sync + 'static;
type Request: crate::dto::DeserializeForVersion + Clone + Send + Sync + 'static;
type Notification: crate::dto::serialize::SerializeForVersion + Send + Sync + 'static;

/// The block to start streaming from. If the subscription endpoint does not
/// support catching up, the value returned by this method is
/// irrelevant.
/// support catching up, this method should always return
/// [`BlockId::Latest`].
fn starting_block(req: &Self::Request) -> BlockId;

/// Fetch historical data from the `from` block to the `to` block. The
Expand All @@ -78,6 +79,7 @@ pub trait RpcSubscriptionFlow: Send + Sync {
/// Subscribe to active updates.
async fn subscribe(
state: RpcContext,
req: Self::Request,
tx: mpsc::Sender<SubscriptionMessage<Self::Notification>>,
);
}
Expand All @@ -101,20 +103,20 @@ where
{
async fn invoke(
&self,
state: RpcContext,
router: RpcRouter,
input: serde_json::Value,
subscription_id: SubscriptionId,
subscriptions: Arc<DashMap<SubscriptionId, tokio::task::JoinHandle<()>>>,
version: RpcVersion,
tx: mpsc::Sender<Result<Message, RpcResponse>>,
req_id: RequestId,
ws_tx: mpsc::Sender<Result<Message, RpcResponse>>,
) -> Result<(), RpcError> {
let req = T::Request::deserialize(crate::dto::Value::new(input, version))
let req = T::Request::deserialize(crate::dto::Value::new(input, router.version))
.map_err(|e| RpcError::InvalidParams(e.to_string()))?;
let tx = SubscriptionSender {
subscription_id,
subscriptions,
tx,
version,
tx: ws_tx.clone(),
version: router.version,
_phantom: Default::default(),
};

Expand All @@ -128,13 +130,32 @@ where
}
BlockId::Latest => {
// No need to catch up. The code below will subscribe to new blocks.
// Only needs to send the subscription ID to the client.
if ws_tx
.send(Ok(Message::Text(
serde_json::to_string(&RpcResponse {
output: Ok(serde_json::to_value(&SubscriptionIdResult {
subscription_id,
})
.unwrap()),
id: req_id.clone(),
})
.unwrap(),
)))
.await
.is_err()
{
return Ok(());
}
BlockNumber::MAX
}
BlockId::Number(_) | BlockId::Hash(_) => {
// Catch up to the latest block in batches of BATCH_SIZE.

// Load the first block number, return an error if it's invalid.
let first_block = pathfinder_storage::BlockId::try_from(T::starting_block(&req))
.map_err(|e| RpcError::InvalidParams(e.to_string()))?;
let storage = state.storage.clone();
let storage = router.context.storage.clone();
let mut current_block =
tokio::task::spawn_blocking(move || -> Result<_, RpcError> {
let mut conn = storage.connection().map_err(RpcError::InternalError)?;
Expand All @@ -145,11 +166,34 @@ where
})
.await
.map_err(|e| RpcError::InternalError(e.into()))??;

// Send the subscription ID to the client.
if ws_tx
.send(Ok(Message::Text(
serde_json::to_string(&RpcResponse {
output: Ok(serde_json::to_value(&SubscriptionIdResult {
subscription_id,
})
.unwrap()),
id: req_id.clone(),
})
.unwrap(),
)))
.await
.is_err()
{
return Ok(());
}

const BATCH_SIZE: u64 = 64;
loop {
let messages =
T::catch_up(&state, &req, current_block, current_block + BATCH_SIZE)
.await?;
let messages = T::catch_up(
&router.context,
&req,
current_block,
current_block + BATCH_SIZE,
)
.await?;
if messages.is_empty() {
// Caught up.
break;
Expand All @@ -174,7 +218,10 @@ where

// Subscribe to new blocks. Receive the first subscription message.
let (tx1, mut rx1) = mpsc::channel::<SubscriptionMessage<T::Notification>>(1024);
tokio::spawn(T::subscribe(state.clone(), tx1));
{
let req = req.clone();
tokio::spawn(T::subscribe(router.context.clone(), req, tx1));
}
let first_msg = match rx1.recv().await {
Some(msg) => msg,
None => {
Expand All @@ -188,7 +235,7 @@ where
// blocks. Because the catch_up range is inclusive, we need to subtract 1 from
// the block number.
if let Some(block_number) = first_msg.block_number.parent() {
let messages = T::catch_up(&state, &req, current_block, block_number).await?;
let messages = T::catch_up(&router.context, &req, current_block, block_number).await?;
for msg in messages {
if tx
.send(msg.notification, msg.subscription_name)
Expand Down Expand Up @@ -415,36 +462,19 @@ pub fn handle_json_rpc_socket(
};

// Start the subscription.
let state = state.clone();
let subscription_id = SubscriptionId::next();
let context = state.context.clone();
let version = state.version;
let ws_tx = ws_tx.clone();
if ws_tx
.send(Ok(Message::Text(
serde_json::to_string(&RpcResponse {
output: Ok(
serde_json::to_value(&SubscriptionIdResult { subscription_id })
.unwrap(),
),
id: req_id.clone(),
})
.unwrap(),
)))
.await
.is_err()
{
break;
}
let handle = tokio::spawn({
let subscriptions = subscriptions.clone();
async move {
if let Err(e) = endpoint
.invoke(
context,
state,
params,
subscription_id,
subscriptions,
version,
subscriptions.clone(),
req_id.clone(),
ws_tx.clone(),
)
.await
Expand All @@ -456,6 +486,10 @@ pub fn handle_json_rpc_socket(
}))
.await
.ok();
while subscriptions.remove(&subscription_id).is_none() {
// Race condition, the insert has not yet happened.
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
Expand Down
1 change: 1 addition & 0 deletions crates/rpc/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub mod get_transaction_receipt;
pub mod get_transaction_status;
pub mod simulate_transactions;
pub mod subscribe_new_heads;
pub mod subscribe_pending_transactions;
pub mod syncing;
pub mod trace_block_transactions;
pub mod trace_transaction;
Expand Down
Loading

0 comments on commit d8cc94a

Please sign in to comment.