From 4c013e87b01c02fb5cd6f2bccd467fb70e674900 Mon Sep 17 00:00:00 2001 From: samuel orji Date: Fri, 6 Oct 2023 22:09:21 +0100 Subject: [PATCH] changed the type of the maximum number of statements in a batch query from an i16 to a u16 according to the CQL protocol spec --- scylla-cql/src/frame/request/batch.rs | 4 ++-- scylla-cql/src/frame/types.rs | 26 +++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/scylla-cql/src/frame/request/batch.rs b/scylla-cql/src/frame/request/batch.rs index 3c0bad3931..92b8b61ec4 100644 --- a/scylla-cql/src/frame/request/batch.rs +++ b/scylla-cql/src/frame/request/batch.rs @@ -81,7 +81,7 @@ where buf.put_u8(self.batch_type as u8); // Serializing queries - types::write_short(self.statements.len().try_into()?, buf); + types::write_u16(self.statements.len().try_into()?, buf); let counts_mismatch_err = |n_values: usize, n_statements: usize| { ParseError::BadDataToSerialize(format!( @@ -190,7 +190,7 @@ impl<'b> DeserializableRequest for Batch<'b, BatchStatement<'b>, Vec Result { let batch_type = buf.get_u8().try_into()?; - let statements_count: usize = types::read_short(buf)?.try_into()?; + let statements_count: usize = types::read_u16(buf)?.try_into()?; let statements_with_values = (0..statements_count) .map(|_| { let batch_statement = BatchStatement::deserialize(buf)?; diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index fd2254c8b0..6af08e217f 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -5,7 +5,7 @@ use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut}; use num_enum::TryFromPrimitive; use std::collections::HashMap; -use std::convert::TryFrom; +use std::convert::{Infallible, TryFrom}; use std::convert::TryInto; use std::net::IpAddr; use std::net::SocketAddr; @@ -98,6 +98,12 @@ impl From for ParseError { } } +impl From for ParseError { + fn from(_: Infallible) -> Self { + ParseError::BadIncomingData("Unexpected Infallible Error".to_string()) + } +} + impl From for ParseError { fn from(_err: std::array::TryFromSliceError) -> Self { ParseError::BadIncomingData("array try from slice failed".to_string()) @@ -174,10 +180,19 @@ pub fn read_short(buf: &mut &[u8]) -> Result { Ok(v) } +pub fn read_u16(buf: &mut &[u8]) -> Result { + let v = buf.read_u16::()?; + Ok(v) +} + pub fn write_short(v: i16, buf: &mut impl BufMut) { buf.put_i16(v); } +pub fn write_u16(v: u16, buf: &mut impl BufMut) { + buf.put_u16(v); +} + pub(crate) fn read_short_length(buf: &mut &[u8]) -> Result { let v = read_short(buf)?; let v: usize = v.try_into()?; @@ -200,6 +215,15 @@ fn type_short() { } } +#[test] +fn type_u16() { + let vals = [0, 1, u16::MAX]; + for val in vals.iter() { + let mut buf = Vec::new(); + write_u16(*val, &mut buf); + assert_eq!(read_u16(&mut &buf[..]).unwrap(), *val); + } +} // https://github.com/apache/cassandra/blob/trunk/doc/native_protocol_v4.spec#L208 pub fn read_bytes_opt<'a>(buf: &mut &'a [u8]) -> Result, ParseError> { let len = read_int(buf)?;