From 2d669f9f5c63c1837a094c73589876a381119b23 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Mon, 22 May 2023 17:17:39 +0100 Subject: [PATCH] Add safe zero-copy converion from bytes::Bytes (#4254) --- arrow-buffer/Cargo.toml | 1 + arrow-buffer/src/bytes.rs | 28 ++++++++++++++++++++++++++++ arrow-flight/src/decode.rs | 3 ++- arrow-flight/src/sql/client.rs | 2 +- 4 files changed, 32 insertions(+), 2 deletions(-) diff --git a/arrow-buffer/Cargo.toml b/arrow-buffer/Cargo.toml index 1db388db8398..746045cc8dde 100644 --- a/arrow-buffer/Cargo.toml +++ b/arrow-buffer/Cargo.toml @@ -34,6 +34,7 @@ path = "src/lib.rs" bench = false [dependencies] +bytes = { version = "1.4" } num = { version = "0.4", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false } diff --git a/arrow-buffer/src/bytes.rs b/arrow-buffer/src/bytes.rs index b3105ed5a3b4..8f5019d5a4cc 100644 --- a/arrow-buffer/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -148,3 +148,31 @@ impl Debug for Bytes { write!(f, " }}") } } + +impl From for Bytes { + fn from(value: bytes::Bytes) -> Self { + Self { + len: value.len(), + ptr: NonNull::new(value.as_ptr() as _).unwrap(), + deallocation: Deallocation::Custom(std::sync::Arc::new(value)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_bytes() { + let bytes = bytes::Bytes::from(vec![1, 2, 3, 4]); + let arrow_bytes: Bytes = bytes.clone().into(); + + assert_eq!(bytes.as_ptr(), arrow_bytes.as_ptr()); + + drop(bytes); + drop(arrow_bytes); + + let _ = Bytes::from(bytes::Bytes::new()); + } +} diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index fe132e3e8448..df74923332e3 100644 --- a/arrow-flight/src/decode.rs +++ b/arrow-flight/src/decode.rs @@ -17,6 +17,7 @@ use crate::{utils::flight_data_to_arrow_batch, FlightData}; use arrow_array::{ArrayRef, RecordBatch}; +use arrow_buffer::Buffer; use arrow_schema::{Schema, SchemaRef}; use bytes::Bytes; use futures::{ready, stream::BoxStream, Stream, StreamExt}; @@ -258,7 +259,7 @@ impl FlightDataDecoder { )); }; - let buffer: arrow_buffer::Buffer = data.data_body.into(); + let buffer = Buffer::from_bytes(data.data_body.into()); let dictionary_batch = message.header_as_dictionary_batch().ok_or_else(|| { FlightError::protocol( diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index c9adc2b98b12..d661c9640908 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -538,7 +538,7 @@ pub fn arrow_data_from_flight_data( let dictionaries_by_field = HashMap::new(); let record_batch = read_record_batch( - &Buffer::from(&flight_data.data_body), + &Buffer::from_bytes(flight_data.data_body.into()), ipc_record_batch, arrow_schema_ref.clone(), &dictionaries_by_field,