
Commit 7fe01bb

Allow overriding the inferred parquet schema root (#5814)
1 parent 09e58a4 commit 7fe01bb

File tree

3 files changed, +41 -16 lines changed


parquet/src/arrow/arrow_writer/mod.rs (+17 -2)
@@ -31,7 +31,8 @@ use arrow_array::{ArrayRef, RecordBatch, RecordBatchWriter};
 use arrow_schema::{ArrowError, DataType as ArrowDataType, Field, IntervalUnit, SchemaRef};
 
 use super::schema::{
-    add_encoded_arrow_schema_to_metadata, arrow_to_parquet_schema, decimal_length_from_precision,
+    add_encoded_arrow_schema_to_metadata, arrow_to_parquet_schema,
+    arrow_to_parquet_schema_with_root, decimal_length_from_precision,
 };
 
 use crate::arrow::arrow_writer::byte_array::ByteArrayEncoder;
@@ -160,7 +161,10 @@ impl<W: Write + Send> ArrowWriter<W> {
         arrow_schema: SchemaRef,
         options: ArrowWriterOptions,
     ) -> Result<Self> {
-        let schema = arrow_to_parquet_schema(&arrow_schema)?;
+        let schema = match options.schema_root {
+            Some(s) => arrow_to_parquet_schema_with_root(&arrow_schema, &s)?,
+            None => arrow_to_parquet_schema(&arrow_schema)?,
+        };
         let mut props = options.properties;
         if !options.skip_arrow_metadata {
             // add serialized arrow schema
@@ -323,6 +327,7 @@ impl<W: Write + Send> RecordBatchWriter for ArrowWriter<W> {
 pub struct ArrowWriterOptions {
     properties: WriterProperties,
     skip_arrow_metadata: bool,
+    schema_root: Option<String>,
 }
 
 impl ArrowWriterOptions {
@@ -346,6 +351,16 @@ impl ArrowWriterOptions {
             ..self
         }
     }
+
+    /// Overrides the name of the root parquet schema element
+    ///
+    /// Defaults to `"arrow_schema"`
+    pub fn with_schema_root(self, name: String) -> Self {
+        Self {
+            schema_root: Some(name),
+            ..self
+        }
+    }
 }
 
 /// A single column chunk produced by [`ArrowColumnWriter`]
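
The root name now flows from `ArrowWriterOptions` into the schema conversion. A minimal sketch of the new option in use (the single-column schema, the root name "my_root", and the output path are illustrative, not taken from this commit):

use std::fs::File;
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use parquet::arrow::arrow_writer::ArrowWriterOptions;
use parquet::arrow::ArrowWriter;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Override the root schema element name, which otherwise
    // defaults to "arrow_schema"
    let options = ArrowWriterOptions::new().with_schema_root("my_root".to_string());

    let file = File::create("example.parquet")?;
    let mut writer = ArrowWriter::try_new_with_options(file, schema, options)?;
    writer.write(&batch)?;
    writer.close()?;
    Ok(())
}

Storing the override as an `Option<String>` keeps the default path untouched: existing callers that never set a root still go through `arrow_to_parquet_schema`.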

parquet/src/arrow/schema/mod.rs (+9 -3)
@@ -223,15 +223,21 @@ pub(crate) fn add_encoded_arrow_schema_to_metadata(schema: &Schema, props: &mut
 }
 
 /// Convert arrow schema to parquet schema
+///
+/// The name of the root schema element defaults to `"arrow_schema"`, this can be
+/// overridden with [`arrow_to_parquet_schema_with_root`]
 pub fn arrow_to_parquet_schema(schema: &Schema) -> Result<SchemaDescriptor> {
+    arrow_to_parquet_schema_with_root(schema, "arrow_schema")
+}
+
+/// Convert arrow schema to parquet schema specifying the name of the root schema element
+pub fn arrow_to_parquet_schema_with_root(schema: &Schema, root: &str) -> Result<SchemaDescriptor> {
     let fields = schema
         .fields()
         .iter()
         .map(|field| arrow_to_parquet_type(field).map(Arc::new))
         .collect::<Result<_>>()?;
-    let group = Type::group_type_builder("arrow_schema")
-        .with_fields(fields)
-        .build()?;
+    let group = Type::group_type_builder(root).with_fields(fields).build()?;
     Ok(SchemaDescriptor::new(Arc::new(group)))
 }

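The conversion can also be driven directly. A short sketch, assuming `arrow_to_parquet_schema_with_root` is re-exported from `parquet::arrow` alongside the existing `arrow_to_parquet_schema` (the root name is illustrative):

use arrow_schema::{DataType, Field, Schema};
use parquet::arrow::arrow_to_parquet_schema_with_root;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);

    // The returned SchemaDescriptor's root group takes the supplied
    // name instead of the default "arrow_schema"
    let descriptor = arrow_to_parquet_schema_with_root(&schema, "my_root")?;
    assert_eq!(descriptor.name(), "my_root");
    Ok(())
}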
parquet/src/bin/parquet-fromcsv.rs (+15 -11)
@@ -81,6 +81,7 @@ use std::{
 use arrow_csv::ReaderBuilder;
 use arrow_schema::{ArrowError, Schema};
 use clap::{Parser, ValueEnum};
+use parquet::arrow::arrow_writer::ArrowWriterOptions;
 use parquet::{
     arrow::{parquet_to_arrow_schema, ArrowWriter},
     basic::Compression,
@@ -333,21 +334,16 @@ fn configure_reader_builder(args: &Args, arrow_schema: Arc<Schema>) -> ReaderBuilder {
     builder
 }
 
-fn arrow_schema_from_string(schema: &str) -> Result<Arc<Schema>, ParquetFromCsvError> {
-    let schema = Arc::new(parse_message_type(schema)?);
-    let desc = SchemaDescriptor::new(schema);
-    let arrow_schema = Arc::new(parquet_to_arrow_schema(&desc, None)?);
-    Ok(arrow_schema)
-}
-
 fn convert_csv_to_parquet(args: &Args) -> Result<(), ParquetFromCsvError> {
     let schema = read_to_string(args.schema_path()).map_err(|e| {
         ParquetFromCsvError::with_context(
             e,
             &format!("Failed to open schema file {:#?}", args.schema_path()),
         )
     })?;
-    let arrow_schema = arrow_schema_from_string(&schema)?;
+    let parquet_schema = Arc::new(parse_message_type(&schema)?);
+    let desc = SchemaDescriptor::new(parquet_schema);
+    let arrow_schema = Arc::new(parquet_to_arrow_schema(&desc, None)?);
 
     // create output parquet writer
     let parquet_file = File::create(&args.output_file).map_err(|e| {
@@ -357,9 +353,12 @@ fn convert_csv_to_parquet(args: &Args) -> Result<(), ParquetFromCsvError> {
         )
     })?;
 
-    let writer_properties = Some(configure_writer_properties(args));
+    let options = ArrowWriterOptions::new()
+        .with_properties(configure_writer_properties(args))
+        .with_schema_root(desc.name().to_string());
+
     let mut arrow_writer =
-        ArrowWriter::try_new(parquet_file, arrow_schema.clone(), writer_properties)
+        ArrowWriter::try_new_with_options(parquet_file, arrow_schema.clone(), options)
             .map_err(|e| ParquetFromCsvError::with_context(e, "Failed to create ArrowWriter"))?;
 
     // open input file
@@ -426,6 +425,7 @@ mod tests {
     use clap::{CommandFactory, Parser};
     use flate2::write::GzEncoder;
     use parquet::basic::{BrotliLevel, GzipLevel, ZstdLevel};
+    use parquet::file::reader::{FileReader, SerializedFileReader};
     use snap::write::FrameEncoder;
     use tempfile::NamedTempFile;
 
@@ -647,7 +647,7 @@ mod tests {
 
     fn test_convert_compressed_csv_to_parquet(csv_compression: Compression) {
         let schema = NamedTempFile::new().unwrap();
-        let schema_text = r"message schema {
+        let schema_text = r"message my_amazing_schema {
             optional int32 id;
             optional binary name (STRING);
         }";
@@ -728,6 +728,10 @@ mod tests {
             help: None,
         };
         convert_csv_to_parquet(&args).unwrap();
+
+        let file = SerializedFileReader::new(output_parquet.into_file()).unwrap();
+        let schema_name = file.metadata().file_metadata().schema().name();
+        assert_eq!(schema_name, "my_amazing_schema");
     }
 
     #[test]
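
The new assertion reads the file back to confirm the root name round-trips through parquet-fromcsv. The same check works standalone against any parquet file written by the updated tool (the path here is illustrative):

use std::fs::File;
use parquet::file::reader::{FileReader, SerializedFileReader};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("output.parquet")?;
    let reader = SerializedFileReader::new(file)?;
    // parquet-fromcsv now writes the root name taken from the schema file
    // (e.g. "my_amazing_schema") rather than the default "arrow_schema"
    let root_name = reader.metadata().file_metadata().schema().name();
    println!("root schema element: {root_name}");
    Ok(())
}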
