diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 8b78b0a341c5..eb136698f5de 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -209,8 +209,12 @@ impl SQLContext { match quantifier { // UNION ALL SetQuantifier::All => concatenated, - // UNION DISTINCT | UNION - _ => concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)), + // UNION [DISTINCT] + SetQuantifier::Distinct | SetQuantifier::None => { + concatenated.map(|lf| lf.unique(None, UniqueKeepStrategy::Any)) + }, + // TODO: support "UNION [ALL] BY NAME" + _ => polars_bail!(InvalidOperation: "UNION {} is not yet supported", quantifier), } } diff --git a/py-polars/tests/unit/sql/test_sql.py b/py-polars/tests/unit/sql/test_sql.py index f0bef2312644..cf94aff147d4 100644 --- a/py-polars/tests/unit/sql/test_sql.py +++ b/py-polars/tests/unit/sql/test_sql.py @@ -908,6 +908,67 @@ def test_sql_trim(foods_ipc_path: Path) -> None: } +@pytest.mark.parametrize( + ("cols1", "cols2", "union_subtype", "expected"), + [ + ( + ["*"], + ["*"], + "", + [(1, "zz"), (2, "yy"), (3, "xx")], + ), + ( + ["*"], + ["frame2.*"], + "ALL", + [(1, "zz"), (2, "yy"), (2, "yy"), (3, "xx")], + ), + ( + ["frame1.*"], + ["c1", "c2"], + "DISTINCT", + [(1, "zz"), (2, "yy"), (3, "xx")], + ), + ( + ["*"], + ["c2", "c1"], + "ALL BY NAME", + None, # [(1, 'zz'), (2, 'yy'), (2, 'yy'), (3, 'xx')], + ), + ( + ["c1", "c2"], + ["c2", "c1"], + "BY NAME", + None, # [(1, 'zz'), (2, 'yy'), (3, 'xx')], + ), + ], +) +def test_sql_union( + cols1: list[str], + cols2: list[str], + union_subtype: str, + expected: dict[str, list[int] | list[str]] | None, +) -> None: + with pl.SQLContext( + frame1=pl.DataFrame({"c1": [1, 2], "c2": ["zz", "yy"]}), + frame2=pl.DataFrame({"c1": [2, 3], "c2": ["yy", "xx"]}), + eager_execution=True, + ) as ctx: + query = f""" + SELECT {', '.join(cols1)} FROM frame1 + UNION {union_subtype} + SELECT {', '.join(cols2)} FROM frame2 + """ + if expected is not None: + assert sorted(ctx.execute(query).rows()) == expected + else: + with pytest.raises( + pl.InvalidOperationError, + match=f"UNION {union_subtype} is not yet supported", + ): + ctx.execute(query) + + def test_sql_nullif_coalesce(foods_ipc_path: Path) -> None: nums = pl.LazyFrame( {