Skip to content

Commit

Permalink
refactor Wrapper type
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Sep 10, 2024
1 parent 837d5d6 commit 42597ae
Show file tree
Hide file tree
Showing 30 changed files with 229 additions and 174 deletions.
4 changes: 2 additions & 2 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ This assists us in knowing when to make the next release a breaking release and

### shotover rust API

* `Transform::transform` now takes `&mut Wrapper` instead of `Wrapper`.
* `Wrapper` is renamed to ChainState.
`Transform::transform` previously took a `Wrapper` type as an argument.
That has now been split into 2 separate types: `&mut ChainState` and `DownChainTransforms`.

## 0.4.0

Expand Down
10 changes: 6 additions & 4 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, MessageType, RedisFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::{
ChainState, Transform, TransformBuilder, TransformConfig, TransformContextConfig,
ChainState, DownChainTransforms, Transform, TransformBuilder, TransformConfig,
TransformContextConfig,
};
use shotover::transforms::{DownChainProtocol, TransformContextBuilder, UpChainProtocol};

Expand Down Expand Up @@ -64,9 +65,10 @@ impl Transform for RedisGetRewrite {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
for message in chain_state.requests.iter_mut() {
if let Some(frame) = message.frame() {
Expand All @@ -75,7 +77,7 @@ impl Transform for RedisGetRewrite {
}
}
}
let mut responses = chain_state.call_next_transform().await?;
let mut responses = down_chain.call_next_transform(chain_state).await?;

for response in responses.iter_mut() {
if response
Expand Down
10 changes: 5 additions & 5 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,14 @@ fn cassandra_parsed_query(query: &str) -> ChainState {
)
}

struct BenchInput<'a> {
struct BenchInput {
chain: TransformChain,
chain_state: ChainState<'a>,
chain_state: ChainState,
}

impl<'a> BenchInput<'a> {
impl BenchInput {
// Setup the bench such that the chain is completely fresh
fn new_fresh(chain: &TransformChainBuilder, chain_state: &ChainState<'a>) -> Self {
fn new_fresh(chain: &TransformChainBuilder, chain_state: &ChainState) -> Self {
BenchInput {
chain: chain.build(TransformContextBuilder::new_test()),
chain_state: chain_state.clone(),
Expand All @@ -358,7 +358,7 @@ impl<'a> BenchInput<'a> {
// Setup the bench such that the chain has already had the test chain_state passed through it.
// This ensures that any adhoc setup for that message type has been performed.
// This is a more realistic bench for typical usage.
fn new_pre_used(chain: &TransformChainBuilder, chain_state: &ChainState<'a>) -> Self {
fn new_pre_used(chain: &TransformChainBuilder, chain_state: &ChainState) -> Self {
let mut chain = chain.build(TransformContextBuilder::new_test());

// Run the chain once so we are measuring the chain once each transform has been fully initialized
Expand Down
9 changes: 5 additions & 4 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,11 @@ impl<C: CodecBuilder + 'static> Handler<C> {
out_tx: &mpsc::UnboundedSender<Messages>,
requests: Messages,
) -> Result<Option<CloseReason>> {
let mut wrapper = ChainState::new_with_addr(requests, local_addr);
let mut chain_state = ChainState::new_with_addr(requests, local_addr);

self.pending_requests.process_requests(&wrapper.requests);
let responses = match self.chain.process_request(&mut wrapper).await {
self.pending_requests
.process_requests(&chain_state.requests);
let responses = match self.chain.process_request(&mut chain_state).await {
Ok(x) => x,
Err(err) => {
let err = err.context("Chain failed to send and/or receive messages, the connection will now be closed.");
Expand All @@ -752,7 +753,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
}

// if requested by a transform, close connection AFTER sending any responses back to the client
if wrapper.close_client_connection {
if chain_state.close_client_connection {
return Ok(Some(CloseReason::TransformRequested));
}

Expand Down
11 changes: 6 additions & 5 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::frame::MessageType;
use crate::message::{Message, MessageIdMap, Messages};
use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event;
use crate::transforms::{
ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig,
TransformContextBuilder, UpChainProtocol,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, UpChainProtocol,
};
use crate::{
frame::{
Expand Down Expand Up @@ -79,9 +79,10 @@ impl Transform for CassandraPeersRewrite {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
// Find the indices of queries to system.peers & system.peers_v2
// we need to know which columns in which CQL queries in which messages have system peers
Expand All @@ -90,7 +91,7 @@ impl Transform for CassandraPeersRewrite {
self.column_names_to_rewrite.insert(request.id(), sys_peers);
}

let mut responses = chain_state.call_next_transform().await?;
let mut responses = down_chain.call_next_transform(chain_state).await?;

for response in &mut responses {
if let Some(Frame::Cassandra(frame)) = response.frame() {
Expand Down
9 changes: 5 additions & 4 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, M
use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::{
ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig,
TransformContextBuilder, TransformContextConfig, UpChainProtocol,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -761,9 +761,10 @@ impl Transform for CassandraSinkCluster {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
_down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
self.send_message(std::mem::take(&mut chain_state.requests))
.await
Expand Down
9 changes: 5 additions & 4 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::frame::MessageType;
use crate::message::{Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::{
ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig,
TransformContextBuilder, TransformContextConfig, UpChainProtocol,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -212,9 +212,10 @@ impl Transform for CassandraSinkSingle {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
_down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
self.send_message(std::mem::take(&mut chain_state.requests))
.await
Expand Down
26 changes: 11 additions & 15 deletions shotover/src/transforms/chain.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::TransformContextBuilder;
use super::{DownChainTransforms, TransformContextBuilder};
use crate::message::Messages;
use crate::transforms::{ChainState, Transform, TransformBuilder};
use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -72,7 +72,7 @@ pub struct BufferedChain {
impl BufferedChain {
pub async fn process_request(
&mut self,
chain_state: ChainState<'_>,
chain_state: ChainState,
buffer_timeout_micros: Option<u64>,
) -> Result<Messages> {
self.process_request_with_receiver(chain_state, buffer_timeout_micros)
Expand All @@ -82,7 +82,7 @@ impl BufferedChain {

async fn process_request_with_receiver(
&mut self,
chain_state: ChainState<'_>,
chain_state: ChainState,
buffer_timeout_micros: Option<u64>,
) -> Result<oneshot::Receiver<Result<Messages>>> {
let (one_tx, one_rx) = oneshot::channel::<Result<Messages>>();
Expand Down Expand Up @@ -119,7 +119,7 @@ impl BufferedChain {

pub async fn process_request_no_return(
&mut self,
chain_state: ChainState<'_>,
chain_state: ChainState,
buffer_timeout_micros: Option<u64>,
) -> Result<()> {
if chain_state.flush {
Expand Down Expand Up @@ -158,16 +158,12 @@ impl BufferedChain {
}

impl TransformChain {
pub async fn process_request<'shorter, 'longer: 'shorter>(
&'longer mut self,
chain_state: &'shorter mut ChainState<'longer>,
) -> Result<Messages> {
pub async fn process_request(&mut self, state: &mut ChainState) -> Result<Messages> {
let start = Instant::now();
chain_state.reset(&mut self.chain);
let down_chain = DownChainTransforms::new(&mut self.chain);

self.chain_batch_size
.record(chain_state.requests.len() as f64);
let result = chain_state.call_next_transform().await;
self.chain_batch_size.record(state.requests.len() as f64);
let result = down_chain.call_next_transform(state).await;
self.chain_total.increment(1);
if result.is_err() {
self.chain_failures.increment(1);
Expand Down Expand Up @@ -322,9 +318,9 @@ impl TransformChainBuilder {
count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}

let mut chain_state = ChainState::new_with_addr(messages, local_addr);
chain_state.flush = flush;
let chain_response = chain.process_request(&mut chain_state).await;
let mut wrapper = ChainState::new_with_addr(messages, local_addr);
wrapper.flush = flush;
let chain_response = chain.process_request(&mut wrapper).await;

if let Err(e) = &chain_response {
error!("Internal error in buffered chain: {e:?}");
Expand Down
22 changes: 15 additions & 7 deletions shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol};
use super::{
DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig,
UpChainProtocol,
};
use crate::message::Messages;
use crate::transforms::{ChainState, Transform, TransformBuilder, TransformConfig};
use anyhow::Result;
Expand Down Expand Up @@ -81,9 +84,10 @@ impl Transform for Coalesce {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
self.buffer.append(&mut chain_state.requests);

Expand All @@ -102,7 +106,7 @@ impl Transform for Coalesce {
self.last_write = Instant::now()
}
std::mem::swap(&mut self.buffer, &mut chain_state.requests);
chain_state.call_next_transform().await
down_chain.call_next_transform(chain_state).await
} else {
Ok(vec![])
}
Expand All @@ -116,7 +120,7 @@ mod test {
use crate::transforms::chain::TransformAndMetrics;
use crate::transforms::coalesce::Coalesce;
use crate::transforms::loopback::Loopback;
use crate::transforms::{ChainState, Transform};
use crate::transforms::{ChainState, DownChainTransforms, Transform};
use pretty_assertions::assert_eq;
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -199,9 +203,13 @@ mod test {
expected_len: usize,
) {
let mut wrapper = ChainState::new_test(requests.to_vec());
wrapper.reset(chain);
let transforms = DownChainTransforms::new(chain);
assert_eq!(
coalesce.transform(&mut wrapper).await.unwrap().len(),
coalesce
.transform(&mut wrapper, transforms)
.await
.unwrap()
.len(),
expected_len
);
}
Expand Down
8 changes: 5 additions & 3 deletions shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::message::Messages;
use crate::transforms::DownChainTransforms;
/// This transform will by default parse requests and responses that pass through it.
/// request and response parsing can be individually disabled if desired.
///
Expand Down Expand Up @@ -105,9 +106,10 @@ impl Transform for DebugForceParse {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
for message in &mut chain_state.requests {
if self.parse_requests {
Expand All @@ -118,7 +120,7 @@ impl Transform for DebugForceParse {
}
}

let mut response = chain_state.call_next_transform().await;
let mut response = down_chain.call_next_transform(chain_state).await;

if let Ok(response) = response.as_mut() {
for message in response {
Expand Down
11 changes: 7 additions & 4 deletions shotover/src/transforms/debug/log_to_file.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::message::{Encodable, Message};
use crate::transforms::{ChainState, Transform, TransformBuilder, TransformContextBuilder};
use crate::transforms::{
ChainState, DownChainTransforms, Transform, TransformBuilder, TransformContextBuilder,
};
#[cfg(feature = "alpha-transforms")]
use crate::transforms::{DownChainProtocol, UpChainProtocol};
use anyhow::{Context, Result};
Expand Down Expand Up @@ -89,9 +91,10 @@ impl Transform for DebugLogToFile {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Vec<Message>> {
for message in &chain_state.requests {
self.request_counter += 1;
Expand All @@ -101,7 +104,7 @@ impl Transform for DebugLogToFile {
log_message(message, path.as_path()).await?;
}

let response = chain_state.call_next_transform().await?;
let response = down_chain.call_next_transform(chain_state).await?;

for message in &response {
self.response_counter += 1;
Expand Down
11 changes: 6 additions & 5 deletions shotover/src/transforms/debug/printer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::message::Messages;
use crate::transforms::{
ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig,
TransformContextBuilder, TransformContextConfig, UpChainProtocol,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol,
};
use anyhow::Result;
use async_trait::async_trait;
Expand Down Expand Up @@ -65,16 +65,17 @@ impl Transform for DebugPrinter {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
for request in &mut chain_state.requests {
info!("Request: {}", request.to_high_level_string());
}

self.counter += 1;
let mut responses = chain_state.call_next_transform().await?;
let mut responses = down_chain.call_next_transform(chain_state).await?;

for response in &mut responses {
info!("Response: {}", response.to_high_level_string());
Expand Down
Loading

0 comments on commit 42597ae

Please sign in to comment.