Commit 217ede8

feat: add register_json (#458)

1 parent 0b22c97

File tree

  datafusion/tests/test_sql.py
  src/context.rs

2 files changed: +88 -1 lines


datafusion/tests/test_sql.py

Lines changed: 52 additions & 1 deletion
@@ -14,12 +14,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import gzip
+import os

 import numpy as np
 import pyarrow as pa
 import pyarrow.dataset as ds
 import pytest
-import gzip

 from datafusion import udf

@@ -154,6 +155,56 @@ def test_register_dataset(ctx, tmp_path):
     assert result.to_pydict() == {"cnt": [100]}


+def test_register_json(ctx, tmp_path):
+    path = os.path.dirname(os.path.abspath(__file__))
+    test_data_path = os.path.join(path, "data_test_context", "data.json")
+    gzip_path = tmp_path / "data.json.gz"
+
+    with open(test_data_path, "rb") as json_file:
+        with gzip.open(gzip_path, "wb") as gzipped_file:
+            gzipped_file.writelines(json_file)
+
+    ctx.register_json("json", test_data_path)
+    ctx.register_json("json1", str(test_data_path))
+    ctx.register_json(
+        "json2",
+        test_data_path,
+        schema_infer_max_records=10,
+    )
+    ctx.register_json(
+        "json_gzip",
+        gzip_path,
+        file_extension="gz",
+        file_compression_type="gzip",
+    )
+
+    alternative_schema = pa.schema(
+        [
+            ("some_int", pa.int16()),
+            ("some_bytes", pa.string()),
+            ("some_floats", pa.float32()),
+        ]
+    )
+    ctx.register_json("json3", path, schema=alternative_schema)
+
+    assert ctx.tables() == {"json", "json1", "json2", "json3", "json_gzip"}
+
+    for table in ["json", "json1", "json2", "json_gzip"]:
+        result = ctx.sql(f'SELECT COUNT("B") AS cnt FROM {table}').collect()
+        result = pa.Table.from_batches(result)
+        assert result.to_pydict() == {"cnt": [3]}
+
+    result = ctx.sql("SELECT * FROM json3").collect()
+    result = pa.Table.from_batches(result)
+    assert result.schema == alternative_schema
+
+    with pytest.raises(
+        ValueError,
+        match="file_compression_type must one of: gzip, bz2, xz, zstd",
+    ):
+        ctx.register_json("json4", gzip_path, file_compression_type="rar")
+
+
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]
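
A minimal end-to-end sketch of the new API, assuming the `ctx` fixture above is a plain `datafusion.SessionContext`; the NDJSON rows are made up for illustration, and the column name "B" plus the expected count of 3 mirror the assertions in the test.

import json
import os
import tempfile

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# Write three newline-delimited JSON records to a temporary file.
path = os.path.join(tempfile.mkdtemp(), "data.json")
with open(path, "w") as f:
    for row in [{"A": "a", "B": 1}, {"A": "b", "B": 2}, {"A": "c", "B": 3}]:
        f.write(json.dumps(row) + "\n")

# Register the file as a named table and query it with SQL.
ctx.register_json("example", path)
batches = ctx.sql('SELECT COUNT("B") AS cnt FROM example').collect()
assert pa.Table.from_batches(batches).to_pydict() == {"cnt": [3]}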

src/context.rs

Lines changed: 36 additions & 0 deletions
@@ -509,6 +509,42 @@ impl PySessionContext {
         Ok(())
     }

+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name,
+                        path,
+                        schema=None,
+                        schema_infer_max_records=1000,
+                        file_extension=".json",
+                        table_partition_cols=vec![],
+                        file_compression_type=None))]
+    fn register_json(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        schema_infer_max_records: usize,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        file_compression_type: Option<String>,
+        py: Python,
+    ) -> PyResult<()> {
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+        let mut options = NdJsonReadOptions::default()
+            .file_compression_type(parse_file_compression_type(file_compression_type)?)
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+        options.schema_infer_max_records = schema_infer_max_records;
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+
+        let result = self.ctx.register_json(name, path, options);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     // Registers a PyArrow.Dataset
     fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
         let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
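
For reference, the `#[pyo3(signature = ...)]` defaults above surface as Python keyword defaults, so the call below should be equivalent to the bare two-argument form. This is an illustrative sketch reusing the hypothetical `ctx` and `path` from the earlier example; the comments describe the likely semantics, not documented behavior.

ctx.register_json(
    "example2",
    path,
    schema=None,                    # infer the schema from the data
    schema_infer_max_records=1000,  # rows scanned during schema inference
    file_extension=".json",         # suffix used to pick files under a directory path
    table_partition_cols=[],        # list of (name, type) pairs on the Rust side
    file_compression_type=None,     # None (uncompressed) or "gzip", "bz2", "xz", "zstd"
)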
