Skip to content

Commit 2cf52bf

Browse files
committed
feat: add register_avro
1 parent 217ede8 commit 2cf52bf

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

datafusion/tests/test_sql.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,32 @@ def test_register_json(ctx, tmp_path):
205205
ctx.register_json("json4", gzip_path, file_compression_type="rar")
206206

207207

208+
def test_register_avro(ctx):
209+
path = "testing/data/avro/alltypes_plain.avro"
210+
ctx.register_avro("alltypes_plain", path)
211+
result = ctx.sql(
212+
"SELECT SUM(tinyint_col) as tinyint_sum FROM alltypes_plain"
213+
).collect()
214+
result = pa.Table.from_batches(result).to_pydict()
215+
assert result["tinyint_sum"][0] > 0
216+
217+
alternative_schema = pa.schema(
218+
[
219+
pa.field("id", pa.int64()),
220+
]
221+
)
222+
223+
ctx.register_avro(
224+
"alltypes_plain_schema",
225+
path,
226+
schema=alternative_schema,
227+
infinite=False,
228+
)
229+
result = ctx.sql("SELECT * FROM alltypes_plain_schema").collect()
230+
result = pa.Table.from_batches(result)
231+
assert result.schema == alternative_schema
232+
233+
208234
def test_execute(ctx, tmp_path):
209235
data = [1, 1, 2, 2, 3, 11, 12]
210236

src/context.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,39 @@ impl PySessionContext {
545545
Ok(())
546546
}
547547

548+
#[allow(clippy::too_many_arguments)]
549+
#[pyo3(signature = (name,
550+
path,
551+
schema=None,
552+
file_extension=".avro",
553+
table_partition_cols=vec![],
554+
infinite=false))]
555+
fn register_avro(
556+
&mut self,
557+
name: &str,
558+
path: PathBuf,
559+
schema: Option<PyArrowType<Schema>>,
560+
file_extension: &str,
561+
table_partition_cols: Vec<(String, String)>,
562+
infinite: bool,
563+
py: Python,
564+
) -> PyResult<()> {
565+
let path = path
566+
.to_str()
567+
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
568+
569+
let mut options = AvroReadOptions::default()
570+
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
571+
.mark_infinite(infinite);
572+
options.file_extension = file_extension;
573+
options.schema = schema.as_ref().map(|x| &x.0);
574+
575+
let result = self.ctx.register_avro(name, path, options);
576+
wait_for_future(py, result).map_err(DataFusionError::from)?;
577+
578+
Ok(())
579+
}
580+
548581
// Registers a PyArrow.Dataset
549582
fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
550583
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);

0 commit comments

Comments
 (0)