Add more group_by aggregate functions
dreadatour committed Sep 30, 2024
1 parent 291e1c3 commit bf6d2e9
Showing 5 changed files with 255 additions and 93 deletions.
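
In short, DataChain.group_by now accepts keyword aggregations built from the new func helpers (count, sum, avg, min, max, concat), each aggregating at most one column, grouped by one or more partition_by columns. A minimal usage sketch assembled from the tests below, assuming a default session (names are illustrative):

    from datachain.lib import func
    from datachain.lib.dc import DataChain

    ds = (
        DataChain.from_values(
            col1=["a", "a", "b"],
            col2=[1, 2, 3],
        )
        .group_by(
            cnt=func.count(),        # rows per group
            total=func.sum("col2"),  # per-group sum of col2
            partition_by="col1",
        )
        .save("my-ds")
    )
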
37 changes: 19 additions & 18 deletions src/datachain/lib/dc.py
@@ -1016,6 +1016,12 @@ def group_by(
         """Groups by specified set of signals."""
         if not kwargs:
             raise ValueError("At least one column should be provided for group_by")
+        for col_name, func in kwargs.items():
+            if not isinstance(func, Func):
+                raise DataChainColumnError(
+                    col_name,
+                    f"Column {col_name} has type {type(func)} but expected Func object",
+                )

         partition_by = [partition_by] if isinstance(partition_by, str) else partition_by
         if not partition_by:
@@ -1040,33 +1046,28 @@ def group_by(
             schema_fields[col_name] = col_type

         select_columns = []
-        for field, func in kwargs.items():
-            cols = []
+        for col_name, func in kwargs.items():
             result_type = func.result_type
-            for col_name in func.cols:
-                col_type = all_columns.get(col_name)
+
+            if func.col is None:
+                select_columns.append(func.inner().label(col_name))
+            else:
+                col_type = all_columns.get(func.col)
                 if col_type is None:
                     raise DataChainColumnError(
-                        col_name, f"Column {col_name} not found in schema"
+                        func.col, f"Column {func.col} not found in schema"
                     )
-                cols.append(Column(col_name, python_to_sql(col_type)))
                 if result_type is None:
                     result_type = col_type
-                elif col_type != result_type:
-                    raise DataChainColumnError(
-                        col_name,
-                        (
-                            f"Column {col_name} has type {col_type}"
-                            f"but expected {result_type}"
-                        ),
-                    )
+
+                col = Column(func.col, python_to_sql(col_type))
+                select_columns.append(func.inner(col).label(col_name))

             if result_type is None:
-                raise ValueError(
-                    f"Cannot infer type for function {func} with columns {func.cols}"
+                raise DataChainColumnError(
+                    col_name, f"Cannot infer type for function {func}"
                 )

-            select_columns.append(func.inner(*cols).label(field))
-            schema_fields[field] = result_type
+            schema_fields[col_name] = result_type

         return self._evolve(
             query=self._query.group_by(select_columns, partition_by_columns),
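
The reworked loop maps each keyword argument onto a single labeled SQL expression: a Func without a column (such as count()) is called with no arguments, while a single-column Func is resolved against the schema and applied to that column, inferring the result type from the column when the Func does not pin one. For func.sum("col2") labeled sum, the expression built is roughly equivalent to this SQLAlchemy sketch (an illustration, not the exact code path):

    import sqlalchemy as sa

    col2 = sa.column("col2", sa.Integer())
    expr = sa.func.sum(col2).label("sum")  # renders as: sum(col2) AS sum
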
35 changes: 28 additions & 7 deletions src/datachain/lib/func.py
@@ -1,6 +1,8 @@
 from typing import TYPE_CHECKING, Callable, Optional

-from sqlalchemy import func
+from sqlalchemy import func as sa_func
+
+from datachain.sql import functions as dc_func

 if TYPE_CHECKING:
     from datachain import DataType
@@ -12,17 +14,36 @@ class Func:
     def __init__(
         self,
         inner: Callable,
-        cols: tuple[str, ...],
+        col: Optional[str] = None,
         result_type: Optional["DataType"] = None,
     ) -> None:
         self.inner = inner
-        self.cols = [col.replace(".", "__") for col in cols]
+        self.col = col.replace(".", "__") if col else None
         self.result_type = result_type


-def sum(*cols: str) -> Func:
-    return Func(inner=func.sum, cols=cols)
+def count(col: Optional[str] = None) -> Func:
+    return Func(inner=sa_func.count, col=col, result_type=int)
+
+
+def sum(col: str) -> Func:
+    return Func(inner=sa_func.sum, col=col)
+
+
+def avg(col: str) -> Func:
+    return Func(inner=dc_func.avg, col=col)
+
+
+def min(col: str) -> Func:
+    return Func(inner=sa_func.min, col=col)
+
+
+def max(col: str) -> Func:
+    return Func(inner=sa_func.max, col=col)
+
+
+def concat(col: str, separator="") -> Func:
+    def inner(arg):
+        return sa_func.aggregate_strings(arg, separator)

-
-def count(*cols: str) -> Func:
-    return Func(inner=func.count, cols=cols, result_type=int)
+    return Func(inner=inner, col=col, result_type=str)
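
Each helper wraps an aggregate callable in a Func, pinning result_type only where the output type is fixed (count is always int, concat always str); for sum, avg, min, and max the type is inferred from the column in group_by. Since SQLAlchemy's func namespace generates SQL for any attribute name, the same pattern extends to other backend aggregates; a hypothetical helper in the same style, not part of this commit:

    def stddev(col: str) -> Func:
        # Hypothetical: renders as stddev(col); whether the aggregate
        # exists is up to the database backend.
        return Func(inner=sa_func.stddev, col=col, result_type=float)
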
39 changes: 5 additions & 34 deletions src/datachain/query/dataset.py
@@ -962,39 +962,21 @@ class SQLGroupBy(SQLClause):
     group_by: Sequence[Union[str, ColumnElement]]

     def apply_sql_clause(self, query) -> Select:
+        if not self.cols:
+            raise ValueError("No columns to select")
+        if not self.group_by:
+            raise ValueError("No columns to group by")
+
         subquery = query.subquery()

         cols = [
             subquery.c[str(c)] if isinstance(c, (str, C)) else c
             for c in [*self.group_by, *self.cols]
         ]
-        if not cols:
-            cols = subquery.c

         return sqlalchemy.select(*cols).select_from(subquery).group_by(*self.group_by)


-@frozen
-class GroupBy(Step):
-    """Group rows by a specific column."""
-
-    cols: PartitionByType
-
-    def clone(self) -> "Self":
-        return self.__class__(self.cols)
-
-    def apply(
-        self, query_generator: QueryGenerator, temp_tables: list[str]
-    ) -> StepResult:
-        query = query_generator.select()
-        grouped_query = query.group_by(*self.cols)
-
-        def q(*columns):
-            return grouped_query.with_only_columns(*columns)
-
-        return step_result(q, grouped_query.selected_columns)
-
-
 def _validate_columns(
     left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
 ) -> set[str]:
@@ -1148,25 +1130,14 @@ def apply_steps(self) -> QueryGenerator:
             query.steps = query.steps[-1:] + query.steps[:-1]

         result = query.starting_step.apply()
-        group_by = None
         self.dependencies.update(result.dependencies)

         for step in query.steps:
-            if isinstance(step, GroupBy):
-                if group_by is not None:
-                    raise TypeError("only one group_by allowed")
-                group_by = step
-                continue
-
             result = step.apply(
                 result.query_generator, self.temp_table_names
             )  # a chain of steps linked by results
             self.dependencies.update(result.dependencies)

-        if group_by:
-            result = group_by.apply(result.query_generator, self.temp_table_names)
-            self.dependencies.update(result.dependencies)
-
         return result.query_generator

     @staticmethod
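
With the deferred GroupBy step removed (and with it the "only one group_by allowed" restriction), grouping is handled by SQLGroupBy like any other clause, applied in step order as an explicit select over a subquery. A sketch of the statement shape it builds, in plain SQLAlchemy (an approximation, not the exact code):

    import sqlalchemy as sa

    sq = sa.select(sa.column("col1"), sa.column("col2")).subquery()
    stmt = (
        sa.select(sq.c["col1"], sa.func.count().label("cnt"))
        .select_from(sq)
        .group_by(sq.c["col1"])
    )
    # roughly: SELECT col1, count(*) AS cnt FROM (...) GROUP BY col1
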
203 changes: 203 additions & 0 deletions tests/func/test_datachain.py
@@ -19,6 +19,7 @@
 from datachain.client.local import FileClient
 from datachain.data_storage.sqlite import SQLiteWarehouse
 from datachain.dataset import DatasetDependencyType, DatasetStats
+from datachain.lib import func
 from datachain.lib.dc import C, DataChain, DataChainColumnError
 from datachain.lib.file import File, ImageFile
 from datachain.lib.listing import (
@@ -1313,3 +1314,205 @@ def test_datachain_save_with_job(test_session, catalog, datachain_job_id):
     dataset = catalog.get_dataset("my-ds")
     result_job_id = dataset.get_version(dataset.latest_version).job_id
     assert result_job_id == datachain_job_id
+
+
+def test_group_by_int(test_session):
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1, 2, 3, 4, 5, 6],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col2"),
+            avg=func.avg("col2"),
+            min=func.min("col2"),
+            max=func.max("col2"),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "int",
+        "avg": "int",
+        "min": "int",
+        "max": "int",
+    }
+    assert ds.to_records() == [
+        {"col1": "a", "cnt": 2, "cnt_col": 2, "sum": 3, "avg": 1.5, "min": 1, "max": 2},
+        {
+            "col1": "b",
+            "cnt": 3,
+            "cnt_col": 3,
+            "sum": 12,
+            "avg": 4.0,
+            "min": 3,
+            "max": 5,
+        },
+        {"col1": "c", "cnt": 1, "cnt_col": 1, "sum": 6, "avg": 6.0, "min": 6, "max": 6},
+    ]
+
+
+def test_group_by_float(test_session):
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col2"),
+            avg=func.avg("col2"),
+            min=func.min("col2"),
+            max=func.max("col2"),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "float",
+        "avg": "float",
+        "min": "float",
+        "max": "float",
+    }
+    assert ds.to_records() == [
+        {
+            "col1": "a",
+            "cnt": 2,
+            "cnt_col": 2,
+            "sum": 4.0,
+            "avg": 2.0,
+            "min": 1.5,
+            "max": 2.5,
+        },
+        {
+            "col1": "b",
+            "cnt": 3,
+            "cnt_col": 3,
+            "sum": 13.5,
+            "avg": 4.5,
+            "min": 3.5,
+            "max": 5.5,
+        },
+        {
+            "col1": "c",
+            "cnt": 1,
+            "cnt_col": 1,
+            "sum": 6.5,
+            "avg": 6.5,
+            "min": 6.5,
+            "max": 6.5,
+        },
+    ]
+
+
+def test_group_by_str(test_session):
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=["1", "2", "3", "4", "5", "6"],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            concat=func.concat("col2"),
+            concat_sep=func.concat("col2", separator=","),
+            partition_by="col1",
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "cnt": "int",
+        "cnt_col": "int",
+        "concat": "str",
+        "concat_sep": "str",
+    }
+    assert ds.to_records() == [
+        {"col1": "a", "cnt": 2, "cnt_col": 2, "concat": "12", "concat_sep": "1,2"},
+        {"col1": "b", "cnt": 3, "cnt_col": 3, "concat": "345", "concat_sep": "3,4,5"},
+        {"col1": "c", "cnt": 1, "cnt_col": 1, "concat": "6", "concat_sep": "6"},
+    ]
+
+
+def test_group_by_multiple_partition_by(test_session):
+    ds = (
+        DataChain.from_values(
+            col1=["a", "a", "b", "b", "b", "c"],
+            col2=[1, 2, 1, 2, 1, 2],
+            col3=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+            col4=["1", "2", "3", "4", "5", "6"],
+            session=test_session,
+        )
+        .group_by(
+            cnt=func.count(),
+            cnt_col=func.count("col2"),
+            sum=func.sum("col3"),
+            concat=func.concat("col4"),
+            partition_by=("col1", "col2"),
+        )
+        .save("my-ds")
+    )
+
+    assert ds.signals_schema.serialize() == {
+        "col1": "str",
+        "col2": "int",
+        "cnt": "int",
+        "cnt_col": "int",
+        "sum": "float",
+        "concat": "str",
+    }
+    assert ds.to_records() == [
+        {"col1": "a", "col2": 1, "cnt": 1, "cnt_col": 1, "sum": 1.0, "concat": "1"},
+        {"col1": "a", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 2.0, "concat": "2"},
+        {"col1": "b", "col2": 1, "cnt": 2, "cnt_col": 2, "sum": 8.0, "concat": "35"},
+        {"col1": "b", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 4.0, "concat": "4"},
+        {"col1": "c", "col2": 2, "cnt": 1, "cnt_col": 1, "sum": 6.0, "concat": "6"},
+    ]
+
+
+def test_group_by_error(test_session):
+    chain = DataChain.from_values(
+        col1=["a", "a", "b", "b", "b", "c"],
+        col2=[1, 2, 3, 4, 5, 6],
+        session=test_session,
+    )
+
+    with pytest.raises(TypeError):
+        chain.group_by(cnt=func.count())
+
+    with pytest.raises(
+        ValueError, match="At least one column should be provided for partition_by"
+    ):
+        chain.group_by(cnt=func.count(), partition_by=tuple())
+
+    with pytest.raises(
+        ValueError, match="At least one column should be provided for group_by"
+    ):
+        chain.group_by(partition_by="col1")
+
+    with pytest.raises(
+        DataChainColumnError,
+        match="Column foo has type <class 'str'> but expected Func object",
+    ):
+        chain.group_by(foo="col2", partition_by="col1")
+
+    with pytest.raises(DataChainColumnError, match="Column col3 not found in schema"):
+        chain.group_by(foo=func.sum("col3"), partition_by="col1")
+
+    with pytest.raises(DataChainColumnError, match="Column col3 not found in schema"):
+        chain.group_by(foo=func.sum("col2"), partition_by="col3")