From 5770599a6c0398a4a750857da0fdd108b6b50fff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?=
 =?UTF-8?q?=E0=A4=A4=29?=
Date: Tue, 24 Dec 2024 21:04:20 +0545
Subject: [PATCH] from_parquet: use virtual filesystem to preserve partition
 information

Also:

* did a minor refactor,
* removed the `nrows` _hack_, and
* disabled prefetching when `nrows` is set, so that we don't download the
  whole dataset.
---
 src/datachain/lib/arrow.py   | 139 ++++++++++++++++++++---------------
 src/datachain/lib/dc.py      |  11 ++-
 tests/unit/lib/test_arrow.py |  26 +++++++
 3 files changed, 114 insertions(+), 62 deletions(-)

diff --git a/src/datachain/lib/arrow.py b/src/datachain/lib/arrow.py
index 58ea2c535..7d7c910db 100644
--- a/src/datachain/lib/arrow.py
+++ b/src/datachain/lib/arrow.py
@@ -1,9 +1,11 @@
 from collections.abc import Sequence
-from tempfile import NamedTemporaryFile
+from itertools import islice
 from typing import TYPE_CHECKING, Any, Optional
 
+import fsspec.implementations.reference
 import orjson
 import pyarrow as pa
+from fsspec.core import split_protocol
 from pyarrow.dataset import CsvFileFormat, dataset
 from tqdm import tqdm
 
@@ -25,7 +27,17 @@
 DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY = b"DataChain SignalSchema"
 
 
+class ReferenceFileSystem(fsspec.implementations.reference.ReferenceFileSystem):
+    def _open(self, path, mode="rb", *args, **kwargs):
+        # `fsspec`'s `ReferenceFileSystem._open` reads the whole file into memory.
+        (uri,) = self.references[path]
+        protocol, _ = split_protocol(uri)
+        return self.fss[protocol]._open(uri, mode, *args, **kwargs)
+
+
 class ArrowGenerator(Generator):
+    DEFAULT_BATCH_SIZE = 2**17  # same as `pyarrow.dataset.DEFAULT_BATCH_SIZE`
+
     def __init__(
         self,
         input_schema: Optional["pa.Schema"] = None,
@@ -55,57 +67,80 @@ def __init__(
     def process(self, file: File):
         if file._caching_enabled:
             file.ensure_cached()
-            path = file.get_local_path()
-            ds = dataset(path, schema=self.input_schema, **self.kwargs)
-        elif self.nrows:
-            path = _nrows_file(file, self.nrows)
-            ds = dataset(path, schema=self.input_schema, **self.kwargs)
+            cache_path = file.get_local_path()
+            fs_path = file.path
+            fs = ReferenceFileSystem({fs_path: [cache_path]})
         else:
-            path = file.get_path()
-            ds = dataset(
-                path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
-            )
+            fs, fs_path = file.get_fs(), file.get_path()
+
+        ds = dataset(fs_path, schema=self.input_schema, filesystem=fs, **self.kwargs)
+
         hf_schema = _get_hf_schema(ds.schema)
         use_datachain_schema = (
             bool(ds.schema.metadata)
             and DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY in ds.schema.metadata
         )
-        index = 0
-        with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
-            for record_batch in ds.to_batches():
-                for record in record_batch.to_pylist():
-                    if use_datachain_schema and self.output_schema:
-                        vals = [_nested_model_instantiate(record, self.output_schema)]
-                    else:
-                        vals = list(record.values())
-                        if self.output_schema:
-                            fields = self.output_schema.model_fields
-                            vals_dict = {}
-                            for i, ((field, field_info), val) in enumerate(
-                                zip(fields.items(), vals)
-                            ):
-                                anno = field_info.annotation
-                                if hf_schema:
-                                    from datachain.lib.hf import convert_feature
-
-                                    feat = list(hf_schema[0].values())[i]
-                                    vals_dict[field] = convert_feature(val, feat, anno)
-                                elif ModelStore.is_pydantic(anno):
-                                    vals_dict[field] = anno(**val)  # type: ignore[misc]
-                                else:
-                                    vals_dict[field] = val
-                            vals = [self.output_schema(**vals_dict)]
-                    if self.source:
-                        kwargs: dict = self.kwargs
-                        # Can't serialize CsvFileFormat; may lose formatting options.
-                        if isinstance(kwargs.get("format"), CsvFileFormat):
-                            kwargs["format"] = "csv"
-                        arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
-                        yield [arrow_file, *vals]
-                    else:
-                        yield vals
-                    index += 1
-                pbar.update(len(record_batch))
+
+        kw = {}
+        if self.nrows:
+            kw = {"batch_size": min(self.DEFAULT_BATCH_SIZE, self.nrows)}
+
+        def iter_records():
+            for record in ds.to_batches(**kw):
+                yield from record.to_pylist()
+
+        it = islice(iter_records(), self.nrows)
+        with tqdm(it, desc="Parsed by pyarrow", unit="rows", total=self.nrows) as pbar:
+            for index, record in enumerate(pbar):
+                yield self._process_record(
+                    record, file, index, hf_schema, use_datachain_schema
+                )
+
+    def _process_record(
+        self,
+        record: dict[str, Any],
+        file: File,
+        index: int,
+        hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+        use_datachain_schema: bool,
+    ):
+        if use_datachain_schema and self.output_schema:
+            vals = [_nested_model_instantiate(record, self.output_schema)]
+        else:
+            vals = self._process_non_datachain_record(record, hf_schema)
+
+        if self.source:
+            kwargs: dict = self.kwargs
+            # Can't serialize CsvFileFormat; may lose formatting options.
+            if isinstance(kwargs.get("format"), CsvFileFormat):
+                kwargs["format"] = "csv"
+            arrow_file = ArrowRow(file=file, index=index, kwargs=kwargs)
+            return [arrow_file, *vals]
+        return vals
+
+    def _process_non_datachain_record(
+        self,
+        record: dict[str, Any],
+        hf_schema: Optional[tuple["Features", dict[str, "DataType"]]],
+    ):
+        vals = list(record.values())
+        if not self.output_schema:
+            return vals
+
+        fields = self.output_schema.model_fields
+        vals_dict = {}
+        for i, ((field, field_info), val) in enumerate(zip(fields.items(), vals)):
+            anno = field_info.annotation
+            if hf_schema:
+                from datachain.lib.hf import convert_feature
+
+                feat = list(hf_schema[0].values())[i]
+                vals_dict[field] = convert_feature(val, feat, anno)
+            elif ModelStore.is_pydantic(anno):
+                vals_dict[field] = anno(**val)  # type: ignore[misc]
+            else:
+                vals_dict[field] = val
+        return [self.output_schema(**vals_dict)]
 
 
 def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
@@ -190,18 +225,6 @@ def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa:
     raise TypeError(f"{col_type!r} datatypes not supported, column: {column}")
 
 
-def _nrows_file(file: File, nrows: int) -> str:
-    tf = NamedTemporaryFile(delete=False)  # noqa: SIM115
-    with file.open(mode="r") as reader:
-        with open(tf.name, "a") as writer:
-            for row, line in enumerate(reader):
-                if row >= nrows:
-                    break
-                writer.write(line)
-                writer.write("\n")
-    return tf.name
-
-
 def _get_hf_schema(
     schema: "pa.Schema",
 ) -> Optional[tuple["Features", dict[str, "DataType"]]]:
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 0900ce697..303b769ad 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -557,7 +557,9 @@ def jmespath_to_name(s: str):
                 nrows=nrows,
             )
         }
-        return chain.gen(**signal_dict)  # type: ignore[misc, arg-type]
+        # disable prefetch if nrows is set
+        settings = {"prefetch": 0} if nrows else {}
+        return chain.settings(**settings).gen(**signal_dict)  # type: ignore[misc, arg-type]
 
     def explode(
         self,
@@ -1894,7 +1896,10 @@ def parse_tabular(
 
         if source:
             output = {"source": ArrowRow} | output  # type: ignore[assignment,operator]
-        return self.gen(
+
+        # disable prefetch if nrows is set
+        settings = {"prefetch": 0} if nrows else {}
+        return self.settings(**settings).gen(  # type: ignore[arg-type]
             ArrowGenerator(schema, model, source, nrows, **kwargs), output=output
         )
 
@@ -1976,8 +1981,6 @@ def from_csv(
             else:
                 msg = f"error parsing csv - incompatible output type {type(output)}"
                 raise DatasetPrepareError(chain.name, msg)
-        elif nrows:
-            nrows += 1
 
         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
diff --git a/tests/unit/lib/test_arrow.py b/tests/unit/lib/test_arrow.py
index 3c24e44e1..422c9ba8d 100644
--- a/tests/unit/lib/test_arrow.py
+++ b/tests/unit/lib/test_arrow.py
@@ -106,6 +106,32 @@ def test_arrow_generator_hf(tmp_path, catalog):
         assert isinstance(obj[1].col, HFClassLabel)
 
 
+@pytest.mark.parametrize("cache", [True, False])
+def test_arrow_generator_partitioned(tmp_path, catalog, cache):
+    pq_path = tmp_path / "parquets"
+    pylist = [
+        {"first_name": "Alice", "age": 25, "city": "New York"},
+        {"first_name": "Bob", "age": 30, "city": "Los Angeles"},
+        {"first_name": "Charlie", "age": 35, "city": "Chicago"},
+    ]
+    table = pa.Table.from_pylist(pylist)
+    pq.write_to_dataset(table, pq_path, partition_cols=["first_name"])
+
+    output, original_names = schema_to_output(table.schema)
+    output_schema = dict_to_data_model("", output, original_names)
+    func = ArrowGenerator(
+        table.schema, output_schema=output_schema, partitioning="hive"
+    )
+
+    for path in pq_path.rglob("*.parquet"):
+        stream = File(path=path.as_posix(), source="file://")
+        stream._set_stream(catalog, caching_enabled=cache)
+
+        (o,) = list(func.process(stream))
+        assert isinstance(o[0], ArrowRow)
+        assert dict(o[1]) in pylist
+
+
 @pytest.mark.parametrize(
     "col_type,expected",
     (
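
Note (not part of the patch): a minimal sketch of the idea behind the `ReferenceFileSystem` mapping. The dataset is opened under its original, partition-bearing path while the bytes are served from a local cache copy, so pyarrow's hive partitioning can still read `first_name=...` from the path. It uses the stock `fsspec` class rather than the lazy-open subclass added above, and all paths and values below are invented for illustration.

    import os
    import tempfile

    import pyarrow as pa
    import pyarrow.dataset as pa_ds
    import pyarrow.parquet as pq
    from fsspec.implementations.reference import ReferenceFileSystem

    # Pretend this is a local cache entry: an opaque file name that carries
    # no partition information itself.
    cache_dir = tempfile.mkdtemp()
    cached = os.path.join(cache_dir, "0a1b2c3d")
    pq.write_table(pa.table({"age": [25]}), cached)

    # The original path still encodes the hive partition.
    original = "data/first_name=Alice/part-0.parquet"

    # Map the original path to the cached bytes, mirroring what
    # ArrowGenerator.process does with {fs_path: [cache_path]}.
    fs = ReferenceFileSystem({original: [cached]})
    ds = pa_ds.dataset(original, filesystem=fs, partitioning="hive")

    # The partition column should be recovered from the path,
    # e.g. {'age': [25], 'first_name': ['Alice']}.
    print(ds.to_table().to_pydict())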