diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index 9a9963ed3259..7274ea7a9d47 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -49,6 +49,7 @@ impl SQLContext { pub fn new() -> Self { Self::default() } + /// Get the names of all registered tables, in sorted order. pub fn get_tables(&self) -> Vec { let mut tables = Vec::from_iter(self.table_map.keys().cloned()); @@ -164,6 +165,7 @@ impl SQLContext { .. } => self.execute_drop_table(stmt)?, stmt @ Statement::Explain { .. } => self.execute_explain(stmt)?, + stmt @ Statement::Truncate { .. } => self.execute_truncate_table(stmt)?, _ => polars_bail!( ComputeError: "SQL statement type {:?} is not supported", ast, ), @@ -262,19 +264,43 @@ impl SQLContext { fn execute_drop_table(&mut self, stmt: &Statement) -> PolarsResult { match stmt { Statement::Drop { names, .. } => { - for name in names { + names.iter().for_each(|name| { self.table_map.remove(&name.to_string()); - } + }); Ok(DataFrame::empty().lazy()) }, _ => unreachable!(), } } + fn execute_truncate_table(&mut self, stmt: &Statement) -> PolarsResult { + if let Statement::Truncate { + table_name, + partitions, + .. + } = stmt + { + match partitions { + None => { + let tbl = table_name.to_string(); + if let Some(lf) = self.table_map.get_mut(&tbl) { + *lf = DataFrame::from(lf.schema().unwrap().as_ref()).lazy(); + Ok(lf.clone()) + } else { + polars_bail!(ComputeError: "table '{}' does not exist", tbl); + } + }, + _ => polars_bail!(ComputeError: "TRUNCATE does not support use of 'partitions'"), + } + } else { + unreachable!() + } + } + fn register_ctes(&mut self, query: &Query) -> PolarsResult<()> { if let Some(with) = &query.with { if with.recursive { - polars_bail!(ComputeError: "Recursive CTEs are not supported") + polars_bail!(ComputeError: "recursive CTEs are not supported") } for cte in &with.cte_tables { let cte_name = cte.alias.name.to_string(); diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 805b96c2e415..cbbff5bcf75b 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -2761,7 +2761,7 @@ def clear(self, n: int = 0) -> LazyFrame: ... "c": [True, True, False, None], ... } ... ) - >>> lf.clear().fetch() + >>> lf.clear().collect() shape: (0, 3) ┌─────┬─────┬──────┐ │ a ┆ b ┆ c │ @@ -2770,7 +2770,7 @@ def clear(self, n: int = 0) -> LazyFrame: ╞═════╪═════╪══════╡ └─────┴─────┴──────┘ - >>> lf.clear(2).fetch() + >>> lf.clear(2).collect() shape: (2, 3) ┌──────┬──────┬──────┐ │ a ┆ b ┆ c │ diff --git a/py-polars/tests/unit/sql/test_table_operations.py b/py-polars/tests/unit/sql/test_table_operations.py new file mode 100644 index 000000000000..9189449ebb96 --- /dev/null +++ b/py-polars/tests/unit/sql/test_table_operations.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import re +from datetime import date + +import pytest + +import polars as pl +from polars.exceptions import ComputeError +from polars.testing import assert_frame_equal + + +@pytest.fixture() +def test_frame() -> pl.LazyFrame: + return pl.LazyFrame( + { + "x": [1, 2, 3], + "y": ["aaa", "bbb", "ccc"], + "z": [date(2000, 12, 31), date(1978, 11, 15), date(2077, 10, 20)], + }, + schema_overrides={"x": pl.UInt8}, + ) + + +def test_drop_table(test_frame: pl.LazyFrame) -> None: + # 'drop' completely removes the table from sql context + expected = pl.DataFrame() + + with pl.SQLContext(frame=test_frame, eager_execution=True) as ctx: + res = ctx.execute("DROP TABLE frame") + assert_frame_equal(res, expected) + + with pytest.raises(ComputeError, match="'frame' was not found"): + ctx.execute("SELECT * FROM frame") + + +def test_explain_query(test_frame: pl.LazyFrame) -> None: + # 'explain' returns the query plan for the given sql + with pl.SQLContext(frame=test_frame) as ctx: + plan = ( + ctx.execute("EXPLAIN SELECT * FROM frame") + .select(pl.col("Logical Plan").str.concat("")) + .collect() + .item() + ) + assert ( + re.search( + pattern=r'SELECT.+?"x".+?"y".+?"z".+?FROM.+?PROJECT.+?COLUMNS', + string=plan, + flags=re.IGNORECASE, + ) + is not None + ) + + +def test_show_tables(test_frame: pl.LazyFrame) -> None: + # 'show tables' lists all tables registered with the sql context in sorted order + with pl.SQLContext( + tbl3=test_frame, + tbl2=test_frame, + tbl1=test_frame, + ) as ctx: + res = ctx.execute("SHOW TABLES").collect() + assert_frame_equal(res, pl.DataFrame({"name": ["tbl1", "tbl2", "tbl3"]})) + + +@pytest.mark.parametrize( + "truncate_sql", + [ + "TRUNCATE TABLE frame", + "TRUNCATE frame", + ], +) +def test_truncate_table(truncate_sql: str, test_frame: pl.LazyFrame) -> None: + # 'truncate' preserves the table, but optimally drops all rows within it + expected = pl.DataFrame(schema=test_frame.schema) + + with pl.SQLContext(frame=test_frame, eager_execution=True) as ctx: + res = ctx.execute(truncate_sql) + assert_frame_equal(res, expected) + + res = ctx.execute("SELECT * FROM frame") + assert_frame_equal(res, expected) diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 83784ca17cb0..0a98b99f95f3 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -1058,7 +1058,7 @@ def test_self_join() -> None: pl.col("employee_name_right").alias("manager_name"), ] ) - .fetch() + .collect() ) assert set(out.rows()) == { (100, "James", None),