from_parquet: use virtual filesystem to preserve partition information
skshetry committed Dec 24, 2024
1 parent 1b34bc0 commit 90f1b7c
Showing 2 changed files with 40 additions and 1 deletion.
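Context for the change: DataChain caches remote files under opaque local cache paths, so when a cached copy of a hive-partitioned parquet file is handed to pyarrow, the directory segments such as first_name=Alice/ that carry the partition values are no longer in the path and the partition columns silently disappear. The fix below keeps pointing pyarrow at the original path and only redirects the byte reads to the cached copy. A minimal sketch of the underlying behaviour (hypothetical paths, plain pyarrow only, not DataChain code):

import glob
import shutil

import pyarrow as pa
import pyarrow.parquet as pq
from pyarrow.dataset import dataset

table = pa.table({"first_name": ["Alice", "Bob"], "age": [25, 30]})
pq.write_to_dataset(table, "users", partition_cols=["first_name"])

# Reading the partitioned tree keeps `first_name`: the value lives only in
# directory names such as users/first_name=Alice/.
assert "first_name" in dataset("users", partitioning="hive").schema.names

# Reading one fragment through a flat copied path (what a naive cache hand-off
# does) loses it, because the parquet file itself does not store that column.
frag = glob.glob("users/first_name=Alice/*.parquet")[0]
shutil.copy(frag, "cached-blob.parquet")
assert "first_name" not in dataset("cached-blob.parquet").schema.names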
15 changes: 14 additions & 1 deletion src/datachain/lib/arrow.py
@@ -2,8 +2,10 @@
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Optional
 
+import fsspec.implementations.reference
 import orjson
 import pyarrow as pa
+from fsspec.core import split_protocol
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
 
@@ -25,6 +27,14 @@
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
+    def _open(self, path, mode="rb", *args, **kwargs):
+        # `fsspec`'s `ReferenceFileSystem` reads the whole file in-memory; open via the target filesystem instead.
+        uri = self.references[path]
+        protocol, _ = split_protocol(uri)
+        return self.fss[protocol]._open(uri, mode, *args, **kwargs)
+
+
 class ArrowGenerator(Generator):
     def __init__(
         self,
@@ -56,7 +66,10 @@ def process(self, file: File):
         if file._caching_enabled:
             file.ensure_cached()
             path = file.get_local_path()
-            ds = dataset(path, schema=self.input_schema, **self.kwargs)
+            fs = ReferenceFileSystem({file.path: path})
+            ds = dataset(
+                file.path, filesystem=fs, schema=self.input_schema, **self.kwargs
+            )
         elif self.nrows:
             path = _nrows_file(file, self.nrows)
             ds = dataset(path, schema=self.input_schema, **self.kwargs)
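Note on the arrow.py change above: instead of handing dataset() the local cache path, the code now builds a one-entry ReferenceFileSystem that maps the file's original path to its cached location, so pyarrow still sees the partition-bearing path while reading bytes locally. A minimal sketch of the same idea outside DataChain (hypothetical paths, using the ReferenceFileSystem subclass defined in the diff above):

from pyarrow.dataset import dataset

logical_path = "bucket/users/first_name=Alice/part-0.parquet"  # original layout
cached_path = "/tmp/datachain-cache/ab/cdef0123456789"         # local cached blob

fs = ReferenceFileSystem({logical_path: cached_path})
ds = dataset(logical_path, filesystem=fs, partitioning="hive")
# ds.schema should now include `first_name`, recovered from the path segments,
# while the actual bytes come from cached_path.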
26 changes: 26 additions & 0 deletions tests/unit/lib/test_arrow.py
@@ -106,6 +106,32 @@ def test_arrow_generator_hf(tmp_path, catalog):
     assert isinstance(obj[1].col, HFClassLabel)
 
 
+@pytest.mark.parametrize("cache", [True, False])
+def test_arrow_generator_partitioned(tmp_path, catalog, cache):
+    pq_path = tmp_path / "parquets"
+    pylist = [
+        {"first_name": "Alice", "age": 25, "city": "New York"},
+        {"first_name": "Bob", "age": 30, "city": "Los Angeles"},
+        {"first_name": "Charlie", "age": 35, "city": "Chicago"},
+    ]
+    table = pa.Table.from_pylist(pylist)
+    pq.write_to_dataset(table, pq_path, partition_cols=["first_name"])
+
+    output, original_names = schema_to_output(table.schema)
+    output_schema = dict_to_data_model("", output, original_names)
+    func = ArrowGenerator(
+        table.schema, output_schema=output_schema, partitioning="hive"
+    )
+
+    for path in pq_path.rglob("*.parquet"):
+        stream = File(path=path.as_posix(), source="file://")
+        stream._set_stream(catalog, caching_enabled=cache)
+
+        (o,) = list(func.process(stream))
+        assert isinstance(o[0], ArrowRow)
+        assert dict(o[1]) in pylist
+
+
 @pytest.mark.parametrize(
     "col_type,expected",
     (
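For completeness, a hedged user-facing sketch (hypothetical bucket; assumes DataChain.from_parquet accepts a partitioning argument as in this release): a column that exists only in hive-style directory names should now survive even when reads go through the local cache.

from datachain import DataChain

dc = DataChain.from_parquet("s3://my-bucket/users/", partitioning="hive")
dc.show()  # partition columns such as first_name appear alongside file columns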
