Adding FSSpec Export for CSV and Parquet #516

Merged (7 commits) on Oct 19, 2024

61 changes: 57 additions & 4 deletions src/datachain/lib/dc.py
@@ -1,6 +1,8 @@
import copy
import os
import os.path
import re
import sys
from collections.abc import Iterator, Sequence
from functools import wraps
from typing import (
@@ -1887,21 +1889,48 @@
path: Union[str, os.PathLike[str], BinaryIO],
partition_cols: Optional[Sequence[str]] = None,
chunk_size: int = DEFAULT_PARQUET_CHUNK_SIZE,
fs_kwargs: Optional[dict[str, Any]] = None,
Member:

I have a concern here that we don't use fs_kwargs in all other places (e.g. anon=True, or from_parquet, which I think reads from a file object - file.get_fs() or something). Can we do a bit of research on that end and unify, or get rid of these additional kwargs?

Collaborator (Author):

These kwargs are optional and are combined with the Catalog's client_config for any unified use cases, such as configuration that applies to both read and write. I added this optional kwargs parameter for users who need to specify a write-only custom configuration, such as an access token used on write that does not (or may not) apply on read or to the whole application / chain. For example, a token can be specified for Hugging Face filesystems on write, as described here: https://huggingface.co/docs/huggingface_hub/en/guides/hf_file_system#authentication, but users may only want to specify this token on write to Hugging Face, not for other clouds or on read. I can rename or change this as desired, but it seems like having extra write-only kwargs can be useful in some cases.

Member:

kk. It seems to me that it should be symmetrical (people might need to provide something extra on reads as well eventually), and then the question is: will we be able to do the same easily and without changing some logic in those methods (like from_parquet, etc.)?

Should we update the docs here, btw?

Collaborator (Author):

Yes, the docs should be updated - I updated them in the latest commit. Shared (read and write) kwargs can be provided in client_config on Session or Catalog (as well as in environment variables), and those settings are automatically used for both read and write. The fs_kwargs option just provides a way to supply write-only configuration, or to override the shared configuration, only when necessary.

**kwargs,
) -> None:
"""Save chain to parquet file with SignalSchema metadata.

Parameters:
path : Path or a file-like binary object to save the file.
path : Path or a file-like binary object to save the file. This supports
local paths as well as remote paths, such as s3:// or hf:// with fsspec.
partition_cols : Column names by which to partition the dataset.
chunk_size : The chunk size of results to read and convert to columnar
data, to avoid running out of memory.
fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
write, for fsspec-type URLs, such as s3:// or hf:// when
provided as the destination path.
"""
import pyarrow as pa
import pyarrow.parquet as pq

from datachain.lib.arrow import DATACHAIN_SIGNAL_SCHEMA_PARQUET_KEY

fsspec_fs = None

if isinstance(path, str) and "://" in path:
from datachain.client.fsspec import Client

fs_kwargs = {
**self._query.catalog.client_config,
**(fs_kwargs or {}),
}

client = Client.get_implementation(path)

if path.startswith("file://"):
# pyarrow does not handle file:// uris, and needs a direct path instead.
from urllib.parse import urlparse

path = urlparse(path).path
if sys.platform == "win32":
path = os.path.normpath(path.lstrip("/"))

Codecov / codecov/patch: Added line src/datachain/lib/dc.py#L1930 was not covered by tests.

fsspec_fs = client.create_fs(**fs_kwargs)

_partition_cols = list(partition_cols) if partition_cols else None
signal_schema_metadata = orjson.dumps(
self._effective_signals_schema.serialize()
@@ -1936,12 +1965,15 @@
table,
root_path=path,
partition_cols=_partition_cols,
filesystem=fsspec_fs,
**kwargs,
)
else:
if first_chunk:
# Write to a single parquet file.
parquet_writer = pq.ParquetWriter(path, parquet_schema, **kwargs)
parquet_writer = pq.ParquetWriter(
path, parquet_schema, filesystem=fsspec_fs, **kwargs
)
first_chunk = False

assert parquet_writer
@@ -1954,22 +1986,43 @@
self,
path: Union[str, os.PathLike[str]],
delimiter: str = ",",
fs_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
"""Save chain to a csv (comma-separated values) file.

Parameters:
path : Path to save the file.
path : Path to save the file. This supports local paths as well as
remote paths, such as s3:// or hf:// with fsspec.
delimiter : Delimiter to use for the resulting file.
fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
write, for fsspec-type URLs, such as s3:// or hf:// when
provided as the destination path.
"""
import csv

opener = open

if isinstance(path, str) and "://" in path:
from datachain.client.fsspec import Client

fs_kwargs = {
**self._query.catalog.client_config,
**(fs_kwargs or {}),
}

client = Client.get_implementation(path)

fsspec_fs = client.create_fs(**fs_kwargs)

opener = fsspec_fs.open

headers, _ = self._effective_signals_schema.get_headers_with_length()
column_names = [".".join(filter(None, header)) for header in headers]

results_iter = self.collect_flatten()

with open(path, "w", newline="") as f:
with opener(path, "w", newline="") as f:
writer = csv.writer(f, delimiter=delimiter, **kwargs)
writer.writerow(column_names)

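For context on the fs_kwargs discussion above, a minimal usage sketch of the new remote export paths. The destination URL, the HF_TOKEN environment variable, and the assumption that the Hugging Face filesystem accepts a token kwarg are illustrative only; shared read/write settings would normally live in the Session or Catalog client_config, with fs_kwargs layering write-only overrides on top.

```python
import os

import pandas as pd

from datachain.lib.dc import DataChain  # module path taken from the diff above

df = pd.DataFrame({"first_name": ["Alice", "Bob"], "age": [25, 30]})
chain = DataChain.from_pandas(df)

# Remote write with a write-only override: these kwargs are merged over the
# catalog's client_config and passed only to the filesystem used for this export.
chain.to_parquet(
    "hf://datasets/my-org/my-dataset/data.parquet",  # hypothetical destination
    fs_kwargs={"token": os.environ.get("HF_TOKEN")},  # hypothetical env var
)

# Plain local paths (and file:// URIs) keep working unchanged; no fs_kwargs needed.
chain.to_csv("out.csv")
```
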
20 changes: 20 additions & 0 deletions tests/conftest.py
@@ -558,6 +558,26 @@ def cloud_test_catalog(
)


@pytest.fixture
def cloud_test_catalog_upload(cloud_test_catalog):
"""This returns a version of the cloud_test_catalog that is suitable for uploading
files, and will perform the necessary cleanup of any uploaded files."""
from datachain.client.fsspec import Client

src = cloud_test_catalog.src_uri
client = Client.get_implementation(src)
fsspec_fs = client.create_fs(**cloud_test_catalog.client_config)
original_paths = set(fsspec_fs.ls(src))

yield cloud_test_catalog

# Cleanup any written files
new_paths = set(fsspec_fs.ls(src))
cleanup_paths = new_paths - original_paths
for p in cleanup_paths:
fsspec_fs.rm(p, recursive=True)


@pytest.fixture
def cloud_test_catalog_tmpfile(
cloud_server,
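
A hypothetical minimal consumer of the cloud_test_catalog_upload fixture added above, just to illustrate the cleanup contract; the real tests live in tests/func/test_datachain.py below.

```python
import pandas as pd

from datachain.lib.dc import DataChain


def test_upload_is_cleaned_up(cloud_test_catalog_upload):
    # Anything written under src_uri during the test is removed by the fixture's
    # teardown, which diffs the bucket listing taken before and after the test.
    ctc = cloud_test_catalog_upload
    df = pd.DataFrame({"x": [1, 2, 3]})
    DataChain.from_pandas(df, session=ctc.session).to_csv(f"{ctc.src_uri}/scratch.csv")
```
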
50 changes: 50 additions & 0 deletions tests/func/test_datachain.py
@@ -38,6 +38,12 @@
text_embedding,
)

DF_DATA = {
"first_name": ["Alice", "Bob", "Charlie", "David", "Eva"],
"age": [25, 30, 35, 40, 45],
"city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
}


def _get_listing_datasets(session):
return sorted(
@@ -1366,3 +1372,47 @@ def file_info(file: File) -> DataModel:
],
"file_info__path",
)


def test_to_from_csv_remote(cloud_test_catalog_upload):
ctc = cloud_test_catalog_upload
path = f"{ctc.src_uri}/test.csv"

df = pd.DataFrame(DF_DATA)
dc_to = DataChain.from_pandas(df, session=ctc.session)
dc_to.to_csv(path)

dc_from = DataChain.from_csv(path, session=ctc.session)
df1 = dc_from.select("first_name", "age", "city").to_pandas()
assert df1.equals(df)


@pytest.mark.parametrize("chunk_size", (1000, 2))
@pytest.mark.parametrize("kwargs", ({}, {"compression": "gzip"}))
def test_to_from_parquet_remote(cloud_test_catalog_upload, chunk_size, kwargs):
ctc = cloud_test_catalog_upload
path = f"{ctc.src_uri}/test.parquet"

df = pd.DataFrame(DF_DATA)
dc_to = DataChain.from_pandas(df, session=ctc.session)
dc_to.to_parquet(path, chunk_size=chunk_size, **kwargs)

dc_from = DataChain.from_parquet(path, session=ctc.session)
df1 = dc_from.select("first_name", "age", "city").to_pandas()

assert df1.equals(df)


@pytest.mark.parametrize("chunk_size", (1000, 2))
def test_to_from_parquet_partitioned_remote(cloud_test_catalog_upload, chunk_size):
ctc = cloud_test_catalog_upload
path = f"{ctc.src_uri}/parquets"

df = pd.DataFrame(DF_DATA)
dc_to = DataChain.from_pandas(df, session=ctc.session)
dc_to.to_parquet(path, partition_cols=["first_name"], chunk_size=chunk_size)

dc_from = DataChain.from_parquet(path, session=ctc.session)
df1 = dc_from.select("first_name", "age", "city").to_pandas()
df1 = df1.sort_values("first_name").reset_index(drop=True)
assert df1.equals(df)
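
A side note on the partitioned case: pyarrow's write_to_dataset lays the output out as hive-style first_name=<value> directories, and reading the dataset back reconstructs the partition column from those directory names, which is presumably why the test sorts by first_name before comparing. A standalone local sketch (not part of the PR):

```python
import glob
import os
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({"first_name": ["Alice", "Bob"], "age": [25, 30]})
with tempfile.TemporaryDirectory() as root:
    # Same call that to_parquet() makes for partitioned writes (see the dc.py diff).
    pq.write_to_dataset(table, root_path=root, partition_cols=["first_name"])
    # Expect paths like <root>/first_name=Alice/<generated-name>.parquet
    print(sorted(glob.glob(os.path.join(root, "*", "*.parquet"))))
```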