From 1d33b35325c43e0d2d702a1ac372ecdfea07d344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Thu, 24 Oct 2024 16:37:15 +0545 Subject: [PATCH] datachain: support mutating existing column --- src/datachain/lib/dc.py | 8 -------- src/datachain/lib/signal_schema.py | 9 ++++----- src/datachain/query/dataset.py | 10 +++++++++- tests/func/test_datachain.py | 12 +++--------- 4 files changed, 16 insertions(+), 23 deletions(-) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 4037ba08..d9f90e32 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1111,14 +1111,6 @@ def mutate(self, **kwargs) -> "Self": ) ``` """ - existing_columns = set(self.signals_schema.values.keys()) - for col_name in kwargs: - if col_name in existing_columns: - raise DataChainColumnError( - col_name, - "Cannot modify existing column with mutate(). " - "Use a different name for the new column.", - ) for col_name, expr in kwargs.items(): if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType): raise DataChainColumnError( diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 9de9d603..10bd3699 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -20,6 +20,7 @@ ) from pydantic import BaseModel, create_model +from sqlalchemy import ColumnElement from typing_extensions import Literal as LiteralEx from datachain.lib.convert.python_to_sql import python_to_sql @@ -491,16 +492,14 @@ def mutate(self, args_map: dict) -> "SignalSchema": # renaming existing signal del new_values[value.name] new_values[name] = self.values[value.name] - elif name in self.values: - # changing the type of existing signal, e.g File -> ImageFile - del new_values[name] - new_values[name] = args_map[name] elif isinstance(value, Func): # adding new signal with function new_values[name] = value.get_result_type(self) - else: + elif isinstance(value, ColumnElement): # adding new signal new_values[name] = sql_to_python(value) + else: + new_values[name] = value return SignalSchema(new_values) diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index a94d773f..88351a6b 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -10,6 +10,7 @@ from collections.abc import Generator, Iterable, Iterator, Sequence from copy import copy from functools import wraps +from secrets import token_hex from typing import ( TYPE_CHECKING, Any, @@ -720,10 +721,17 @@ class SQLMutate(SQLClause): def apply_sql_clause(self, query: Select) -> Select: original_subquery = query.subquery() + to_mutate = {c.name for c in self.args} + + prefix = f"mutate{token_hex(8)}_" + cols = [ + c.label(prefix + c.name) if c.name in to_mutate else c + for c in original_subquery.c + ] # this is needed for new column to be used in clauses # like ORDER BY, otherwise new column is not recognized subquery = ( - sqlalchemy.select(*original_subquery.c, *self.args) + sqlalchemy.select(*cols, *self.args) .select_from(original_subquery) .subquery() ) diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index cb286b90..ce7b84d9 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -24,7 +24,7 @@ from datachain.lib.listing import LISTING_TTL, is_listing_dataset, parse_listing_uri from datachain.lib.tar import process_tar from datachain.lib.udf import Mapper -from datachain.lib.utils import DataChainColumnError, DataChainError +from datachain.lib.utils import DataChainError from datachain.query.dataset import QueryStep from datachain.sql.functions import path as pathfunc from datachain.sql.functions.array import cosine_distance, euclidean_distance @@ -491,15 +491,9 @@ def test_from_storage_check_rows(tmp_dir, test_session): def test_mutate_existing_column(test_session): ds = DataChain.from_values(ids=[1, 2, 3], session=test_session) + ds = ds.mutate(ids=Column("ids") + 1) - with pytest.raises(DataChainColumnError) as excinfo: - ds.mutate(ids=Column("ids") + 1) - - assert ( - str(excinfo.value) - == "Error for column ids: Cannot modify existing column with mutate()." - " Use a different name for the new column." - ) + assert list(ds.collect()) == [(2,), (3,), (4,)] @pytest.mark.parametrize(