diff --git a/datafusion/input/location.py b/datafusion/input/location.py
index efbc82f2..939c7f41 100644
--- a/datafusion/input/location.py
+++ b/datafusion/input/location.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import os
+import glob
 from typing import Any
 
 from datafusion.common import DataTypeMap, SqlTable
@@ -41,14 +42,12 @@ def build_table(
     format = extension.lstrip(".").lower()
     num_rows = 0  # Total number of rows in the file. Used for statistics
     columns = []
-
     if format == "parquet":
         import pyarrow.parquet as pq
 
         # Read the Parquet metadata
         metadata = pq.read_metadata(input_file)
         num_rows = metadata.num_rows
-
         # Iterate through the schema and build the SqlTable
         for col in metadata.schema:
             columns.append(
@@ -57,7 +56,6 @@ def build_table(
                     DataTypeMap.from_parquet_type_str(col.physical_type),
                 )
             )
-
     elif format == "csv":
         import csv
 
@@ -73,7 +71,6 @@ def build_table(
             print(header_row)
             for _ in reader:
                 num_rows += 1
-
         # TODO: Need to actually consume this row into resonable columns
         raise RuntimeError(
             "TODO: Currently unable to support CSV input files."
@@ -84,4 +81,7 @@ def build_table(
             Only Parquet and CSV."
         )
 
-    return SqlTable(table_name, columns, num_rows, input_file)
+    # Input could possibly be multiple files. Create a list if so
+    input_files = glob.glob(input_file)
+
+    return SqlTable(table_name, columns, num_rows, input_files)
diff --git a/datafusion/tests/test_input.py b/datafusion/tests/test_input.py
index 1e2ef416..5b1decf2 100644
--- a/datafusion/tests/test_input.py
+++ b/datafusion/tests/test_input.py
@@ -30,4 +30,4 @@ def test_location_input():
     tbl = location_input.build_table(input_file, table_name)
     assert "blog" == tbl.name
     assert 3 == len(tbl.columns)
-    assert "blogs.parquet" in tbl.filepath
+    assert "blogs.parquet" in tbl.filepaths[0]
diff --git a/src/common/schema.rs b/src/common/schema.rs
index a003d0ca..77b0ce2b 100644
--- a/src/common/schema.rs
+++ b/src/common/schema.rs
@@ -56,7 +56,7 @@ pub struct SqlTable {
     #[pyo3(get, set)]
     pub statistics: SqlStatistics,
     #[pyo3(get, set)]
-    pub filepath: Option<String>,
+    pub filepaths: Option<Vec<String>>,
 }
 
 #[pymethods]
@@ -66,7 +66,7 @@ impl SqlTable {
         table_name: String,
         columns: Vec<(String, DataTypeMap)>,
         row_count: f64,
-        filepath: Option<String>,
+        filepaths: Option<Vec<String>>,
     ) -> Self {
         Self {
             name: table_name,
@@ -76,7 +76,7 @@ impl SqlTable {
             indexes: Vec::new(),
             constraints: Vec::new(),
             statistics: SqlStatistics::new(row_count),
-            filepath,
+            filepaths,
         }
     }
 }
@@ -124,7 +124,7 @@ impl SqlSchema {
 pub struct SqlTableSource {
     schema: SchemaRef,
     statistics: Option<SqlStatistics>,
-    filepath: Option<String>,
+    filepaths: Option<Vec<String>>,
 }
 
 impl SqlTableSource {
@@ -132,12 +132,12 @@ impl SqlTableSource {
     pub fn new(
         schema: SchemaRef,
         statistics: Option<SqlStatistics>,
-        filepath: Option<String>,
+        filepaths: Option<Vec<String>>,
     ) -> Self {
         Self {
             schema,
             statistics,
-            filepath,
+            filepaths,
         }
     }
 
@@ -148,8 +148,8 @@ impl SqlTableSource {
 
     /// Access optional filepath associated with this table source
     #[allow(dead_code)]
-    pub fn filepath(&self) -> Option<&String> {
-        self.filepath.as_ref()
+    pub fn filepaths(&self) -> Option<&Vec<String>> {
+        self.filepaths.as_ref()
     }
 }
 