Skip to content

Commit f9e091b

Browse files
committed
Avoid use of flatbuffers::size_prefixed_root
1 parent d4b9482 commit f9e091b

File tree

2 files changed

+80
-28
lines changed

2 files changed

+80
-28
lines changed

arrow-ipc/src/convert.rs

+36-25
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use std::fmt::{Debug, Formatter};
2828
use std::sync::Arc;
2929

3030
use crate::writer::DictionaryTracker;
31-
use crate::{size_prefixed_root_as_message, KeyValue, Message, CONTINUATION_MARKER};
31+
use crate::{KeyValue, Message, CONTINUATION_MARKER};
3232
use DataType::*;
3333

3434
/// Low level Arrow [Schema] to IPC bytes converter
@@ -255,32 +255,43 @@ pub fn try_schema_from_ipc_buffer(buffer: &[u8]) -> Result<Schema, ArrowError> {
255255
// 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix
256256
// 4 bytes - the byte length of the payload
257257
// a flatbuffer Message whose header is the Schema
258-
if buffer.len() >= 4 {
259-
// check continuation marker
260-
let continuation_marker = &buffer[0..4];
261-
let begin_offset: usize = if continuation_marker.eq(&CONTINUATION_MARKER) {
262-
// 4 bytes: CONTINUATION_MARKER
263-
// 4 bytes: length
264-
// buffer
265-
4
266-
} else {
267-
// backward compatibility for buffer without the continuation marker
268-
// 4 bytes: length
269-
// buffer
270-
0
271-
};
272-
let msg = size_prefixed_root_as_message(&buffer[begin_offset..]).map_err(|err| {
273-
ArrowError::ParseError(format!("Unable to convert flight info to a message: {err}"))
274-
})?;
275-
let ipc_schema = msg.header_as_schema().ok_or_else(|| {
276-
ArrowError::ParseError("Unable to convert flight info to a schema".to_string())
277-
})?;
278-
Ok(fb_to_schema(ipc_schema))
279-
} else {
280-
Err(ArrowError::ParseError(
258+
if buffer.len() < 4 {
259+
return Err(ArrowError::ParseError(
281260
"The buffer length is less than 4 and missing the continuation marker or length of buffer".to_string()
282-
))
261+
));
262+
}
263+
264+
let (len, buffer) = if buffer[..4] == CONTINUATION_MARKER {
265+
if buffer.len() < 8 {
266+
return Err(ArrowError::ParseError(
267+
"The buffer length is less than 8 and missing the length of buffer".to_string(),
268+
));
269+
}
270+
buffer[4..].split_at(4)
271+
} else {
272+
buffer.split_at(4)
273+
};
274+
275+
let len = <i32>::from_le_bytes(len.try_into().unwrap());
276+
if len < 0 {
277+
return Err(ArrowError::ParseError(format!(
278+
"The encapsulated message's reported length is negative ({len})"
279+
)));
280+
}
281+
282+
if buffer.len() < len as usize {
283+
let actual_len = buffer.len();
284+
return Err(ArrowError::ParseError(
285+
format!("The buffer length ({actual_len}) is less than the encapsulated message's reported length ({len})")
286+
));
283287
}
288+
289+
let msg = crate::root_as_message(buffer)
290+
.map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?;
291+
let ipc_schema = msg.header_as_schema().ok_or_else(|| {
292+
ArrowError::ParseError("Unable to convert flight info to a schema".to_string())
293+
})?;
294+
Ok(fb_to_schema(ipc_schema))
284295
}
285296

286297
/// Get the Arrow data type from the flatbuffer Field table

arrow-ipc/src/reader.rs

+44-3
Original file line numberDiff line numberDiff line change
@@ -1480,12 +1480,14 @@ impl<R: Read> RecordBatchReader for StreamReader<R> {
14801480

14811481
#[cfg(test)]
14821482
mod tests {
1483-
use crate::writer::{unslice_run_array, DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
1483+
use crate::convert::fb_to_schema;
1484+
use crate::writer::{
1485+
unslice_run_array, write_message, DictionaryTracker, IpcDataGenerator, IpcWriteOptions,
1486+
};
14841487

14851488
use super::*;
14861489

1487-
use crate::convert::fb_to_schema;
1488-
use crate::{root_as_footer, root_as_message};
1490+
use crate::{root_as_footer, root_as_message, size_prefixed_root_as_message};
14891491
use arrow_array::builder::{PrimitiveRunBuilder, UnionBuilder};
14901492
use arrow_array::types::*;
14911493
use arrow_buffer::{NullBuffer, OffsetBuffer};
@@ -2617,4 +2619,43 @@ mod tests {
26172619
let err = read_ipc_with_decoder(buf).unwrap_err();
26182620
assert_eq!(err.to_string(), expected_err);
26192621
}
2622+
2623+
#[test]
2624+
fn test_roundtrip_schema() {
2625+
let schema = Schema::new(vec![
2626+
Field::new(
2627+
"a",
2628+
DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
2629+
false,
2630+
),
2631+
Field::new(
2632+
"b",
2633+
DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
2634+
false,
2635+
),
2636+
]);
2637+
2638+
let options = IpcWriteOptions::default();
2639+
let data_gen = IpcDataGenerator::default();
2640+
let mut dict_tracker = DictionaryTracker::new(false);
2641+
let encoded_data =
2642+
data_gen.schema_to_bytes_with_dictionary_tracker(&schema, &mut dict_tracker, &options);
2643+
let mut schema_bytes = vec![];
2644+
write_message(&mut schema_bytes, encoded_data, &options).expect("write_message");
2645+
2646+
let begin_offset: usize = if schema_bytes[0..4].eq(&CONTINUATION_MARKER) {
2647+
4
2648+
} else {
2649+
0
2650+
};
2651+
2652+
size_prefixed_root_as_message(&schema_bytes[begin_offset..])
2653+
.expect_err("size_prefixed_root_as_message");
2654+
2655+
let msg = parse_message(&schema_bytes).expect("parse_message");
2656+
let ipc_schema = msg.header_as_schema().expect("header_as_schema");
2657+
let new_schema = fb_to_schema(ipc_schema);
2658+
2659+
assert_eq!(schema, new_schema);
2660+
}
26202661
}

0 commit comments

Comments
 (0)