From 33427249260dc1d64eb1a3dfefc53818ce99c251 Mon Sep 17 00:00:00 2001
From: Vladimir Rudnyh
Date: Fri, 27 Sep 2024 21:44:38 +0700
Subject: [PATCH 1/5] Implement chain group_by

---
 src/datachain/__init__.py            |  2 +
 src/datachain/data_storage/sqlite.py |  8 ++++
 src/datachain/lib/dc.py              | 67 ++++++++++++++++++++++++++++
 src/datachain/lib/func.py            | 28 ++++++++++++
 src/datachain/query/dataset.py       | 26 ++++++++++-
 5 files changed, 129 insertions(+), 2 deletions(-)
 create mode 100644 src/datachain/lib/func.py

diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py
index e8bbc00bf..98c7b8641 100644
--- a/src/datachain/__init__.py
+++ b/src/datachain/__init__.py
@@ -1,3 +1,4 @@
+from datachain.lib import func
 from datachain.lib.data_model import DataModel, DataType, is_chain_type
 from datachain.lib.dc import C, Column, DataChain, Sys
 from datachain.lib.file import (
@@ -34,6 +35,7 @@
     "Sys",
     "TarVFile",
     "TextFile",
+    "func",
     "is_chain_type",
     "metrics",
     "param",
diff --git a/src/datachain/data_storage/sqlite.py b/src/datachain/data_storage/sqlite.py
index 39dffdcb2..bd0868e66 100644
--- a/src/datachain/data_storage/sqlite.py
+++ b/src/datachain/data_storage/sqlite.py
@@ -763,6 +763,14 @@ def copy_table(
         query: Select,
         progress_cb: Optional[Callable[[int], None]] = None,
     ) -> None:
+        if len(query._group_by_clause) > 0:
+            select_q = query.with_only_columns(
+                *[c for c in query.selected_columns if c.name != "sys__id"]
+            )
+            q = table.insert().from_select(list(select_q.selected_columns), select_q)
+            self.db.execute(q)
+            return
+
         if "sys__id" in query.selected_columns:
             col_id = query.selected_columns.sys__id
         else:
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index e66abcd99..78b7ba802 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -29,6 +29,7 @@
 from datachain.lib.dataset_info import DatasetInfo
 from datachain.lib.file import ArrowRow, File, get_file_type
 from datachain.lib.file import ExportPlacement as FileExportPlacement
+from datachain.lib.func import Func
 from datachain.lib.listing import (
     is_listing_dataset,
     is_listing_expired,
@@ -1011,6 +1012,72 @@ def select_except(self, *args: str) -> "Self":
             query=self._query.select(*columns), signal_schema=new_schema
         )
 
+    def group_by(
+        self,
+        *,
+        partition_by: Union[str, Sequence[str]],
+        **kwargs: Func,
+    ) -> "Self":
+        """Groups by specified set of signals."""
+        if not kwargs:
+            raise ValueError("At least one column should be provided for group_by")
+
+        partition_by = [partition_by] if isinstance(partition_by, str) else partition_by
+        if not partition_by:
+            raise ValueError("At least one column should be provided for partition_by")
+
+        all_columns = {
+            DEFAULT_DELIMITER.join(path): _type
+            for path, _type, has_subtree, _ in self.signals_schema.get_flat_tree()
+            if not has_subtree
+        }
+
+        partition_by_columns = []
+        schema_fields = {}
+        for col_name in partition_by:
+            col_type = all_columns.get(col_name)
+            if col_type is None:
+                raise DataChainColumnError(
+                    col_name, f"Column {col_name} not found in schema"
+                )
+            column = Column(col_name, python_to_sql(col_type))
+            partition_by_columns.append(column)
+            schema_fields[col_name] = col_type
+
+        select_columns = []
+        for field, func in kwargs.items():
+            cols = []
+            result_type = func.result_type
+            for col_name in func.cols:
+                col_type = all_columns.get(col_name)
+                if col_type is None:
+                    raise DataChainColumnError(
+                        col_name, f"Column {col_name} not found in schema"
+                    )
+                cols.append(Column(col_name, python_to_sql(col_type)))
+                if result_type is None:
+                    result_type = col_type
+                elif col_type != result_type:
+                    raise DataChainColumnError(
+                        col_name,
+                        (
+                            f"Column {col_name} has type {col_type} "
+                            f"but expected {result_type}"
+                        ),
+                    )
+            if result_type is None:
+                raise ValueError(
+                    f"Cannot infer type for function {func} with columns {func.cols}"
+                )
+
+            select_columns.append(func.inner(*cols).label(field))
+            schema_fields[field] = result_type
+
+        return self._evolve(
+            query=self._query.group_by(select_columns, partition_by_columns),
+            signal_schema=SignalSchema(schema_fields),
+        )
+
     def mutate(self, **kwargs) -> "Self":
         """Create new signals based on existing signals.
diff --git a/src/datachain/lib/func.py b/src/datachain/lib/func.py
new file mode 100644
index 000000000..61f163af3
--- /dev/null
+++ b/src/datachain/lib/func.py
@@ -0,0 +1,28 @@
+from typing import TYPE_CHECKING, Callable, Optional
+
+from sqlalchemy import func
+
+if TYPE_CHECKING:
+    from datachain import DataType
+
+
+class Func:
+    result_type: Optional["DataType"] = None
+
+    def __init__(
+        self,
+        inner: Callable,
+        cols: tuple[str, ...],
+        result_type: Optional["DataType"] = None,
+    ) -> None:
+        self.inner = inner
+        self.cols = [col.replace(".", "__") for col in cols]
+        self.result_type = result_type
+
+
+def sum(*cols: str) -> Func:
+    return Func(inner=func.sum, cols=cols)
+
+
+def count(*cols: str) -> Func:
+    return Func(inner=func.count, cols=cols, result_type=int)
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index 67559317c..3c370e4b3 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -956,6 +956,24 @@ def q(*columns):
     )
 
 
+@frozen
+class SQLGroupBy(SQLClause):
+    cols: Sequence[Union[str, ColumnElement]]
+    group_by: Sequence[Union[str, ColumnElement]]
+
+    def apply_sql_clause(self, query) -> Select:
+        subquery = query.subquery()
+
+        cols = [
+            subquery.c[str(c)] if isinstance(c, (str, C)) else c
+            for c in [*self.group_by, *self.cols]
+        ]
+        if not cols:
+            cols = subquery.c
+
+        return sqlalchemy.select(*cols).select_from(subquery).group_by(*self.group_by)
+
+
 @frozen
 class GroupBy(Step):
     """Group rows by a specific column."""
@@ -1410,9 +1428,13 @@ def max(self, col: ColumnElement) -> int:
         return query.as_scalar()
 
     @detach
-    def group_by(self, *cols: ColumnElement) -> "Self":
+    def group_by(
+        self,
+        cols: Sequence[ColumnElement],
+        group_by: Sequence[ColumnElement],
+    ) -> "Self":
         query = self.clone()
-        query.steps.append(GroupBy(cols))
+        query.steps.append(SQLGroupBy(cols, group_by))
         return query
 
     @detach
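A minimal usage sketch of the API this first patch introduces — not part of the patch itself, and `category`/`size` are illustrative signal names:

```py
from datachain import DataChain, func

chain = DataChain.from_values(
    category=["a", "a", "b"],
    size=[1, 2, 3],
)

# one aggregated row per distinct value of "category"
grouped = chain.group_by(
    cnt=func.count(),        # row count per group, always typed int
    total=func.sum("size"),  # sum of "size", inherits the column type
    partition_by="category",
)
```

`partition_by` takes a single signal name or a sequence of them; every other keyword argument must be a `Func` wrapper from the new `datachain.lib.func` module.
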
From e6d628f49ee7e008beff21d66ac9279cbeb31733 Mon Sep 17 00:00:00 2001
From: Vladimir Rudnyh
Date: Tue, 1 Oct 2024 00:31:13 +0700
Subject: [PATCH 2/5] Add more group_by aggregate functions

---
 src/datachain/lib/dc.py            |  64 +++----
 src/datachain/lib/func.py          |  36 +++-
 src/datachain/lib/signal_schema.py |   7 +
 src/datachain/query/dataset.py     |  39 +---
 tests/func/test_datachain.py       | 295 +++++++++++++++++++++++++++++
 tests/func/test_dataset_query.py   |  34 ----
 6 files changed, 367 insertions(+), 108 deletions(-)

diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 78b7ba802..ad120ad72 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -1019,62 +1019,60 @@ def group_by(
         **kwargs: Func,
     ) -> "Self":
         """Groups by specified set of signals."""
-        if not kwargs:
-            raise ValueError("At least one column should be provided for group_by")
-
         partition_by = [partition_by] if isinstance(partition_by, str) else partition_by
         if not partition_by:
             raise ValueError("At least one column should be provided for partition_by")
 
-        all_columns = {
-            DEFAULT_DELIMITER.join(path): _type
-            for path, _type, has_subtree, _ in self.signals_schema.get_flat_tree()
-            if not has_subtree
-        }
+        if not kwargs:
+            raise ValueError("At least one column should be provided for group_by")
+        for col_name, func in kwargs.items():
+            if not isinstance(func, Func):
+                raise DataChainColumnError(
+                    col_name,
+                    f"Column {col_name} has type {type(func)} but expected Func object",
+                )
 
-        partition_by_columns = []
+        schema_columns = self.signals_schema.db_columns_types()
         schema_fields = {}
+
+        # validate partition_by columns and add them to the schema
+        partition_by_columns: list[Column] = []
         for col_name in partition_by:
-            col_type = all_columns.get(col_name)
+            db_col_name = col_name.replace(".", DEFAULT_DELIMITER)
+            col_type = schema_columns.get(db_col_name)
             if col_type is None:
                 raise DataChainColumnError(
                     col_name, f"Column {col_name} not found in schema"
                 )
-            column = Column(col_name, python_to_sql(col_type))
-            partition_by_columns.append(column)
-            schema_fields[col_name] = col_type
+            partition_by_columns.append(Column(db_col_name, python_to_sql(col_type)))
+            schema_fields[db_col_name] = col_type
 
-        select_columns = []
-        for field, func in kwargs.items():
-            cols = []
+        # validate signal columns and add them to the schema
+        signal_columns: list[Column] = []
+        for col_name, func in kwargs.items():
             result_type = func.result_type
-            for col_name in func.cols:
-                col_type = all_columns.get(col_name)
+            if func.col is None:
+                signal_columns.append(func.inner().label(col_name))
+            else:
+                col_type = schema_columns.get(func.col)
                 if col_type is None:
                     raise DataChainColumnError(
-                        col_name, f"Column {col_name} not found in schema"
+                        func.col, f"Column {func.col} not found in schema"
                     )
-                cols.append(Column(col_name, python_to_sql(col_type)))
                 if result_type is None:
                     result_type = col_type
-                elif col_type != result_type:
-                    raise DataChainColumnError(
-                        col_name,
-                        (
-                            f"Column {col_name} has type {col_type} "
-                            f"but expected {result_type}"
-                        ),
-                    )
+                col = Column(func.col, python_to_sql(col_type))
+                signal_columns.append(func.inner(col).label(col_name))
+
             if result_type is None:
-                raise ValueError(
-                    f"Cannot infer type for function {func} with columns {func.cols}"
+                raise DataChainColumnError(
+                    col_name, f"Cannot infer type for function {func}"
                 )
 
-            select_columns.append(func.inner(*cols).label(field))
-            schema_fields[field] = result_type
+            schema_fields[col_name] = result_type
 
         return self._evolve(
-            query=self._query.group_by(select_columns, partition_by_columns),
+            query=self._query.group_by(signal_columns, partition_by_columns),
             signal_schema=SignalSchema(schema_fields),
         )
diff --git a/src/datachain/lib/func.py b/src/datachain/lib/func.py
index 61f163af3..da5014388 100644
--- a/src/datachain/lib/func.py
+++ b/src/datachain/lib/func.py
@@ -1,6 +1,9 @@
 from typing import TYPE_CHECKING, Callable, Optional
 
-from sqlalchemy import func
+from sqlalchemy import func as sa_func
+
+from datachain.query.schema import DEFAULT_DELIMITER
+from datachain.sql import functions as dc_func
 
 if TYPE_CHECKING:
     from datachain import DataType
@@ -12,17 +15,36 @@ class Func:
     def __init__(
         self,
         inner: Callable,
-        cols: tuple[str, ...],
+        col: Optional[str] = None,
         result_type: Optional["DataType"] = None,
     ) -> None:
         self.inner = inner
-        self.cols = [col.replace(".", "__") for col in cols]
+        self.col = col.replace(".", DEFAULT_DELIMITER) if col else None
         self.result_type = result_type
 
 
-def sum(*cols: str) -> Func:
-    return Func(inner=func.sum, cols=cols)
+def count(col: Optional[str] = None) -> Func:
+    return Func(inner=sa_func.count, col=col, result_type=int)
+
+
+def sum(col: str) -> Func:
+    return Func(inner=sa_func.sum, col=col)
+
+
+def avg(col: str) -> Func:
+    return Func(inner=dc_func.avg, col=col)
+
+
+def min(col: str) -> Func:
+    return Func(inner=sa_func.min, col=col)
+
+
+def max(col: str) -> Func:
+    return Func(inner=sa_func.max, col=col)
+
+
+def concat(col: str, separator="") -> Func:
+    def inner(arg):
+        return sa_func.aggregate_strings(arg, separator)
 
-def count(*cols: str) -> Func:
-    return Func(inner=func.count, cols=cols, result_type=int)
+    return Func(inner=inner, col=col, result_type=str)
diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py
index b6e4573d4..9e25de074 100644
--- a/src/datachain/lib/signal_schema.py
+++ b/src/datachain/lib/signal_schema.py
@@ -400,6 +400,13 @@ def _set_file_stream(
         if ModelStore.is_pydantic(finfo.annotation):
             SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
 
+    def db_columns_types(self) -> dict[str, type]:
+        return {
+            DEFAULT_DELIMITER.join(path): _type
+            for path, _type, has_subtree, _ in self.get_flat_tree()
+            if not has_subtree
+        }
+
     def db_signals(
         self, name: Optional[str] = None, as_columns=False
     ) -> Union[list[str], list[Column]]:
diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py
index 3c370e4b3..05a68dc33 100644
--- a/src/datachain/query/dataset.py
+++ b/src/datachain/query/dataset.py
@@ -962,39 +962,21 @@ class SQLGroupBy(SQLClause):
     group_by: Sequence[Union[str, ColumnElement]]
 
     def apply_sql_clause(self, query) -> Select:
+        if not self.cols:
+            raise ValueError("No columns to select")
+        if not self.group_by:
+            raise ValueError("No columns to group by")
+
         subquery = query.subquery()
 
         cols = [
             subquery.c[str(c)] if isinstance(c, (str, C)) else c
             for c in [*self.group_by, *self.cols]
         ]
-        if not cols:
-            cols = subquery.c
 
         return sqlalchemy.select(*cols).select_from(subquery).group_by(*self.group_by)
 
 
-@frozen
-class GroupBy(Step):
-    """Group rows by a specific column."""
-
-    cols: PartitionByType
-
-    def clone(self) -> "Self":
-        return self.__class__(self.cols)
-
-    def apply(
-        self, query_generator: QueryGenerator, temp_tables: list[str]
-    ) -> StepResult:
-        query = query_generator.select()
-        grouped_query = query.group_by(*self.cols)
-
-        def q(*columns):
-            return grouped_query.with_only_columns(*columns)
-
-        return step_result(q, grouped_query.selected_columns)
-
-
 def _validate_columns(
     left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
 ) -> set[str]:
@@ -1148,25 +1130,14 @@ def apply_steps(self) -> QueryGenerator:
             query.steps = query.steps[-1:] + query.steps[:-1]
 
         result = query.starting_step.apply()
-        group_by = None
         self.dependencies.update(result.dependencies)
 
         for step in query.steps:
-            if isinstance(step, GroupBy):
-                if group_by is not None:
-                    raise TypeError("only one group_by allowed")
-                group_by = step
-                continue
-
             result = step.apply(
                 result.query_generator, self.temp_table_names
             )  # a chain of steps linked by results
             self.dependencies.update(result.dependencies)
 
-        if group_by:
-            result = group_by.apply(result.query_generator, self.temp_table_names)
-            self.dependencies.update(result.dependencies)
-
         return result.query_generator
 
     @staticmethod
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index 43da575cc..8fcb513da 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -15,10 +15,12 @@
 from PIL import Image
 from sqlalchemy import Column
 
+from datachain import DataModel
 from datachain.catalog.catalog import QUERY_SCRIPT_CANCELED_EXIT_CODE
 from datachain.client.local import FileClient
 from datachain.data_storage.sqlite import SQLiteWarehouse
 from datachain.dataset import DatasetDependencyType, DatasetStats
+from datachain.lib import func
 from datachain.lib.dc import C, DataChain, DataChainColumnError
 from datachain.lib.file import File, ImageFile
 from datachain.lib.listing import (
@@ -47,6 +49,10 @@ def _get_listing_datasets(session):
     )
 
 
+def _sorted_records(list, *signals):
+    return sorted(list, key=lambda x: tuple(x[s] for s in signals))
+
+
 @pytest.mark.parametrize("anon", [True, False])
 def test_catalog_anon(tmp_dir, catalog, anon):
     chain = DataChain.from_storage(tmp_dir.as_uri(), anon=anon)
@@ -1313,3 +1319,292 @@ def test_datachain_save_with_job(test_session, catalog, datachain_job_id):
     dataset = catalog.get_dataset("my-ds")
     result_job_id = dataset.get_version(dataset.latest_version).job_id
     assert result_job_id == datachain_job_id
+
+
+def test_group_by_int(test_session):
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1, 2, 3, 4, 5, 6],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col2"),
+            avg=func.avg("col2"),
+            min=func.min("col2"),
+            max=func.max("col2"),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "int",
+        "avg": "int",
+        "min": "int",
+        "max": "int",
+    }
+    assert _sorted_records(ds.to_records(), "col1") == _sorted_records(
+        [
+            {
+                "col1": "a",
+                "cnt": 2,
+                "cnt_col": 2,
+                "sum": 3,
+                "avg": 1.5,
+                "min": 1,
+                "max": 2,
+            },
+            {
+                "col1": "b",
+                "cnt": 3,
+                "cnt_col": 3,
+                "sum": 12,
+                "avg": 4.0,
+                "min": 3,
+                "max": 5,
+            },
+            {
+                "col1": "c",
+                "cnt": 1,
+                "cnt_col": 1,
+                "sum": 6,
+                "avg": 6.0,
+                "min": 6,
+                "max": 6,
+            },
+        ],
+        "col1",
+    )
+
+
+def test_group_by_float(test_session):
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col2"),
+            avg=func.avg("col2"),
+            min=func.min("col2"),
+            max=func.max("col2"),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "float",
+        "avg": "float",
+        "min": "float",
+        "max": "float",
+    }
+    assert _sorted_records(ds.to_records(), "col1") == _sorted_records(
+        [
+            {
+                "col1": "a",
+                "cnt": 2,
+                "cnt_col": 2,
+                "sum": 4.0,
+                "avg": 2.0,
+                "min": 1.5,
+                "max": 2.5,
+            },
+            {
+                "col1": "b",
+                "cnt": 3,
+                "cnt_col": 3,
+                "sum": 13.5,
+                "avg": 4.5,
+                "min": 3.5,
+                "max": 5.5,
+            },
+            {
+                "col1": "c",
+                "cnt": 1,
+                "cnt_col": 1,
+                "sum": 6.5,
+                "avg": 6.5,
+                "min": 6.5,
+                "max": 6.5,
+            },
+        ],
+        "col1",
+    )
+
+
+def test_group_by_str(test_session):
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=["1", "2", "3", "4", "5", "6"],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            concat=func.concat("col2"),
+            concat_sep=func.concat("col2", separator=","),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "concat": "str",
+        "concat_sep": "str",
+    }
+    assert _sorted_records(ds.to_records(), "col1") == _sorted_records(
+        [
+            {"col1": "a", "cnt": 2, "cnt_col": 2, "concat": "12", "concat_sep": "1,2"},
+            {
+                "col1": "b",
+                "cnt": 3,
+                "cnt_col": 3,
+                "concat": "345",
+                "concat_sep": "3,4,5",
+            },
+            {"col1": "c", "cnt": 1, "cnt_col": 1, "concat": "6", "concat_sep": "6"},
+        ],
+        "col1",
+    )
+
+
+def test_group_by_multiple_partition_by(test_session):
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1, 2, 1, 2, 1, 2],
+            col3=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            col4=["1", "2", "3", "4", "5", "6"],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col3"),
+            concat=func.concat("col4"),
+            partition_by=("col1", "col2"),
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "col2": "int",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "float",
+        "concat": "str",
+    }
+    assert _sorted_records(ds.to_records(), "col1", "col2") == _sorted_records(
+        [
+            {"col1": "a", "col2": 1, "cnt": 1, "cnt_col": 1, "sum": 1.0, "concat": "1"},
+            {"col1": "a", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 2.0, "concat": "2"},
+            {
+                "col1": "b",
+                "col2": 1,
+                "cnt": 2,
+                "cnt_col": 2,
+                "sum": 8.0,
+                "concat": "35",
+            },
+            {"col1": "b", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 4.0, "concat": "4"},
+            {"col1": "c", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 6.0, "concat": "6"},
+        ],
+        "col1",
+        "col2",
+    )
+
+
+def test_group_by_error(test_session):
+    dc = DataChain.from_values(
+        col1=["a", "a", "b", "b", "b", "c"],
+        col2=[1, 2, 3, 4, 5, 6],
+        session=test_session,
+    )
+
+    with pytest.raises(TypeError):
+        dc.group_by(cnt=func.count())
+
+    with pytest.raises(
+        ValueError, match="At least one column should be provided for partition_by"
+    ):
+        dc.group_by(cnt=func.count(), partition_by=())
+
+    with pytest.raises(
+        ValueError, match="At least one column should be provided for group_by"
+    ):
+        dc.group_by(partition_by="col1")
+
+    with pytest.raises(
+        DataChainColumnError,
+        match="Column foo has type .* but expected Func object",
+    ):
+        dc.group_by(foo="col2", partition_by="col1")
+
+    with pytest.raises(DataChainColumnError, match="Column col3 not found in schema"):
+        dc.group_by(foo=func.sum("col3"), partition_by="col1")
+
+    with pytest.raises(DataChainColumnError, match="Column col3 not found in schema"):
+        dc.group_by(foo=func.sum("col2"), partition_by="col3")
+
+
+@pytest.mark.parametrize("partition_by", ["file_info.path", "file_info__path"])
+@pytest.mark.parametrize("signal_name", ["file.size", "file__size"])
+def test_group_by_signals(cloud_test_catalog, partition_by, signal_name):
+    session = cloud_test_catalog.session
+    src_uri = cloud_test_catalog.src_uri
+
+    class FileInfo(DataModel):
+        path: str = ""
+        name: str = ""
+
+    def file_info(file: File) -> DataModel:
+        full_path = file.source.rstrip("/") + "/" + file.path
+        rel_path = posixpath.relpath(full_path, src_uri)
+        path_parts = rel_path.split("/", 1)
+        return FileInfo(
+            path=path_parts[0] if len(path_parts) > 1 else "",
+            name=path_parts[1] if len(path_parts) > 1 else path_parts[0],
+        )
+
+    ds = (
+        DataChain.from_storage(src_uri, session=session)
+        .map(file_info, params=["file"], output={"file_info": FileInfo})
+        .group_by(
+            cnt=func.count(),
+            sum=func.sum(signal_name),
+            partition_by=partition_by,
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "file_info__path": "str",
+        "cnt": "int",
+        "sum": "int",
+    }
+    assert _sorted_records(ds.to_records(), "file_info__path") == _sorted_records(
+        [
+            {"file_info__path": "", "cnt": 1, "sum": 13},
+            {"file_info__path": "cats", "cnt": 2, "sum": 8},
+            {"file_info__path": "dogs", "cnt": 4, "sum": 15},
+        ],
+        "file_info__path",
+    )
diff --git a/tests/func/test_dataset_query.py b/tests/func/test_dataset_query.py
index b32910260..37937824a 100644
--- a/tests/func/test_dataset_query.py
+++ b/tests/func/test_dataset_query.py
@@ -14,7 +14,6 @@
     DatasetVersionNotFoundError,
 )
 from datachain.query import C, DatasetQuery, Object, Stream
-from datachain.sql import functions
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.types import String
 from tests.utils import assert_row_names, dataset_dependency_asdict
@@ -920,39 +919,6 @@ def test_aggregate(cloud_test_catalog, dogs_dataset):
     assert q.max(C("file.size")) == 4
 
 
-def test_group_by(cloud_test_catalog, cloud_type, dogs_dataset):
-    catalog = cloud_test_catalog.catalog
-
-    q = (
-        DatasetQuery(name=dogs_dataset.name, version=1, catalog=catalog)
-        .mutate(parent=pathfunc.parent(C("file.path")))
-        .group_by(C.parent)
-        .select(
-            C.parent,
-            functions.count(),
-            functions.sum(C("file.size")),
-            functions.avg(C("file.size")),
-            functions.min(C("file.size")),
-            functions.max(C("file.size")),
-        )
-    )
-    result = q.db_results()
-    assert len(result) == 2
-
-    result_dict = {r[0]: r[1:] for r in result}
-    if cloud_type == "file":
-        assert result_dict == {
-            f"{cloud_test_catalog.partial_path}/dogs": (3, 11, 11 / 3, 3, 4),
-            f"{cloud_test_catalog.partial_path}/dogs/others": (1, 4, 4, 4, 4),
-        }
-
-    else:
-        assert result_dict == {
-            "dogs": (3, 11, 11 / 3, 3, 4),
-            "dogs/others": (1, 4, 4, 4, 4),
-        }
-
-
 @pytest.mark.parametrize(
     "cloud_type,version_aware",
     [("s3", True)],
From efcea64d4337bda3b5d289c69bdcb9d606e6b80d Mon Sep 17 00:00:00 2001
From: Vladimir Rudnyh
Date: Tue, 1 Oct 2024 22:27:37 +0700
Subject: [PATCH 3/5] Implement custom 'group_concat' SQLAlchemy function

---
 src/datachain/lib/dc.py              | 57 +++++++++++++++++++---------
 src/datachain/lib/func.py            |  2 +-
 src/datachain/sql/functions/array.py | 10 ++++-
 src/datachain/sql/sqlite/base.py     |  5 +++
 4 files changed, 55 insertions(+), 19 deletions(-)

diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index ad120ad72..e361bc2bd 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -983,10 +983,9 @@ def distinct(self, arg: str, *args: str) -> "Self":  # type: ignore[override]
         row is left in the result set.
 
         Example:
-        ```py
-        dc.distinct("file.parent", "file.name")
-        )
-        ```
+            ```py
+            dc.distinct("file.parent", "file.name")
+            ```
         """
         return self._evolve(
             query=self._query.distinct(
@@ -1015,11 +1014,33 @@ def select_except(self, *args: str) -> "Self":
     def group_by(
         self,
         *,
-        partition_by: Union[str, Sequence[str]],
+        partition_by: Union[
+            Union[str, GenericFunction], Sequence[Union[str, GenericFunction]]
+        ],
         **kwargs: Func,
     ) -> "Self":
-        """Groups by specified set of signals."""
-        partition_by = [partition_by] if isinstance(partition_by, str) else partition_by
+        """Group rows by specified set of signals and return new signals
+        with aggregated values.
+
+        Example:
+            Using column name(s) in partition_by:
+            ```py
+            chain = chain.group_by(
+                cnt=func.count(),
+                partition_by=("file_source", "file_ext"),
+            )
+            ```
+
+            Using GenericFunction in partition_by:
+            ```py
+            chain = chain.group_by(
+                total_size=func.sum("file.size"),
+                partition_by=func.file_ext(C("file.path")),
+            )
+            ```
+        """
+        partition_by = (
+            partition_by if isinstance(partition_by, (list, tuple)) else [partition_by]
+        )
         if not partition_by:
             raise ValueError("At least one column should be provided for partition_by")
@@ -1036,16 +1057,18 @@ def group_by(
         schema_fields = {}
 
         # validate partition_by columns and add them to the schema
-        partition_by_columns: list[Column] = []
-        for col_name in partition_by:
-            db_col_name = col_name.replace(".", DEFAULT_DELIMITER)
-            col_type = schema_columns.get(db_col_name)
-            if col_type is None:
-                raise DataChainColumnError(
-                    col_name, f"Column {col_name} not found in schema"
-                )
-            partition_by_columns.append(Column(db_col_name, python_to_sql(col_type)))
-            schema_fields[db_col_name] = col_type
+        partition_by_columns: list[Union[Column, GenericFunction]] = []
+        for col in partition_by:
+            if isinstance(col, GenericFunction):
+                partition_by_columns.append(col)
+                schema_fields[col.name] = col.type
+            else:
+                col_name = col.replace(".", DEFAULT_DELIMITER)
+                col_type = schema_columns.get(col_name)
+                if col_type is None:
+                    raise DataChainColumnError(col, f"Column {col} not found in schema")
+                partition_by_columns.append(Column(col_name, python_to_sql(col_type)))
+                schema_fields[col_name] = col_type
 
         # validate signal columns and add them to the schema
         signal_columns: list[Column] = []
diff --git a/src/datachain/lib/func.py b/src/datachain/lib/func.py
index da5014388..57f72e75f 100644
--- a/src/datachain/lib/func.py
+++ b/src/datachain/lib/func.py
@@ -45,6 +45,6 @@ def max(col: str) -> Func:
 
 def concat(col: str, separator="") -> Func:
     def inner(arg):
-        return sa_func.aggregate_strings(arg, separator)
+        return dc_func.array.group_concat(arg, separator)
 
     return Func(inner=inner, col=col, result_type=str)
diff --git a/src/datachain/sql/functions/array.py b/src/datachain/sql/functions/array.py
index 9da84d87e..15378b71b 100644
--- a/src/datachain/sql/functions/array.py
+++ b/src/datachain/sql/functions/array.py
@@ -1,6 +1,6 @@
 from sqlalchemy.sql.functions import GenericFunction
 
-from datachain.sql.types import Float, Int64
+from datachain.sql.types import Float, Int64, String
 from datachain.sql.utils import compiler_not_implemented
 
@@ -51,8 +51,16 @@ class avg(GenericFunction):  # noqa: N801
     inherit_cache = True
 
 
+class group_concat(GenericFunction):  # noqa: N801
+    type = String()
+    package = "array"
+    name = "group_concat"
+    inherit_cache = True
+
+
 compiler_not_implemented(cosine_distance)
 compiler_not_implemented(euclidean_distance)
 compiler_not_implemented(length)
 compiler_not_implemented(sip_hash_64)
 compiler_not_implemented(avg)
+compiler_not_implemented(group_concat)
diff --git a/src/datachain/sql/sqlite/base.py b/src/datachain/sql/sqlite/base.py
index 35cf25724..001198b6a 100644
--- a/src/datachain/sql/sqlite/base.py
+++ b/src/datachain/sql/sqlite/base.py
@@ -85,6 +85,7 @@ def setup():
     compiles(Values, "sqlite")(compile_values)
     compiles(random.rand, "sqlite")(compile_rand)
     compiles(array.avg, "sqlite")(compile_avg)
+    compiles(array.group_concat, "sqlite")(compile_group_concat)
 
     if load_usearch_extension(sqlite3.connect(":memory:")):
         compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
@@ -400,6 +401,10 @@ def compile_avg(element, compiler, **kwargs):
     return compiler.process(func.avg(*element.clauses.clauses), **kwargs)
 
 
+def compile_group_concat(element, compiler, **kwargs):
+    return compiler.process(func.aggregate_strings(*element.clauses.clauses), **kwargs)
+
+
 def load_usearch_extension(conn) -> bool:
     try:
         # usearch is part of the vector optional dependencies
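The new `array.group_concat` generic function only gets a concrete rendering for SQLite, where it lowers to SQLAlchemy's built-in `aggregate_strings` (available in SQLAlchemy 2.0+; it renders as `group_concat()` on SQLite and `string_agg()` on PostgreSQL). Outside DataChain, the equivalent core construct is roughly this sketch:

```py
import sqlalchemy as sa

# GROUP BY letter, joining the "word" values of each group with "-"
words = sa.table("words", sa.column("letter"), sa.column("word"))
stmt = sa.select(
    words.c.letter,
    sa.func.aggregate_strings(words.c.word, "-").label("joined"),
).group_by(words.c.letter)
```

On dialects without a registered compiler, `compiler_not_implemented` keeps `group_concat` from silently emitting unsupported SQL.
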
From a00a1e81336dd0361b28a867dfa9894aebb7b675 Mon Sep 17 00:00:00 2001
From: Vladimir Rudnyh
Date: Tue, 1 Oct 2024 22:59:45 +0700
Subject: [PATCH 4/5] Use 'sql_to_python' for GenericFunction type conversion

---
 src/datachain/lib/convert/sql_to_python.py | 20 +++++------
 src/datachain/lib/dc.py                    | 13 +++++---
 src/datachain/lib/signal_schema.py         |  8 ++---
 tests/unit/lib/test_sql_to_python.py       | 37 ++++++++++------------
 4 files changed, 37 insertions(+), 41 deletions(-)

diff --git a/src/datachain/lib/convert/sql_to_python.py b/src/datachain/lib/convert/sql_to_python.py
index 8003a9101..bcce9809b 100644
--- a/src/datachain/lib/convert/sql_to_python.py
+++ b/src/datachain/lib/convert/sql_to_python.py
@@ -4,15 +4,11 @@
 from sqlalchemy import ColumnElement
 
 
-def sql_to_python(args_map: dict[str, ColumnElement]) -> dict[str, Any]:
-    res = {}
-    for name, sql_exp in args_map.items():
-        try:
-            type_ = sql_exp.type.python_type
-            if type_ == Decimal:
-                type_ = float
-        except NotImplementedError:
-            type_ = str
-        res[name] = type_
-
-    return res
+def sql_to_python(sql_exp: ColumnElement) -> Any:
+    try:
+        type_ = sql_exp.type.python_type
+        if type_ == Decimal:
+            type_ = float
+    except NotImplementedError:
+        type_ = str
+    return type_
diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index e361bc2bd..3d539dd71 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -24,6 +24,7 @@
 from sqlalchemy.sql.sqltypes import NullType
 
 from datachain.lib.convert.python_to_sql import python_to_sql
+from datachain.lib.convert.sql_to_python import sql_to_python
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
@@ -1039,7 +1040,9 @@ def group_by(
             ```
         """
         partition_by = (
-            partition_by if isinstance(partition_by, (list, tuple)) else [partition_by]
+            [partition_by]
+            if isinstance(partition_by, (str, GenericFunction))
+            else partition_by
         )
         if not partition_by:
             raise ValueError("At least one column should be provided for partition_by")
@@ -1054,14 +1057,14 @@ def group_by(
             )
 
         schema_columns = self.signals_schema.db_columns_types()
-        schema_fields = {}
+        schema_fields: dict[str, DataType] = {}
 
         # validate partition_by columns and add them to the schema
         partition_by_columns: list[Union[Column, GenericFunction]] = []
         for col in partition_by:
             if isinstance(col, GenericFunction):
                 partition_by_columns.append(col)
-                schema_fields[col.name] = col.type
+                schema_fields[col.name] = sql_to_python(col)
             else:
                 col_name = col.replace(".", DEFAULT_DELIMITER)
                 col_type = schema_columns.get(col_name)
@@ -1084,8 +1087,8 @@ def group_by(
                 )
                 if result_type is None:
                     result_type = col_type
-                col = Column(func.col, python_to_sql(col_type))
-                signal_columns.append(func.inner(col).label(col_name))
+                func_col = Column(func.col, python_to_sql(col_type))
+                signal_columns.append(func.inner(func_col).label(col_name))
 
             if result_type is None:
                 raise DataChainColumnError(
diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py
index 9e25de074..3b21cd51b 100644
--- a/src/datachain/lib/signal_schema.py
+++ b/src/datachain/lib/signal_schema.py
@@ -400,7 +400,7 @@ def _set_file_stream(
         if ModelStore.is_pydantic(finfo.annotation):
             SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
 
-    def db_columns_types(self) -> dict[str, type]:
+    def db_columns_types(self) -> dict[str, DataType]:
         return {
             DEFAULT_DELIMITER.join(path): _type
             for path, _type, has_subtree, _ in self.get_flat_tree()
@@ -497,7 +497,7 @@ def mutate(self, args_map: dict) -> "SignalSchema":
                 new_values[name] = args_map[name]
             else:
                 # adding new signal
-                new_values.update(sql_to_python({name: value}))
+                new_values[name] = sql_to_python(value)
 
         return SignalSchema(new_values)
 
@@ -541,12 +541,12 @@ def _build_tree(
             for name, val in values.items()
         }
 
-    def get_flat_tree(self) -> Iterator[tuple[list[str], type, bool, int]]:
+    def get_flat_tree(self) -> Iterator[tuple[list[str], DataType, bool, int]]:
         yield from self._get_flat_tree(self.tree, [], 0)
 
     def _get_flat_tree(
         self, tree: dict, prefix: list[str], depth: int
-    ) -> Iterator[tuple[list[str], type, bool, int]]:
+    ) -> Iterator[tuple[list[str], DataType, bool, int]]:
         for name, (type_, substree) in tree.items():
             suffix = name.split(".")
             new_prefix = prefix + suffix
diff --git a/tests/unit/lib/test_sql_to_python.py b/tests/unit/lib/test_sql_to_python.py
index ea11210c9..85c973ac9 100644
--- a/tests/unit/lib/test_sql_to_python.py
+++ b/tests/unit/lib/test_sql_to_python.py
@@ -1,3 +1,4 @@
+import pytest
 from sqlalchemy.sql.sqltypes import NullType
 
 from datachain import Column
@@ -6,23 +7,19 @@
 from datachain.sql.types import Float, Int64, String
 
 
-def test_sql_columns_to_python_types():
-    assert sql_to_python(
-        {
-            "name": Column("name", String),
-            "age": Column("age", Int64),
-            "score": Column("score", Float),
-        }
-    ) == {"name": str, "age": int, "score": float}
-
-
-def test_sql_expression_to_python_types():
-    assert sql_to_python({"age": Column("age", Int64) - 2}) == {"age": int}
-
-
-def test_sql_function_to_python_types():
-    assert sql_to_python({"age": func.avg(Column("age", Int64))}) == {"age": float}
-
-
-def test_sql_to_python_types_default_type():
-    assert sql_to_python({"null": Column("null", NullType)}) == {"null": str}
+@pytest.mark.parametrize(
+    "sql_column, expected",
+    [
+        (Column("name", String), str),
+        (Column("age", Int64), int),
+        (Column("score", Float), float),
+        # SQL expression
+        (Column("age", Int64) - 2, int),
+        # SQL function
+        (func.avg(Column("age", Int64)), float),
+        # Default type
+        (Column("null", NullType), str),
+    ],
+)
+def test_sql_columns_to_python_types(sql_column, expected):
+    assert sql_to_python(sql_column) == expected
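The reshaped helper now maps a single SQLAlchemy expression to a Python type, with `Decimal` collapsed to `float` and unknown types defaulting to `str` — mirroring the parametrized test above:

```py
from datachain import Column
from datachain.lib.convert.sql_to_python import sql_to_python
from datachain.sql import functions as func
from datachain.sql.types import Int64

sql_to_python(Column("age", Int64))            # int
sql_to_python(Column("age", Int64) - 2)        # int (expressions keep the type)
sql_to_python(func.avg(Column("age", Int64)))  # float (avg is declared Float)
```

Callers that previously passed a dict now call it once per expression, as `SignalSchema.mutate` does after this patch.
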
From 56999e83cd14c0253bc8d28f17991a0cdf8bae28 Mon Sep 17 00:00:00 2001
From: Vladimir Rudnyh
Date: Tue, 15 Oct 2024 04:09:24 +0700
Subject: [PATCH 5/5] Add missing aggregate functions

---
 src/datachain/lib/dc.py                  |  87 ++----
 src/datachain/lib/func.py                |  50 ----
 src/datachain/lib/func/__init__.py       |  14 +
 src/datachain/lib/func/aggregate.py      |  42 +++
 src/datachain/lib/func/func.py           |  64 ++++
 src/datachain/lib/signal_schema.py       |  11 +-
 src/datachain/lib/utils.py               |   5 +
 src/datachain/sql/functions/__init__.py  |   2 +-
 src/datachain/sql/functions/aggregate.py |  47 +++
 src/datachain/sql/functions/array.py     |  18 +-
 src/datachain/sql/sqlite/base.py         |  19 +-
 tests/func/test_datachain.py             | 280 ++----------------
 tests/unit/lib/test_datachain.py         | 355 ++++++++++++++++++++++-
 tests/utils.py                           |  20 ++
 14 files changed, 602 insertions(+), 412 deletions(-)
 delete mode 100644 src/datachain/lib/func.py
 create mode 100644 src/datachain/lib/func/__init__.py
 create mode 100644 src/datachain/lib/func/aggregate.py
 create mode 100644 src/datachain/lib/func/func.py
 create mode 100644 src/datachain/sql/functions/aggregate.py

diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py
index 3d539dd71..8fe032873 100644
--- a/src/datachain/lib/dc.py
+++ b/src/datachain/lib/dc.py
@@ -24,7 +24,6 @@
 from sqlalchemy.sql.sqltypes import NullType
 
 from datachain.lib.convert.python_to_sql import python_to_sql
-from datachain.lib.convert.sql_to_python import sql_to_python
 from datachain.lib.convert.values_to_tuples import values_to_tuples
 from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
 from datachain.lib.dataset_info import DatasetInfo
@@ -44,21 +43,12 @@
 from datachain.lib.model_store import ModelStore
 from datachain.lib.settings import Settings
 from datachain.lib.signal_schema import SignalSchema
-from datachain.lib.udf import (
-    Aggregator,
-    BatchMapper,
-    Generator,
-    Mapper,
-    UDFBase,
-)
+from datachain.lib.udf import Aggregator, BatchMapper, Generator, Mapper, UDFBase
 from datachain.lib.udf_signature import UdfSignature
-from datachain.lib.utils import DataChainParamsError
+from datachain.lib.utils import DataChainColumnError, DataChainParamsError
 from datachain.query import Session
-from datachain.query.dataset import (
-    DatasetQuery,
-    PartitionByType,
-)
-from datachain.query.schema import DEFAULT_DELIMITER, Column
+from datachain.query.dataset import DatasetQuery, PartitionByType
+from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
 from datachain.sql.functions import path as pathfunc
 from datachain.telemetry import telemetry
 from datachain.utils import batched_it, inside_notebook
@@ -151,11 +141,6 @@ def _get_str(on: Sequence[Union[str, sqlalchemy.ColumnElement]]) -> str:
         super().__init__(f"Merge error on='{on_str}'{right_on_str}: {msg}")
 
 
-class DataChainColumnError(DataChainParamsError):  # noqa: D101
-    def __init__(self, col_name, msg):  # noqa: D107
-        super().__init__(f"Error for column {col_name}: {msg}")
-
-
 OutputType = Union[None, DataType, Sequence[str], dict[str, DataType]]
 
 
@@ -1015,35 +1000,22 @@ def select_except(self, *args: str) -> "Self":
     def group_by(
         self,
         *,
-        partition_by: Union[
-            Union[str, GenericFunction], Sequence[Union[str, GenericFunction]]
-        ],
+        partition_by: Union[str, Sequence[str]],
         **kwargs: Func,
     ) -> "Self":
         """Group rows by specified set of signals and return new signals
         with aggregated values.
 
         Example:
-            Using column name(s) in partition_by:
             ```py
             chain = chain.group_by(
                 cnt=func.count(),
                 partition_by=("file_source", "file_ext"),
             )
             ```
-
-            Using GenericFunction in partition_by:
-            ```py
-            chain = chain.group_by(
-                total_size=func.sum("file.size"),
-                partition_by=func.file_ext(C("file.path")),
-            )
-            ```
         """
-        partition_by = (
-            [partition_by]
-            if isinstance(partition_by, (str, GenericFunction))
-            else partition_by
-        )
+        if isinstance(partition_by, str):
+            partition_by = [partition_by]
         if not partition_by:
             raise ValueError("At least one column should be provided for partition_by")
@@ -1056,46 +1028,23 @@ def group_by(
                     f"Column {col_name} has type {type(func)} but expected Func object",
                 )
 
-        schema_columns = self.signals_schema.db_columns_types()
+        partition_by_columns: list[Column] = []
+        signal_columns: list[Column] = []
         schema_fields: dict[str, DataType] = {}
 
         # validate partition_by columns and add them to the schema
-        partition_by_columns: list[Union[Column, GenericFunction]] = []
-        for col in partition_by:
-            if isinstance(col, GenericFunction):
-                partition_by_columns.append(col)
-                schema_fields[col.name] = sql_to_python(col)
-            else:
-                col_name = col.replace(".", DEFAULT_DELIMITER)
-                col_type = schema_columns.get(col_name)
-                if col_type is None:
-                    raise DataChainColumnError(col, f"Column {col} not found in schema")
-                partition_by_columns.append(Column(col_name, python_to_sql(col_type)))
-                schema_fields[col_name] = col_type
+        for col_name in partition_by:
+            col_db_name = ColumnMeta.to_db_name(col_name)
+            col_type = self.signals_schema.get_column_type(col_db_name)
+            col = Column(col_db_name, python_to_sql(col_type))
+            partition_by_columns.append(col)
+            schema_fields[col_db_name] = col_type
 
         # validate signal columns and add them to the schema
-        signal_columns: list[Column] = []
         for col_name, func in kwargs.items():
-            result_type = func.result_type
-            if func.col is None:
-                signal_columns.append(func.inner().label(col_name))
-            else:
-                col_type = schema_columns.get(func.col)
-                if col_type is None:
-                    raise DataChainColumnError(
-                        func.col, f"Column {func.col} not found in schema"
-                    )
-                if result_type is None:
-                    result_type = col_type
-                func_col = Column(func.col, python_to_sql(col_type))
-                signal_columns.append(func.inner(func_col).label(col_name))
-
-            if result_type is None:
-                raise DataChainColumnError(
-                    col_name, f"Cannot infer type for function {func}"
-                )
-
-            schema_fields[col_name] = result_type
+            col = func.get_column(self.signals_schema, label=col_name)
+            signal_columns.append(col)
+            schema_fields[col_name] = func.get_result_type(self.signals_schema)
 
         return self._evolve(
             query=self._query.group_by(signal_columns, partition_by_columns),
diff --git a/src/datachain/lib/func.py b/src/datachain/lib/func.py
deleted file mode 100644
index 57f72e75f..000000000
--- a/src/datachain/lib/func.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from typing import TYPE_CHECKING, Callable, Optional
-
-from sqlalchemy import func as sa_func
-
-from datachain.query.schema import DEFAULT_DELIMITER
-from datachain.sql import functions as dc_func
-
-if TYPE_CHECKING:
-    from datachain import DataType
-
-
-class Func:
-    result_type: Optional["DataType"] = None
-
-    def __init__(
-        self,
-        inner: Callable,
-        col: Optional[str] = None,
-        result_type: Optional["DataType"] = None,
-    ) -> None:
-        self.inner = inner
-        self.col = col.replace(".", DEFAULT_DELIMITER) if col else None
-        self.result_type = result_type
-
-
-def count(col: Optional[str] = None) -> Func:
-    return Func(inner=sa_func.count, col=col, result_type=int)
-
-
-def sum(col: str) -> Func:
-    return Func(inner=sa_func.sum, col=col)
-
-
-def avg(col: str) -> Func:
-    return Func(inner=dc_func.avg, col=col)
-
-
-def min(col: str) -> Func:
-    return Func(inner=sa_func.min, col=col)
-
-
-def max(col: str) -> Func:
-    return Func(inner=sa_func.max, col=col)
-
-
-def concat(col: str, separator="") -> Func:
-    def inner(arg):
-        return dc_func.array.group_concat(arg, separator)
-
-    return Func(inner=inner, col=col, result_type=str)
diff --git a/src/datachain/lib/func/__init__.py b/src/datachain/lib/func/__init__.py
new file mode 100644
index 000000000..5b4c5524a
--- /dev/null
+++ b/src/datachain/lib/func/__init__.py
@@ -0,0 +1,14 @@
+from .aggregate import any_value, avg, collect, concat, count, max, min, sum
+from .func import Func
+
+__all__ = [
+    "Func",
+    "any_value",
+    "avg",
+    "collect",
+    "concat",
+    "count",
+    "max",
+    "min",
+    "sum",
+]
diff --git a/src/datachain/lib/func/aggregate.py b/src/datachain/lib/func/aggregate.py
new file mode 100644
index 000000000..cfe04beb6
--- /dev/null
+++ b/src/datachain/lib/func/aggregate.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+from sqlalchemy import func as sa_func
+
+from datachain.sql import functions as dc_func
+
+from .func import Func
+
+
+def count(col: Optional[str] = None) -> Func:
+    return Func(inner=sa_func.count, col=col, result_type=int)
+
+
+def sum(col: str) -> Func:
+    return Func(inner=sa_func.sum, col=col)
+
+
+def avg(col: str) -> Func:
+    return Func(inner=dc_func.aggregate.avg, col=col)
+
+
+def min(col: str) -> Func:
+    return Func(inner=sa_func.min, col=col)
+
+
+def max(col: str) -> Func:
+    return Func(inner=sa_func.max, col=col)
+
+
+def any_value(col: str) -> Func:
+    return Func(inner=dc_func.aggregate.any_value, col=col)
+
+
+def collect(col: str) -> Func:
+    return Func(inner=dc_func.aggregate.collect, col=col, is_array=True)
+
+
+def concat(col: str, separator="") -> Func:
+    def inner(arg):
+        return dc_func.aggregate.group_concat(arg, separator)
+
+    return Func(inner=inner, col=col, result_type=str)
diff --git a/src/datachain/lib/func/func.py b/src/datachain/lib/func/func.py
new file mode 100644
index 000000000..ef4f3781e
--- /dev/null
+++ b/src/datachain/lib/func/func.py
@@ -0,0 +1,64 @@
+from typing import TYPE_CHECKING, Callable, Optional
+
+from datachain.lib.convert.python_to_sql import python_to_sql
+from datachain.lib.utils import DataChainColumnError
+from datachain.query.schema import Column, ColumnMeta
+
+if TYPE_CHECKING:
+    from datachain import DataType
+    from datachain.lib.signal_schema import SignalSchema
+
+
+class Func:
+    def __init__(
+        self,
+        inner: Callable,
+        col: Optional[str] = None,
+        result_type: Optional["DataType"] = None,
+        is_array: bool = False,
+    ) -> None:
+        self.inner = inner
+        self.col = col
+        self.result_type = result_type
+        self.is_array = is_array
+
+    @property
+    def db_col(self) -> Optional[str]:
+        return ColumnMeta.to_db_name(self.col) if self.col else None
+
+    def db_col_type(self, signals_schema: "SignalSchema") -> Optional["DataType"]:
+        if not self.db_col:
+            return None
+        col_type: type = signals_schema.get_column_type(self.db_col)
+        return list[col_type] if self.is_array else col_type  # type: ignore[valid-type]
+
+    def get_result_type(self, signals_schema: "SignalSchema") -> "DataType":
+        col_type = self.db_col_type(signals_schema)
+
+        if self.result_type:
+            return self.result_type
+
+        if col_type:
+            return col_type
+
+        raise DataChainColumnError(
+            str(self.inner),
+            "Column name is required to infer result type",
+        )
+
+    def get_column(
+        self, signals_schema: "SignalSchema", label: Optional[str] = None
+    ) -> Column:
+        if self.col:
+            col_type = self.get_result_type(signals_schema)
+            col = Column(self.db_col, python_to_sql(col_type))
+            func_col = self.inner(col)
+        else:
+            func_col = self.inner()
+
+        if label:
+            func_col = func_col.label(label)
+
+        return func_col
diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py
index 3b21cd51b..5cdb1fe5c 100644
--- a/src/datachain/lib/signal_schema.py
+++ b/src/datachain/lib/signal_schema.py
@@ -400,12 +400,11 @@ def _set_file_stream(
         if ModelStore.is_pydantic(finfo.annotation):
             SignalSchema._set_file_stream(getattr(obj, field), catalog, cache)
 
-    def db_columns_types(self) -> dict[str, DataType]:
-        return {
-            DEFAULT_DELIMITER.join(path): _type
-            for path, _type, has_subtree, _ in self.get_flat_tree()
-            if not has_subtree
-        }
+    def get_column_type(self, col_name: str) -> DataType:
+        for path, _type, has_subtree, _ in self.get_flat_tree():
+            if not has_subtree and DEFAULT_DELIMITER.join(path) == col_name:
+                return _type
+        raise SignalResolvingError([col_name], "is not found")
 
     def db_signals(
         self, name: Optional[str] = None, as_columns=False
diff --git a/src/datachain/lib/utils.py b/src/datachain/lib/utils.py
index 5b653265d..cd11da9cd 100644
--- a/src/datachain/lib/utils.py
+++ b/src/datachain/lib/utils.py
@@ -23,3 +23,8 @@ def __init__(self, message):
 class DataChainParamsError(DataChainError):
     def __init__(self, message):
         super().__init__(message)
+
+
+class DataChainColumnError(DataChainParamsError):
+    def __init__(self, col_name, msg):
+        super().__init__(f"Error for column {col_name}: {msg}")
diff --git a/src/datachain/sql/functions/__init__.py b/src/datachain/sql/functions/__init__.py
index 3634e51b8..c8d4ef0de 100644
--- a/src/datachain/sql/functions/__init__.py
+++ b/src/datachain/sql/functions/__init__.py
@@ -1,7 +1,7 @@
 from sqlalchemy.sql.expression import func
 
 from . import array, path, string
-from .array import avg
+from .aggregate import avg
 from .conditional import greatest, least
 from .random import rand
diff --git a/src/datachain/sql/functions/aggregate.py b/src/datachain/sql/functions/aggregate.py
new file mode 100644
index 000000000..dab916a42
--- /dev/null
+++ b/src/datachain/sql/functions/aggregate.py
@@ -0,0 +1,47 @@
+from sqlalchemy.sql.functions import GenericFunction, ReturnTypeFromArgs
+
+from datachain.sql.types import Float, String
+from datachain.sql.utils import compiler_not_implemented
+
+
+class avg(GenericFunction):  # noqa: N801
+    """
+    Returns the average of the column.
+    """
+
+    type = Float()
+    package = "array"
+    name = "avg"
+    inherit_cache = True
+
+
+class group_concat(GenericFunction):  # noqa: N801
+    """
+    Returns the concatenated string of the column.
+    """
+
+    type = String()
+    package = "array"
+    name = "group_concat"
+    inherit_cache = True
+
+
+class any_value(ReturnTypeFromArgs):  # noqa: N801
+    """
+    Returns first value of the column.
+    """
+
+    inherit_cache = True
+
+
+class collect(ReturnTypeFromArgs):  # noqa: N801
+    """
+    Returns an array of the column.
+    """
+
+    inherit_cache = True
+
+
+compiler_not_implemented(avg)
+compiler_not_implemented(group_concat)
+compiler_not_implemented(any_value)
+compiler_not_implemented(collect)
diff --git a/src/datachain/sql/functions/array.py b/src/datachain/sql/functions/array.py
index 15378b71b..567162fe6 100644
--- a/src/datachain/sql/functions/array.py
+++ b/src/datachain/sql/functions/array.py
@@ -1,6 +1,6 @@
 from sqlalchemy.sql.functions import GenericFunction
 
-from datachain.sql.types import Float, Int64, String
+from datachain.sql.types import Float, Int64
 from datachain.sql.utils import compiler_not_implemented
 
@@ -44,23 +44,7 @@ class sip_hash_64(GenericFunction):  # noqa: N801
     inherit_cache = True
 
 
-class avg(GenericFunction):  # noqa: N801
-    type = Float()
-    package = "array"
-    name = "avg"
-    inherit_cache = True
-
-
-class group_concat(GenericFunction):  # noqa: N801
-    type = String()
-    package = "array"
-    name = "group_concat"
-    inherit_cache = True
-
-
 compiler_not_implemented(cosine_distance)
 compiler_not_implemented(euclidean_distance)
 compiler_not_implemented(length)
 compiler_not_implemented(sip_hash_64)
-compiler_not_implemented(avg)
-compiler_not_implemented(group_concat)
diff --git a/src/datachain/sql/sqlite/base.py b/src/datachain/sql/sqlite/base.py
index 001198b6a..d99e49de1 100644
--- a/src/datachain/sql/sqlite/base.py
+++ b/src/datachain/sql/sqlite/base.py
@@ -14,7 +14,7 @@
 from sqlalchemy.sql.expression import case
 from sqlalchemy.sql.functions import func
 
-from datachain.sql.functions import array, conditional, random, string
+from datachain.sql.functions import aggregate, array, conditional, random, string
 from datachain.sql.functions import path as sql_path
 from datachain.sql.selectable import Values, base_values_compiler
 from datachain.sql.sqlite.types import (
@@ -84,8 +84,10 @@ def setup():
     compiles(conditional.least, "sqlite")(compile_least)
     compiles(Values, "sqlite")(compile_values)
     compiles(random.rand, "sqlite")(compile_rand)
-    compiles(array.avg, "sqlite")(compile_avg)
-    compiles(array.group_concat, "sqlite")(compile_group_concat)
+    compiles(aggregate.avg, "sqlite")(compile_avg)
+    compiles(aggregate.group_concat, "sqlite")(compile_group_concat)
+    compiles(aggregate.any_value, "sqlite")(compile_any_value)
+    compiles(aggregate.collect, "sqlite")(compile_collect)
 
     if load_usearch_extension(sqlite3.connect(":memory:")):
         compiles(array.cosine_distance, "sqlite")(compile_cosine_distance_ext)
@@ -405,6 +407,17 @@ def compile_group_concat(element, compiler, **kwargs):
     return compiler.process(func.aggregate_strings(*element.clauses.clauses), **kwargs)
 
 
+def compile_any_value(element, compiler, **kwargs):
+    # use bare column to return any value from the group,
+    # this is documented behavior for sqlite,
+    # see https://www.sqlite.org/lang_select.html#bare_columns_in_an_aggregate_query
+    return compiler.process(*element.clauses.clauses, **kwargs)
+
+
+def compile_collect(element, compiler, **kwargs):
+    return compiler.process(func.json_group_array(*element.clauses.clauses), **kwargs)
+
+
 def load_usearch_extension(conn) -> bool:
     try:
         # usearch is part of the vector optional dependencies
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index 8fcb513da..b256ba296 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -20,21 +20,23 @@
 from datachain.client.local import FileClient
 from datachain.data_storage.sqlite import SQLiteWarehouse
 from datachain.dataset import DatasetDependencyType, DatasetStats
-from datachain.lib import func
-from datachain.lib.dc import C, DataChain, DataChainColumnError
+from datachain.lib.dc import C, DataChain
 from datachain.lib.file import File, ImageFile
-from datachain.lib.listing import (
-    LISTING_TTL,
-    is_listing_dataset,
-    parse_listing_uri,
-)
+from datachain.lib.listing import LISTING_TTL, is_listing_dataset, parse_listing_uri
 from datachain.lib.tar import process_tar
 from datachain.lib.udf import Mapper
-from datachain.lib.utils import DataChainError
+from datachain.lib.utils import DataChainColumnError, DataChainError
 from datachain.query.dataset import QueryStep
 from datachain.sql.functions import path as pathfunc
 from datachain.sql.functions.array import cosine_distance, euclidean_distance
-from tests.utils import NUM_TREE, TARRED_TREE, images_equal, text_embedding
+from tests.utils import (
+    ANY_VALUE,
+    NUM_TREE,
+    TARRED_TREE,
+    images_equal,
+    sorted_dicts,
+    text_embedding,
+)
 
 
 def _get_listing_datasets(session):
@@ -49,10 +51,6 @@ def _get_listing_datasets(session):
     )
 
 
-def _sorted_records(list, *signals):
-    return sorted(list, key=lambda x: tuple(x[s] for s in signals))
-
-
 @pytest.mark.parametrize("anon", [True, False])
 def test_catalog_anon(tmp_dir, catalog, anon):
     chain = DataChain.from_storage(tmp_dir.as_uri(), anon=anon)
@@ -1321,253 +1319,11 @@ def test_datachain_save_with_job(test_session, catalog, datachain_job_id):
     dataset = catalog.get_dataset("my-ds")
     result_job_id = dataset.get_version(dataset.latest_version).job_id
     assert result_job_id == datachain_job_id
 
 
-def test_group_by_int(test_session):
-    ds = (
-        DataChain.from_values(
-            col1=["a", "a", "b", "b", "b", "c"],
-            col2=[1, 2, 3, 4, 5, 6],
-            session=test_session,
-        )
-        .group_by(
-            cnt=func.count(),
-            cnt_col=func.count("col2"),
-            sum=func.sum("col2"),
-            avg=func.avg("col2"),
-            min=func.min("col2"),
-            max=func.max("col2"),
-            partition_by="col1",
-        )
-        .save("my-ds")
-    )
-
-    assert ds.signals_schema.serialize() == {
-        "col1": "str",
-        "cnt": "int",
-        "cnt_col": "int",
-        "sum": "int",
-        "avg": "int",
-        "min": "int",
-        "max": "int",
-    }
-    assert _sorted_records(ds.to_records(), "col1") == _sorted_records(
-        [
-            {
-                "col1": "a",
-                "cnt": 2,
-                "cnt_col": 2,
-                "sum": 3,
-                "avg": 1.5,
-                "min": 1,
-                "max": 2,
-            },
-            {
-                "col1": "b",
-                "cnt": 3,
-                "cnt_col": 3,
-                "sum": 12,
-                "avg": 4.0,
-                "min": 3,
-                "max": 5,
-            },
-            {
-                "col1": "c",
-                "cnt": 1,
-                "cnt_col": 1,
-                "sum": 6,
-                "avg": 6.0,
-                "min": 6,
-                "max": 6,
-            },
-        ],
-        "col1",
-    )
-
-
-def test_group_by_float(test_session):
-    ds = (
-        DataChain.from_values(
-            col1=["a", "a", "b", "b", "b", "c"],
-            col2=[1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
-            session=test_session,
-        )
-        .group_by(
-            cnt=func.count(),
-            cnt_col=func.count("col2"),
-            sum=func.sum("col2"),
-            avg=func.avg("col2"),
-            min=func.min("col2"),
-            max=func.max("col2"),
-            partition_by="col1",
-        )
-        .save("my-ds")
-    )
-
-    assert ds.signals_schema.serialize() == {
-        "col1": "str",
-        "cnt": "int",
-        "cnt_col": "int",
-        "sum": "float",
-        "avg": "float",
-        "min": "float",
-        "max": "float",
-    }
-    assert _sorted_records(ds.to_records(), "col1") == _sorted_records(
-        [
-            {
-                "col1": "a",
-                "cnt": 2,
-                "cnt_col": 2,
-                "sum": 4.0,
-                "avg": 2.0,
-                "min": 1.5,
-                "max": 2.5,
-            },
-            {
-                "col1": "b",
-                "cnt": 3,
-                "cnt_col": 3,
-                "sum": 13.5,
-                "avg": 4.5,
-                "min": 3.5,
-                "max": 5.5,
-            },
-            {
-                "col1": "c",
-                "cnt": 1,
-                "cnt_col": 1,
-                "sum": 6.5,
-                "avg": 6.5,
-                "min": 6.5,
-                "max": 6.5,
-            },
-        ],
-        "col1",
-    )
-
-
-def test_group_by_str(test_session):
-    ds = (
-        DataChain.from_values(
-            col1=["a", "a", "b", "b", "b", "c"],
-            col2=["1", "2", "3", "4", "5", "6"],
-            session=test_session,
-        )
-        .group_by(
-            cnt=func.count(),
-            cnt_col=func.count("col2"),
-            concat=func.concat("col2"),
-            concat_sep=func.concat("col2", separator=","),
-            partition_by="col1",
-        )
-        .save("my-ds")
-    )
-
-    assert ds.signals_schema.serialize() == {
-        "col1": "str",
-        "cnt": "int",
-        "cnt_col": "int",
-        "concat": "str",
-        "concat_sep": "str",
-    }
-    assert _sorted_records(ds.to_records(), "col1") == _sorted_records(
-        [
-            {"col1": "a", "cnt": 2, "cnt_col": 2, "concat": "12", "concat_sep": "1,2"},
-            {
-                "col1": "b",
-                "cnt": 3,
-                "cnt_col": 3,
-                "concat": "345",
-                "concat_sep": "3,4,5",
-            },
-            {"col1": "c", "cnt": 1, "cnt_col": 1, "concat": "6", "concat_sep": "6"},
-        ],
-        "col1",
-    )
-
-
-def test_group_by_multiple_partition_by(test_session):
-    ds = (
-        DataChain.from_values(
-            col1=["a", "a", "b", "b", "b", "c"],
-            col2=[1, 2, 1, 2, 1, 2],
-            col3=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
-            col4=["1", "2", "3", "4", "5", "6"],
-            session=test_session,
-        )
-        .group_by(
-            cnt=func.count(),
-            cnt_col=func.count("col2"),
-            sum=func.sum("col3"),
-            concat=func.concat("col4"),
-            partition_by=("col1", "col2"),
-        )
-        .save("my-ds")
-    )
-
-    assert ds.signals_schema.serialize() == {
-        "col1": "str",
-        "col2": "int",
-        "cnt": "int",
-        "cnt_col": "int",
-        "sum": "float",
-        "concat": "str",
-    }
-    assert _sorted_records(ds.to_records(), "col1", "col2") == _sorted_records(
-        [
-            {"col1": "a", "col2": 1, "cnt": 1, "cnt_col": 1, "sum": 1.0, "concat": "1"},
-            {"col1": "a", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 2.0, "concat": "2"},
-            {
-                "col1": "b",
-                "col2": 1,
-                "cnt": 2,
-                "cnt_col": 2,
-                "sum": 8.0,
-                "concat": "35",
-            },
-            {"col1": "b", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 4.0, "concat": "4"},
-            {"col1": "c", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 6.0, "concat": "6"},
-        ],
-        "col1",
-        "col2",
-    )
-
-
-def test_group_by_error(test_session):
-    dc = DataChain.from_values(
-        col1=["a", "a", "b", "b", "b", "c"],
-        col2=[1, 2, 3, 4, 5, 6],
-        session=test_session,
-    )
-
-    with pytest.raises(TypeError):
-        dc.group_by(cnt=func.count())
-
-    with pytest.raises(
-        ValueError, match="At least one column should be provided for partition_by"
-    ):
-        dc.group_by(cnt=func.count(), partition_by=())
-
-    with pytest.raises(
-        ValueError, match="At least one column should be provided for group_by"
-    ):
-        dc.group_by(partition_by="col1")
-
-    with pytest.raises(
-        DataChainColumnError,
-        match="Column foo has type .* but expected Func object",
-    ):
-        dc.group_by(foo="col2", partition_by="col1")
-
-    with pytest.raises(DataChainColumnError, match="Column col3 not found in schema"):
-        dc.group_by(foo=func.sum("col3"), partition_by="col1")
-
-    with pytest.raises(DataChainColumnError, match="Column col3 not found in schema"):
-        dc.group_by(foo=func.sum("col2"), partition_by="col3")
-
-
 @pytest.mark.parametrize("partition_by", ["file_info.path", "file_info__path"])
 @pytest.mark.parametrize("signal_name", ["file.size", "file__size"])
 def test_group_by_signals(cloud_test_catalog, partition_by, signal_name):
+    from datachain import func
+
     session = cloud_test_catalog.session
     src_uri = cloud_test_catalog.src_uri
 
     class FileInfo(DataModel):
         path: str = ""
         name: str = ""
 
     def file_info(file: File) -> DataModel:
         full_path = file.source.rstrip("/") + "/" + file.path
         rel_path = posixpath.relpath(full_path, src_uri)
         path_parts = rel_path.split("/", 1)
         return FileInfo(
             path=path_parts[0] if len(path_parts) > 1 else "",
             name=path_parts[1] if len(path_parts) > 1 else path_parts[0],
         )
 
     ds = (
         DataChain.from_storage(src_uri, session=session)
         .map(file_info, params=["file"], output={"file_info": FileInfo})
         .group_by(
             cnt=func.count(),
             sum=func.sum(signal_name),
+            value=func.any_value(signal_name),
             partition_by=partition_by,
         )
         .save("my-ds")
     )
 
     assert ds.signals_schema.serialize() == {
         "file_info__path": "str",
         "cnt": "int",
         "sum": "int",
+        "value": "int",
     }
-    assert _sorted_records(ds.to_records(), "file_info__path") == _sorted_records(
+    assert sorted_dicts(ds.to_records(), "file_info__path") == sorted_dicts(
         [
-            {"file_info__path": "", "cnt": 1, "sum": 13},
-            {"file_info__path": "cats", "cnt": 2, "sum": 8},
-            {"file_info__path": "dogs", "cnt": 4, "sum": 15},
+            {"file_info__path": "", "cnt": 1, "sum": 13, "value": ANY_VALUE(13)},
+            {"file_info__path": "cats", "cnt": 2, "sum": 8, "value": ANY_VALUE(4)},
+            {"file_info__path": "dogs", "cnt": 4, "sum": 15, "value": ANY_VALUE(3, 4)},
         ],
         "file_info__path",
     )
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index 19c971c7a..8d2ca157d 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -14,7 +14,7 @@
 from datachain import Column
 from datachain.client import Client
 from datachain.lib.data_model import DataModel
-from datachain.lib.dc import C, DataChain, DataChainColumnError, Sys
+from datachain.lib.dc import C, DataChain, Sys
 from datachain.lib.file import File
 from datachain.lib.listing import LISTING_PREFIX
 from datachain.lib.listing_info import ListingInfo
@@ -24,10 +24,9 @@
     SignalSchema,
 )
 from datachain.lib.udf_signature import UdfSignatureError
-from datachain.lib.utils import DataChainParamsError
-from datachain.sql import functions as func
+from datachain.lib.utils import DataChainColumnError, DataChainParamsError
 from datachain.sql.types import Float, Int64, String
-from tests.utils import skip_if_not_sqlite
+from tests.utils import ANY_VALUE, skip_if_not_sqlite, sorted_dicts
 
 DF_DATA = {
     "first_name": ["Alice", "Bob", "Charlie", "David", "Eva"],
@@ -2000,7 +1999,9 @@ def test_mutate_with_multiplication(test_session):
     assert ds.mutate(new=ds.column("id") * 10).signals_schema.values["new"] is int
 
 
-def test_mutate_with_func(test_session):
+def test_mutate_with_sql_func(test_session):
+    from datachain.sql import functions as func
+
     ds = DataChain.from_values(id=[1, 2], session=test_session)
     assert (
         ds.mutate(new=func.avg(ds.column("id"))).signals_schema.values["new"] is float
@@ -2008,6 +2009,8 @@ def test_mutate_with_sql_func(test_session):
 
 
 def test_mutate_with_complex_expression(test_session):
+    from datachain.sql import functions as func
+
     ds = DataChain.from_values(id=[1, 2], name=["Jim", "Jon"], session=test_session)
     assert (
         ds.mutate(
@@ -2098,3 +2101,345 @@ def test_from_hf_object_name(test_session):
 def test_from_hf_invalid(test_session):
     with pytest.raises(FileNotFoundError):
         DataChain.from_hf("invalid_dataset", session=test_session)
+
+
+def test_group_by_int(test_session):
+    from datachain import func
+
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1, 2, 3, 4, 5, 6],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col2"),
+            avg=func.avg("col2"),
+            min=func.min("col2"),
+            max=func.max("col2"),
+            value=func.any_value("col2"),
+            collect=func.collect("col2"),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "int",
+        "avg": "int",
+        "min": "int",
+        "max": "int",
+        "value": "int",
+        "collect": "list[int]",
+    }
+    assert sorted_dicts(ds.to_records(), "col1") == sorted_dicts(
+        [
+            {
+                "col1": "a",
+                "cnt": 2,
+                "cnt_col": 2,
+                "sum": 3,
+                "avg": 1.5,
+                "min": 1,
+                "max": 2,
+                "value": ANY_VALUE(1, 2),
+                "collect": [1, 2],
+            },
+            {
+                "col1": "b",
+                "cnt": 3,
+                "cnt_col": 3,
+                "sum": 12,
+                "avg": 4.0,
+                "min": 3,
+                "max": 5,
+                "value": ANY_VALUE(3, 4, 5),
+                "collect": [3, 4, 5],
+            },
+            {
+                "col1": "c",
+                "cnt": 1,
+                "cnt_col": 1,
+                "sum": 6,
+                "avg": 6.0,
+                "min": 6,
+                "max": 6,
+                "value": ANY_VALUE(6),
+                "collect": [6],
+            },
+        ],
+        "col1",
+    )
+
+
+def test_group_by_float(test_session):
+    from datachain import func
+
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col2"),
+            avg=func.avg("col2"),
+            min=func.min("col2"),
+            max=func.max("col2"),
+            value=func.any_value("col2"),
+            collect=func.collect("col2"),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "float",
+        "avg": "float",
+        "min": "float",
+        "max": "float",
+        "value": "float",
+        "collect": "list[float]",
+    }
+    assert sorted_dicts(ds.to_records(), "col1") == sorted_dicts(
+        [
+            {
+                "col1": "a",
+                "cnt": 2,
+                "cnt_col": 2,
+                "sum": 4.0,
+                "avg": 2.0,
+                "min": 1.5,
+                "max": 2.5,
+                "value": ANY_VALUE(1.5, 2.5),
+                "collect": [1.5, 2.5],
+            },
+            {
+                "col1": "b",
+                "cnt": 3,
+                "cnt_col": 3,
+                "sum": 13.5,
+                "avg": 4.5,
+                "min": 3.5,
+                "max": 5.5,
+                "value": ANY_VALUE(3.5, 4.5, 5.5),
+                "collect": [3.5, 4.5, 5.5],
+            },
+            {
+                "col1": "c",
+                "cnt": 1,
+                "cnt_col": 1,
+                "sum": 6.5,
+                "avg": 6.5,
+                "min": 6.5,
+                "max": 6.5,
+                "value": ANY_VALUE(6.5),
+                "collect": [6.5],
+            },
+        ],
+        "col1",
+    )
+
+
+def test_group_by_str(test_session):
+    from datachain import func
+
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=["1", "2", "3", "4", "5", "6"],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            concat=func.concat("col2"),
+            concat_sep=func.concat("col2", separator=","),
+            value=func.any_value("col2"),
+            collect=func.collect("col2"),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "concat": "str",
+        "concat_sep": "str",
+        "value": "str",
+        "collect": "list[str]",
+    }
+    assert sorted_dicts(ds.to_records(), "col1") == sorted_dicts(
+        [
+            {
+                "col1": "a",
+                "cnt": 2,
+                "cnt_col": 2,
+                "concat": "12",
+                "concat_sep": "1,2",
+                "value": ANY_VALUE("1", "2"),
+                "collect": ["1", "2"],
+            },
+            {
+                "col1": "b",
+                "cnt": 3,
+                "cnt_col": 3,
+                "concat": "345",
+                "concat_sep": "3,4,5",
+                "value": ANY_VALUE("3", "4", "5"),
+                "collect": ["3", "4", "5"],
+            },
+            {
+                "col1": "c",
+                "cnt": 1,
+                "cnt_col": 1,
+                "concat": "6",
+                "concat_sep": "6",
+                "value": ANY_VALUE("6"),
+                "collect": ["6"],
+            },
+        ],
+        "col1",
+    )
+
+
+def test_group_by_multiple_partition_by(test_session):
+    from datachain import func
+
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1, 2, 1, 2, 1, 2],
+            col3=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            col4=["1", "2", "3", "4", "5", "6"],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col3"),
+            concat=func.concat("col4"),
+            value=func.any_value("col3"),
+            collect=func.collect("col3"),
+            partition_by=("col1", "col2"),
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "col2": "int",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "float",
+        "concat": "str",
+        "value": "float",
+        "collect": "list[float]",
+    }
+    assert sorted_dicts(ds.to_records(), "col1", "col2") == sorted_dicts(
+        [
+            {
+                "col1": "a",
+                "col2": 1,
+                "cnt": 1,
+                "cnt_col": 1,
+                "sum": 1.0,
+                "concat": "1",
+                "value": ANY_VALUE(1.0),
+                "collect": [1.0],
+            },
+            {
+                "col1": "a",
+                "col2": 2,
+                "cnt": 1,
+                "cnt_col": 1,
+                "sum": 2.0,
+                "concat": "2",
+                "value": ANY_VALUE(2.0),
+                "collect": [2.0],
+            },
+            {
+                "col1": "b",
+                "col2": 1,
+                "cnt": 2,
+                "cnt_col": 2,
+                "sum": 8.0,
+                "concat": "35",
+                "value": ANY_VALUE(3.0, 5.0),
+                "collect": [3.0, 5.0],
+            },
+            {
+                "col1": "b",
+                "col2": 2,
+                "cnt": 1,
+                "cnt_col": 1,
+                "sum": 4.0,
+                "concat": "4",
+                "value": ANY_VALUE(4.0),
+                "collect": [4.0],
+            },
+            {
+                "col1": "c",
+                "col2": 2,
+                "cnt": 1,
+                "cnt_col": 1,
+                "sum": 6.0,
+                "concat": "6",
+                "value": ANY_VALUE(6.0),
+                "collect": [6.0],
+            },
+        ],
+        "col1",
+        "col2",
+    )
+
+
+def test_group_by_error(test_session):
+    from datachain import func
+
+    dc = DataChain.from_values(
+        col1=["a", "a", "b", "b", "b", "c"],
+        col2=[1, 2, 3, 4, 5, 6],
+        session=test_session,
+    )
+
+    with pytest.raises(TypeError):
+        dc.group_by(cnt=func.count())
+
+    with pytest.raises(
+        ValueError, match="At least one column should be provided for partition_by"
+    ):
+        dc.group_by(cnt=func.count(), partition_by=())
+
+    with pytest.raises(
+        ValueError, match="At least one column should be provided for group_by"
+    ):
+        dc.group_by(partition_by="col1")
+
+    with pytest.raises(
+        DataChainColumnError,
+        match="Column foo has type but expected Func object",
+    ):
+        dc.group_by(foo="col2", partition_by="col1")
+
+    with pytest.raises(
+        SignalResolvingError, match="cannot resolve signal name 'col3': is not found"
+    ):
+        dc.group_by(foo=func.sum("col3"), partition_by="col1")
+
+    with pytest.raises(
+        SignalResolvingError, match="cannot resolve signal name 'col3': is not found"
+    ):
+        dc.group_by(foo=func.sum("col2"), partition_by="col3")
diff --git a/tests/utils.py b/tests/utils.py
index 7232c7c47..a1dd0ddc3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -212,3 +212,23 @@ def assert_row_names(
 def images_equal(img1: Image.Image, img2: Image.Image):
     """Checks if two image objects have exactly the same data"""
     return list(img1.getdata()) == list(img2.getdata())
+
+
+def sorted_dicts(list_of_dicts, *keys):
+    return sorted(list_of_dicts, key=lambda x: tuple(x[k] for k in keys))
+
+
+class ANY_VALUE:  # noqa: N801
+    """A helper object that compares equal to any value from the list."""
+
+    def __init__(self, *args):
+        self.values = args
+
+    def __eq__(self, other) -> bool:
+        return other in self.values
+
+    def __ne__(self, other) -> bool:
+        return other not in self.values
+
+    def __repr__(self) -> str:
+        return f"ANY_VALUE{self.values}"
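
For reference, a minimal sketch of the API these tests exercise, assuming the
default local session; the signal names ("letter", "size") and the dataset name
"letters" are illustrative, not part of the patch:

    from datachain import DataChain, func

    chain = DataChain.from_values(
        letter=["a", "a", "b"],
        size=[1, 2, 3],
    )

    ds = chain.group_by(
        cnt=func.count(),        # rows per group
        total=func.sum("size"),  # result type is inferred from the "size" column
        partition_by="letter",   # a single signal name or a sequence of names
    ).save("letters")

    # Each record carries the partition columns plus the aggregates, e.g.
    # [{"letter": "a", "cnt": 2, "total": 3}, {"letter": "b", "cnt": 1, "total": 3}]
    print(ds.to_records())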