diff --git a/scylla-cql/src/errors.rs b/scylla-cql/src/errors.rs index 40587cfef6..c7cc85d233 100644 --- a/scylla-cql/src/errors.rs +++ b/scylla-cql/src/errors.rs @@ -348,6 +348,10 @@ pub enum BadQuery { #[error("Passed invalid keyspace name to use: {0}")] BadKeyspaceName(#[from] BadKeyspaceName), + /// Too many queries in the batch statement + #[error("Number of Queries in Batch Statement has exceeded the max value of 65,536")] + TooManyQueriesInBatchStatement, + /// Other reasons of bad query #[error("{0}")] Other(String), diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 6af08e217f..fa964e9478 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -5,8 +5,8 @@ use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut}; use num_enum::TryFromPrimitive; use std::collections::HashMap; -use std::convert::{Infallible, TryFrom}; use std::convert::TryInto; +use std::convert::{Infallible, TryFrom}; use std::net::IpAddr; use std::net::SocketAddr; use std::str; diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml index 2460e020b9..3408d10330 100644 --- a/scylla/Cargo.toml +++ b/scylla/Cargo.toml @@ -60,6 +60,7 @@ criterion = "0.4" # Note: v0.5 needs at least rust 1.70.0 tracing-subscriber = { version = "0.3.14", features = ["env-filter"] } assert_matches = "1.5.0" rand_chacha = "0.3.1" +bcs = "0.1.5" [[bench]] name = "benchmark" diff --git a/scylla/src/transport/large_batch_statements_test.rs b/scylla/src/transport/large_batch_statements_test.rs new file mode 100644 index 0000000000..6195de30df --- /dev/null +++ b/scylla/src/transport/large_batch_statements_test.rs @@ -0,0 +1,106 @@ +use bcs::serialize_into; +use scylla_cql::errors::{BadQuery, QueryError}; + +use crate::batch::BatchType; +use crate::query::Query; +use crate::{ + batch::Batch, + prepared_statement::PreparedStatement, + test_utils::{create_new_session_builder, unique_keyspace_name}, + IntoTypedRows, QueryResult, Session, +}; + +#[tokio::test] +async fn test_large_batch_statements() { + let mut session = create_new_session_builder().build().await.unwrap(); + let ks = unique_keyspace_name(); + session = create_test_session(session, &ks).await; + + let max_number_of_queries = u16::MAX as usize; + write_batch(&session, max_number_of_queries).await; + + let key_prefix = vec![0]; + let keys = find_keys_by_prefix(&session, key_prefix.clone()).await; + assert_eq!(keys.len(), max_number_of_queries); + + let too_many_queries = u16::MAX as usize + 1; + + let err = write_batch(&session, too_many_queries).await; + + assert!(err.is_err()); +} + +async fn create_test_session(session: Session, ks: &String) -> Session { + session + .query( + format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }}",ks), + &[], + ) + .await.unwrap(); + session + .query("DROP TABLE IF EXISTS kv.pairs;", &[]) + .await + .unwrap(); + session + .query( + "CREATE TABLE IF NOT EXISTS kv.pairs (dummy int, k blob, v blob, primary key (dummy, k))", + &[], + ) + .await.unwrap(); + session +} + +async fn write_batch(session: &Session, n: usize) -> Result { + let mut batch_query = Batch::new(BatchType::Logged); + let mut batch_values = Vec::new(); + for i in 0..n { + let mut key = vec![0]; + serialize_into(&mut key, &(i as usize)).unwrap(); + let value = key.clone(); + let query = "INSERT INTO kv.pairs (dummy, k, v) VALUES (0, ?, ?)"; + let values = vec![key, value]; + batch_values.push(values); + let query = Query::new(query); + batch_query.append_statement(query); + } + session.batch(&batch_query, batch_values).await +} + +async fn find_keys_by_prefix(session: &Session, key_prefix: Vec) -> Vec> { + let len = key_prefix.len(); + let rows = match get_upper_bound_option(&key_prefix) { + None => { + let values = (key_prefix,); + let query = "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? ALLOW FILTERING"; + session.query(query, values).await.unwrap() + } + Some(upper_bound) => { + let values = (key_prefix, upper_bound); + let query = + "SELECT k FROM kv.pairs WHERE dummy = 0 AND k >= ? AND k < ? ALLOW FILTERING"; + session.query(query, values).await.unwrap() + } + }; + let mut keys = Vec::new(); + if let Some(rows) = rows.rows { + for row in rows.into_typed::<(Vec,)>() { + let key = row.unwrap(); + let short_key = key.0[len..].to_vec(); + keys.push(short_key); + } + } + keys +} + +fn get_upper_bound_option(key_prefix: &[u8]) -> Option> { + let len = key_prefix.len(); + for i in (0..len).rev() { + let val = key_prefix[i]; + if val < u8::MAX { + let mut upper_bound = key_prefix[0..i + 1].to_vec(); + upper_bound[i] += 1; + return Some(upper_bound); + } + } + None +} diff --git a/scylla/src/transport/mod.rs b/scylla/src/transport/mod.rs index 939983cfc4..a33943645d 100644 --- a/scylla/src/transport/mod.rs +++ b/scylla/src/transport/mod.rs @@ -35,6 +35,8 @@ mod silent_prepare_batch_test; mod cql_types_test; #[cfg(test)] mod cql_value_test; +#[cfg(test)] +mod large_batch_statements_test; pub use cluster::ClusterData; pub use node::{KnownNode, Node, NodeAddr, NodeRef}; diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 35ff25475f..f92067363d 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -76,6 +76,7 @@ pub use crate::transport::connection_pool::PoolSize; use crate::authentication::AuthenticatorProvider; #[cfg(feature = "ssl")] use openssl::ssl::SslContext; +use scylla_cql::errors::BadQuery; /// Translates IP addresses received from ScyllaDB nodes into locally reachable addresses. /// @@ -1143,6 +1144,12 @@ impl Session { // Shard-awareness behavior for batch will be to pick shard based on first batch statement's shard // If users batch statements by shard, they will be rewarded with full shard awareness + // check to ensure that we don't send a batch statement with more than u16::MAX queries + if batch.statements.len() > u16::MAX as usize { + return Err(QueryError::BadQuery( + BadQuery::TooManyQueriesInBatchStatement, + )); + } // Extract first serialized_value let first_serialized_value = values.batch_values_iter().next_serialized().transpose()?; let first_serialized_value = first_serialized_value.as_deref();