Skip to content

Commit 15b96c4

Browse files
feat: add missing PyLogicalPlan to_variant (#1085)
* add expr * format * clippy * add license * update * ruff * Update expr.py * add test * ruff * Minor ruff whitespace change * Minor format change --------- Co-authored-by: Tim Saucer <[email protected]>
1 parent 6fbecef commit 15b96c4

21 files changed

+2372
-16
lines changed

python/datafusion/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,12 @@
3333
SqlTable = common_internal.SqlTable
3434
SqlType = common_internal.SqlType
3535
SqlView = common_internal.SqlView
36+
TableType = common_internal.TableType
37+
TableSource = common_internal.TableSource
38+
Constraints = common_internal.Constraints
3639

3740
__all__ = [
41+
"Constraints",
3842
"DFSchema",
3943
"DataType",
4044
"DataTypeMap",
@@ -47,6 +51,8 @@
4751
"SqlTable",
4852
"SqlType",
4953
"SqlView",
54+
"TableSource",
55+
"TableType",
5056
]
5157

5258

python/datafusion/expr.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,29 @@
5454
Case = expr_internal.Case
5555
Cast = expr_internal.Cast
5656
Column = expr_internal.Column
57+
CopyTo = expr_internal.CopyTo
58+
CreateCatalog = expr_internal.CreateCatalog
59+
CreateCatalogSchema = expr_internal.CreateCatalogSchema
60+
CreateExternalTable = expr_internal.CreateExternalTable
61+
CreateFunction = expr_internal.CreateFunction
62+
CreateFunctionBody = expr_internal.CreateFunctionBody
63+
CreateIndex = expr_internal.CreateIndex
5764
CreateMemoryTable = expr_internal.CreateMemoryTable
5865
CreateView = expr_internal.CreateView
66+
Deallocate = expr_internal.Deallocate
67+
DescribeTable = expr_internal.DescribeTable
5968
Distinct = expr_internal.Distinct
69+
DmlStatement = expr_internal.DmlStatement
70+
DropCatalogSchema = expr_internal.DropCatalogSchema
71+
DropFunction = expr_internal.DropFunction
6072
DropTable = expr_internal.DropTable
73+
DropView = expr_internal.DropView
6174
EmptyRelation = expr_internal.EmptyRelation
75+
Execute = expr_internal.Execute
6276
Exists = expr_internal.Exists
6377
Explain = expr_internal.Explain
6478
Extension = expr_internal.Extension
79+
FileType = expr_internal.FileType
6580
Filter = expr_internal.Filter
6681
GroupingSet = expr_internal.GroupingSet
6782
Join = expr_internal.Join
@@ -83,21 +98,31 @@
8398
Literal = expr_internal.Literal
8499
Negative = expr_internal.Negative
85100
Not = expr_internal.Not
101+
OperateFunctionArg = expr_internal.OperateFunctionArg
86102
Partitioning = expr_internal.Partitioning
87103
Placeholder = expr_internal.Placeholder
104+
Prepare = expr_internal.Prepare
88105
Projection = expr_internal.Projection
106+
RecursiveQuery = expr_internal.RecursiveQuery
89107
Repartition = expr_internal.Repartition
90108
ScalarSubquery = expr_internal.ScalarSubquery
91109
ScalarVariable = expr_internal.ScalarVariable
110+
SetVariable = expr_internal.SetVariable
92111
SimilarTo = expr_internal.SimilarTo
93112
Sort = expr_internal.Sort
94113
Subquery = expr_internal.Subquery
95114
SubqueryAlias = expr_internal.SubqueryAlias
96115
TableScan = expr_internal.TableScan
116+
TransactionAccessMode = expr_internal.TransactionAccessMode
117+
TransactionConclusion = expr_internal.TransactionConclusion
118+
TransactionEnd = expr_internal.TransactionEnd
119+
TransactionIsolationLevel = expr_internal.TransactionIsolationLevel
120+
TransactionStart = expr_internal.TransactionStart
97121
TryCast = expr_internal.TryCast
98122
Union = expr_internal.Union
99123
Unnest = expr_internal.Unnest
100124
UnnestExpr = expr_internal.UnnestExpr
125+
Values = expr_internal.Values
101126
WindowExpr = expr_internal.WindowExpr
102127

103128
__all__ = [
@@ -111,15 +136,30 @@
111136
"CaseBuilder",
112137
"Cast",
113138
"Column",
139+
"CopyTo",
140+
"CreateCatalog",
141+
"CreateCatalogSchema",
142+
"CreateExternalTable",
143+
"CreateFunction",
144+
"CreateFunctionBody",
145+
"CreateIndex",
114146
"CreateMemoryTable",
115147
"CreateView",
148+
"Deallocate",
149+
"DescribeTable",
116150
"Distinct",
151+
"DmlStatement",
152+
"DropCatalogSchema",
153+
"DropFunction",
117154
"DropTable",
155+
"DropView",
118156
"EmptyRelation",
157+
"Execute",
119158
"Exists",
120159
"Explain",
121160
"Expr",
122161
"Extension",
162+
"FileType",
123163
"Filter",
124164
"GroupingSet",
125165
"ILike",
@@ -142,22 +182,32 @@
142182
"Literal",
143183
"Negative",
144184
"Not",
185+
"OperateFunctionArg",
145186
"Partitioning",
146187
"Placeholder",
188+
"Prepare",
147189
"Projection",
190+
"RecursiveQuery",
148191
"Repartition",
149192
"ScalarSubquery",
150193
"ScalarVariable",
194+
"SetVariable",
151195
"SimilarTo",
152196
"Sort",
153197
"SortExpr",
154198
"Subquery",
155199
"SubqueryAlias",
156200
"TableScan",
201+
"TransactionAccessMode",
202+
"TransactionConclusion",
203+
"TransactionEnd",
204+
"TransactionIsolationLevel",
205+
"TransactionStart",
157206
"TryCast",
158207
"Union",
159208
"Unnest",
160209
"UnnestExpr",
210+
"Values",
161211
"Window",
162212
"WindowExpr",
163213
"WindowFrame",
@@ -686,8 +736,8 @@ def log10(self) -> Expr:
686736
def initcap(self) -> Expr:
687737
"""Set the initial letter of each word to capital.
688738
689-
Converts the first letter of each word in ``string``
690-
to uppercase and the remaining characters to lowercase.
739+
Converts the first letter of each word in ``string`` to uppercase and the
740+
remaining characters to lowercase.
691741
"""
692742
from . import functions as F
693743

python/tests/test_expr.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,21 @@
2323
AggregateFunction,
2424
BinaryExpr,
2525
Column,
26+
CopyTo,
27+
CreateIndex,
28+
DescribeTable,
29+
DmlStatement,
30+
DropCatalogSchema,
2631
Filter,
2732
Limit,
2833
Literal,
2934
Projection,
35+
RecursiveQuery,
3036
Sort,
3137
TableScan,
38+
TransactionEnd,
39+
TransactionStart,
40+
Values,
3241
)
3342

3443

@@ -249,6 +258,83 @@ def test_fill_null(df):
249258
assert result.column(2) == pa.array([1234, 1234, 8])
250259

251260

261+
def test_copy_to():
262+
ctx = SessionContext()
263+
ctx.sql("CREATE TABLE foo (a int, b int)").collect()
264+
df = ctx.sql("COPY foo TO bar STORED AS CSV")
265+
plan = df.logical_plan()
266+
plan = plan.to_variant()
267+
assert isinstance(plan, CopyTo)
268+
269+
270+
def test_create_index():
271+
ctx = SessionContext()
272+
ctx.sql("CREATE TABLE foo (a int, b int)").collect()
273+
plan = ctx.sql("create index idx on foo (a)").logical_plan()
274+
plan = plan.to_variant()
275+
assert isinstance(plan, CreateIndex)
276+
277+
278+
def test_describe_table():
279+
ctx = SessionContext()
280+
ctx.sql("CREATE TABLE foo (a int, b int)").collect()
281+
plan = ctx.sql("describe foo").logical_plan()
282+
plan = plan.to_variant()
283+
assert isinstance(plan, DescribeTable)
284+
285+
286+
def test_dml_statement():
287+
ctx = SessionContext()
288+
ctx.sql("CREATE TABLE foo (a int, b int)").collect()
289+
plan = ctx.sql("insert into foo values (1, 2)").logical_plan()
290+
plan = plan.to_variant()
291+
assert isinstance(plan, DmlStatement)
292+
293+
294+
def drop_catalog_schema():
295+
ctx = SessionContext()
296+
plan = ctx.sql("drop schema cat").logical_plan()
297+
plan = plan.to_variant()
298+
assert isinstance(plan, DropCatalogSchema)
299+
300+
301+
def test_recursive_query():
302+
ctx = SessionContext()
303+
plan = ctx.sql(
304+
"""
305+
WITH RECURSIVE cte AS (
306+
SELECT 1 as n
307+
UNION ALL
308+
SELECT n + 1 FROM cte WHERE n < 5
309+
)
310+
SELECT * FROM cte;
311+
"""
312+
).logical_plan()
313+
plan = plan.inputs()[0].inputs()[0].to_variant()
314+
assert isinstance(plan, RecursiveQuery)
315+
316+
317+
def test_values():
318+
ctx = SessionContext()
319+
plan = ctx.sql("values (1, 'foo'), (2, 'bar')").logical_plan()
320+
plan = plan.to_variant()
321+
assert isinstance(plan, Values)
322+
323+
324+
def test_transaction_start():
325+
ctx = SessionContext()
326+
plan = ctx.sql("START TRANSACTION").logical_plan()
327+
plan = plan.to_variant()
328+
assert isinstance(plan, TransactionStart)
329+
330+
331+
def test_transaction_end():
332+
ctx = SessionContext()
333+
plan = ctx.sql("COMMIT").logical_plan()
334+
plan = plan.to_variant()
335+
assert isinstance(plan, TransactionEnd)
336+
337+
252338
def test_col_getattr():
253339
ctx = SessionContext()
254340
data = {

src/common.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,8 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
3636
m.add_class::<schema::SqlView>()?;
3737
m.add_class::<schema::SqlStatistics>()?;
3838
m.add_class::<function::SqlFunction>()?;
39+
m.add_class::<schema::PyTableType>()?;
40+
m.add_class::<schema::PyTableSource>()?;
41+
m.add_class::<schema::PyConstraints>()?;
3942
Ok(())
4043
}

src/common/schema.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,22 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::fmt::{self, Display, Formatter};
19+
use std::sync::Arc;
1820
use std::{any::Any, borrow::Cow};
1921

22+
use arrow::datatypes::Schema;
23+
use arrow::pyarrow::PyArrowType;
2024
use datafusion::arrow::datatypes::SchemaRef;
25+
use datafusion::common::Constraints;
26+
use datafusion::datasource::TableType;
2127
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
2228
use pyo3::prelude::*;
2329

2430
use datafusion::logical_expr::utils::split_conjunction;
2531

32+
use crate::sql::logical::PyLogicalPlan;
33+
2634
use super::{data_type::DataTypeMap, function::SqlFunction};
2735

2836
#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
@@ -218,3 +226,84 @@ impl SqlStatistics {
218226
self.row_count
219227
}
220228
}
229+
230+
#[pyclass(name = "Constraints", module = "datafusion.expr", subclass)]
231+
#[derive(Clone)]
232+
pub struct PyConstraints {
233+
pub constraints: Constraints,
234+
}
235+
236+
impl From<PyConstraints> for Constraints {
237+
fn from(constraints: PyConstraints) -> Self {
238+
constraints.constraints
239+
}
240+
}
241+
242+
impl From<Constraints> for PyConstraints {
243+
fn from(constraints: Constraints) -> Self {
244+
PyConstraints { constraints }
245+
}
246+
}
247+
248+
impl Display for PyConstraints {
249+
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
250+
write!(f, "Constraints: {:?}", self.constraints)
251+
}
252+
}
253+
254+
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
255+
#[pyclass(eq, eq_int, name = "TableType", module = "datafusion.common")]
256+
pub enum PyTableType {
257+
Base,
258+
View,
259+
Temporary,
260+
}
261+
262+
impl From<PyTableType> for datafusion::logical_expr::TableType {
263+
fn from(table_type: PyTableType) -> Self {
264+
match table_type {
265+
PyTableType::Base => datafusion::logical_expr::TableType::Base,
266+
PyTableType::View => datafusion::logical_expr::TableType::View,
267+
PyTableType::Temporary => datafusion::logical_expr::TableType::Temporary,
268+
}
269+
}
270+
}
271+
272+
impl From<TableType> for PyTableType {
273+
fn from(table_type: TableType) -> Self {
274+
match table_type {
275+
datafusion::logical_expr::TableType::Base => PyTableType::Base,
276+
datafusion::logical_expr::TableType::View => PyTableType::View,
277+
datafusion::logical_expr::TableType::Temporary => PyTableType::Temporary,
278+
}
279+
}
280+
}
281+
282+
#[pyclass(name = "TableSource", module = "datafusion.common", subclass)]
283+
#[derive(Clone)]
284+
pub struct PyTableSource {
285+
pub table_source: Arc<dyn TableSource>,
286+
}
287+
288+
#[pymethods]
289+
impl PyTableSource {
290+
pub fn schema(&self) -> PyArrowType<Schema> {
291+
(*self.table_source.schema()).clone().into()
292+
}
293+
294+
pub fn constraints(&self) -> Option<PyConstraints> {
295+
self.table_source.constraints().map(|c| PyConstraints {
296+
constraints: c.clone(),
297+
})
298+
}
299+
300+
pub fn table_type(&self) -> PyTableType {
301+
self.table_source.table_type().into()
302+
}
303+
304+
pub fn get_logical_plan(&self) -> Option<PyLogicalPlan> {
305+
self.table_source
306+
.get_logical_plan()
307+
.map(|plan| PyLogicalPlan::new(plan.into_owned()))
308+
}
309+
}

0 commit comments

Comments
 (0)