From fd4716204aa9a313322c668b8ebf5832e6b867ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Thu, 24 Oct 2024 16:37:15 +0545 Subject: [PATCH] datachain: support mutating existing column --- src/datachain/lib/dc.py | 8 -------- src/datachain/lib/signal_schema.py | 8 ++++---- src/datachain/query/dataset.py | 8 +++++++- tests/func/test_datachain.py | 10 ++-------- 4 files changed, 13 insertions(+), 21 deletions(-) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 4037ba08..d9f90e32 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1111,14 +1111,6 @@ def mutate(self, **kwargs) -> "Self": ) ``` """ - existing_columns = set(self.signals_schema.values.keys()) - for col_name in kwargs: - if col_name in existing_columns: - raise DataChainColumnError( - col_name, - "Cannot modify existing column with mutate(). " - "Use a different name for the new column.", - ) for col_name, expr in kwargs.items(): if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType): raise DataChainColumnError( diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 9de9d603..3a6cdb51 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -491,10 +491,10 @@ def mutate(self, args_map: dict) -> "SignalSchema": # renaming existing signal del new_values[value.name] new_values[name] = self.values[value.name] - elif name in self.values: - # changing the type of existing signal, e.g File -> ImageFile - del new_values[name] - new_values[name] = args_map[name] + # elif name in self.values: + # # changing the type of existing signal, e.g File -> ImageFile + # del new_values[name] + # new_values[name] = args_map[name] elif isinstance(value, Func): # adding new signal with function new_values[name] = value.get_result_type(self) diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index a94d773f..7105bb5e 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -720,10 +720,16 @@ class SQLMutate(SQLClause): def apply_sql_clause(self, query: Select) -> Select: original_subquery = query.subquery() + to_mutate = {c.name for c in self.args} + + cols = [ + c if c.name not in to_mutate else c.label(f"{c.name}_2") + for c in original_subquery.c + ] # this is needed for new column to be used in clauses # like ORDER BY, otherwise new column is not recognized subquery = ( - sqlalchemy.select(*original_subquery.c, *self.args) + sqlalchemy.select(*cols, *self.args) .select_from(original_subquery) .subquery() ) diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index cb286b90..9040b778 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -491,15 +491,9 @@ def test_from_storage_check_rows(tmp_dir, test_session): def test_mutate_existing_column(test_session): ds = DataChain.from_values(ids=[1, 2, 3], session=test_session) + ds = ds.mutate(ids=Column("ids") + 1) - with pytest.raises(DataChainColumnError) as excinfo: - ds.mutate(ids=Column("ids") + 1) - - assert ( - str(excinfo.value) - == "Error for column ids: Cannot modify existing column with mutate()." - " Use a different name for the new column." - ) + assert list(ds.collect()) == [(2,), (3,), (4,)] @pytest.mark.parametrize(