Fix CLI to work with DataChain's new listing #517

Open · wants to merge 21 commits into main
243 changes: 57 additions & 186 deletions src/datachain/catalog/catalog.py

Large diffs are not rendered by default.

12 changes: 0 additions & 12 deletions src/datachain/cli.py
@@ -249,12 +249,6 @@ def get_parser() -> ArgumentParser:  # noqa: PLR0915
action="store_true",
help="AWS anon (aka awscli's --no-sign-request)",
)
parent_parser.add_argument(
"--ttl",
type=human_time_type,
default=TTL_HUMAN,
help="Time-to-live of data source cache. Negative equals forever.",
)
parent_parser.add_argument(
"-u", "--update", action="count", default=0, help="Update cache"
)
@@ -1011,7 +1005,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR0915
             edatachain_only=False,
             no_edatachain_file=True,
             no_glob=args.no_glob,
-            ttl=args.ttl,
         )
     elif args.command == "clone":
         catalog.clone(
@@ -1021,7 +1014,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR0915
             update=bool(args.update),
             recursive=bool(args.recursive),
             no_glob=args.no_glob,
-            ttl=args.ttl,
             no_cp=args.no_cp,
             edatachain=args.edatachain,
             edatachain_file=args.edatachain_file,
@@ -1047,7 +1039,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR0915
             args.sources,
             long=bool(args.long),
             remote=args.remote,
-            ttl=args.ttl,
             update=bool(args.update),
             client_config=client_config,
         )
Expand Down Expand Up @@ -1081,15 +1072,13 @@ def main(argv: Optional[list[str]] = None) -> int: # noqa: C901, PLR0912, PLR09
             show_bytes=args.bytes,
             depth=args.depth,
             si=args.si,
-            ttl=args.ttl,
             update=bool(args.update),
             client_config=client_config,
         )
elif args.command == "find":
results_found = False
for result in catalog.find(
args.sources,
ttl=args.ttl,
update=bool(args.update),
names=args.name,
inames=args.iname,
@@ -1107,7 +1096,6 @@ def main(argv: Optional[list[str]] = None) -> int:  # noqa: C901, PLR0912, PLR0915
         index(
             catalog,
             args.sources,
-            ttl=args.ttl,
             update=bool(args.update),
         )
     elif args.command == "completion":
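Editor's note: with `--ttl` removed from `parent_parser`, every subcommand above stops forwarding `ttl=args.ttl` and relies on the catalog's new listing logic instead. A hypothetical smoke check, not part of the PR, assuming the `datachain` CLI is installed and that argparse therefore rejects the flag:

```python
import subprocess

# argparse should now exit non-zero with "unrecognized arguments: --ttl ..."
result = subprocess.run(
    ["datachain", "ls", "s3://bucket/path", "--ttl", "4h"],
    capture_output=True,
    text=True,
)
assert result.returncode != 0
assert "--ttl" in result.stderr
```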
99 changes: 66 additions & 33 deletions src/datachain/data_storage/schema.py
@@ -26,6 +26,13 @@
 from sqlalchemy.sql.elements import ColumnElement
 
 
+DEFAULT_DELIMITER = "__"
+
+
+def col_name(name: str, object_name: str = "file") -> str:
+    return f"{object_name}{DEFAULT_DELIMITER}{name}"
+
+
 def dedup_columns(columns: Iterable[sa.Column]) -> list[sa.Column]:
     """
     Removes duplicate columns from a list of columns.
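Editor's note: the new module-level helper simply joins an object prefix and a field name with the double-underscore delimiter. A quick illustration, assuming the module layout shown in this hunk:

```python
from datachain.data_storage.schema import DEFAULT_DELIMITER, col_name

assert DEFAULT_DELIMITER == "__"
assert col_name("path") == "file__path"                        # default object_name="file"
assert col_name("is_dir", object_name="dir") == "dir__is_dir"  # explicit prefix
```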
@@ -76,64 +83,81 @@ def convert_rows_custom_column_types(


 class DirExpansion:
-    @staticmethod
-    def base_select(q):
+    def __init__(self, object_name: str):
+        self.object_name = object_name
+
+    def col_name(self, name: str, object_name: Optional[str] = None) -> str:
+        object_name = object_name or self.object_name
+        return col_name(name, object_name)
+
+    def c(self, query, name: str, object_name: Optional[str] = None) -> str:
+        return getattr(query.c, self.col_name(name, object_name=object_name))
+
+    def base_select(self, q):
         return sa.select(
-            q.c.sys__id,
-            false().label("is_dir"),
-            q.c.source,
-            q.c.path,
-            q.c.version,
-            q.c.location,
+            self.c(q, "id", object_name="sys"),
+            false().label(self.col_name("is_dir")),
+            self.c(q, "source"),
+            self.c(q, "path"),
+            self.c(q, "version"),
+            self.c(q, "location"),
         )
 
-    @staticmethod
-    def apply_group_by(q):
+    def apply_group_by(self, q):
         return (
             sa.select(
                 f.min(q.c.sys__id).label("sys__id"),
-                q.c.is_dir,
-                q.c.source,
-                q.c.path,
-                q.c.version,
-                f.max(q.c.location).label("location"),
+                self.c(q, "is_dir"),
+                self.c(q, "source"),
+                self.c(q, "path"),
+                self.c(q, "version"),
+                f.max(self.c(q, "location")).label(self.col_name("location")),
             )
             .select_from(q)
-            .group_by(q.c.source, q.c.path, q.c.is_dir, q.c.version)
-            .order_by(q.c.source, q.c.path, q.c.is_dir, q.c.version)
+            .group_by(
+                self.c(q, "source"),
+                self.c(q, "path"),
+                self.c(q, "is_dir"),
+                self.c(q, "version"),
+            )
+            .order_by(
+                self.c(q, "source"),
+                self.c(q, "path"),
+                self.c(q, "is_dir"),
+                self.c(q, "version"),
+            )
         )
 
-    @classmethod
-    def query(cls, q):
-        q = cls.base_select(q).cte(recursive=True)
-        parent = path.parent(q.c.path)
+    def query(self, q):
+        q = self.base_select(q).cte(recursive=True)
+        parent = path.parent(self.c(q, "path"))
         q = q.union_all(
             sa.select(
                 sa.literal(-1).label("sys__id"),
-                true().label("is_dir"),
-                q.c.source,
-                parent.label("path"),
-                sa.literal("").label("version"),
-                null().label("location"),
+                true().label(self.col_name("is_dir")),
+                self.c(q, "source"),
+                parent.label(self.col_name("path")),
+                sa.literal("").label(self.col_name("version")),
+                null().label(self.col_name("location")),
             ).where(parent != "")
         )
-        return cls.apply_group_by(q)
+        return self.apply_group_by(q)
 
 
 class DataTable:
-    dataset_dir_expansion = staticmethod(DirExpansion.query)
-
     def __init__(
         self,
         name: str,
         engine: "Engine",
         metadata: Optional["sa.MetaData"] = None,
         column_types: Optional[dict[str, SQLType]] = None,
+        object_name: str = "file",
     ):
         self.name: str = name
         self.engine = engine
         self.metadata: sa.MetaData = metadata if metadata is not None else sa.MetaData()
         self.column_types: dict[str, SQLType] = column_types or {}
+        self.object_name = object_name
 
     @staticmethod
     def copy_column(
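Editor's note: `DirExpansion` changes from a bag of static/class methods into an instance parameterized by `object_name`, so the same directory-expansion CTE works for any column prefix. A sketch of how it can now be driven; the table name and columns below are hypothetical, chosen only to match the prefixed layout the class expects:

```python
import sqlalchemy as sa
from datachain.data_storage.schema import DirExpansion

meta = sa.MetaData()
files = sa.Table(
    "files",
    meta,
    sa.Column("sys__id", sa.Integer),
    sa.Column("file__source", sa.Text),
    sa.Column("file__path", sa.Text),
    sa.Column("file__version", sa.Text),
    sa.Column("file__location", sa.Text),
)

expansion = DirExpansion("file")
# self.c(q, "path") resolves to q.c.file__path, so the recursive CTE walks
# parent directories of file__path and groups the result per directory.
q = expansion.query(sa.select(files).subquery())
```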
@@ -204,9 +228,18 @@ def get_table(self) -> "sa.Table":
     def columns(self) -> "ReadOnlyColumnCollection[str, sa.Column[Any]]":
         return self.table.columns
 
-    @property
-    def c(self):
-        return self.columns
+    def col_name(self, name: str, object_name: Optional[str] = None) -> str:
+        object_name = object_name or self.object_name
+        return col_name(name, object_name)
+
+    def without_object(
+        self, column_name: str, object_name: Optional[str] = None
+    ) -> str:
+        object_name = object_name or self.object_name
+        return column_name.removeprefix(f"{object_name}{DEFAULT_DELIMITER}")
+
+    def c(self, name: str, object_name: Optional[str] = None):
+        return getattr(self.columns, self.col_name(name, object_name=object_name))
 
     @property
     def table(self) -> "sa.Table":
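Editor's note: `DataTable.c` switches from a property aliasing `.columns` to a prefix-aware lookup, with `col_name` and `without_object` mapping names in each direction. A short sketch of the resulting API; `engine` here is an assumed pre-existing SQLAlchemy engine:

```python
dt = DataTable("ds_rows_v1", engine, object_name="file")

dt.col_name("path")                     # -> "file__path"
dt.without_object("file__path")         # -> "path"
path_col = dt.c("path")                 # column object named "file__path"
id_col = dt.c("id", object_name="sys")  # column object named "sys__id"
```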
@@ -246,7 +279,7 @@ def sys_columns():
         ]
 
     def dir_expansion(self):
-        return self.dataset_dir_expansion(self)
+        return DirExpansion(self.object_name)
 
 
 PARTITION_COLUMN_ID = "partition_id"
8 changes: 4 additions & 4 deletions src/datachain/data_storage/sqlite.py
@@ -643,7 +643,7 @@
         self, dataset: DatasetRecord, version: int
     ) -> list[StorageURI]:
         dr = self.dataset_rows(dataset, version)
-        query = dr.select(dr.c.file__source).distinct()
+        query = dr.select(dr.c("source", object_name="file")).distinct()
         cur = self.db.cursor()
         cur.row_factory = sqlite3.Row  # type: ignore[assignment]
@@ -671,13 +671,13 @@
             # destination table doesn't exist, create it
             self.create_dataset_rows_table(
                 self.dataset_table_name(dst.name, dst_version),
-                columns=src_dr.c,
+                columns=src_dr.columns,
             )
             dst_empty = True
 
         dst_dr = self.dataset_rows(dst, dst_version).table
-        merge_fields = [c.name for c in src_dr.c if c.name != "sys__id"]
-        select_src = select(*(getattr(src_dr.c, f) for f in merge_fields))
+        merge_fields = [c.name for c in src_dr.columns if c.name != "sys__id"]
+        select_src = select(*(getattr(src_dr.columns, f) for f in merge_fields))

Codecov / codecov/patch check warning (src/datachain/data_storage/sqlite.py#L679-L680): added lines #L679-L680 were not covered by tests.

         if dst_empty:
             # we don't need union, but just select from source to destination
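Editor's note: for the `dst_empty` fast path above, a bulk `INSERT ... FROM SELECT` is enough; no union with existing destination rows is needed. A minimal sketch under that assumption, reusing the `dst_dr`, `merge_fields`, and `select_src` names from the hunk; the `self.db.execute` wrapper is assumed, not shown in this diff:

```python
if dst_empty:
    # copy every non-sys__id column straight from source to destination
    insert_stmt = dst_dr.insert().from_select(merge_fields, select_src)
    self.db.execute(insert_stmt)
```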