Skip to content

Commit

Permalink
Remove Entry class and use File instead
Browse files Browse the repository at this point in the history
  • Loading branch information
rlamy committed Sep 12, 2024
1 parent 21a80e9 commit 4f741e6
Show file tree
Hide file tree
Showing 16 changed files with 59 additions and 164 deletions.
13 changes: 4 additions & 9 deletions src/datachain/asyn.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
import asyncio
from collections.abc import Awaitable, Coroutine, Iterable
from collections.abc import AsyncIterable, Awaitable, Coroutine, Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from heapq import heappop, heappush
from typing import (
Any,
Callable,
Generic,
Optional,
TypeVar,
)
from typing import Any, Callable, Generic, Optional, TypeVar

from fsspec.asyn import get_loop

ASYNC_WORKERS = 20

InputT = TypeVar("InputT", contravariant=True) # noqa: PLC0105
ResultT = TypeVar("ResultT", covariant=True) # noqa: PLC0105
T = TypeVar("T")


class AsyncMapper(Generic[InputT, ResultT]):
Expand Down Expand Up @@ -226,7 +221,7 @@ async def _break_iteration(self) -> None:
self._push_result(self._next_yield, None)


def iter_over_async(ait, loop):
def iter_over_async(ait: AsyncIterable[T], loop) -> Iterator[T]:
"""Wrap an asynchronous iterator into a synchronous one"""
ait = ait.__aiter__()

Expand Down
14 changes: 1 addition & 13 deletions src/datachain/client/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tqdm import tqdm

from datachain.lib.file import File
from datachain.node import Entry

from .fsspec import DELIMITER, Client, ResultQueue

Expand All @@ -14,17 +13,6 @@ class AzureClient(Client):
PREFIX = "az://"
protocol = "az"

def convert_info(self, v: dict[str, Any], path: str) -> Entry:
    """Build an ``Entry`` for *path* from the filesystem info dict *v*.

    Reads the ``etag``, ``version_id``, ``is_current_version``,
    ``last_modified`` and ``size`` keys of *v*; only ``last_modified`` is
    required, the others fall back to defaults.
    """
    blob_version = v.get("version_id")
    if blob_version is None:
        # No version id at all: the entry is flagged as the latest one.
        latest = True
    else:
        latest = bool(v.get("is_current_version"))
    fields = {
        "path": path,
        # Leading/trailing double quotes are stripped from the etag.
        "etag": v.get("etag", "").strip('"'),
        "version": blob_version or "",
        "is_latest": latest,
        "last_modified": v["last_modified"],
        "size": v.get("size", ""),
    }
    return Entry.from_file(**fields)

def info_to_file(self, v: dict[str, Any], path: str) -> File:
version_id = v.get("version_id")
return File(
Expand Down Expand Up @@ -57,7 +45,7 @@ async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> Non
continue
info = (await self.fs._details([b]))[0]
entries.append(
self.convert_info(info, self.rel_path(info["name"]))
self.info_to_file(info, self.rel_path(info["name"]))
)
if entries:
await result_queue.put(entries)
Expand Down
14 changes: 7 additions & 7 deletions src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from datachain.cache import DataChainCache, UniqueId
from datachain.client.fileslice import FileSlice, FileWrapper
from datachain.error import ClientError as DataChainClientError
from datachain.node import Entry
from datachain.lib.file import File
from datachain.nodes_fetcher import NodesFetcher
from datachain.nodes_thread_pool import NodeChunk
from datachain.storage import StorageURI
Expand All @@ -45,7 +45,7 @@

DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")

ResultQueue = asyncio.Queue[Optional[Sequence[Entry]]]
ResultQueue = asyncio.Queue[Optional[Sequence[File]]]


def _is_win_local_path(uri: str) -> bool:
Expand Down Expand Up @@ -188,7 +188,7 @@ def url(self, path: str, expires: int = 3600, **kwargs) -> str:

async def get_current_etag(self, uid: UniqueId) -> str:
info = await self.fs._info(self.get_full_path(uid.path))
return self.convert_info(info, "").etag
return self.info_to_file(info, "").etag

async def get_size(self, path: str) -> int:
    """Await the filesystem's ``_size`` coroutine for *path* and return its result."""
    size = await self.fs._size(path)
    return size
Expand All @@ -198,7 +198,7 @@ async def get_file(self, lpath, rpath, callback):

async def scandir(
self, start_prefix: str, method: str = "default"
) -> AsyncIterator[Sequence[Entry]]:
) -> AsyncIterator[Sequence[File]]:
try:
impl = getattr(self, f"_fetch_{method}")
except AttributeError:
Expand Down Expand Up @@ -264,7 +264,7 @@ async def _fetch_default(
) -> None:
await self._fetch_nested(start_prefix, result_queue)

async def _fetch_dir(self, prefix, pbar, result_queue) -> set[str]:
async def _fetch_dir(self, prefix, pbar, result_queue: ResultQueue) -> set[str]:
path = f"{self.name}/{prefix}"
infos = await self.ls_dir(path)
files = []
Expand All @@ -277,7 +277,7 @@ async def _fetch_dir(self, prefix, pbar, result_queue) -> set[str]:
if info["type"] == "directory":
subdirs.add(subprefix)
else:
files.append(self.convert_info(info, subprefix))
files.append(self.info_to_file(info, subprefix))
if files:
await result_queue.put(files)
found_count = len(subdirs) + len(files)
Expand All @@ -303,7 +303,7 @@ def get_full_path(self, rel_path: str) -> str:
return f"{self.PREFIX}{self.name}/{rel_path}"

@abstractmethod
def convert_info(self, v: dict[str, Any], parent: str) -> Entry: ...
def info_to_file(self, v: dict[str, Any], parent: str) -> File: ...

def fetch_nodes(
self,
Expand Down
15 changes: 2 additions & 13 deletions src/datachain/client/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from tqdm import tqdm

from datachain.lib.file import File
from datachain.node import Entry

from .fsspec import DELIMITER, Client, ResultQueue

Expand Down Expand Up @@ -108,19 +107,9 @@ async def _get_pages(self, path: str, page_queue: PageQueue) -> None:
finally:
await page_queue.put(None)

def _entry_from_dict(self, d: dict[str, Any]) -> Entry:
def _entry_from_dict(self, d: dict[str, Any]) -> File:
info = self.fs._process_object(self.name, d)
return self.convert_info(info, self.rel_path(info["name"]))

def convert_info(self, v: dict[str, Any], path: str) -> Entry:
    """Build an ``Entry`` for *path* from the object info dict *v*.

    ``updated`` is the only required key and is passed through
    ``self.parse_timestamp``; the entry counts as latest when
    ``timeDeleted`` is missing or falsy.
    """
    deleted_at = v.get("timeDeleted")
    fields = {
        "path": path,
        "etag": v.get("etag", ""),
        "version": v.get("generation", ""),
        "is_latest": not deleted_at,
        "last_modified": self.parse_timestamp(v["updated"]),
        "size": v.get("size", ""),
    }
    return Entry.from_file(**fields)
return self.info_to_file(info, self.rel_path(info["name"]))

def info_to_file(self, v: dict[str, Any], path: str) -> File:
return File(
Expand Down
10 changes: 0 additions & 10 deletions src/datachain/client/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from huggingface_hub import HfFileSystem

from datachain.lib.file import File
from datachain.node import Entry

from .fsspec import Client

Expand All @@ -22,15 +21,6 @@ def create_fs(cls, **kwargs) -> HfFileSystem:

return cast(HfFileSystem, super().create_fs(**kwargs))

def convert_info(self, v: dict[str, Any], path: str) -> Entry:
    """Build an ``Entry`` for *path* from the info dict *v*.

    ``size`` and ``last_commit`` are required keys; the version and
    modification time come from the ``oid`` and ``date`` attributes of
    ``last_commit``, and the etag from the optional ``blob_id`` key.
    """
    size = v["size"]
    commit = v["last_commit"]
    version = commit.oid
    etag = v.get("blob_id", "")
    return Entry.from_file(
        path=path,
        size=size,
        version=version,
        etag=etag,
        last_modified=commit.date,
    )

def info_to_file(self, v: dict[str, Any], path: str) -> File:
return File(
path=path,
Expand Down
15 changes: 3 additions & 12 deletions src/datachain/client/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

from fsspec.implementations.local import LocalFileSystem

from datachain.cache import UniqueId
from datachain.lib.file import File
from datachain.node import Entry
from datachain.storage import StorageURI

from .fsspec import Client
Expand Down Expand Up @@ -114,9 +114,9 @@ def from_source(
use_symlinks=use_symlinks,
)

async def get_current_etag(self, uid) -> str:
async def get_current_etag(self, uid: UniqueId) -> str:
info = self.fs.info(self.get_full_path(uid.path))
return self.convert_info(info, "").etag
return self.info_to_file(info, "").etag

async def get_size(self, path: str) -> int:
    """Return ``self.fs.size(path)``; the filesystem call itself is not awaited."""
    size = self.fs.size(path)
    return size
Expand All @@ -136,15 +136,6 @@ def get_full_path(self, rel_path):
full_path += "/"
return full_path

def convert_info(self, v: dict[str, Any], path: str) -> Entry:
    """Build an ``Entry`` for *path* from the local filesystem info dict *v*.

    The required ``mtime`` value supplies both the etag (its ``.hex()``
    form) and the UTC modification time; the entry is always marked latest.
    """
    mtime = v["mtime"]
    etag = mtime.hex()
    modified = datetime.fromtimestamp(mtime, timezone.utc)
    size = v.get("size", "")
    return Entry.from_file(
        path=path,
        etag=etag,
        is_latest=True,
        last_modified=modified,
        size=size,
    )

def info_to_file(self, v: dict[str, Any], path: str) -> File:
return File(
source=self.uri,
Expand Down
28 changes: 9 additions & 19 deletions src/datachain/client/s3.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import asyncio
from typing import Any, cast
from typing import Any, Optional, cast

from botocore.exceptions import NoCredentialsError
from s3fs import S3FileSystem
from tqdm import tqdm

from datachain.lib.file import File
from datachain.node import Entry

from .fsspec import DELIMITER, Client, ResultQueue

Expand Down Expand Up @@ -111,8 +110,9 @@ async def _fetch_default(
) -> None:
await self._fetch_flat(start_prefix, result_queue)

def _entry_from_boto(self, v, bucket, versions=False):
return Entry.from_file(
def _entry_from_boto(self, v, bucket, versions=False) -> File:
return File(
source=self.uri,
path=v["Key"],
etag=v.get("ETag", "").strip('"'),
version=ClientS3.clean_s3_version(v.get("VersionId", "")),
Expand All @@ -125,8 +125,8 @@ async def _fetch_dir(
self,
prefix,
pbar,
result_queue,
):
result_queue: ResultQueue,
) -> set[str]:
if prefix:
prefix = prefix.lstrip(DELIMITER) + DELIMITER
files = []
Expand All @@ -141,7 +141,7 @@ async def _fetch_dir(
if info["type"] == "directory":
subdirs.add(subprefix)
else:
files.append(self.convert_info(info, subprefix))
files.append(self.info_to_file(info, subprefix))
pbar.update()
found = True
if not found:
Expand All @@ -152,18 +152,8 @@ async def _fetch_dir(
return subdirs

@staticmethod
def clean_s3_version(ver) -> str:
    """Normalize an S3 version id to a plain string.

    Both ``None`` and the literal ``"null"`` are mapped to ``""``; any
    other value is returned unchanged. Handling ``None`` here keeps the
    result a ``str`` even when the caller has no version id to pass.
    """
    if ver is None or ver == "null":
        return ""
    return ver

def convert_info(self, v: dict[str, Any], path: str) -> Entry:
    """Build an ``Entry`` for *path* from the S3 info dict *v*.

    ``size`` is the only required key. The etag has surrounding double
    quotes stripped, the version id goes through
    ``ClientS3.clean_s3_version``, and ``IsLatest`` defaults to ``True``.
    """
    etag = v.get("ETag", "").strip('"')
    version = ClientS3.clean_s3_version(v.get("VersionId", ""))
    latest = v.get("IsLatest", True)
    modified = v.get("LastModified", "")
    return Entry.from_file(
        path=path,
        etag=etag,
        version=version,
        is_latest=latest,
        last_modified=modified,
        size=v["size"],
    )
def clean_s3_version(ver: Optional[str]) -> str:
    """Normalize an S3 version id: ``None`` and ``"null"`` become ``""``."""
    if ver is None or ver == "null":
        return ""
    return ver

def info_to_file(self, v: dict[str, Any], path: str) -> File:
return File(
Expand Down
5 changes: 5 additions & 0 deletions src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.types import TypeEngine

from datachain.lib.file import File


logger = logging.getLogger("datachain")

Expand Down Expand Up @@ -708,6 +710,9 @@ def merge_dataset_rows(

self.db.execute(insert_query)

def prepare_entries(self, entries: "Iterable[File]") -> Iterable[dict[str, Any]]:
    """Lazily yield ``model_dump()`` of every entry, in input order."""
    # Obtain the iterator up front so a non-iterable argument fails at call time.
    source = iter(entries)

    def _dumped():
        for file in source:
            yield file.model_dump()

    return _dumped()

def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
rows = list(rows)
if not rows:
Expand Down
17 changes: 5 additions & 12 deletions src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from datachain.data_storage.schema import convert_rows_custom_column_types
from datachain.data_storage.serializer import Serializable
from datachain.dataset import DatasetRecord
from datachain.node import DirType, DirTypeGroup, Entry, Node, NodeWithPath, get_path
from datachain.node import DirType, DirTypeGroup, Node, NodeWithPath, get_path
from datachain.sql.functions import path as pathfunc
from datachain.sql.types import Int, SQLType
from datachain.storage import StorageURI
Expand All @@ -34,6 +34,7 @@
from datachain.data_storage import AbstractIDGenerator, schema
from datachain.data_storage.db_engine import DatabaseEngine
from datachain.data_storage.schema import DataTable
from datachain.lib.file import File

try:
import numpy as np
Expand Down Expand Up @@ -410,17 +411,9 @@ def dataset_stats(
((nrows, *rest),) = self.db.execute(query)
return nrows, rest[0] if rest else 0

def prepare_entries(
    self, uri: str, entries: Iterable[Entry]
) -> list[dict[str, Any]]:
    """
    Turn bucket listing entries into row dicts ready for database insertion.

    Each row is ``attrs.asdict(entry)`` with its ``source`` key set to *uri*.
    """
    source_field = {"source": uri}
    rows = []
    for entry in entries:
        # `|` builds a fresh dict per row, so the shared mapping is never mutated.
        rows.append(attrs.asdict(entry) | source_field)
    return rows
@abstractmethod
def prepare_entries(self, entries: "Iterable[File]") -> Iterable[dict[str, Any]]:
    """
    Turn an iterable of File entries into row dicts for `insert_rows()`.

    Abstract: this declaration has no implementation of its own.
    """

@abstractmethod
def insert_rows(self, table: Table, rows: Iterable[dict[str, Any]]) -> None:
Expand Down
3 changes: 1 addition & 2 deletions src/datachain/lib/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def list_func() -> Iterator[File]:
config = client_config or {}
client, path = Client.parse_url(uri, None, **config) # type: ignore[arg-type]
for entries in iter_over_async(client.scandir(path.rstrip("/")), get_loop()):
for entry in entries:
yield entry.to_file(client.uri)
yield from entries

return list_func

Expand Down
14 changes: 6 additions & 8 deletions src/datachain/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from sqlalchemy.sql import func
from tqdm import tqdm

from datachain.node import DirType, Entry, Node, NodeWithPath
from datachain.lib.file import File
from datachain.node import DirType, Node, NodeWithPath
from datachain.sql.functions import path as pathfunc
from datachain.utils import suffix_to_number

Expand Down Expand Up @@ -80,16 +81,13 @@ async def _fetch(self, start_prefix: str, method: str) -> None:
finally:
fetch_listing.insert_entries_done()

def insert_entry(self, entry: Entry) -> None:
    """Prepare a single entry for ``self.client.uri`` and insert it as a row."""
    insert_rows = self.warehouse.insert_rows
    table = self.dataset_rows.get_table()
    rows = self.warehouse.prepare_entries(self.client.uri, [entry])
    insert_rows(table, rows)
def insert_entry(self, entry: File) -> None:
    """Insert one entry by delegating to ``insert_entries`` with a one-item list."""
    single = [entry]
    self.insert_entries(single)

def insert_entries(self, entries: Iterable[Entry]) -> None:
def insert_entries(self, entries: Iterable[File]) -> None:
self.warehouse.insert_rows(
self.dataset_rows.get_table(),
self.warehouse.prepare_entries(self.client.uri, entries),
self.warehouse.prepare_entries(entries),
)

def insert_entries_done(self) -> None:
Expand Down
Loading

0 comments on commit 4f741e6

Please sign in to comment.