Skip to content

Commit

Permalink
feat: Add SQL support for TRUNCATE TABLE command (#15513)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Apr 7, 2024
1 parent eda3ccd commit 4b94d2f
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 6 deletions.
32 changes: 29 additions & 3 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl SQLContext {
pub fn new() -> Self {
    // Delegates to the `Default` implementation of `SQLContext`.
    Self::default()
}

/// Get the names of all registered tables, in sorted order.
pub fn get_tables(&self) -> Vec<String> {
let mut tables = Vec::from_iter(self.table_map.keys().cloned());
Expand Down Expand Up @@ -164,6 +165,7 @@ impl SQLContext {
..
} => self.execute_drop_table(stmt)?,
stmt @ Statement::Explain { .. } => self.execute_explain(stmt)?,
stmt @ Statement::Truncate { .. } => self.execute_truncate_table(stmt)?,
_ => polars_bail!(
ComputeError: "SQL statement type {:?} is not supported", ast,
),
Expand Down Expand Up @@ -262,19 +264,43 @@ impl SQLContext {
fn execute_drop_table(&mut self, stmt: &Statement) -> PolarsResult<LazyFrame> {
match stmt {
Statement::Drop { names, .. } => {
for name in names {
names.iter().for_each(|name| {
self.table_map.remove(&name.to_string());
}
});
Ok(DataFrame::empty().lazy())
},
_ => unreachable!(),
}
}

/// Execute a `TRUNCATE [TABLE] <name>` statement.
///
/// The registered table is replaced by an empty frame built from the table's
/// own schema, so the table stays registered but holds no rows. A clone of
/// that empty frame is returned.
///
/// # Errors
/// - `ComputeError` if no table with the given name is registered.
/// - `ComputeError` if the statement carries a `partitions` clause.
/// - Any error raised while resolving the table's schema.
///
/// # Panics
/// Hits `unreachable!()` if `stmt` is not a `Statement::Truncate`; the caller
/// is expected to dispatch only TRUNCATE statements to this method.
fn execute_truncate_table(&mut self, stmt: &Statement) -> PolarsResult<LazyFrame> {
    if let Statement::Truncate {
        table_name,
        partitions,
        ..
    } = stmt
    {
        match partitions {
            None => {
                let tbl = table_name.to_string();
                if let Some(lf) = self.table_map.get_mut(&tbl) {
                    // Propagate schema-resolution failures to the caller instead
                    // of panicking: this function already returns `PolarsResult`.
                    let schema = lf.schema()?;
                    *lf = DataFrame::from(schema.as_ref()).lazy();
                    Ok(lf.clone())
                } else {
                    polars_bail!(ComputeError: "table '{}' does not exist", tbl);
                }
            },
            _ => polars_bail!(ComputeError: "TRUNCATE does not support use of 'partitions'"),
        }
    } else {
        unreachable!()
    }
}

fn register_ctes(&mut self, query: &Query) -> PolarsResult<()> {
if let Some(with) = &query.with {
if with.recursive {
polars_bail!(ComputeError: "Recursive CTEs are not supported")
polars_bail!(ComputeError: "recursive CTEs are not supported")
}
for cte in &with.cte_tables {
let cte_name = cte.alias.name.to_string();
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2761,7 +2761,7 @@ def clear(self, n: int = 0) -> LazyFrame:
... "c": [True, True, False, None],
... }
... )
>>> lf.clear().fetch()
>>> lf.clear().collect()
shape: (0, 3)
┌─────┬─────┬──────┐
│ a ┆ b ┆ c │
Expand All @@ -2770,7 +2770,7 @@ def clear(self, n: int = 0) -> LazyFrame:
╞═════╪═════╪══════╡
└─────┴─────┴──────┘
>>> lf.clear(2).fetch()
>>> lf.clear(2).collect()
shape: (2, 3)
┌──────┬──────┬──────┐
│ a ┆ b ┆ c │
Expand Down
83 changes: 83 additions & 0 deletions py-polars/tests/unit/sql/test_table_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

import re
from datetime import date

import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal


@pytest.fixture()
def test_frame() -> pl.LazyFrame:
    """Small three-row LazyFrame shared by the SQL table-operation tests."""
    columns = {
        "x": [1, 2, 3],
        "y": ["aaa", "bbb", "ccc"],
        "z": [date(2000, 12, 31), date(1978, 11, 15), date(2077, 10, 20)],
    }
    # Column "x" is explicitly overridden to UInt8.
    return pl.LazyFrame(columns, schema_overrides={"x": pl.UInt8})


def test_drop_table(test_frame: pl.LazyFrame) -> None:
    """DROP TABLE removes the table from the SQL context entirely."""
    empty = pl.DataFrame()

    with pl.SQLContext(frame=test_frame, eager_execution=True) as ctx:
        dropped = ctx.execute("DROP TABLE frame")
        assert_frame_equal(dropped, empty)

        # The dropped table must no longer be resolvable by name.
        with pytest.raises(ComputeError, match="'frame' was not found"):
            ctx.execute("SELECT * FROM frame")


def test_explain_query(test_frame: pl.LazyFrame) -> None:
    """EXPLAIN returns the query plan for the given SQL."""
    plan_shape = r'SELECT.+?"x".+?"y".+?"z".+?FROM.+?PROJECT.+?COLUMNS'

    with pl.SQLContext(frame=test_frame) as ctx:
        explained = ctx.execute("EXPLAIN SELECT * FROM frame")
        # Join the plan rows into one string so it can be matched with one regex.
        plan = explained.select(pl.col("Logical Plan").str.concat("")).collect().item()
        found = re.search(pattern=plan_shape, string=plan, flags=re.IGNORECASE)
        assert found is not None


def test_show_tables(test_frame: pl.LazyFrame) -> None:
    """SHOW TABLES lists every registered table, in sorted order."""
    expected = pl.DataFrame({"name": ["tbl1", "tbl2", "tbl3"]})

    # Tables are registered in reverse order so the sort is actually exercised.
    with pl.SQLContext(tbl3=test_frame, tbl2=test_frame, tbl1=test_frame) as ctx:
        listed = ctx.execute("SHOW TABLES").collect()
        assert_frame_equal(listed, expected)


@pytest.mark.parametrize(
    "truncate_sql",
    ["TRUNCATE TABLE frame", "TRUNCATE frame"],
)
def test_truncate_table(truncate_sql: str, test_frame: pl.LazyFrame) -> None:
    """TRUNCATE keeps the table registered but removes all of its rows."""
    expected = pl.DataFrame(schema=test_frame.schema)

    with pl.SQLContext(frame=test_frame, eager_execution=True) as ctx:
        # The truncate statement itself yields the empty, schema-preserving frame.
        truncated = ctx.execute(truncate_sql)
        assert_frame_equal(truncated, expected)

        # A follow-up query sees the same empty table.
        remaining = ctx.execute("SELECT * FROM frame")
        assert_frame_equal(remaining, expected)
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,7 +1058,7 @@ def test_self_join() -> None:
pl.col("employee_name_right").alias("manager_name"),
]
)
.fetch()
.collect()
)
assert set(out.rows()) == {
(100, "James", None),
Expand Down

0 comments on commit 4b94d2f

Please sign in to comment.