Skip to content

Commit

Permalink
datachain: support mutating existing column
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Oct 24, 2024
1 parent 34e7c2b commit fd47162
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 21 deletions.
8 changes: 0 additions & 8 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,14 +1111,6 @@ def mutate(self, **kwargs) -> "Self":
)
```
"""
existing_columns = set(self.signals_schema.values.keys())
for col_name in kwargs:
if col_name in existing_columns:
raise DataChainColumnError(
col_name,
"Cannot modify existing column with mutate(). "
"Use a different name for the new column.",
)
for col_name, expr in kwargs.items():
if not isinstance(expr, (Column, Func)) and isinstance(expr.type, NullType):
raise DataChainColumnError(
Expand Down
8 changes: 4 additions & 4 deletions src/datachain/lib/signal_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,10 @@ def mutate(self, args_map: dict) -> "SignalSchema":
# renaming existing signal
del new_values[value.name]
new_values[name] = self.values[value.name]
elif name in self.values:
# changing the type of existing signal, e.g File -> ImageFile
del new_values[name]
new_values[name] = args_map[name]
# elif name in self.values:
# # changing the type of existing signal, e.g File -> ImageFile
# del new_values[name]
# new_values[name] = args_map[name]
elif isinstance(value, Func):
# adding new signal with function
new_values[name] = value.get_result_type(self)
Expand Down
8 changes: 7 additions & 1 deletion src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,10 +720,16 @@ class SQLMutate(SQLClause):

def apply_sql_clause(self, query: Select) -> Select:
original_subquery = query.subquery()
to_mutate = {c.name for c in self.args}

cols = [
c if c.name not in to_mutate else c.label(f"{c.name}_2")
for c in original_subquery.c
]
# this is needed for new column to be used in clauses
# like ORDER BY, otherwise new column is not recognized
subquery = (
sqlalchemy.select(*original_subquery.c, *self.args)
sqlalchemy.select(*cols, *self.args)
.select_from(original_subquery)
.subquery()
)
Expand Down
10 changes: 2 additions & 8 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,15 +491,9 @@ def test_from_storage_check_rows(tmp_dir, test_session):

def test_mutate_existing_column(test_session):
ds = DataChain.from_values(ids=[1, 2, 3], session=test_session)
ds = ds.mutate(ids=Column("ids") + 1)

with pytest.raises(DataChainColumnError) as excinfo:
ds.mutate(ids=Column("ids") + 1)

assert (
str(excinfo.value)
== "Error for column ids: Cannot modify existing column with mutate()."
" Use a different name for the new column."
)
assert list(ds.collect()) == [(2,), (3,), (4,)]


@pytest.mark.parametrize(
Expand Down

0 comments on commit fd47162

Please sign in to comment.