Skip to content

Adding JSON / JSON Lines Export Support #538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
from datachain.sql.functions import path as pathfunc
from datachain.telemetry import telemetry
from datachain.utils import batched_it, inside_notebook
from datachain.utils import batched_it, inside_notebook, row_to_nested_dict

if TYPE_CHECKING:
from pyarrow import DataType as ArrowDataType
Expand Down Expand Up @@ -2051,6 +2051,79 @@ def to_csv(
for row in results_iter:
writer.writerow(row)

def to_json(
    self,
    path: Union[str, os.PathLike[str]],
    fs_kwargs: Optional[dict[str, Any]] = None,
    include_outer_list: bool = True,
) -> None:
    """Save chain to a JSON file.

    Parameters:
        path : Path to save the file. This supports local paths as well as
            remote paths, such as s3:// or hf:// with fsspec.
        fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only
            for write, for fsspec-type URLs, such as s3:// or hf:// when
            provided as the destination path.
        include_outer_list : Sets whether to include an outer list for all rows.
            Setting this to True makes the file valid JSON, while False instead
            writes in the JSON lines format.
    """
    open_fn = open

    # A "scheme://" path means a remote filesystem; open via fsspec instead.
    if isinstance(path, str) and "://" in path:
        from datachain.client.fsspec import Client

        client_config = {
            **self._query.catalog.client_config,
            **(fs_kwargs or {}),
        }
        remote_fs = Client.get_implementation(path).create_fs(**client_config)
        open_fn = remote_fs.open

    raw_headers, _ = self._effective_signals_schema.get_headers_with_length()
    # Drop empty path segments so each header is a clean key path.
    headers = [[segment for segment in header if segment] for header in raw_headers]

    # Rows are comma-separated inside a JSON array, newline-separated for JSONL.
    separator = b",\n" if include_outer_list else b"\n"

    with open_fn(path, "wb") as out:
        if include_outer_list:
            # Opening bracket makes the file a single JSON array.
            out.write(b"[\n")
        for index, row in enumerate(self.collect_flatten()):
            if index:
                out.write(separator)
            out.write(orjson.dumps(row_to_nested_dict(headers, row)))
        if include_outer_list:
            # Closing bracket completes the JSON array.
            out.write(b"\n]\n")

def to_jsonl(
    self,
    path: Union[str, os.PathLike[str]],
    fs_kwargs: Optional[dict[str, Any]] = None,
) -> None:
    """Save chain to a JSON lines file.

    Parameters:
        path : Path to save the file. This supports local paths as well as
            remote paths, such as s3:// or hf:// with fsspec.
        fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only
            for write, for fsspec-type URLs, such as s3:// or hf:// when
            provided as the destination path.
    """
    # JSON lines is simply the JSON export without the enclosing array.
    self.to_json(path, fs_kwargs, include_outer_list=False)

@classmethod
def from_records(
cls,
Expand Down
24 changes: 24 additions & 0 deletions src/datachain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,27 @@ def env2bool(var, undefined=False):
if var is None:
return undefined
return bool(re.search("1|y|yes|true", var, flags=re.IGNORECASE))


def nested_dict_path_set(
    data: dict[str, Any], path: Sequence[str], value: Any
) -> dict[str, Any]:
    """Set ``value`` inside a nested dict at the key path ``path``.

    Intermediate sub-dicts are created as needed; existing sibling keys are
    left untouched. ``data`` is modified in place and also returned for
    convenience.

    Args:
        data: The dict to modify.
        path: Sequence of keys leading to the target location; must be
            non-empty (an empty path raises ``IndexError``).
        value: The value to store at the end of the path.

    Returns:
        The same ``data`` dict, after mutation.
    """
    sub_data = data
    for element in path[:-1]:
        # setdefault both creates a missing sub-dict and descends into it.
        sub_data = sub_data.setdefault(element, {})
    sub_data[path[-1]] = value
    return data


def row_to_nested_dict(
    headers: Iterable[Sequence[str]], row: Iterable[Any]
) -> dict[str, Any]:
    """Convert a flat row into a nested dict keyed by the header key paths."""
    nested: dict[str, Any] = {}
    # Each header is a key path; pair it with the matching row value.
    for key_path, value in zip(headers, row):
        nested_dict_path_set(nested, key_path, value)
    return nested
32 changes: 32 additions & 0 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1505,3 +1505,35 @@ def test_to_from_parquet_partitioned_remote(cloud_test_catalog_upload, chunk_siz
df1 = dc_from.select("first_name", "age", "city").to_pandas()
df1 = df1.sort_values("first_name").reset_index(drop=True)
assert df1.equals(df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json_remote(cloud_test_catalog_upload):
    """Round-trip a DataFrame through to_json/from_json on a remote path."""
    ctc = cloud_test_catalog_upload
    json_path = f"{ctc.src_uri}/test.json"

    original_df = pd.DataFrame(DF_DATA)
    DataChain.from_pandas(original_df, session=ctc.session).to_json(json_path)

    chain = DataChain.from_json(json_path, session=ctc.session)
    restored = chain.select("json.first_name", "json.age", "json.city").to_pandas()
    restored = restored["json"]
    assert restored.equals(original_df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_jsonl_remote(cloud_test_catalog_upload):
    """Round-trip a DataFrame through to_jsonl/from_jsonl on a remote path."""
    ctc = cloud_test_catalog_upload
    jsonl_path = f"{ctc.src_uri}/test.jsonl"

    original_df = pd.DataFrame(DF_DATA)
    DataChain.from_pandas(original_df, session=ctc.session).to_jsonl(jsonl_path)

    chain = DataChain.from_jsonl(jsonl_path, session=ctc.session)
    restored = chain.select("jsonl.first_name", "jsonl.age", "jsonl.city").to_pandas()
    restored = restored["jsonl"]
    assert restored.equals(original_df)
139 changes: 139 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import json
import math
import os
import re
Expand Down Expand Up @@ -1275,6 +1276,144 @@ def test_to_csv_features_nested(tmp_dir, test_session):
]


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json(tmp_dir, test_session):
    """to_json must emit a valid JSON array; from_json must read it back."""
    original_df = pd.DataFrame(DF_DATA)
    json_path = tmp_dir / "test.json"
    DataChain.from_pandas(original_df, session=test_session).to_json(json_path)

    # The file content must be a single JSON array of row objects.
    with open(json_path) as f:
        rows = json.load(f)
    assert rows == [
        {"first_name": n, "age": a, "city": c}
        for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
    ]

    chain = DataChain.from_json(json_path.as_uri(), session=test_session)
    restored = chain.select("json.first_name", "json.age", "json.city").to_pandas()
    restored = restored["json"]
    assert restored.equals(original_df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_from_json_jmespath(tmp_dir, test_session):
    """from_json must locate nested records via a JMESPath expression."""
    original_df = pd.DataFrame(DF_DATA)
    records = [
        {"first_name": n, "age": a, "city": c}
        for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
    ]
    json_path = tmp_dir / "test.json"
    # Wrap the records so a JMESPath expression is required to reach them.
    with open(json_path, "w") as f:
        json.dump({"author": "Test User", "version": 5, "values": records}, f)

    chain = DataChain.from_json(
        json_path.as_uri(), jmespath="values", session=test_session
    )
    restored = chain.select("values.first_name", "values.age", "values.city").to_pandas()
    restored = restored["values"]
    assert restored.equals(original_df)


def test_to_json_features(tmp_dir, test_session):
    """Feature (model) columns must serialize as nested JSON objects."""
    chain = DataChain.from_values(
        f1=features, num=range(len(features)), session=test_session
    )
    json_path = tmp_dir / "test.json"
    chain.to_json(json_path)
    with open(json_path) as f:
        rows = json.load(f)
    assert rows == [
        {"f1": {"nnn": f.nnn, "count": f.count}, "num": n}
        for n, f in enumerate(features)
    ]


def test_to_json_features_nested(tmp_dir, test_session):
    """Doubly nested feature models must serialize as nested JSON objects."""
    chain = DataChain.from_values(sign1=features_nested, session=test_session)
    json_path = tmp_dir / "test.json"
    chain.to_json(json_path)
    with open(json_path) as f:
        rows = json.load(f)
    assert rows == [
        {"sign1": {"label": f"label_{n}", "fr": {"nnn": f.nnn, "count": f.count}}}
        for n, f in enumerate(features)
    ]


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_jsonl(tmp_dir, test_session):
    """to_jsonl must emit one JSON object per line; from_jsonl reads it back."""
    original_df = pd.DataFrame(DF_DATA)
    jsonl_path = tmp_dir / "test.jsonl"
    DataChain.from_pandas(original_df, session=test_session).to_jsonl(jsonl_path)

    # Each line is a standalone JSON object (no outer array, no trailing newline).
    with open(jsonl_path) as f:
        rows = [json.loads(line) for line in f.read().split("\n")]
    assert rows == [
        {"first_name": n, "age": a, "city": c}
        for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
    ]

    chain = DataChain.from_jsonl(jsonl_path.as_uri(), session=test_session)
    restored = chain.select("jsonl.first_name", "jsonl.age", "jsonl.city").to_pandas()
    restored = restored["jsonl"]
    assert restored.equals(original_df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_from_jsonl_jmespath(tmp_dir, test_session):
    """from_jsonl must locate nested records via a JMESPath expression."""
    original_df = pd.DataFrame(DF_DATA)
    records = [
        {"first_name": n, "age": a, "city": c}
        for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
    ]
    jsonl_path = tmp_dir / "test.jsonl"
    # Wrap each record so a JMESPath expression is required to reach it.
    with open(jsonl_path, "w") as f:
        for record in records:
            f.write(
                json.dumps({"data": "Contained Within", "row_version": 5, "value": record})
            )
            f.write("\n")

    chain = DataChain.from_jsonl(
        jsonl_path.as_uri(), jmespath="value", session=test_session
    )
    restored = chain.select("value.first_name", "value.age", "value.city").to_pandas()
    restored = restored["value"]
    assert restored.equals(original_df)


def test_to_jsonl_features(tmp_dir, test_session):
    """Feature (model) columns must serialize as nested objects in JSONL."""
    dc_to = DataChain.from_values(
        f1=features, num=range(len(features)), session=test_session
    )
    # Use the .jsonl extension for consistency with the other JSON-lines tests.
    path = tmp_dir / "test.jsonl"
    dc_to.to_jsonl(path)
    with open(path) as f:
        values = [json.loads(line) for line in f.read().split("\n")]
    assert values == [
        {"f1": {"nnn": f.nnn, "count": f.count}, "num": n}
        for n, f in enumerate(features)
    ]


def test_to_jsonl_features_nested(tmp_dir, test_session):
    """Doubly nested feature models must serialize as nested objects in JSONL."""
    dc_to = DataChain.from_values(sign1=features_nested, session=test_session)
    # Use the .jsonl extension for consistency with the other JSON-lines tests.
    path = tmp_dir / "test.jsonl"
    dc_to.to_jsonl(path)
    with open(path) as f:
        values = [json.loads(line) for line in f.read().split("\n")]
    assert values == [
        {"sign1": {"label": f"label_{n}", "fr": {"nnn": f.nnn, "count": f.count}}}
        for n, f in enumerate(features)
    ]


def test_from_parquet(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
path = tmp_dir / "test.parquet"
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from datachain.utils import (
datachain_paths_join,
determine_processes,
nested_dict_path_set,
retry_with_backoff,
row_to_nested_dict,
sizeof_fmt,
sql_escape_like,
suffix_to_number,
Expand Down Expand Up @@ -170,3 +172,48 @@ def test_determine_processes(parallel, settings, expected):
)
def test_uses_glob(path, expected):
assert uses_glob(path) is expected


@pytest.mark.parametrize(
    "data,path,value,expected",
    (
        # Single-level key set on an empty dict.
        ({}, ["test"], True, {"test": True}),
        # Existing unrelated keys are preserved.
        ({"extra": False}, ["test"], True, {"extra": False, "test": True}),
        # A missing intermediate dict is created for a two-level path.
        (
            {"extra": False},
            ["test", "nested"],
            True,
            {"extra": False, "test": {"nested": True}},
        ),
        # Multiple missing levels are created for a three-level path.
        (
            {"extra": False},
            ["test", "nested", "deep"],
            True,
            {"extra": False, "test": {"nested": {"deep": True}}},
        ),
        # Existing sub-dicts (and their sibling keys) are reused, not replaced.
        (
            {"extra": False, "test": {"test2": 5, "nested": {}}},
            ["test", "nested", "deep"],
            True,
            {"extra": False, "test": {"test2": 5, "nested": {"deep": True}}},
        ),
    ),
)
def test_nested_dict_path_set(data, path, value, expected):
    """nested_dict_path_set must set the value at the key path, creating
    intermediate dicts as needed, and return the mutated dict."""
    assert nested_dict_path_set(data, path, value) == expected


@pytest.mark.parametrize(
    "headers,row,expected",
    (
        # Flat headers: each single-element path maps directly to a key.
        ([["a"], ["b"]], (3, 7), {"a": 3, "b": 7}),
        # A two-element header path produces one level of nesting.
        ([["a"], ["b", "c"]], (3, 7), {"a": 3, "b": {"c": 7}}),
        # Paths sharing a prefix ("a") merge into the same sub-dict.
        (
            [["a", "b"], ["a", "c"], ["d"], ["a", "e", "f"]],
            (1, 5, "test", 11),
            {"a": {"b": 1, "c": 5, "e": {"f": 11}}, "d": "test"},
        ),
    ),
)
def test_row_to_nested_dict(headers, row, expected):
    """row_to_nested_dict must pair each header key path with the row value
    at the same position and build the corresponding nested dict."""
    assert row_to_nested_dict(headers, row) == expected