feat: make register_csv accept a list of paths
mesejo committed Sep 25, 2024
1 parent f6261b0 commit b3faa7c
Showing 3 changed files with 99 additions and 11 deletions.
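
For context, a minimal usage sketch of the new signature, distilled from the test added in this commit (the file name is a placeholder, and a datafusion build that includes this change is assumed):

import pathlib

import pyarrow as pa
from pyarrow.csv import write_csv

from datafusion import SessionContext

ctx = SessionContext()
path = pathlib.Path("test.csv")  # placeholder path, for illustration only
write_csv(pa.Table.from_arrays([[1, 2, 3, 4]], names=["int"]), path)

# A single str or pathlib.Path still works as before.
ctx.register_csv("csv", path)

# New in this commit: a list of paths registers one table over all the files.
ctx.register_csv("double_csv", path=[path, path])
assert ctx.table("double_csv").count() == 2 * ctx.table("csv").count()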
11 changes: 8 additions & 3 deletions python/datafusion/context.py
@@ -714,7 +714,7 @@ def register_parquet(
     def register_csv(
         self,
         name: str,
-        path: str | pathlib.Path,
+        path: str | pathlib.Path | list[str | pathlib.Path],
         schema: pyarrow.Schema | None = None,
         has_header: bool = True,
         delimiter: str = ",",
@@ -728,7 +728,7 @@ def register_csv(

         Args:
             name: Name of the table to register.
-            path: Path to the CSV file.
+            path: Path to the CSV file. It also accepts a list of Paths.
             schema: An optional schema representing the CSV file. If None, the
                 CSV reader will try to infer it based on data in file.
             has_header: Whether the CSV file have a header. If schema inference
@@ -741,9 +741,14 @@
                 selected for data input.
             file_compression_type: File compression type.
         """
+        if isinstance(path, list):
+            path = [str(p) for p in path]
+        else:
+            path = str(path)
+
         self.ctx.register_csv(
             name,
-            str(path),
+            path,
             schema,
             has_header,
             delimiter,
35 changes: 35 additions & 0 deletions python/datafusion/tests/test_sql.py
@@ -104,6 +104,41 @@ def test_register_csv(ctx, tmp_path):
         ctx.register_csv("csv4", path, file_compression_type="rar")


+def test_register_csv_list(ctx, tmp_path):
+    path = tmp_path / "test.csv"
+
+    int_values = [1, 2, 3, 4]
+    table = pa.Table.from_arrays(
+        [
+            int_values,
+            ["a", "b", "c", "d"],
+            [1.1, 2.2, 3.3, 4.4],
+        ],
+        names=["int", "str", "float"],
+    )
+    write_csv(table, path)
+    ctx.register_csv("csv", path)
+
+    csv_df = ctx.table("csv")
+    expected_count = csv_df.count() * 2
+    ctx.register_csv(
+        "double_csv",
+        path=[
+            path,
+            path,
+        ],
+    )
+
+    double_csv_df = ctx.table("double_csv")
+    actual_count = double_csv_df.count()
+    assert actual_count == expected_count
+
+    int_sum = ctx.sql("select sum(int) from double_csv").to_pydict()[
+        "sum(double_csv.int)"
+    ][0]
+    assert int_sum == 2 * sum(int_values)
+
+
 def test_register_parquet(ctx, tmp_path):
     path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
     ctx.register_parquet("t", path)
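
The new test can be run on its own with pytest's keyword filter, e.g. pytest python/datafusion/tests/test_sql.py -k test_register_csv_list (assuming a development build of the bindings is installed in the active environment).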
64 changes: 56 additions & 8 deletions src/context.rs
@@ -46,17 +46,21 @@ use crate::utils::{get_tokio_runtime, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::common::ScalarValue;
+use datafusion::catalog_common::TableReference;
+use datafusion::common::{exec_err, ScalarValue};
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
 use datafusion::datasource::file_format::parquet::ParquetFormat;
 use datafusion::datasource::listing::{
     ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
 };
 use datafusion::datasource::MemTable;
 use datafusion::datasource::TableProvider;
-use datafusion::execution::context::{SQLOptions, SessionConfig, SessionContext, TaskContext};
+use datafusion::execution::context::{
+    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
+};
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
+use datafusion::execution::options::ReadOptions;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::physical_plan::SendableRecordBatchStream;
 use datafusion::prelude::{
@@ -621,7 +625,7 @@ impl PySessionContext {
     pub fn register_csv(
         &mut self,
         name: &str,
-        path: PathBuf,
+        path: &Bound<'_, PyAny>,
         schema: Option<PyArrowType<Schema>>,
         has_header: bool,
         delimiter: &str,
@@ -630,9 +634,6 @@
         file_compression_type: Option<String>,
         py: Python,
     ) -> PyResult<()> {
-        let path = path
-            .to_str()
-            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
         let delimiter = delimiter.as_bytes();
         if delimiter.len() != 1 {
             return Err(PyValueError::new_err(
@@ -648,8 +649,15 @@
             .file_compression_type(parse_file_compression_type(file_compression_type)?);
         options.schema = schema.as_ref().map(|x| &x.0);

-        let result = self.ctx.register_csv(name, path, options);
-        wait_for_future(py, result).map_err(DataFusionError::from)?;
+        if path.is_instance_of::<PyList>() {
+            let paths = path.extract::<Vec<String>>()?;
+            let result = self.register_csv_from_multiple_paths(name, paths, options);
+            wait_for_future(py, result).map_err(DataFusionError::from)?;
+        } else {
+            let path = path.extract::<String>()?;
+            let result = self.ctx.register_csv(name, &path, options);
+            wait_for_future(py, result).map_err(DataFusionError::from)?;
+        }

         Ok(())
     }
@@ -981,6 +989,46 @@
     async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
         self.ctx.table(name).await
     }
+
+    async fn register_csv_from_multiple_paths(
+        &self,
+        name: &str,
+        table_paths: Vec<String>,
+        options: CsvReadOptions<'_>,
+    ) -> datafusion::common::Result<()> {
+        let table_paths = table_paths.to_urls()?;
+        let session_config = self.ctx.copied_config();
+        let listing_options =
+            options.to_listing_options(&session_config, self.ctx.copied_table_options());
+
+        let option_extension = listing_options.file_extension.clone();
+
+        if table_paths.is_empty() {
+            return exec_err!("No table paths were provided");
+        }
+
+        // check if the file extension matches the expected extension
+        for path in &table_paths {
+            let file_path = path.as_str();
+            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
+                return exec_err!(
+                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
+                );
+            }
+        }
+
+        let resolved_schema = options
+            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
+            .await?;
+
+        let config = ListingTableConfig::new_with_multi_paths(table_paths)
+            .with_listing_options(listing_options)
+            .with_schema(resolved_schema);
+        let table = ListingTable::try_new(config)?;
+        self.ctx
+            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
+        Ok(())
+    }
 }

 pub fn convert_table_partition_cols(
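
Two behaviors of the new Rust helper are visible from Python and worth noting: an empty list is rejected outright, and every file in the list must end with the configured file_extension (".csv" by default) unless the path is a directory. A sketch of both failure modes (file names are placeholders, and the exact Python exception class surfaced from the Rust error is not asserted here, only its message):

import pathlib

import pyarrow as pa
import pytest
from pyarrow.csv import write_csv

from datafusion import SessionContext

ctx = SessionContext()
table = pa.Table.from_arrays([[1, 2]], names=["int"])
good = pathlib.Path("a.csv")     # placeholder file names
bad = pathlib.Path("b.parquet")  # CSV content, but the wrong extension
write_csv(table, good)
write_csv(table, bad)

# The helper returns exec_err!("No table paths were provided") for [].
with pytest.raises(Exception, match="No table paths were provided"):
    ctx.register_csv("empty", [])

# Files whose extension differs from file_extension are rejected before
# any schema inference is attempted.
with pytest.raises(Exception, match="does not match the expected extension"):
    ctx.register_csv("mixed", [good, bad])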