Skip to content

Commit

Permalink
feat: add register_avro and read_table (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo committed Aug 22, 2023
1 parent 9c643bf commit e24dc75
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
10 changes: 10 additions & 0 deletions datafusion/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,16 @@ def test_register_table(ctx, database):
assert public.names() == {"csv", "csv1", "csv2", "csv3"}


def test_read_table(ctx, database):
    """A table registered in the catalog can be fetched and read as a DataFrame."""
    catalog = ctx.catalog()
    schema = catalog.database("public")
    # Sanity check: the fixture registered exactly these tables.
    assert schema.names() == {"csv", "csv1", "csv2"}

    # Retrieve the provider from the catalog and turn it back into a DataFrame.
    csv_table = schema.table("csv")
    df = ctx.read_table(csv_table)
    df.show()


def test_deregister_table(ctx, database):
default = ctx.catalog()
public = default.database("public")
Expand Down
26 changes: 26 additions & 0 deletions datafusion/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,32 @@ def test_register_json(ctx, tmp_path):
ctx.register_json("json4", gzip_path, file_compression_type="rar")


def test_register_avro(ctx):
    """Avro files can be registered and queried, with and without an explicit schema."""
    path = "testing/data/avro/alltypes_plain.avro"

    # Register with the schema inferred from the file and run an aggregate over it.
    ctx.register_avro("alltypes_plain", path)
    batches = ctx.sql(
        "SELECT SUM(tinyint_col) as tinyint_sum FROM alltypes_plain"
    ).collect()
    summed = pa.Table.from_batches(batches).to_pydict()
    assert summed["tinyint_sum"][0] > 0

    # Register again under another name, this time forcing a caller-supplied schema.
    alternative_schema = pa.schema(
        [
            pa.field("id", pa.int64()),
        ]
    )
    ctx.register_avro(
        "alltypes_plain_schema",
        path,
        schema=alternative_schema,
        infinite=False,
    )
    projected = pa.Table.from_batches(
        ctx.sql("SELECT * FROM alltypes_plain_schema").collect()
    )
    # The explicit schema must win over the one embedded in the Avro file.
    assert projected.schema == alternative_schema


def test_execute(ctx, tmp_path):
data = [1, 1, 2, 2, 3, 11, 12]

Expand Down
41 changes: 41 additions & 0 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,39 @@ impl PySessionContext {
Ok(())
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (name,
    path,
    schema=None,
    file_extension=".avro",
    table_partition_cols=vec![],
    infinite=false))]
/// Register an Avro file at `path` as table `name` in this session context.
///
/// Mirrors DataFusion's `SessionContext::register_avro`; an explicit
/// `schema` (if given) overrides the schema embedded in the file.
fn register_avro(
    &mut self,
    name: &str,
    path: PathBuf,
    schema: Option<PyArrowType<Schema>>,
    file_extension: &str,
    table_partition_cols: Vec<(String, String)>,
    infinite: bool,
    py: Python,
) -> PyResult<()> {
    // DataFusion takes the path as &str; reject non-UTF-8 paths explicitly.
    let path_str = path
        .to_str()
        .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

    // Builder-style options first, then the fields that have no builder method.
    let mut options = AvroReadOptions::default()
        .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
        .mark_infinite(infinite);
    options.file_extension = file_extension;
    // `AvroReadOptions` borrows the schema, so `schema` must outlive the await below.
    options.schema = schema.as_ref().map(|s| &s.0);

    // Registration is async; block the current thread until it completes.
    wait_for_future(py, self.ctx.register_avro(name, path_str, options))
        .map_err(DataFusionError::from)?;
    Ok(())
}

// Registers a PyArrow.Dataset
fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
Expand Down Expand Up @@ -734,6 +767,14 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}

/// Wrap an already-registered table provider in a new `PyDataFrame`.
fn read_table(&self, table: &PyTable) -> PyResult<PyDataFrame> {
    let provider = table.table();
    match self.ctx.read_table(provider) {
        Ok(df) => Ok(PyDataFrame::new(df)),
        // Route through DataFusionError so the usual PyErr conversion applies.
        Err(e) => Err(DataFusionError::from(e).into()),
    }
}

fn __repr__(&self) -> PyResult<String> {
let config = self.ctx.copied_config();
let mut config_entries = config
Expand Down

0 comments on commit e24dc75

Please sign in to comment.