Skip to content

Commit

Permalink
Merge pull request #5 from raisadz/duckdb-relational-replace-all
Browse files Browse the repository at this point in the history
Add `replace_all` for DuckDB and xfail tests for `replace`
  • Loading branch information
MarcoGorelli authored Dec 30, 2024
2 parents 2abf875 + 364ae0d commit ca1c643
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
25 changes: 25 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any
from typing import Callable
from typing import Literal
from typing import NoReturn
from typing import Sequence

from narwhals._duckdb.utils import get_column_name
Expand Down Expand Up @@ -609,6 +610,30 @@ def strip_chars(self, characters: str | None) -> DuckDBExpr:
returns_scalar=False,
)

def replace_all(
self, pattern: str, value: str, *, literal: bool = False
) -> DuckDBExpr:
from duckdb import ConstantExpression
from duckdb import FunctionExpression

if literal is False:
msg = "`replace_all` for DuckDB currently only supports `literal=True`."
raise NotImplementedError(msg)
return self._compliant_expr._from_call(
lambda _input: FunctionExpression(
"replace",
_input,
ConstantExpression(pattern),
ConstantExpression(value),
),
"replace_all",
returns_scalar=False,
)

def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> NoReturn:
msg = "`replace` is currently not supported for DuckDB"
raise NotImplementedError(msg)


class DuckDBExprDateTimeNamespace:
def __init__(self, expr: DuckDBExpr) -> None:
Expand Down
8 changes: 6 additions & 2 deletions tests/expr_and_series/str/replace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,17 @@ def test_str_replace_all_series(
)
def test_str_replace_expr(
constructor: Constructor,
request: pytest.FixtureRequest,
data: dict[str, list[str]],
pattern: str,
value: str,
n: int,
literal: bool, # noqa: FBT001
expected: dict[str, list[str]],
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))

result_df = df.select(
nw.col("a").str.replace(pattern=pattern, value=value, n=n, literal=literal)
)
Expand All @@ -114,14 +116,16 @@ def test_str_replace_expr(
)
def test_str_replace_all_expr(
constructor: Constructor,
request: pytest.FixtureRequest,
data: dict[str, list[str]],
pattern: str,
value: str,
literal: bool, # noqa: FBT001
expected: dict[str, list[str]],
) -> None:
if "duckdb" in str(constructor) and literal is False:
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))

result = df.select(
nw.col("a").str.replace_all(pattern=pattern, value=value, literal=literal)
)
Expand Down

0 comments on commit ca1c643

Please sign in to comment.