diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 1d20bb7b0a..f516d6e510 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -116,6 +116,31 @@ pub enum CqlValue { Varint(BigInt), } +impl ColumnType { + // Returns true if the type allows a special, empty value in addition to its + // natural representation. For example, bigint represents a 32-bit integer, + // but it can also hold a 0-bit empty value. + // + // It looks like Cassandra 4.1.3 rejects empty values for some more types than + // Scylla: date, time, smallint and tinyint. We will only check against + // Scylla's set of types supported for empty values as it's smaller; + // with Cassandra, some rejects will just have to be rejected on the db side. + pub(crate) fn supports_special_empty_value(&self) -> bool { + #[allow(clippy::match_like_matches_macro)] + match self { + ColumnType::Counter + | ColumnType::Duration + | ColumnType::List(_) + | ColumnType::Map(_, _) + | ColumnType::Set(_) + | ColumnType::UserDefinedType { .. 
} + | ColumnType::Custom(_) => false, + + _ => true, + } + } +} + impl CqlValue { pub fn as_ascii(&self) -> Option<&String> { match self { diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index 4faa8df501..a5fa8462f4 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -31,6 +31,10 @@ pub trait Value { #[error("Value too big to be sent in a request - max 2GiB allowed")] pub struct ValueTooBig; +#[derive(Debug, Error, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[error("Value is too large to fit in the CQL type")] +pub struct ValueOverflow; + /// Represents an unset value pub struct Unset; @@ -40,7 +44,7 @@ pub struct Counter(pub i64); /// Enum providing a way to represent a value that might be unset #[derive(Clone, Copy)] -pub enum MaybeUnset { +pub enum MaybeUnset { Unset, Set(V), } @@ -78,7 +82,7 @@ impl From for CqlDate { #[cfg(feature = "chrono")] impl TryInto for CqlDate { - type Error = ValueTooBig; + type Error = ValueOverflow; fn try_into(self) -> Result { let days_since_unix_epoch = self.0 as i64 - (1 << 31); @@ -90,7 +94,7 @@ impl TryInto for CqlDate { NaiveDate::from_yo_opt(1970, 1) .unwrap() .checked_add_signed(duration_since_unix_epoch) - .ok_or(ValueTooBig) + .ok_or(ValueOverflow) } } @@ -103,19 +107,19 @@ impl From> for CqlTimestamp { #[cfg(feature = "chrono")] impl TryInto> for CqlTimestamp { - type Error = ValueTooBig; + type Error = ValueOverflow; fn try_into(self) -> Result, Self::Error> { match Utc.timestamp_millis_opt(self.0) { chrono::LocalResult::Single(datetime) => Ok(datetime), - _ => Err(ValueTooBig), + _ => Err(ValueOverflow), } } } #[cfg(feature = "chrono")] impl TryFrom for CqlTime { - type Error = ValueTooBig; + type Error = ValueOverflow; fn try_from(value: NaiveTime) -> Result { let nanos = value @@ -127,23 +131,23 @@ impl TryFrom for CqlTime { if nanos <= 86399999999999 { Ok(Self(nanos)) } else { - Err(ValueTooBig) + Err(ValueOverflow) } } } #[cfg(feature = "chrono")] impl 
TryInto for CqlTime { - type Error = ValueTooBig; + type Error = ValueOverflow; fn try_into(self) -> Result { let secs = (self.0 / 1_000_000_000) .try_into() - .map_err(|_| ValueTooBig)?; + .map_err(|_| ValueOverflow)?; let nanos = (self.0 % 1_000_000_000) .try_into() - .map_err(|_| ValueTooBig)?; - NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).ok_or(ValueTooBig) + .map_err(|_| ValueOverflow)?; + NaiveTime::from_num_seconds_from_midnight_opt(secs, nanos).ok_or(ValueOverflow) } } @@ -167,7 +171,7 @@ impl From for CqlDate { #[cfg(feature = "time")] impl TryInto for CqlDate { - type Error = ValueTooBig; + type Error = ValueOverflow; fn try_into(self) -> Result { const JULIAN_DAY_OFFSET: i64 = @@ -175,9 +179,9 @@ impl TryInto for CqlDate { let julian_days = (self.0 as i64 - JULIAN_DAY_OFFSET) .try_into() - .map_err(|_| ValueTooBig)?; + .map_err(|_| ValueOverflow)?; - time::Date::from_julian_day(julian_days).map_err(|_| ValueTooBig) + time::Date::from_julian_day(julian_days).map_err(|_| ValueOverflow) } } @@ -209,11 +213,11 @@ impl From for CqlTimestamp { #[cfg(feature = "time")] impl TryInto for CqlTimestamp { - type Error = ValueTooBig; + type Error = ValueOverflow; fn try_into(self) -> Result { time::OffsetDateTime::from_unix_timestamp_nanos(self.0 as i128 * 1_000_000) - .map_err(|_| ValueTooBig) + .map_err(|_| ValueOverflow) } } @@ -231,7 +235,7 @@ impl From for CqlTime { #[cfg(feature = "time")] impl TryInto for CqlTime { - type Error = ValueTooBig; + type Error = ValueOverflow; fn try_into(self) -> Result { let h = self.0 / 3_600_000_000_000; @@ -240,12 +244,12 @@ impl TryInto for CqlTime { let n = self.0 % 1_000_000_000; time::Time::from_hms_nano( - h.try_into().map_err(|_| ValueTooBig)?, + h.try_into().map_err(|_| ValueOverflow)?, m as u8, s as u8, n as u32, ) - .map_err(|_| ValueTooBig) + .map_err(|_| ValueOverflow) } } @@ -631,7 +635,9 @@ impl Value for time::OffsetDateTime { #[cfg(feature = "chrono")] impl Value for NaiveTime { fn 
serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { - CqlTime::try_from(*self)?.serialize(buf) + CqlTime::try_from(*self) + .map_err(|_| ValueTooBig)? + .serialize(buf) } } diff --git a/scylla-cql/src/frame/value_tests.rs b/scylla-cql/src/frame/value_tests.rs index 003ff0116a..5f692d53b6 100644 --- a/scylla-cql/src/frame/value_tests.rs +++ b/scylla-cql/src/frame/value_tests.rs @@ -1,74 +1,246 @@ -use crate::frame::{types::RawValue, value::BatchValuesIterator}; +use crate::frame::{response::result::CqlValue, types::RawValue, value::BatchValuesIterator}; +use crate::types::serialize::row::{RowSerializationContext, SerializeRow}; +use crate::types::serialize::value::SerializeCql; +use crate::types::serialize::{BufBackedCellWriter, BufBackedRowWriter}; +use super::response::result::{ColumnSpec, ColumnType, TableSpec}; use super::value::{ - BatchValues, CqlDate, CqlTime, CqlTimestamp, MaybeUnset, SerializeValuesError, + BatchValues, CqlDate, CqlDuration, CqlTime, CqlTimestamp, MaybeUnset, SerializeValuesError, SerializedValues, Unset, Value, ValueList, ValueTooBig, }; +use bigdecimal::BigDecimal; use bytes::BufMut; +use num_bigint::BigInt; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::hash::{BuildHasherDefault, Hasher}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::{borrow::Cow, convert::TryInto}; use uuid::Uuid; -fn serialized(val: impl Value) -> Vec { +fn serialized(val: T, typ: ColumnType) -> Vec +where + T: Value + SerializeCql, +{ let mut result: Vec = Vec::new(); - val.serialize(&mut result).unwrap(); + Value::serialize(&val, &mut result).unwrap(); + + T::preliminary_type_check(&typ).unwrap(); + + let mut new_result: Vec = Vec::new(); + let writer = BufBackedCellWriter::new(&mut new_result); + SerializeCql::serialize(&val, &typ, writer).unwrap(); + + assert_eq!(result, new_result); + + result +} + +fn serialized_only_new(val: T, typ: ColumnType) -> Vec { + let mut result: Vec = Vec::new(); + let writer = 
BufBackedCellWriter::new(&mut result); + SerializeCql::serialize(&val, &typ, writer).unwrap(); result } #[test] -fn basic_serialization() { - assert_eq!(serialized(8_i8), vec![0, 0, 0, 1, 8]); - assert_eq!(serialized(16_i16), vec![0, 0, 0, 2, 0, 16]); - assert_eq!(serialized(32_i32), vec![0, 0, 0, 4, 0, 0, 0, 32]); +fn boolean_serialization() { + assert_eq!(serialized(true, ColumnType::Boolean), vec![0, 0, 0, 1, 1]); + assert_eq!(serialized(false, ColumnType::Boolean), vec![0, 0, 0, 1, 0]); +} + +#[test] +fn fixed_integral_serialization() { + assert_eq!(serialized(8_i8, ColumnType::TinyInt), vec![0, 0, 0, 1, 8]); assert_eq!( - serialized(64_i64), + serialized(16_i16, ColumnType::SmallInt), + vec![0, 0, 0, 2, 0, 16] + ); + assert_eq!( + serialized(32_i32, ColumnType::Int), + vec![0, 0, 0, 4, 0, 0, 0, 32] + ); + assert_eq!( + serialized(64_i64, ColumnType::BigInt), vec![0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 64] ); +} + +#[test] +fn counter_serialization() { + assert_eq!( + serialized(0x0123456789abcdef_i64, ColumnType::BigInt), + vec![0, 0, 0, 8, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef] + ); +} + +#[test] +fn bigint_serialization() { + let cases_from_the_spec: &[(i64, Vec)] = &[ + (0, vec![0x00]), + (1, vec![0x01]), + (127, vec![0x7F]), + (128, vec![0x00, 0x80]), + (129, vec![0x00, 0x81]), + (-1, vec![0xFF]), + (-128, vec![0x80]), + (-129, vec![0xFF, 0x7F]), + ]; + + for (i, b) in cases_from_the_spec { + let x = BigInt::from(*i); + let b_with_len = (b.len() as i32) + .to_be_bytes() + .iter() + .chain(b) + .cloned() + .collect::>(); + assert_eq!(serialized(x, ColumnType::Varint), b_with_len); + } +} + +#[test] +fn bigdecimal_serialization() { + // Bigint cases + let cases_from_the_spec: &[(i64, Vec)] = &[ + (0, vec![0x00]), + (1, vec![0x01]), + (127, vec![0x7F]), + (128, vec![0x00, 0x80]), + (129, vec![0x00, 0x81]), + (-1, vec![0xFF]), + (-128, vec![0x80]), + (-129, vec![0xFF, 0x7F]), + ]; + + for exponent in -10_i32..10_i32 { + for (digits, serialized_digits) in 
cases_from_the_spec { + let repr = ((serialized_digits.len() + 4) as i32) + .to_be_bytes() + .iter() + .chain(&exponent.to_be_bytes()) + .chain(serialized_digits) + .cloned() + .collect::>(); + let digits = BigInt::from(*digits); + let x = BigDecimal::new(digits, exponent as i64); + assert_eq!(serialized(x, ColumnType::Decimal), repr); + } + } +} + +#[test] +fn floating_point_serialization() { + assert_eq!( + serialized(123.456f32, ColumnType::Float), + [0, 0, 0, 4] + .into_iter() + .chain((123.456f32).to_be_bytes()) + .collect::>() + ); + assert_eq!( + serialized(123.456f64, ColumnType::Double), + [0, 0, 0, 8] + .into_iter() + .chain((123.456f64).to_be_bytes()) + .collect::>() + ); +} - assert_eq!(serialized("abc"), vec![0, 0, 0, 3, 97, 98, 99]); - assert_eq!(serialized("abc".to_string()), vec![0, 0, 0, 3, 97, 98, 99]); +#[test] +fn text_serialization() { + assert_eq!( + serialized("abc", ColumnType::Text), + vec![0, 0, 0, 3, 97, 98, 99] + ); + assert_eq!( + serialized("abc".to_string(), ColumnType::Ascii), + vec![0, 0, 0, 3, 97, 98, 99] + ); } #[test] fn u8_array_serialization() { let val = [1u8; 4]; - assert_eq!(serialized(val), vec![0, 0, 0, 4, 1, 1, 1, 1]); + assert_eq!( + serialized(val, ColumnType::Blob), + vec![0, 0, 0, 4, 1, 1, 1, 1] + ); } #[test] fn u8_slice_serialization() { let val = vec![1u8, 1, 1, 1]; - assert_eq!(serialized(val.as_slice()), vec![0, 0, 0, 4, 1, 1, 1, 1]); + assert_eq!( + serialized(val.as_slice(), ColumnType::Blob), + vec![0, 0, 0, 4, 1, 1, 1, 1] + ); } #[test] fn cql_date_serialization() { - assert_eq!(serialized(CqlDate(0)), vec![0, 0, 0, 4, 0, 0, 0, 0]); assert_eq!( - serialized(CqlDate(u32::MAX)), + serialized(CqlDate(0), ColumnType::Date), + vec![0, 0, 0, 4, 0, 0, 0, 0] + ); + assert_eq!( + serialized(CqlDate(u32::MAX), ColumnType::Date), vec![0, 0, 0, 4, 255, 255, 255, 255] ); } +#[test] +fn vec_u8_slice_serialization() { + let val = vec![1u8, 1, 1, 1]; + assert_eq!( + serialized(val, ColumnType::Blob), + vec![0, 0, 0, 4, 1, 1, 
1, 1] + ); +} + +#[test] +fn ipaddr_serialization() { + let ipv4 = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + assert_eq!( + serialized(ipv4, ColumnType::Inet), + vec![0, 0, 0, 4, 1, 2, 3, 4] + ); + + let ipv6 = IpAddr::V6(Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8)); + assert_eq!( + serialized(ipv6, ColumnType::Inet), + vec![ + 0, 0, 0, 16, // serialized size + 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, // contents + ] + ); +} + #[cfg(feature = "chrono")] #[test] fn naive_date_serialization() { use chrono::NaiveDate; // 1970-01-31 is 2^31 let unix_epoch: NaiveDate = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - assert_eq!(serialized(unix_epoch), vec![0, 0, 0, 4, 128, 0, 0, 0]); + assert_eq!( + serialized(unix_epoch, ColumnType::Date), + vec![0, 0, 0, 4, 128, 0, 0, 0] + ); assert_eq!(2_u32.pow(31).to_be_bytes(), [128, 0, 0, 0]); // 1969-12-02 is 2^31 - 30 let before_epoch: NaiveDate = NaiveDate::from_ymd_opt(1969, 12, 2).unwrap(); assert_eq!( - serialized(before_epoch), + serialized(before_epoch, ColumnType::Date), vec![0, 0, 0, 4, 127, 255, 255, 226] ); assert_eq!((2_u32.pow(31) - 30).to_be_bytes(), [127, 255, 255, 226]); // 1970-01-31 is 2^31 + 30 let after_epoch: NaiveDate = NaiveDate::from_ymd_opt(1970, 1, 31).unwrap(); - assert_eq!(serialized(after_epoch), vec![0, 0, 0, 4, 128, 0, 0, 30]); + assert_eq!( + serialized(after_epoch, ColumnType::Date), + vec![0, 0, 0, 4, 128, 0, 0, 30] + ); assert_eq!((2_u32.pow(31) + 30).to_be_bytes(), [128, 0, 0, 30]); } @@ -77,20 +249,26 @@ fn naive_date_serialization() { fn date_serialization() { // 1970-01-31 is 2^31 let unix_epoch = time::Date::from_ordinal_date(1970, 1).unwrap(); - assert_eq!(serialized(unix_epoch), vec![0, 0, 0, 4, 128, 0, 0, 0]); + assert_eq!( + serialized(unix_epoch, ColumnType::Date), + vec![0, 0, 0, 4, 128, 0, 0, 0] + ); assert_eq!(2_u32.pow(31).to_be_bytes(), [128, 0, 0, 0]); // 1969-12-02 is 2^31 - 30 let before_epoch = time::Date::from_calendar_date(1969, time::Month::December, 2).unwrap(); assert_eq!( 
- serialized(before_epoch), + serialized(before_epoch, ColumnType::Date), vec![0, 0, 0, 4, 127, 255, 255, 226] ); assert_eq!((2_u32.pow(31) - 30).to_be_bytes(), [127, 255, 255, 226]); // 1970-01-31 is 2^31 + 30 let after_epoch = time::Date::from_calendar_date(1970, time::Month::January, 31).unwrap(); - assert_eq!(serialized(after_epoch), vec![0, 0, 0, 4, 128, 0, 0, 30]); + assert_eq!( + serialized(after_epoch, ColumnType::Date), + vec![0, 0, 0, 4, 128, 0, 0, 30] + ); assert_eq!((2_u32.pow(31) + 30).to_be_bytes(), [128, 0, 0, 30]); // Min date represented by time::Date (without large-dates feature) @@ -101,7 +279,7 @@ fn date_serialization() { [127, 189, 75, 125] ); assert_eq!( - serialized(long_before_epoch), + serialized(long_before_epoch, ColumnType::Date), vec![0, 0, 0, 4, 127, 189, 75, 125] ); @@ -113,7 +291,7 @@ fn date_serialization() { [128, 44, 192, 160] ); assert_eq!( - serialized(long_after_epoch), + serialized(long_after_epoch, ColumnType::Date), vec![0, 0, 0, 4, 128, 44, 192, 160] ); } @@ -130,7 +308,7 @@ fn cql_time_serialization() { // Invalid values are also serialized correctly - database will respond with an error for test_val in [0, 1, 15, 18463, max_time, -1, -324234, max_time + 16].into_iter() { let test_time: CqlTime = CqlTime(test_val); - let bytes: Vec = serialized(test_time); + let bytes: Vec = serialized(test_time, ColumnType::Time); let mut expected_bytes: Vec = vec![0, 0, 0, 8]; expected_bytes.extend_from_slice(&test_val.to_be_bytes()); @@ -160,7 +338,7 @@ fn naive_time_serialization() { ), ]; for (time, expected) in test_cases { - let bytes = serialized(time); + let bytes = serialized(time, ColumnType::Time); let mut expected_bytes: Vec = vec![0, 0, 0, 8]; expected_bytes.extend_from_slice(&expected); @@ -171,7 +349,10 @@ fn naive_time_serialization() { // Leap second must return error on serialize let leap_second = NaiveTime::from_hms_nano_opt(23, 59, 59, 1_500_000_000).unwrap(); let mut buffer = Vec::new(); - 
assert_eq!(leap_second.serialize(&mut buffer), Err(ValueTooBig)) + assert_eq!( + <_ as Value>::serialize(&leap_second, &mut buffer), + Err(ValueTooBig) + ) } #[cfg(feature = "time")] @@ -192,7 +373,7 @@ fn time_serialization() { ), ]; for (time, expected) in test_cases { - let bytes = serialized(time); + let bytes = serialized(time, ColumnType::Time); let mut expected_bytes: Vec = vec![0, 0, 0, 8]; expected_bytes.extend_from_slice(&expected); @@ -207,7 +388,7 @@ fn cql_timestamp_serialization() { for test_val in &[0, -1, 1, -45345346, 453451, i64::MIN, i64::MAX] { let test_timestamp: CqlTimestamp = CqlTimestamp(*test_val); - let bytes: Vec = serialized(test_timestamp); + let bytes: Vec = serialized(test_timestamp, ColumnType::Timestamp); let mut expected_bytes: Vec = vec![0, 0, 0, 8]; expected_bytes.extend_from_slice(&test_val.to_be_bytes()); @@ -260,7 +441,7 @@ fn naive_date_time_serialization() { ]; for (datetime, expected) in test_cases { let test_datetime = datetime.and_utc(); - let bytes: Vec = serialized(test_datetime); + let bytes: Vec = serialized(test_datetime, ColumnType::Timestamp); let mut expected_bytes: Vec = vec![0, 0, 0, 8]; expected_bytes.extend_from_slice(&expected); @@ -322,7 +503,7 @@ fn offset_date_time_serialization() { ), ]; for (datetime, expected) in test_cases { - let bytes: Vec = serialized(datetime); + let bytes: Vec = serialized(datetime, ColumnType::Timestamp); let mut expected_bytes: Vec = vec![0, 0, 0, 8]; expected_bytes.extend_from_slice(&expected); @@ -349,7 +530,7 @@ fn timeuuid_serialization() { for uuid_bytes in &tests { let uuid = Uuid::from_slice(uuid_bytes.as_ref()).unwrap(); - let uuid_serialized: Vec = serialized(uuid); + let uuid_serialized: Vec = serialized(uuid, ColumnType::Uuid); let mut expected_serialized: Vec = vec![0, 0, 0, 16]; expected_serialized.extend_from_slice(uuid_bytes.as_ref()); @@ -358,22 +539,279 @@ fn timeuuid_serialization() { } } +#[test] +fn cqlduration_serialization() { + let duration = CqlDuration { 
+ months: 1, + days: 2, + nanoseconds: 3, + }; + assert_eq!( + serialized(duration, ColumnType::Duration), + vec![0, 0, 0, 3, 2, 4, 6] + ); +} + +#[test] +fn box_serialization() { + let x: Box = Box::new(123); + assert_eq!( + serialized(x, ColumnType::Int), + vec![0, 0, 0, 4, 0, 0, 0, 123] + ); +} + +#[test] +fn vec_set_serialization() { + let m = vec!["ala", "ma", "kota"]; + assert_eq!( + serialized(m, ColumnType::Set(Box::new(ColumnType::Text))), + vec![ + 0, 0, 0, 25, // 25 bytes + 0, 0, 0, 3, // 3 items + 0, 0, 0, 3, 97, 108, 97, // ala + 0, 0, 0, 2, 109, 97, // ma + 0, 0, 0, 4, 107, 111, 116, 97, // kota + ] + ) +} + +#[test] +fn slice_set_serialization() { + let m = ["ala", "ma", "kota"]; + assert_eq!( + serialized(m.as_ref(), ColumnType::Set(Box::new(ColumnType::Text))), + vec![ + 0, 0, 0, 25, // 25 bytes + 0, 0, 0, 3, // 3 items + 0, 0, 0, 3, 97, 108, 97, // ala + 0, 0, 0, 2, 109, 97, // ma + 0, 0, 0, 4, 107, 111, 116, 97, // kota + ] + ) +} + +// A deterministic hasher just for the tests. 
+#[derive(Default)] +struct DumbHasher { + state: u8, +} + +impl Hasher for DumbHasher { + fn finish(&self) -> u64 { + self.state as u64 + } + + fn write(&mut self, bytes: &[u8]) { + for b in bytes { + self.state ^= b; + } + } +} + +type DumbBuildHasher = BuildHasherDefault; + +#[test] +fn hashset_serialization() { + let m: HashSet<&'static str, DumbBuildHasher> = ["ala", "ma", "kota"].into_iter().collect(); + assert_eq!( + serialized(m, ColumnType::Set(Box::new(ColumnType::Text))), + vec![ + 0, 0, 0, 25, // 25 bytes + 0, 0, 0, 3, // 3 items + 0, 0, 0, 2, 109, 97, // ma + 0, 0, 0, 4, 107, 111, 116, 97, // kota + 0, 0, 0, 3, 97, 108, 97, // ala + ] + ) +} + +#[test] +fn hashmap_serialization() { + let m: HashMap<&'static str, i32, DumbBuildHasher> = + [("ala", 1), ("ma", 2), ("kota", 3)].into_iter().collect(); + assert_eq!( + serialized( + m, + ColumnType::Map(Box::new(ColumnType::Text), Box::new(ColumnType::Int)) + ), + vec![ + 0, 0, 0, 49, // 49 bytes + 0, 0, 0, 3, // 3 items + 0, 0, 0, 2, 109, 97, // ma + 0, 0, 0, 4, 0, 0, 0, 2, // 2 + 0, 0, 0, 4, 107, 111, 116, 97, // kota + 0, 0, 0, 4, 0, 0, 0, 3, // 3 + 0, 0, 0, 3, 97, 108, 97, // ala + 0, 0, 0, 4, 0, 0, 0, 1, // 1 + ] + ) +} + +#[test] +fn btreeset_serialization() { + let m: BTreeSet<&'static str> = ["ala", "ma", "kota"].into_iter().collect(); + assert_eq!( + serialized(m, ColumnType::Set(Box::new(ColumnType::Text))), + vec![ + 0, 0, 0, 25, // 25 bytes + 0, 0, 0, 3, // 3 items + 0, 0, 0, 3, 97, 108, 97, // ala + 0, 0, 0, 4, 107, 111, 116, 97, // kota + 0, 0, 0, 2, 109, 97, // ma + ] + ) +} + +#[test] +fn btreemap_serialization() { + let m: BTreeMap<&'static str, i32> = [("ala", 1), ("ma", 2), ("kota", 3)].into_iter().collect(); + assert_eq!( + serialized( + m, + ColumnType::Map(Box::new(ColumnType::Text), Box::new(ColumnType::Int)) + ), + vec![ + 0, 0, 0, 49, // 49 bytes + 0, 0, 0, 3, // 3 items + 0, 0, 0, 3, 97, 108, 97, // ala + 0, 0, 0, 4, 0, 0, 0, 1, // 1 + 0, 0, 0, 4, 107, 111, 116, 97, // kota + 0, 0, 
0, 4, 0, 0, 0, 3, // 3 + 0, 0, 0, 2, 109, 97, // ma + 0, 0, 0, 4, 0, 0, 0, 2, // 2 + ] + ) +} + +#[test] +fn cqlvalue_serialization() { + // We only check those variants here which have some custom logic, + // e.g. UDTs or tuples. + + // Empty + assert_eq!( + serialized(CqlValue::Empty, ColumnType::Int), + vec![0, 0, 0, 0], + ); + + // UDTs + let udt = CqlValue::UserDefinedType { + keyspace: "ks".to_string(), + type_name: "t".to_string(), + fields: vec![ + ("foo".to_string(), Some(CqlValue::Int(123))), + ("bar".to_string(), None), + ], + }; + let typ = ColumnType::UserDefinedType { + type_name: "t".to_string(), + keyspace: "ks".to_string(), + field_types: vec![ + ("foo".to_string(), ColumnType::Int), + ("bar".to_string(), ColumnType::Text), + ], + }; + + assert_eq!( + serialized(udt, typ.clone()), + vec![ + 0, 0, 0, 12, // size of the whole thing + 0, 0, 0, 4, 0, 0, 0, 123, // foo: 123_i32 + 255, 255, 255, 255, // bar: null + ] + ); + + // Unlike the legacy Value trait, SerializeCql takes case of reordering + // the fields + let udt = CqlValue::UserDefinedType { + keyspace: "ks".to_string(), + type_name: "t".to_string(), + fields: vec![ + ("bar".to_string(), None), + ("foo".to_string(), Some(CqlValue::Int(123))), + ], + }; + + assert_eq!( + serialized_only_new(udt, typ.clone()), + vec![ + 0, 0, 0, 12, // size of the whole thing + 0, 0, 0, 4, 0, 0, 0, 123, // foo: 123_i32 + 255, 255, 255, 255, // bar: null + ] + ); + + // Tuples + let tup = CqlValue::Tuple(vec![Some(CqlValue::Int(123)), None]); + let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Text]); + assert_eq!( + serialized(tup, typ), + vec![ + 0, 0, 0, 12, // size of the whole thing + 0, 0, 0, 4, 0, 0, 0, 123, // 123_i32 + 255, 255, 255, 255, // null + ] + ); + + // It's not required to specify all the values for the tuple, + // only some prefix is sufficient. The rest will be treated by the DB + // as nulls. 
+ // TODO: Need a database test for that + let tup = CqlValue::Tuple(vec![Some(CqlValue::Int(123)), None]); + let typ = ColumnType::Tuple(vec![ColumnType::Int, ColumnType::Text, ColumnType::Counter]); + assert_eq!( + serialized(tup, typ), + vec![ + 0, 0, 0, 12, // size of the whole thing + 0, 0, 0, 4, 0, 0, 0, 123, // 123_i32 + 255, 255, 255, 255, // null + ] + ); +} + +#[cfg(feature = "secret")] +#[test] +fn secret_serialization() { + use secrecy::Secret; + let secret = Secret::new(987654i32); + assert_eq!( + serialized(secret, ColumnType::Int), + vec![0, 0, 0, 4, 0x00, 0x0f, 0x12, 0x06] + ); +} + #[test] fn option_value() { - assert_eq!(serialized(Some(32_i32)), vec![0, 0, 0, 4, 0, 0, 0, 32]); + assert_eq!( + serialized(Some(32_i32), ColumnType::Int), + vec![0, 0, 0, 4, 0, 0, 0, 32] + ); let null_i32: Option = None; - assert_eq!(serialized(null_i32), &(-1_i32).to_be_bytes()[..]); + assert_eq!( + serialized(null_i32, ColumnType::Int), + &(-1_i32).to_be_bytes()[..] + ); } #[test] fn unset_value() { - assert_eq!(serialized(Unset), &(-2_i32).to_be_bytes()[..]); + assert_eq!( + serialized(Unset, ColumnType::Int), + &(-2_i32).to_be_bytes()[..] + ); let unset_i32: MaybeUnset = MaybeUnset::Unset; - assert_eq!(serialized(unset_i32), &(-2_i32).to_be_bytes()[..]); + assert_eq!( + serialized(unset_i32, ColumnType::Int), + &(-2_i32).to_be_bytes()[..] 
+ ); let set_i32: MaybeUnset = MaybeUnset::Set(32); - assert_eq!(serialized(set_i32), vec![0, 0, 0, 4, 0, 0, 0, 32]); + assert_eq!( + serialized(set_i32, ColumnType::Int), + vec![0, 0, 0, 4, 0, 0, 0, 32] + ); } #[test] @@ -500,9 +938,12 @@ fn empty_array_value_list() { #[test] fn slice_value_list() { let values: &[i32] = &[1, 2, 3]; - let serialized: SerializedValues = <&[i32] as ValueList>::serialized(&values) - .unwrap() - .into_owned(); + let cols = &[ + col_spec("ala", ColumnType::Int), + col_spec("ma", ColumnType::Int), + col_spec("kota", ColumnType::Int), + ]; + let serialized = serialize_values(values, cols); assert_eq!( serialized.iter().collect::>(), @@ -517,9 +958,12 @@ fn slice_value_list() { #[test] fn vec_value_list() { let values: Vec = vec![1, 2, 3]; - let serialized: SerializedValues = as ValueList>::serialized(&values) - .unwrap() - .into_owned(); + let cols = &[ + col_spec("ala", ColumnType::Int), + col_spec("ma", ColumnType::Int), + col_spec("kota", ColumnType::Int), + ]; + let serialized = serialize_values(values, cols); assert_eq!( serialized.iter().collect::>(), @@ -531,10 +975,63 @@ fn vec_value_list() { ); } +fn col_spec(name: &str, typ: ColumnType) -> ColumnSpec { + ColumnSpec { + table_spec: TableSpec { + ks_name: "ks".to_string(), + table_name: "tbl".to_string(), + }, + name: name.to_string(), + typ, + } +} + +fn serialize_values( + vl: T, + columns: &[ColumnSpec], +) -> SerializedValues { + let serialized = ::serialized(&vl).unwrap().into_owned(); + let mut old_serialized = Vec::new(); + serialized.write_to_request(&mut old_serialized); + + let ctx = RowSerializationContext { columns }; + ::preliminary_type_check(&ctx).unwrap(); + let mut new_serialized = vec![0, 0]; + let mut writer = BufBackedRowWriter::new(&mut new_serialized); + ::serialize(&vl, &ctx, &mut writer).unwrap(); + let value_count: u16 = writer.value_count().try_into().unwrap(); + + // Prepend with value count, like `ValueList` does + 
new_serialized[0..2].copy_from_slice(&value_count.to_be_bytes()); + + assert_eq!(old_serialized, new_serialized); + + serialized +} + +fn serialize_values_only_new(vl: T, columns: &[ColumnSpec]) -> Vec { + let ctx = RowSerializationContext { columns }; + ::preliminary_type_check(&ctx).unwrap(); + let mut serialized = vec![0, 0]; + let mut writer = BufBackedRowWriter::new(&mut serialized); + ::serialize(&vl, &ctx, &mut writer).unwrap(); + let value_count: u16 = writer.value_count().try_into().unwrap(); + + // Prepend with value count, like `ValueList` does + serialized[0..2].copy_from_slice(&value_count.to_be_bytes()); + + serialized +} + #[test] fn tuple_value_list() { - fn check_i8_tuple(tuple: impl ValueList, expected: core::ops::Range) { - let serialized: SerializedValues = tuple.serialized().unwrap().into_owned(); + fn check_i8_tuple(tuple: impl ValueList + SerializeRow, expected: core::ops::Range) { + let typs = expected + .clone() + .enumerate() + .map(|(i, _)| col_spec(&format!("col_{i}"), ColumnType::TinyInt)) + .collect::>(); + let serialized = serialize_values(tuple, &typs); assert_eq!(serialized.len() as usize, expected.len()); let serialized_vals: Vec = serialized @@ -547,6 +1044,7 @@ fn tuple_value_list() { assert_eq!(serialized_vals, expected); } + check_i8_tuple((), 1..1); check_i8_tuple((1_i8,), 1..2); check_i8_tuple((1_i8, 2_i8), 1..3); check_i8_tuple((1_i8, 2_i8, 3_i8), 1..4); @@ -603,12 +1101,48 @@ fn tuple_value_list() { ); } +#[test] +fn map_value_list() { + // The legacy ValueList would serialize this as a list of named values, + // whereas the new SerializeRow will order the values by their names. + + // Note that the alphabetical order of the keys is "ala", "kota", "ma", + // but the impl sorts properly. 
+ let row = BTreeMap::from_iter([("ala", 1), ("ma", 2), ("kota", 3)]); + let cols = &[ + col_spec("ala", ColumnType::Int), + col_spec("ma", ColumnType::Int), + col_spec("kota", ColumnType::Int), + ]; + let new_values = serialize_values_only_new(row.clone(), cols); + assert_eq!( + new_values, + vec![ + 0, 3, // value count: 3 + 0, 0, 0, 4, 0, 0, 0, 1, // ala: 1 + 0, 0, 0, 4, 0, 0, 0, 2, // ma: 2 + 0, 0, 0, 4, 0, 0, 0, 3, // kota: 3 + ] + ); + + // While ValueList will serialize differently, the fallback SerializeRow impl + // should convert it to how serialized BTreeMap would look like if serialized + // directly through SerializeRow. + let ser = <_ as ValueList>::serialized(&row).unwrap(); + let fallbacked = serialize_values_only_new(ser, cols); + + assert_eq!(new_values, fallbacked); +} + #[test] fn ref_value_list() { let values: &[i32] = &[1, 2, 3]; - let serialized: SerializedValues = <&&[i32] as ValueList>::serialized(&&values) - .unwrap() - .into_owned(); + let typs = &[ + col_spec("col_1", ColumnType::Int), + col_spec("col_2", ColumnType::Int), + col_spec("col_3", ColumnType::Int), + ]; + let serialized = serialize_values::<&&[i32]>(&values, typs); assert_eq!( serialized.iter().collect::>(), diff --git a/scylla-cql/src/types/serialize/mod.rs b/scylla-cql/src/types/serialize/mod.rs index 8e6c91983c..617fbc5f88 100644 --- a/scylla-cql/src/types/serialize/mod.rs +++ b/scylla-cql/src/types/serialize/mod.rs @@ -8,11 +8,19 @@ pub mod writers; pub use writers::{ BufBackedCellValueBuilder, BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder, - CellWriter, CountingWriter, RowWriter, + CellWriter, CountingCellWriter, RowWriter, }; #[derive(Debug, Clone, Error)] pub struct SerializationError(Arc); +impl SerializationError { + /// Constructs a new `SerializationError`. 
+ #[inline] + pub fn new(err: impl Error + Send + Sync + 'static) -> SerializationError { + SerializationError(Arc::new(err)) + } +} + impl Display for SerializationError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "SerializationError: {}", self.0) diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index e58bd6d32e..c2d8a45246 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -1,15 +1,20 @@ +use std::borrow::Cow; +use std::collections::{BTreeMap, HashSet}; +use std::fmt::Display; +use std::hash::BuildHasher; use std::{collections::HashMap, sync::Arc}; use thiserror::Error; -use crate::frame::value::ValueList; +use crate::frame::value::{SerializedValues, ValueList}; use crate::frame::{response::result::ColumnSpec, types::RawValue}; +use super::value::SerializeCql; use super::{CellWriter, RowWriter, SerializationError}; /// Contains information needed to serialize a row. pub struct RowSerializationContext<'a> { - columns: &'a [ColumnSpec], + pub(crate) columns: &'a [ColumnSpec], } impl<'a> RowSerializationContext<'a> { @@ -50,11 +55,200 @@ pub trait SerializeRow { ) -> Result<(), SerializationError>; } -impl SerializeRow for T { - fn preliminary_type_check( - _ctx: &RowSerializationContext<'_>, - ) -> Result<(), SerializationError> { - Ok(()) +macro_rules! fallback_impl_contents { + () => { + fn preliminary_type_check( + _ctx: &RowSerializationContext<'_>, + ) -> Result<(), SerializationError> { + Ok(()) + } + fn serialize( + &self, + ctx: &RowSerializationContext<'_>, + writer: &mut W, + ) -> Result<(), SerializationError> { + serialize_legacy_row(self, ctx, writer) + } + }; +} + +macro_rules! 
impl_serialize_row_for_unit { + () => { + fn preliminary_type_check( + ctx: &RowSerializationContext<'_>, + ) -> Result<(), SerializationError> { + if !ctx.columns().is_empty() { + return Err(mk_typck_err::( + BuiltinTypeCheckErrorKind::WrongColumnCount { + actual: 0, + asked_for: ctx.columns().len(), + }, + )); + } + Ok(()) + } + + fn serialize( + &self, + _ctx: &RowSerializationContext<'_>, + _writer: &mut W, + ) -> Result<(), SerializationError> { + // Row is empty - do nothing + Ok(()) + } + }; +} + +impl SerializeRow for () { + impl_serialize_row_for_unit!(); +} + +impl SerializeRow for [u8; 0] { + impl_serialize_row_for_unit!(); +} + +macro_rules! impl_serialize_row_for_slice { + () => { + fn preliminary_type_check( + ctx: &RowSerializationContext<'_>, + ) -> Result<(), SerializationError> { + // While we don't know how many columns will be there during serialization, + // we can at least check that all provided columns match T. + for col in ctx.columns() { + ::preliminary_type_check(&col.typ).map_err(|err| { + mk_typck_err::(BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + name: col.name.clone(), + err, + }) + })?; + } + Ok(()) + } + + fn serialize( + &self, + ctx: &RowSerializationContext<'_>, + writer: &mut W, + ) -> Result<(), SerializationError> { + if ctx.columns().len() != self.len() { + return Err(mk_typck_err::( + BuiltinTypeCheckErrorKind::WrongColumnCount { + actual: self.len(), + asked_for: ctx.columns().len(), + }, + )); + } + for (col, val) in ctx.columns().iter().zip(self.iter()) { + ::serialize(val, &col.typ, writer.make_cell_writer()).map_err( + |err| { + mk_ser_err::( + BuiltinSerializationErrorKind::ColumnSerializationFailed { + name: col.name.clone(), + err, + }, + ) + }, + )?; + } + Ok(()) + } + }; +} + +impl<'a, T: SerializeCql + 'a> SerializeRow for &'a [T] { + impl_serialize_row_for_slice!(); +} + +impl SerializeRow for Vec { + impl_serialize_row_for_slice!(); +} + +macro_rules! 
impl_serialize_row_for_map { + () => { + fn preliminary_type_check( + ctx: &RowSerializationContext<'_>, + ) -> Result<(), SerializationError> { + // While we don't know the column count or their names, + // we can go over all columns and check that their types match T. + for col in ctx.columns() { + ::preliminary_type_check(&col.typ).map_err(|err| { + mk_typck_err::(BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + name: col.name.clone(), + err, + }) + })?; + } + Ok(()) + } + + fn serialize( + &self, + ctx: &RowSerializationContext<'_>, + writer: &mut W, + ) -> Result<(), SerializationError> { + // Unfortunately, column names aren't guaranteed to be unique. + // We need to track not-yet-used columns in order to see + // whether some values were not used at the end, and report an error. + let mut unused_columns: HashSet<&str> = self.keys().map(|k| k.as_ref()).collect(); + + for col in ctx.columns.iter() { + match self.get(col.name.as_str()) { + None => { + return Err(mk_typck_err::( + BuiltinTypeCheckErrorKind::MissingValueForColumn { + name: col.name.clone(), + }, + )) + } + Some(v) => { + ::serialize(v, &col.typ, writer.make_cell_writer()) + .map_err(|err| { + mk_typck_err::( + BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + name: col.name.clone(), + err, + }, + ) + })?; + let _ = unused_columns.remove(col.name.as_str()); + } + } + } + + if !unused_columns.is_empty() { + // Report the lexicographically first value for deterministic error messages + let name = unused_columns.iter().min().unwrap(); + return Err(mk_typck_err::( + BuiltinTypeCheckErrorKind::ColumnMissingForValue { + name: name.to_string(), + }, + )); + } + + Ok(()) + } + }; +} + +impl SerializeRow for BTreeMap { + impl_serialize_row_for_map!(); +} + +impl SerializeRow for BTreeMap<&str, T> { + impl_serialize_row_for_map!(); +} + +impl SerializeRow for HashMap { + impl_serialize_row_for_map!(); +} + +impl SerializeRow for HashMap<&str, T, S> { + impl_serialize_row_for_map!(); +} + +impl 
SerializeRow for &T { + fn preliminary_type_check(ctx: &RowSerializationContext<'_>) -> Result<(), SerializationError> { + ::preliminary_type_check(ctx) } fn serialize( @@ -62,10 +256,200 @@ impl SerializeRow for T { ctx: &RowSerializationContext<'_>, writer: &mut W, ) -> Result<(), SerializationError> { - serialize_legacy_row(self, ctx, writer) + ::serialize(self, ctx, writer) } } +impl SerializeRow for SerializedValues { + fallback_impl_contents!(); +} + +impl<'b> SerializeRow for Cow<'b, SerializedValues> { + fallback_impl_contents!(); +} + +macro_rules! impl_tuple { + ( + $($typs:ident),*; + $($fidents:ident),*; + $($tidents:ident),*; + $length:expr + ) => { + impl<$($typs: SerializeCql),*> SerializeRow for ($($typs,)*) { + fn preliminary_type_check( + ctx: &RowSerializationContext<'_>, + ) -> Result<(), SerializationError> { + match ctx.columns() { + [$($tidents),*] => { + $( + <$typs as SerializeCql>::preliminary_type_check(&$tidents.typ).map_err(|err| { + mk_typck_err::(BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { + name: $tidents.name.clone(), + err, + }) + })?; + )* + } + _ => return Err(mk_typck_err::( + BuiltinTypeCheckErrorKind::WrongColumnCount { + actual: $length, + asked_for: ctx.columns().len(), + }, + )), + }; + Ok(()) + } + + fn serialize( + &self, + ctx: &RowSerializationContext<'_>, + writer: &mut W, + ) -> Result<(), SerializationError> { + let ($($tidents,)*) = match ctx.columns() { + [$($tidents),*] => ($($tidents,)*), + _ => return Err(mk_typck_err::( + BuiltinTypeCheckErrorKind::WrongColumnCount { + actual: $length, + asked_for: ctx.columns().len(), + }, + )), + }; + let ($($fidents,)*) = self; + $( + <$typs as SerializeCql>::serialize($fidents, &$tidents.typ, writer.make_cell_writer()).map_err(|err| { + mk_ser_err::(BuiltinSerializationErrorKind::ColumnSerializationFailed { + name: $tidents.name.clone(), + err, + }) + })?; + )* + Ok(()) + } + } + }; +} + +macro_rules! 
impl_tuples { + (;;;$length:expr) => {}; + ( + $typ:ident$(, $($typs:ident),*)?; + $fident:ident$(, $($fidents:ident),*)?; + $tident:ident$(, $($tidents:ident),*)?; + $length:expr + ) => { + impl_tuples!( + $($($typs),*)?; + $($($fidents),*)?; + $($($tidents),*)?; + $length - 1 + ); + impl_tuple!( + $typ$(, $($typs),*)?; + $fident$(, $($fidents),*)?; + $tident$(, $($tidents),*)?; + $length + ); + }; +} + +impl_tuples!( + T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; + f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15; + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15; + 16 +); + +/// Implements the [`SerializeRow`] trait for a type, provided that the type +/// already implements the legacy +/// [`ValueList`](crate::frame::value::ValueList) trait. +/// +/// # Note +/// +/// The translation from one trait to another encounters a performance penalty +/// and does not utilize the stronger guarantees of `SerializeRow`. Before +/// resorting to this macro, you should consider other options instead: +/// +/// - If the impl was generated using the `ValueList` procedural macro, you +/// should switch to the `SerializeRow` procedural macro. *The new macro +/// behaves differently by default, so please read its documentation first!* +/// - If the impl was written by hand, it is still preferable to rewrite it +/// manually. You have an opportunity to make your serialization logic +/// type-safe and potentially improve performance. +/// +/// Basically, you should consider using the macro if you have a hand-written +/// impl and the moment it is not easy/not desirable to rewrite it. 
+/// +/// # Example +/// +/// ```rust +/// # use std::borrow::Cow; +/// # use scylla_cql::frame::value::{Value, ValueList, SerializedResult, SerializedValues}; +/// # use scylla_cql::impl_serialize_row_via_value_list; +/// struct NoGenerics {} +/// impl ValueList for NoGenerics { +/// fn serialized(&self) -> SerializedResult<'_> { +/// ().serialized() +/// } +/// } +/// impl_serialize_row_via_value_list!(NoGenerics); +/// +/// // Generic types are also supported. You must specify the bounds if the +/// // struct/enum contains any. +/// struct WithGenerics(T, U); +/// impl ValueList for WithGenerics { +/// fn serialized(&self) -> SerializedResult<'_> { +/// let mut values = SerializedValues::new(); +/// values.add_value(&self.0); +/// values.add_value(&self.1.clone()); +/// Ok(Cow::Owned(values)) +/// } +/// } +/// impl_serialize_row_via_value_list!(WithGenerics); +/// ``` +#[macro_export] +macro_rules! impl_serialize_row_via_value_list { + ($t:ident$(<$($targ:tt $(: $tbound:tt)?),*>)?) => { + impl $(<$($targ $(: $tbound)?),*>)? $crate::types::serialize::row::SerializeRow + for $t$(<$($targ),*>)? + where + Self: $crate::frame::value::ValueList, + { + fn preliminary_type_check( + _ctx: &$crate::types::serialize::row::RowSerializationContext<'_>, + ) -> ::std::result::Result<(), $crate::types::serialize::SerializationError> { + // No-op - the old interface didn't offer type safety + ::std::result::Result::Ok(()) + } + + fn serialize( + &self, + ctx: &$crate::types::serialize::row::RowSerializationContext<'_>, + writer: &mut W, + ) -> ::std::result::Result<(), $crate::types::serialize::SerializationError> { + $crate::types::serialize::row::serialize_legacy_row(self, ctx, writer) + } + } + }; +} + +/// Serializes an object implementing [`ValueList`] by using the [`RowWriter`] +/// interface. +/// +/// The function first serializes the value with [`ValueList::serialized`], then +/// parses the result and serializes it again with given `RowWriter`. 
In case +/// of serialized values with names, they are converted to serialized values +/// without names, based on the information about the bind markers provided +/// in the [`RowSerializationContext`]. +/// +/// It is a lazy and inefficient way to implement `RowWriter` via an existing +/// `ValueList` impl. +/// +/// Returns an error if `ValueList::serialized` call failed or, in case of +/// named serialized values, some bind markers couldn't be matched to a +/// named value. +/// +/// See [`impl_serialize_row_via_value_list`] which generates a boilerplate +/// [`SerializeRow`] implementation that uses this function. pub fn serialize_legacy_row( r: &T, ctx: &RowSerializationContext<'_>, @@ -79,7 +463,9 @@ pub fn serialize_legacy_row( let _proof = match value { RawValue::Null => cell_writer.set_null(), RawValue::Unset => cell_writer.set_unset(), - RawValue::Value(v) => cell_writer.set_value(v), + // The unwrap below will succeed because the value was successfully + // deserialized from the CQL format, so it must have an appropriate size + RawValue::Value(v) => cell_writer.set_value(v).unwrap(), }; }; @@ -106,6 +492,125 @@ pub fn serialize_legacy_row( Ok(()) } +/// Failed to type check values for a statement, represented by one of the types +/// built into the driver. +#[derive(Debug, Error, Clone)] +#[error("Failed to type check query arguments {rust_name}: {kind}")] +pub struct BuiltinTypeCheckError { + /// Name of the Rust type used to represent the values. + pub rust_name: &'static str, + + /// Detailed information about the failure. 
+ pub kind: BuiltinTypeCheckErrorKind, +} + +fn mk_typck_err(kind: impl Into) -> SerializationError { + mk_typck_err_named(std::any::type_name::(), kind) +} + +fn mk_typck_err_named( + name: &'static str, + kind: impl Into, +) -> SerializationError { + SerializationError::new(BuiltinTypeCheckError { + rust_name: name, + kind: kind.into(), + }) +} + +/// Failed to serialize values for a statement, represented by one of the types +/// built into the driver. +#[derive(Debug, Error, Clone)] +#[error("Failed to serialize query arguments {rust_name}: {kind}")] +pub struct BuiltinSerializationError { + /// Name of the Rust type used to represent the values. + pub rust_name: &'static str, + + /// Detailed information about the failure. + pub kind: BuiltinSerializationErrorKind, +} + +fn mk_ser_err(kind: impl Into) -> SerializationError { + mk_ser_err_named(std::any::type_name::(), kind) +} + +fn mk_ser_err_named( + name: &'static str, + kind: impl Into, +) -> SerializationError { + SerializationError::new(BuiltinSerializationError { + rust_name: name, + kind: kind.into(), + }) +} + +/// Describes why type checking values for a statement failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinTypeCheckErrorKind { + /// The Rust type expects `asked_for` column, but the query requires `actual`. + WrongColumnCount { actual: usize, asked_for: usize }, + + /// The Rust type provides a value for some column, but that column is not + /// present in the statement. + MissingValueForColumn { name: String }, + + /// A value required by the statement is not provided by the Rust type. + ColumnMissingForValue { name: String }, + + /// One of the columns failed to type check. 
+ ColumnTypeCheckFailed { + name: String, + err: SerializationError, + }, +} + +impl Display for BuiltinTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinTypeCheckErrorKind::WrongColumnCount { actual, asked_for } => { + write!(f, "wrong column count: the query requires {asked_for} columns, but {actual} were provided") + } + BuiltinTypeCheckErrorKind::MissingValueForColumn { name } => { + write!( + f, + "value for column {name} was not provided, but the query requires it" + ) + } + BuiltinTypeCheckErrorKind::ColumnMissingForValue { name } => { + write!( + f, + "value for column {name} was provided, but there is no bind marker for this column in the query" + ) + } + BuiltinTypeCheckErrorKind::ColumnTypeCheckFailed { name, err } => { + write!(f, "failed to check column {name}: {err}") + } + } + } +} + +/// Describes why serializing values for a statement failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinSerializationErrorKind { + /// One of the columns failed to serialize. 
+ ColumnSerializationFailed { + name: String, + err: SerializationError, + }, +} + +impl Display for BuiltinSerializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinSerializationErrorKind::ColumnSerializationFailed { name, err } => { + write!(f, "failed to serialize column {name}: {err}") + } + } + } +} + #[derive(Error, Debug)] pub enum ValueListToSerializeRowAdapterError { #[error("There is no bind marker with name {name}, but a value for it was provided")] @@ -145,7 +650,14 @@ mod tests { let mut new_data = Vec::new(); let mut new_data_writer = BufBackedRowWriter::new(&mut new_data); - let ctx = RowSerializationContext { columns: &[] }; + let ctx = RowSerializationContext { + columns: &[ + col_spec("a", ColumnType::Int), + col_spec("b", ColumnType::Text), + col_spec("c", ColumnType::BigInt), + col_spec("b", ColumnType::Ascii), + ], + }; <_ as SerializeRow>::serialize(&row, &ctx, &mut new_data_writer).unwrap(); assert_eq!(new_data_writer.value_count(), 4); diff --git a/scylla-cql/src/types/serialize/value.rs b/scylla-cql/src/types/serialize/value.rs index fde7067265..5d81cdb938 100644 --- a/scylla-cql/src/types/serialize/value.rs +++ b/scylla-cql/src/types/serialize/value.rs @@ -1,11 +1,30 @@ +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::fmt::Display; +use std::hash::BuildHasher; +use std::net::IpAddr; use std::sync::Arc; +use bigdecimal::BigDecimal; +use num_bigint::BigInt; use thiserror::Error; +use uuid::Uuid; -use crate::frame::response::result::ColumnType; -use crate::frame::value::Value; +#[cfg(feature = "chrono")] +use chrono::{DateTime, NaiveDate, NaiveTime, Utc}; -use super::{CellWriter, SerializationError}; +#[cfg(feature = "secret")] +use secrecy::{ExposeSecret, Secret, Zeroize}; + +use crate::frame::response::result::{ColumnType, CqlValue}; +use crate::frame::types::vint_encode; +use crate::frame::value::{ + Counter, CqlDate, CqlDuration, CqlTime, CqlTimestamp, 
MaybeUnset, Unset, Value, +}; + +#[cfg(feature = "chrono")] +use crate::frame::value::ValueOverflow; + +use super::{CellValueBuilder, CellWriter, SerializationError}; pub trait SerializeCql { /// Given a CQL type, checks if it _might_ be possible to serialize to that type. @@ -29,20 +48,1027 @@ pub trait SerializeCql { ) -> Result; } -impl SerializeCql for T { +macro_rules! impl_exact_preliminary_type_check { + ($($cql:tt),*) => { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + match typ { + $(ColumnType::$cql)|* => Ok(()), + _ => Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::MismatchedType { + expected: &[$(ColumnType::$cql),*], + } + )) + } + } + }; +} + +macro_rules! impl_serialize_via_writer { + (|$me:ident, $writer:ident| $e:expr) => { + impl_serialize_via_writer!(|$me, _typ, $writer| $e); + }; + (|$me:ident, $typ:ident, $writer:ident| $e:expr) => { + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + let $writer = writer; + let $typ = typ; + let $me = self; + let proof = $e; + Ok(proof) + } + }; +} + +impl SerializeCql for i8 { + impl_exact_preliminary_type_check!(TinyInt); + impl_serialize_via_writer!(|me, writer| writer.set_value(me.to_be_bytes().as_slice()).unwrap()); +} +impl SerializeCql for i16 { + impl_exact_preliminary_type_check!(SmallInt); + impl_serialize_via_writer!(|me, writer| writer.set_value(me.to_be_bytes().as_slice()).unwrap()); +} +impl SerializeCql for i32 { + impl_exact_preliminary_type_check!(Int); + impl_serialize_via_writer!(|me, writer| writer.set_value(me.to_be_bytes().as_slice()).unwrap()); +} +impl SerializeCql for i64 { + impl_exact_preliminary_type_check!(BigInt); + impl_serialize_via_writer!(|me, writer| writer.set_value(me.to_be_bytes().as_slice()).unwrap()); +} +impl SerializeCql for BigDecimal { + impl_exact_preliminary_type_check!(Decimal); + impl_serialize_via_writer!(|me, typ, writer| { + let mut builder = writer.into_value_builder(); + let (value, scale) 
= me.as_bigint_and_exponent(); + let scale: i32 = scale + .try_into() + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::ValueOverflow))?; + builder.append_bytes(&scale.to_be_bytes()); + builder.append_bytes(&value.to_signed_bytes_be()); + builder + .finish() + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow))? + }); +} +impl SerializeCql for CqlDate { + impl_exact_preliminary_type_check!(Date); + impl_serialize_via_writer!(|me, writer| { + writer.set_value(me.0.to_be_bytes().as_slice()).unwrap() + }); +} +impl SerializeCql for CqlTimestamp { + impl_exact_preliminary_type_check!(Timestamp); + impl_serialize_via_writer!(|me, writer| { + writer.set_value(me.0.to_be_bytes().as_slice()).unwrap() + }); +} +impl SerializeCql for CqlTime { + impl_exact_preliminary_type_check!(Time); + impl_serialize_via_writer!(|me, writer| { + writer.set_value(me.0.to_be_bytes().as_slice()).unwrap() + }); +} +#[cfg(feature = "chrono")] +impl SerializeCql for NaiveDate { + impl_exact_preliminary_type_check!(Date); + impl_serialize_via_writer!(|me, typ, writer| { + ::serialize(&(*me).into(), typ, writer)? + }); +} +#[cfg(feature = "chrono")] +impl SerializeCql for DateTime { + impl_exact_preliminary_type_check!(Timestamp); + impl_serialize_via_writer!(|me, typ, writer| { + ::serialize(&(*me).into(), typ, writer)? + }); +} +#[cfg(feature = "chrono")] +impl SerializeCql for NaiveTime { + impl_exact_preliminary_type_check!(Time); + impl_serialize_via_writer!(|me, typ, writer| { + let cql_time = CqlTime::try_from(*me).map_err(|_: ValueOverflow| { + mk_ser_err::(typ, BuiltinSerializationErrorKind::ValueOverflow) + })?; + ::serialize(&cql_time, typ, writer)? + }); +} +#[cfg(feature = "chrono")] +impl SerializeCql for time::Date { + impl_exact_preliminary_type_check!(Date); + impl_serialize_via_writer!(|me, typ, writer| { + ::serialize(&(*me).into(), typ, writer)? 
+ }); +} +#[cfg(feature = "chrono")] +impl SerializeCql for time::OffsetDateTime { + impl_exact_preliminary_type_check!(Timestamp); + impl_serialize_via_writer!(|me, typ, writer| { + ::serialize(&(*me).into(), typ, writer)? + }); +} +#[cfg(feature = "chrono")] +impl SerializeCql for time::Time { + impl_exact_preliminary_type_check!(Time); + impl_serialize_via_writer!(|me, typ, writer| { + ::serialize(&(*me).into(), typ, writer)? + }); +} +#[cfg(feature = "secret")] +impl SerializeCql for Secret { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + V::preliminary_type_check(typ) + } + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + V::serialize(self.expose_secret(), typ, writer) + } +} +impl SerializeCql for bool { + impl_exact_preliminary_type_check!(Boolean); + impl_serialize_via_writer!(|me, writer| writer.set_value(&[*me as u8]).unwrap()); +} +impl SerializeCql for f32 { + impl_exact_preliminary_type_check!(Float); + impl_serialize_via_writer!(|me, writer| writer.set_value(me.to_be_bytes().as_slice()).unwrap()); +} +impl SerializeCql for f64 { + impl_exact_preliminary_type_check!(Double); + impl_serialize_via_writer!(|me, writer| writer.set_value(me.to_be_bytes().as_slice()).unwrap()); +} +impl SerializeCql for Uuid { + impl_exact_preliminary_type_check!(Uuid, Timeuuid); + impl_serialize_via_writer!(|me, writer| writer.set_value(me.as_bytes().as_ref()).unwrap()); +} +impl SerializeCql for BigInt { + impl_exact_preliminary_type_check!(Varint); + impl_serialize_via_writer!(|me, typ, writer| { + // TODO: The allocation here can be avoided and we can reimplement + // `to_signed_bytes_be` by using `to_u64_digits` and a bit of custom + // logic. Need better tests in order to do this. + writer + .set_value(me.to_signed_bytes_be().as_slice()) + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow))? 
+ }); +} +impl SerializeCql for &str { + impl_exact_preliminary_type_check!(Ascii, Text); + impl_serialize_via_writer!(|me, typ, writer| { + writer + .set_value(me.as_bytes()) + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow))? + }); +} +impl SerializeCql for Vec { + impl_exact_preliminary_type_check!(Blob); + impl_serialize_via_writer!(|me, typ, writer| { + writer + .set_value(me.as_ref()) + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow))? + }); +} +impl SerializeCql for &[u8] { + impl_exact_preliminary_type_check!(Blob); + impl_serialize_via_writer!(|me, typ, writer| { + writer + .set_value(me) + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow))? + }); +} +impl SerializeCql for [u8; N] { + impl_exact_preliminary_type_check!(Blob); + impl_serialize_via_writer!(|me, typ, writer| { + writer + .set_value(me.as_ref()) + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow))? + }); +} +impl SerializeCql for IpAddr { + impl_exact_preliminary_type_check!(Inet); + impl_serialize_via_writer!(|me, writer| { + match me { + IpAddr::V4(ip) => writer.set_value(&ip.octets()).unwrap(), + IpAddr::V6(ip) => writer.set_value(&ip.octets()).unwrap(), + } + }); +} +impl SerializeCql for String { + impl_exact_preliminary_type_check!(Ascii, Text); + impl_serialize_via_writer!(|me, typ, writer| { + writer + .set_value(me.as_bytes()) + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow))? 
+ }); +} +impl SerializeCql for Option { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + T::preliminary_type_check(typ) + } + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + match self { + Some(v) => v.serialize(typ, writer), + None => Ok(writer.set_null()), + } + } +} +impl SerializeCql for Unset { fn preliminary_type_check(_typ: &ColumnType) -> Result<(), SerializationError> { - Ok(()) + Ok(()) // Fits everything + } + impl_serialize_via_writer!(|_me, writer| writer.set_unset()); +} +impl SerializeCql for Counter { + impl_exact_preliminary_type_check!(Counter); + impl_serialize_via_writer!(|me, writer| { + writer.set_value(me.0.to_be_bytes().as_slice()).unwrap() + }); +} +impl SerializeCql for CqlDuration { + impl_exact_preliminary_type_check!(Duration); + impl_serialize_via_writer!(|me, writer| { + // TODO: adjust vint_encode to use CellValueBuilder or something like that + let mut buf = Vec::with_capacity(27); // worst case size is 27 + vint_encode(me.months as i64, &mut buf); + vint_encode(me.days as i64, &mut buf); + vint_encode(me.nanoseconds, &mut buf); + writer.set_value(buf.as_slice()).unwrap() + }); +} +impl SerializeCql for MaybeUnset { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + V::preliminary_type_check(typ) + } + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + match self { + MaybeUnset::Set(v) => v.serialize(typ, writer), + MaybeUnset::Unset => Ok(writer.set_unset()), + } + } +} +impl SerializeCql for &T { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + T::preliminary_type_check(typ) + } + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + T::serialize(*self, typ, writer) + } +} +impl SerializeCql for Box { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + T::preliminary_type_check(typ) + } + fn serialize( + &self, + typ: 
&ColumnType, + writer: W, + ) -> Result { + T::serialize(&**self, typ, writer) + } +} +impl SerializeCql for HashSet { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + match typ { + ColumnType::Set(elt) => V::preliminary_type_check(elt).map_err(|err| { + mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err), + ) + }), + _ => Err(mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::NotSetOrList, + )), + } + } + + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + serialize_sequence( + std::any::type_name::(), + self.len(), + self.iter(), + typ, + writer, + ) + } +} +impl SerializeCql for HashMap { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + match typ { + ColumnType::Map(k, v) => { + K::preliminary_type_check(k).map_err(|err| { + mk_typck_err::(typ, MapTypeCheckErrorKind::KeyTypeCheckFailed(err)) + })?; + V::preliminary_type_check(v).map_err(|err| { + mk_typck_err::(typ, MapTypeCheckErrorKind::ValueTypeCheckFailed(err)) + })?; + Ok(()) + } + _ => Err(mk_typck_err::(typ, MapTypeCheckErrorKind::NotMap)), + } + } + + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + serialize_mapping( + std::any::type_name::(), + self.len(), + self.iter(), + typ, + writer, + ) + } +} +impl SerializeCql for BTreeSet { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + match typ { + ColumnType::Set(elt) => V::preliminary_type_check(elt).map_err(|err| { + mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err), + ) + }), + _ => Err(mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::NotSetOrList, + )), + } + } + + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + serialize_sequence( + std::any::type_name::(), + self.len(), + self.iter(), + typ, + writer, + ) + } +} +impl SerializeCql for BTreeMap { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), 
SerializationError> { + match typ { + ColumnType::Map(k, v) => { + K::preliminary_type_check(k).map_err(|err| { + mk_typck_err::(typ, MapTypeCheckErrorKind::KeyTypeCheckFailed(err)) + })?; + V::preliminary_type_check(v).map_err(|err| { + mk_typck_err::(typ, MapTypeCheckErrorKind::ValueTypeCheckFailed(err)) + })?; + Ok(()) + } + _ => Err(mk_typck_err::(typ, MapTypeCheckErrorKind::NotMap)), + } + } + + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + serialize_mapping( + std::any::type_name::(), + self.len(), + self.iter(), + typ, + writer, + ) + } +} +impl SerializeCql for Vec { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + match typ { + ColumnType::List(elt) | ColumnType::Set(elt) => { + T::preliminary_type_check(elt).map_err(|err| { + mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err), + ) + }) + } + _ => Err(mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::NotSetOrList, + )), + } + } + + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + serialize_sequence( + std::any::type_name::(), + self.len(), + self.iter(), + typ, + writer, + ) + } +} +impl<'a, T: SerializeCql + 'a> SerializeCql for &'a [T] { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + match typ { + ColumnType::List(elt) | ColumnType::Set(elt) => { + T::preliminary_type_check(elt).map_err(|err| { + mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err), + ) + }) + } + _ => Err(mk_typck_err::( + typ, + SetOrListTypeCheckErrorKind::NotSetOrList, + )), + } + } + + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + serialize_sequence( + std::any::type_name::(), + self.len(), + self.iter(), + typ, + writer, + ) + } +} +impl SerializeCql for CqlValue { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + match typ { + ColumnType::Custom(_) => Err(mk_typck_err::( + typ, + 
BuiltinTypeCheckErrorKind::CustomTypeUnsupported, + )), + _ => Ok(()), + } } fn serialize( &self, - _typ: &ColumnType, + typ: &ColumnType, writer: W, ) -> Result { - serialize_legacy_value(self, writer) + serialize_cql_value(self, typ, writer).map_err(fix_cql_value_name_in_err) + } +} + +fn serialize_cql_value( + value: &CqlValue, + typ: &ColumnType, + writer: W, +) -> Result { + match value { + CqlValue::Ascii(a) => check_and_serialize(a, typ, writer), + CqlValue::Boolean(b) => check_and_serialize(b, typ, writer), + CqlValue::Blob(b) => check_and_serialize(b, typ, writer), + CqlValue::Counter(c) => check_and_serialize(c, typ, writer), + CqlValue::Decimal(d) => check_and_serialize(d, typ, writer), + CqlValue::Date(d) => check_and_serialize(d, typ, writer), + CqlValue::Double(d) => check_and_serialize(d, typ, writer), + CqlValue::Duration(d) => check_and_serialize(d, typ, writer), + CqlValue::Empty => { + if !typ.supports_special_empty_value() { + return Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::NotEmptyable, + )); + } + Ok(writer.set_value(&[]).unwrap()) + } + CqlValue::Float(f) => check_and_serialize(f, typ, writer), + CqlValue::Int(i) => check_and_serialize(i, typ, writer), + CqlValue::BigInt(b) => check_and_serialize(b, typ, writer), + CqlValue::Text(t) => check_and_serialize(t, typ, writer), + CqlValue::Timestamp(t) => check_and_serialize(t, typ, writer), + CqlValue::Inet(i) => check_and_serialize(i, typ, writer), + CqlValue::List(l) => check_and_serialize(l, typ, writer), + CqlValue::Map(m) => serialize_mapping( + std::any::type_name::(), + m.len(), + m.iter().map(|(ref k, ref v)| (k, v)), + typ, + writer, + ), + CqlValue::Set(s) => check_and_serialize(s, typ, writer), + CqlValue::UserDefinedType { + keyspace, + type_name, + fields, + } => serialize_udt(typ, keyspace, type_name, fields, writer), + CqlValue::SmallInt(s) => check_and_serialize(s, typ, writer), + CqlValue::TinyInt(t) => check_and_serialize(t, typ, writer), + CqlValue::Time(t) => 
check_and_serialize(t, typ, writer), + CqlValue::Timeuuid(t) => check_and_serialize(t, typ, writer), + CqlValue::Tuple(t) => { + // We allow serializing tuples that have fewer fields + // than the database tuple, but not the other way around. + let fields = match typ { + ColumnType::Tuple(fields) => { + if fields.len() < t.len() { + return Err(mk_typck_err::( + typ, + TupleTypeCheckErrorKind::WrongElementCount { + actual: t.len(), + asked_for: fields.len(), + }, + )); + } + fields + } + _ => { + return Err(mk_typck_err::( + typ, + TupleTypeCheckErrorKind::NotTuple, + )) + } + }; + serialize_tuple_like(typ, fields.iter(), t.iter(), writer) + } + CqlValue::Uuid(u) => check_and_serialize(u, typ, writer), + CqlValue::Varint(v) => check_and_serialize(v, typ, writer), + } +} + +fn fix_cql_value_name_in_err(mut err: SerializationError) -> SerializationError { + // The purpose of this function is to change the `rust_name` field + // in the error to CqlValue. Most of the time, the `err` given to the + // function here will be the sole owner of the data, so theoretically + // we could fix this in place. + + let rust_name = std::any::type_name::(); + + match Arc::get_mut(&mut err.0) { + Some(err_mut) => { + if let Some(err) = err_mut.downcast_mut::() { + err.rust_name = rust_name; + } else if let Some(err) = err_mut.downcast_mut::() { + err.rust_name = rust_name; + } + } + None => { + // The `None` case shouldn't happen considering how we are using + // the function in the code now, but let's provide it here anyway + // for correctness. 
+ if let Some(err) = err.0.downcast_ref::() { + if err.rust_name != rust_name { + return SerializationError::new(BuiltinTypeCheckError { + rust_name, + ..err.clone() + }); + } + } + if let Some(err) = err.0.downcast_ref::() { + if err.rust_name != rust_name { + return SerializationError::new(BuiltinSerializationError { + rust_name, + ..err.clone() + }); + } + } + } + }; + + err +} + +fn check_and_serialize( + v: &V, + typ: &ColumnType, + writer: W, +) -> Result { + V::preliminary_type_check(typ)?; + v.serialize(typ, writer) +} + +fn serialize_udt( + typ: &ColumnType, + keyspace: &str, + type_name: &str, + values: &[(String, Option)], + writer: W, +) -> Result { + let (dst_type_name, dst_keyspace, field_types) = match typ { + ColumnType::UserDefinedType { + type_name, + keyspace, + field_types, + } => (type_name, keyspace, field_types), + _ => return Err(mk_typck_err::(typ, UdtTypeCheckErrorKind::NotUdt)), + }; + + if keyspace != dst_keyspace || type_name != dst_type_name { + return Err(mk_typck_err::( + typ, + UdtTypeCheckErrorKind::NameMismatch { + keyspace: dst_keyspace.clone(), + type_name: dst_type_name.clone(), + }, + )); + } + + // Allow columns present in the CQL type which are not present in CqlValue, + // but not the other way around + let mut indexed_fields: HashMap<_, _> = values.iter().map(|(k, v)| (k.as_str(), v)).collect(); + + let mut builder = writer.into_value_builder(); + for (fname, ftyp) in field_types { + // Take a value from the original list. + // If a field is missing, write null instead. 
+ let fvalue = indexed_fields + .remove(fname.as_str()) + .and_then(|x| x.as_ref()); + + let writer = builder.make_sub_writer(); + match fvalue { + None => writer.set_null(), + Some(v) => serialize_cql_value(v, ftyp, writer).map_err(|err| { + let err = fix_cql_value_name_in_err(err); + mk_ser_err::( + typ, + UdtSerializationErrorKind::FieldSerializationFailed { + field_name: fname.clone(), + err, + }, + ) + })?, + }; + } + + // If there are some leftover fields, it's an error. + if !indexed_fields.is_empty() { + // In order to have deterministic errors, return an error about + // the lexicographically smallest field. + let fname = indexed_fields.keys().min().unwrap(); + return Err(mk_typck_err::( + typ, + UdtTypeCheckErrorKind::UnexpectedFieldInDestination { + field_name: fname.to_string(), + }, + )); + } + + builder + .finish() + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow)) +} + +fn serialize_tuple_like<'t, W: CellWriter>( + typ: &ColumnType, + field_types: impl Iterator, + field_values: impl Iterator>, + writer: W, +) -> Result { + let mut builder = writer.into_value_builder(); + + for (index, (el, typ)) in field_values.zip(field_types).enumerate() { + let sub = builder.make_sub_writer(); + match el { + None => sub.set_null(), + Some(el) => serialize_cql_value(el, typ, sub).map_err(|err| { + let err = fix_cql_value_name_in_err(err); + mk_ser_err::( + typ, + TupleSerializationErrorKind::ElementSerializationFailed { index, err }, + ) + })?, + }; + } + + builder + .finish() + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow)) +} + +macro_rules! impl_tuple { + ( + $($typs:ident),*; + $($fidents:ident),*; + $($tidents:ident),*; + $length:expr + ) => { + impl<$($typs: SerializeCql),*> SerializeCql for ($($typs,)*) { + fn preliminary_type_check(typ: &ColumnType) -> Result<(), SerializationError> { + match typ { + ColumnType::Tuple(typs) => match typs.as_slice() { + [$($tidents),*, ..] 
=> { + let index = 0; + $( + <$typs as SerializeCql>::preliminary_type_check($tidents) + .map_err(|err| + mk_typck_err::( + typ, + TupleTypeCheckErrorKind::ElementTypeCheckFailed { + index, + err, + } + ) + )?; + let index = index + 1; + )* + let _ = index; + } + _ => return Err(mk_typck_err::( + typ, + TupleTypeCheckErrorKind::WrongElementCount { + actual: $length, + asked_for: typs.len(), + } + )) + } + _ => return Err(mk_typck_err::( + typ, + TupleTypeCheckErrorKind::NotTuple + )), + }; + Ok(()) + } + + fn serialize( + &self, + typ: &ColumnType, + writer: W, + ) -> Result { + let ($($tidents,)*) = match typ { + ColumnType::Tuple(typs) => match typs.as_slice() { + [$($tidents),*] => ($($tidents,)*), + _ => return Err(mk_typck_err::( + typ, + TupleTypeCheckErrorKind::WrongElementCount { + actual: $length, + asked_for: typs.len(), + } + )) + } + _ => return Err(mk_typck_err::( + typ, + TupleTypeCheckErrorKind::NotTuple, + )) + }; + let ($($fidents,)*) = self; + let mut builder = writer.into_value_builder(); + let index = 0; + $( + <$typs as SerializeCql>::serialize($fidents, $tidents, builder.make_sub_writer()) + .map_err(|err| mk_ser_err::( + typ, + TupleSerializationErrorKind::ElementSerializationFailed { + index, + err, + } + ))?; + let index = index + 1; + )* + let _ = index; + builder + .finish() + .map_err(|_| mk_ser_err::(typ, BuiltinSerializationErrorKind::SizeOverflow)) + } + } + }; +} + +macro_rules! 
impl_tuples { + (;;;$length:expr) => {}; + ( + $typ:ident$(, $($typs:ident),*)?; + $fident:ident$(, $($fidents:ident),*)?; + $tident:ident$(, $($tidents:ident),*)?; + $length:expr + ) => { + impl_tuples!( + $($($typs),*)?; + $($($fidents),*)?; + $($($tidents),*)?; + $length - 1 + ); + impl_tuple!( + $typ$(, $($typs),*)?; + $fident$(, $($fidents),*)?; + $tident$(, $($tidents),*)?; + $length + ); + }; +} + +impl_tuples!( + T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15; + f0, f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15; + t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15; + 16 +); + +fn serialize_sequence<'t, T: SerializeCql + 't, W: CellWriter>( + rust_name: &'static str, + len: usize, + iter: impl Iterator, + typ: &ColumnType, + writer: W, +) -> Result { + let elt = match typ { + ColumnType::List(elt) | ColumnType::Set(elt) => elt, + _ => { + return Err(mk_typck_err_named( + rust_name, + typ, + SetOrListTypeCheckErrorKind::NotSetOrList, + )); + } + }; + + let mut builder = writer.into_value_builder(); + + let element_count: i32 = len.try_into().map_err(|_| { + mk_ser_err_named( + rust_name, + typ, + SetOrListSerializationErrorKind::TooManyElements, + ) + })?; + builder.append_bytes(&element_count.to_be_bytes()); + + for el in iter { + T::serialize(el, elt, builder.make_sub_writer()).map_err(|err| { + mk_ser_err_named( + rust_name, + typ, + SetOrListSerializationErrorKind::ElementSerializationFailed(err), + ) + })?; } + + builder + .finish() + .map_err(|_| mk_ser_err_named(rust_name, typ, BuiltinSerializationErrorKind::SizeOverflow)) } +fn serialize_mapping<'t, K: SerializeCql + 't, V: SerializeCql + 't, W: CellWriter>( + rust_name: &'static str, + len: usize, + iter: impl Iterator, + typ: &ColumnType, + writer: W, +) -> Result { + let (ktyp, vtyp) = match typ { + ColumnType::Map(k, v) => (k, v), + _ => { + return Err(mk_typck_err_named( + rust_name, + typ, + MapTypeCheckErrorKind::NotMap, + )); + } + }; 
+ + let mut builder = writer.into_value_builder(); + + let element_count: i32 = len.try_into().map_err(|_| { + mk_ser_err_named(rust_name, typ, MapSerializationErrorKind::TooManyElements) + })?; + builder.append_bytes(&element_count.to_be_bytes()); + + for (k, v) in iter { + K::serialize(k, ktyp, builder.make_sub_writer()).map_err(|err| { + mk_ser_err_named( + rust_name, + typ, + MapSerializationErrorKind::KeySerializationFailed(err), + ) + })?; + V::serialize(v, vtyp, builder.make_sub_writer()).map_err(|err| { + mk_ser_err_named( + rust_name, + typ, + MapSerializationErrorKind::ValueSerializationFailed(err), + ) + })?; + } + + builder + .finish() + .map_err(|_| mk_ser_err_named(rust_name, typ, BuiltinSerializationErrorKind::SizeOverflow)) +} + +/// Implements the [`SerializeCql`] trait for a type, provided that the type +/// already implements the legacy [`Value`](crate::frame::value::Value) trait. +/// +/// # Note +/// +/// The translation from one trait to another encounters a performance penalty +/// and does not utilize the stronger guarantees of `SerializeCql`. Before +/// resorting to this macro, you should consider other options instead: +/// +/// - If the impl was generated using the `Value` procedural macro, you should +/// switch to the `SerializeCql` procedural macro. *The new macro behaves +/// differently by default, so please read its documentation first!* +/// - If the impl was written by hand, it is still preferable to rewrite it +/// manually. You have an opportunity to make your serialization logic +/// type-safe and potentially improve performance. +/// +/// Basically, you should consider using the macro if you have a hand-written +/// impl and the moment it is not easy/not desirable to rewrite it. 
+/// +/// # Example +/// +/// ```rust +/// # use scylla_cql::frame::value::{Value, ValueTooBig}; +/// # use scylla_cql::impl_serialize_cql_via_value; +/// struct NoGenerics {} +/// impl Value for NoGenerics { +/// fn serialize(&self, _buf: &mut Vec) -> Result<(), ValueTooBig> { +/// Ok(()) +/// } +/// } +/// impl_serialize_cql_via_value!(NoGenerics); +/// +/// // Generic types are also supported. You must specify the bounds if the +/// // struct/enum contains any. +/// struct WithGenerics(T, U); +/// impl Value for WithGenerics { +/// fn serialize(&self, buf: &mut Vec) -> Result<(), ValueTooBig> { +/// self.0.serialize(buf)?; +/// self.1.clone().serialize(buf)?; +/// Ok(()) +/// } +/// } +/// impl_serialize_cql_via_value!(WithGenerics); +/// ``` +#[macro_export] +macro_rules! impl_serialize_cql_via_value { + ($t:ident$(<$($targ:tt $(: $tbound:tt)?),*>)?) => { + impl $(<$($targ $(: $tbound)?),*>)? $crate::types::serialize::value::SerializeCql + for $t$(<$($targ),*>)? + where + Self: $crate::frame::value::Value, + { + fn preliminary_type_check( + _typ: &$crate::frame::response::result::ColumnType, + ) -> ::std::result::Result<(), $crate::types::serialize::SerializationError> { + // No-op - the old interface didn't offer type safety + ::std::result::Result::Ok(()) + } + + fn serialize( + &self, + _typ: &$crate::frame::response::result::ColumnType, + writer: W, + ) -> ::std::result::Result< + W::WrittenCellProof, + $crate::types::serialize::SerializationError, + > { + $crate::types::serialize::value::serialize_legacy_value(self, writer) + } + } + }; +} + +/// Serializes a value implementing [`Value`] by using the [`CellWriter`] +/// interface. +/// +/// The function first serializes the value with [`Value::serialize`], then +/// parses the result and serializes it again with given `CellWriter`. It is +/// a lazy and inefficient way to implement `CellWriter` via an existing `Value` +/// impl. 
+/// +/// Returns an error if the result of the `Value::serialize` call was not +/// a properly encoded `[value]` as defined in the CQL protocol spec. +/// +/// See [`impl_serialize_cql_via_value`] which generates a boilerplate +/// [`SerializeCql`] implementation that uses this function. pub fn serialize_legacy_value( v: &T, writer: W, @@ -73,7 +1099,7 @@ pub fn serialize_legacy_value( }, ))) } else { - Ok(writer.set_value(contents)) + Ok(writer.set_value(contents).unwrap()) // len <= i32::MAX, so unwrap will succeed } } _ => Err(SerializationError(Arc::new( @@ -82,6 +1108,455 @@ pub fn serialize_legacy_value( } } +/// Type checking of one of the built-in types failed. +#[derive(Debug, Error, Clone)] +#[error("Failed to type check Rust type {rust_name} against CQL type {got:?}: {kind}")] +pub struct BuiltinTypeCheckError { + /// Name of the Rust type being serialized. + pub rust_name: &'static str, + + /// The CQL type that the Rust type was being serialized to. + pub got: ColumnType, + + /// Detailed information about the failure. + pub kind: BuiltinTypeCheckErrorKind, +} + +fn mk_typck_err( + got: &ColumnType, + kind: impl Into, +) -> SerializationError { + mk_typck_err_named(std::any::type_name::(), got, kind) +} + +fn mk_typck_err_named( + name: &'static str, + got: &ColumnType, + kind: impl Into, +) -> SerializationError { + SerializationError::new(BuiltinTypeCheckError { + rust_name: name, + got: got.clone(), + kind: kind.into(), + }) +} + +/// Serialization of one of the built-in types failed. +#[derive(Debug, Error, Clone)] +#[error("Failed to serialize Rust type {rust_name} into CQL type {got:?}: {kind}")] +pub struct BuiltinSerializationError { + /// Name of the Rust type being serialized. + pub rust_name: &'static str, + + /// The CQL type that the Rust type was being serialized to. + pub got: ColumnType, + + /// Detailed information about the failure. 
+ pub kind: BuiltinSerializationErrorKind, +} + +fn mk_ser_err( + got: &ColumnType, + kind: impl Into, +) -> SerializationError { + mk_ser_err_named(std::any::type_name::(), got, kind) +} + +fn mk_ser_err_named( + name: &'static str, + got: &ColumnType, + kind: impl Into, +) -> SerializationError { + SerializationError::new(BuiltinSerializationError { + rust_name: name, + got: got.clone(), + kind: kind.into(), + }) +} + +/// Describes why type checking some of the built-in types has failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinTypeCheckErrorKind { + /// Expected one from a list of particular types. + MismatchedType { expected: &'static [ColumnType] }, + + /// Expected a type that can be empty. + NotEmptyable, + + /// A type check failure specific to a CQL set or list. + SetOrListError(SetOrListTypeCheckErrorKind), + + /// A type check failure specific to a CQL map. + MapError(MapTypeCheckErrorKind), + + /// A type check failure specific to a CQL tuple. + TupleError(TupleTypeCheckErrorKind), + + /// A type check failure specific to a CQL UDT. + UdtError(UdtTypeCheckErrorKind), + + /// Custom CQL type - unsupported + // TODO: Should we actually support it? Counters used to be implemented like that. 
+ CustomTypeUnsupported, +} + +impl From for BuiltinTypeCheckErrorKind { + fn from(value: SetOrListTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::SetOrListError(value) + } +} + +impl From for BuiltinTypeCheckErrorKind { + fn from(value: MapTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::MapError(value) + } +} + +impl From for BuiltinTypeCheckErrorKind { + fn from(value: TupleTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::TupleError(value) + } +} + +impl From for BuiltinTypeCheckErrorKind { + fn from(value: UdtTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::UdtError(value) + } +} + +impl Display for BuiltinTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinTypeCheckErrorKind::MismatchedType { expected } => { + write!(f, "expected one of the CQL types: {expected:?}") + } + BuiltinTypeCheckErrorKind::NotEmptyable => { + write!( + f, + "the separate empty representation is not valid for this type" + ) + } + BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::TupleError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::UdtError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::CustomTypeUnsupported => { + write!(f, "custom CQL types are unsupported") + } + } + } +} + +/// Describes why serialization of some of the built-in types has failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum BuiltinSerializationErrorKind { + /// The size of the Rust value is too large to fit in the CQL serialization + /// format (over i32::MAX bytes). + SizeOverflow, + + /// The Rust value is out of range supported by the CQL type. + ValueOverflow, + + /// A serialization failure specific to a CQL set or list. + SetOrListError(SetOrListSerializationErrorKind), + + /// A serialization failure specific to a CQL map. 
+ MapError(MapSerializationErrorKind), + + /// A serialization failure specific to a CQL tuple. + TupleError(TupleSerializationErrorKind), + + /// A serialization failure specific to a CQL UDT. + UdtError(UdtSerializationErrorKind), +} + +impl From for BuiltinSerializationErrorKind { + fn from(value: SetOrListSerializationErrorKind) -> Self { + BuiltinSerializationErrorKind::SetOrListError(value) + } +} + +impl From for BuiltinSerializationErrorKind { + fn from(value: MapSerializationErrorKind) -> Self { + BuiltinSerializationErrorKind::MapError(value) + } +} + +impl From for BuiltinSerializationErrorKind { + fn from(value: TupleSerializationErrorKind) -> Self { + BuiltinSerializationErrorKind::TupleError(value) + } +} + +impl From for BuiltinSerializationErrorKind { + fn from(value: UdtSerializationErrorKind) -> Self { + BuiltinSerializationErrorKind::UdtError(value) + } +} + +impl Display for BuiltinSerializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BuiltinSerializationErrorKind::SizeOverflow => { + write!( + f, + "the Rust value is too big to be serialized in the CQL protocol format" + ) + } + BuiltinSerializationErrorKind::ValueOverflow => { + write!( + f, + "the Rust value is out of range supported by the CQL type" + ) + } + BuiltinSerializationErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinSerializationErrorKind::MapError(err) => err.fmt(f), + BuiltinSerializationErrorKind::TupleError(err) => err.fmt(f), + BuiltinSerializationErrorKind::UdtError(err) => err.fmt(f), + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum MapTypeCheckErrorKind { + /// The CQL type is not a map. + NotMap, + + /// Checking the map key type failed. + KeyTypeCheckFailed(SerializationError), + + /// Checking the map value type failed. 
+ ValueTypeCheckFailed(SerializationError), +} + +impl Display for MapTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MapTypeCheckErrorKind::NotMap => { + write!( + f, + "the CQL type the map was attempted to be serialized to was not map" + ) + } + MapTypeCheckErrorKind::KeyTypeCheckFailed(err) => { + write!(f, "failed to type check one of the keys: {}", err) + } + MapTypeCheckErrorKind::ValueTypeCheckFailed(err) => { + write!(f, "failed to type check one of the values: {}", err) + } + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum MapSerializationErrorKind { + /// The many contains too many items, exceeding the protocol limit (i32::MAX). + TooManyElements, + + /// One of the keys in the map failed to serialize. + KeySerializationFailed(SerializationError), + + /// One of the values in the map failed to serialize. + ValueSerializationFailed(SerializationError), +} + +impl Display for MapSerializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MapSerializationErrorKind::TooManyElements => { + write!( + f, + "the map contains too many elements to fit in CQL representation" + ) + } + MapSerializationErrorKind::KeySerializationFailed(err) => { + write!(f, "failed to serialize one of the keys: {}", err) + } + MapSerializationErrorKind::ValueSerializationFailed(err) => { + write!(f, "failed to serialize one of the values: {}", err) + } + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum SetOrListTypeCheckErrorKind { + /// The CQL type is neither a set not a list. + NotSetOrList, + + /// Checking the type of the set/list element failed. 
+ ElementTypeCheckFailed(SerializationError), +} + +impl Display for SetOrListTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SetOrListTypeCheckErrorKind::NotSetOrList => { + write!( + f, + "the CQL type the tuple was attempted to was neither a set or a list" + ) + } + SetOrListTypeCheckErrorKind::ElementTypeCheckFailed(err) => { + write!(f, "failed to type check one of the elements: {err}") + } + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum SetOrListSerializationErrorKind { + /// The set/list contains too many items, exceeding the protocol limit (i32::MAX). + TooManyElements, + + /// One of the elements of the set/list failed to serialize. + ElementSerializationFailed(SerializationError), +} + +impl Display for SetOrListSerializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SetOrListSerializationErrorKind::TooManyElements => { + write!( + f, + "the collection contains too many elements to fit in CQL representation" + ) + } + SetOrListSerializationErrorKind::ElementSerializationFailed(err) => { + write!(f, "failed to serialize one of the elements: {err}") + } + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum TupleTypeCheckErrorKind { + /// The CQL type is not a tuple. + NotTuple, + + /// The tuple has the wrong element count. + /// + /// Note that it is allowed to write a Rust tuple with less elements + /// than the corresponding CQL type, but not more. The additional, unknown + /// elements will be set to null. + WrongElementCount { actual: usize, asked_for: usize }, + + /// One of the tuple elements failed to type check. 
+ ElementTypeCheckFailed { + index: usize, + err: SerializationError, + }, +} + +impl Display for TupleTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TupleTypeCheckErrorKind::NotTuple => write!( + f, + "the CQL type the tuple was attempted to be serialized to is not a tuple" + ), + TupleTypeCheckErrorKind::WrongElementCount { actual, asked_for } => write!( + f, + "wrong tuple element count: CQL type has {asked_for}, the Rust tuple has {actual}" + ), + TupleTypeCheckErrorKind::ElementTypeCheckFailed { index, err } => { + write!(f, "element no. {index} failed to type check: {err}") + } + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum TupleSerializationErrorKind { + /// One of the tuple elements failed to serialize. + ElementSerializationFailed { + index: usize, + err: SerializationError, + }, +} + +impl Display for TupleSerializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TupleSerializationErrorKind::ElementSerializationFailed { index, err } => { + write!(f, "element no. {index} failed to serialize: {err}") + } + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum UdtTypeCheckErrorKind { + /// The CQL type is not a user defined type. + NotUdt, + + /// The name of the UDT being serialized to does not match. + NameMismatch { keyspace: String, type_name: String }, + + /// The Rust data contains a field that is not present in the UDT + UnexpectedFieldInDestination { field_name: String }, + + /// One of the fields failed to type check. 
+ FieldTypeCheckFailed { + field_name: String, + err: SerializationError, + }, +} + +impl Display for UdtTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UdtTypeCheckErrorKind::NotUdt => write!( + f, + "the CQL type the tuple was attempted to be type checked against is not a UDT" + ), + UdtTypeCheckErrorKind::NameMismatch { + keyspace, + type_name, + } => write!( + f, + "the Rust UDT name does not match the actual CQL UDT name ({keyspace}.{type_name})" + ), + UdtTypeCheckErrorKind::UnexpectedFieldInDestination { field_name } => write!( + f, + "the field {field_name} present in the Rust data is not present in the CQL type" + ), + UdtTypeCheckErrorKind::FieldTypeCheckFailed { field_name, err } => { + write!(f, "field {field_name} failed to type check: {err}") + } + } + } +} + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum UdtSerializationErrorKind { + /// One of the fields failed to serialize. + FieldSerializationFailed { + field_name: String, + err: SerializationError, + }, +} + +impl Display for UdtSerializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UdtSerializationErrorKind::FieldSerializationFailed { field_name, err } => { + write!(f, "field {field_name} failed to serialize: {err}") + } + } + } +} + #[derive(Error, Debug)] pub enum ValueToSerializeCqlAdapterError { #[error("Output produced by the Value trait is too short to be considered a value: {size} < 4 minimum bytes")] diff --git a/scylla-cql/src/types/serialize/writers.rs b/scylla-cql/src/types/serialize/writers.rs index cafd5442fc..ecb8a1fcc1 100644 --- a/scylla-cql/src/types/serialize/writers.rs +++ b/scylla-cql/src/types/serialize/writers.rs @@ -1,5 +1,7 @@ //! Contains types and traits used for safe serialization of values for a CQL statement. +use thiserror::Error; + /// An interface that facilitates writing values for a CQL query. 
pub trait RowWriter { type CellWriter<'a>: CellWriter @@ -62,7 +64,10 @@ pub trait CellWriter { /// Prefer this to [`into_value_builder`](CellWriter::into_value_builder) /// if you have all of the contents of the value ready up front (e.g. for /// fixed size types). - fn set_value(self, contents: &[u8]) -> Self::WrittenCellProof; + /// + /// Fails if the contents size overflows the maximum allowed CQL cell size + /// (which is i32::MAX). + fn set_value(self, contents: &[u8]) -> Result; /// Turns this writter into a [`CellValueBuilder`] which can be used /// to gradually initialize the CQL value. @@ -94,16 +99,24 @@ pub trait CellValueBuilder { fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_>; /// Finishes serializing the value. - fn finish(self) -> Self::WrittenCellProof; + /// + /// Fails if the constructed cell size overflows the maximum allowed + /// CQL cell size (which is i32::MAX). + fn finish(self) -> Result; } +/// There was an attempt to produce a CQL value over the maximum size limit (i32::MAX) +#[derive(Debug, Clone, Copy, Error)] +#[error("CQL cell overflowed the maximum allowed size of 2^31 - 1")] +pub struct CellOverflowError; + /// A row writer backed by a buffer (vec). pub struct BufBackedRowWriter<'buf> { // Buffer that this value should be serialized to. buf: &'buf mut Vec, // Number of values written so far. - value_count: u16, + value_count: usize, } impl<'buf> BufBackedRowWriter<'buf> { @@ -119,8 +132,11 @@ impl<'buf> BufBackedRowWriter<'buf> { } /// Returns the number of values that were written so far. + /// + /// Note that the protocol allows at most u16::MAX to be written into a query, + /// but the writer's interface allows more to be written. 
#[inline] - pub fn value_count(&self) -> u16 { + pub fn value_count(&self) -> usize { self.value_count } } @@ -130,10 +146,7 @@ impl<'buf> RowWriter for BufBackedRowWriter<'buf> { #[inline] fn make_cell_writer(&mut self) -> Self::CellWriter<'_> { - self.value_count = self - .value_count - .checked_add(1) - .expect("tried to serialize too many values for a query (more than u16::MAX)"); + self.value_count += 1; BufBackedCellWriter::new(self.buf) } } @@ -169,13 +182,11 @@ impl<'buf> CellWriter for BufBackedCellWriter<'buf> { } #[inline] - fn set_value(self, bytes: &[u8]) { - let value_len: i32 = bytes - .len() - .try_into() - .expect("value is too big to fit into a CQL [bytes] object (larger than i32::MAX)"); + fn set_value(self, bytes: &[u8]) -> Result<(), CellOverflowError> { + let value_len: i32 = bytes.len().try_into().map_err(|_| CellOverflowError)?; self.buf.extend_from_slice(&value_len.to_be_bytes()); self.buf.extend_from_slice(bytes); + Ok(()) } #[inline] @@ -226,49 +237,55 @@ impl<'buf> CellValueBuilder for BufBackedCellValueBuilder<'buf> { } #[inline] - fn finish(self) { - // TODO: Should this panic, or should we catch this error earlier? - // Vec will panic anyway if we overflow isize, so at least this - // behavior is consistent with what the stdlib does. + fn finish(self) -> Result<(), CellOverflowError> { let value_len: i32 = (self.buf.len() - self.starting_pos - 4) .try_into() - .expect("value is too big to fit into a CQL [bytes] object (larger than i32::MAX)"); + .map_err(|_| CellOverflowError)?; self.buf[self.starting_pos..self.starting_pos + 4] .copy_from_slice(&value_len.to_be_bytes()); + Ok(()) } } -/// A writer that does not actually write anything, just counts the bytes. -/// -/// It can serve as a: -/// -/// - [`RowWriter`] -/// - [`CellWriter`] -/// - [`CellValueBuilder`] -pub struct CountingWriter<'buf> { +/// A row writer that does not actually write anything, just counts the bytes. 
+pub struct CountingRowWriter<'buf> { buf: &'buf mut usize, } -impl<'buf> CountingWriter<'buf> { +impl<'buf> CountingRowWriter<'buf> { /// Creates a new writer which increments the counter under given reference /// when bytes are appended. #[inline] - fn new(buf: &'buf mut usize) -> Self { - CountingWriter { buf } + pub fn new(buf: &'buf mut usize) -> Self { + CountingRowWriter { buf } } } -impl<'buf> RowWriter for CountingWriter<'buf> { - type CellWriter<'a> = CountingWriter<'a> where Self: 'a; +impl<'buf> RowWriter for CountingRowWriter<'buf> { + type CellWriter<'a> = CountingCellWriter<'a> where Self: 'a; #[inline] fn make_cell_writer(&mut self) -> Self::CellWriter<'_> { - CountingWriter::new(self.buf) + CountingCellWriter::new(self.buf) + } +} + +/// A cell writer that does not actually write anything, just counts the bytes. +pub struct CountingCellWriter<'buf> { + buf: &'buf mut usize, +} + +impl<'buf> CountingCellWriter<'buf> { + /// Creates a new writer which increments the counter under given reference + /// when bytes are appended. 
+ #[inline] + fn new(buf: &'buf mut usize) -> Self { + CountingCellWriter { buf } } } -impl<'buf> CellWriter for CountingWriter<'buf> { - type ValueBuilder = CountingWriter<'buf>; +impl<'buf> CellWriter for CountingCellWriter<'buf> { + type ValueBuilder = CountingCellValueBuilder<'buf>; type WrittenCellProof = (); @@ -283,19 +300,39 @@ impl<'buf> CellWriter for CountingWriter<'buf> { } #[inline] - fn set_value(self, contents: &[u8]) { + fn set_value(self, contents: &[u8]) -> Result<(), CellOverflowError> { + if contents.len() > i32::MAX as usize { + return Err(CellOverflowError); + } *self.buf += 4 + contents.len(); + Ok(()) } #[inline] fn into_value_builder(self) -> Self::ValueBuilder { *self.buf += 4; - CountingWriter::new(self.buf) + CountingCellValueBuilder::new(self.buf) } } -impl<'buf> CellValueBuilder for CountingWriter<'buf> { - type SubCellWriter<'a> = CountingWriter<'a> +pub struct CountingCellValueBuilder<'buf> { + buf: &'buf mut usize, + + starting_pos: usize, +} + +impl<'buf> CountingCellValueBuilder<'buf> { + /// Creates a new builder which increments the counter under given reference + /// when bytes are appended. 
+ #[inline] + fn new(buf: &'buf mut usize) -> Self { + let starting_pos = *buf; + CountingCellValueBuilder { buf, starting_pos } + } +} + +impl<'buf> CellValueBuilder for CountingCellValueBuilder<'buf> { + type SubCellWriter<'a> = CountingCellWriter<'a> where Self: 'a; @@ -308,17 +345,24 @@ impl<'buf> CellValueBuilder for CountingWriter<'buf> { #[inline] fn make_sub_writer(&mut self) -> Self::SubCellWriter<'_> { - CountingWriter::new(self.buf) + CountingCellWriter::new(self.buf) } #[inline] - fn finish(self) -> Self::WrittenCellProof {} + fn finish(self) -> Result { + if *self.buf - self.starting_pos > i32::MAX as usize { + return Err(CellOverflowError); + } + Ok(()) + } } #[cfg(test)] mod tests { + use crate::types::serialize::writers::CountingRowWriter; + use super::{ - BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder, CellWriter, CountingWriter, + BufBackedCellWriter, BufBackedRowWriter, CellValueBuilder, CellWriter, CountingCellWriter, RowWriter, }; @@ -335,7 +379,7 @@ mod tests { c.check(writer); let mut byte_count = 0usize; - let counting_writer = CountingWriter::new(&mut byte_count); + let counting_writer = CountingCellWriter::new(&mut byte_count); c.check(counting_writer); assert_eq!(data.len(), byte_count); @@ -349,9 +393,12 @@ mod tests { fn check(&self, writer: W) { let mut sub_writer = writer.into_value_builder(); sub_writer.make_sub_writer().set_null(); - sub_writer.make_sub_writer().set_value(&[1, 2, 3, 4]); + sub_writer + .make_sub_writer() + .set_value(&[1, 2, 3, 4]) + .unwrap(); sub_writer.make_sub_writer().set_unset(); - sub_writer.finish(); + sub_writer.finish().unwrap(); } } @@ -395,7 +442,7 @@ mod tests { c.check(&mut writer); let mut byte_count = 0usize; - let mut counting_writer = CountingWriter::new(&mut byte_count); + let mut counting_writer = CountingRowWriter::new(&mut byte_count); c.check(&mut counting_writer); assert_eq!(data.len(), byte_count); @@ -408,7 +455,7 @@ mod tests { impl RowSerializeCheck for Check { fn check(&self, 
writer: &mut W) { writer.make_cell_writer().set_null(); - writer.make_cell_writer().set_value(&[1, 2, 3, 4]); + writer.make_cell_writer().set_value(&[1, 2, 3, 4]).unwrap(); writer.make_cell_writer().set_unset(); } }