Skip to content

Commit

Permalink
Bind SQLOptions and relative ctx method #567
Browse files Browse the repository at this point in the history
  • Loading branch information
giacomo.rebecchi committed Feb 25, 2024
1 parent 697ca2c commit aaa4bdd
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 1 deletion.
2 changes: 2 additions & 0 deletions datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
SessionConfig,
RuntimeConfig,
ScalarUDF,
SQLOptions,
)

from .common import (
Expand Down Expand Up @@ -96,6 +97,7 @@
"DataFrame",
"SessionContext",
"SessionConfig",
"SQLOptions",
"RuntimeConfig",
"Expr",
"AggregateUDF",
Expand Down
5 changes: 5 additions & 0 deletions datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List

from datafusion import SQLOptions
from datafusion.common import SqlSchema, SqlTable


Expand Down Expand Up @@ -140,3 +141,7 @@ def explain(self, sql):
@abstractmethod
def sql(self, sql):
    """Abstract hook that receives a single ``sql`` argument.

    Declared with ``@abstractmethod``, so there is no behavior here.

    NOTE(review): presumably ``sql`` is the query text to execute -- confirm
    against the concrete implementations.
    """
    ...

@abstractmethod
def sql_with_options(self, sql, options: SQLOptions):
    """Abstract hook that receives ``sql`` together with a ``SQLOptions`` value.

    Declared with ``@abstractmethod``, so there is no behavior here.

    NOTE(review): presumably ``options`` restricts which kinds of statements
    in ``sql`` may run -- confirm against the concrete implementations.
    """
    ...
26 changes: 26 additions & 0 deletions datafusion/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
SessionConfig,
RuntimeConfig,
DataFrame,
SQLOptions,
)
import pytest

Expand Down Expand Up @@ -395,3 +396,28 @@ def test_read_parquet(ctx):
def test_read_avro(ctx):
    """Smoke test: read the Avro fixture and call ``show()`` on the result.

    Makes no assertions; it only checks that neither call raises.
    """
    avro_path = "testing/data/avro/alltypes_plain.avro"
    avro_df = ctx.read_avro(path=avro_path)
    avro_df.show()


def test_create_sql_options():
    """``SQLOptions`` can be constructed with no arguments without raising."""
    _ = SQLOptions()


def test_sql_with_options_no_ddl(ctx):
    """With DDL disallowed, a ``CREATE TABLE`` statement must raise."""
    create_stmt = (
        "CREATE TABLE IF NOT EXISTS valuetable AS VALUES(1,'HELLO'),(12,'DATAFUSION')"
    )
    no_ddl = SQLOptions().with_allow_ddl(False)
    # The raised error's message is expected to mention "DDL".
    with pytest.raises(Exception, match="DDL"):
        ctx.sql_with_options(create_stmt, options=no_ddl)


def test_sql_with_options_no_dml(ctx):
    """With DML disallowed, an ``INSERT`` into a registered dataset must raise."""
    table_name = "t"

    # Register a small two-column dataset so the INSERT has a target table.
    columns = [pa.array([1, 2, 3]), pa.array([4, 5, 6])]
    batch = pa.RecordBatch.from_arrays(columns, names=["a", "b"])
    ctx.register_dataset(table_name, ds.dataset([batch]))

    no_dml = SQLOptions().with_allow_dml(False)
    insert_stmt = f'INSERT INTO "{table_name}" VALUES (1, 2), (2, 3);'
    # The raised error's message is expected to mention "DML".
    with pytest.raises(Exception, match="DML"):
        ctx.sql_with_options(insert_stmt, options=no_dml)
57 changes: 56 additions & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::MemTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{SessionConfig, SessionContext, SessionState, TaskContext};
use datafusion::execution::context::{
SQLOptions, SessionConfig, SessionContext, SessionState, TaskContext,
};
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
Expand Down Expand Up @@ -210,6 +212,43 @@ impl PyRuntimeConfig {
}
}

/// `PySQLOptions` allows you to specify options for SQL execution.
///
/// It is a thin wrapper around DataFusion's `SQLOptions` and is exposed to
/// Python under the name `SQLOptions` in the `datafusion` module.
#[pyclass(name = "SQLOptions", module = "datafusion", subclass)]
#[derive(Clone)]
pub struct PySQLOptions {
/// The wrapped DataFusion `SQLOptions` value.
pub options: SQLOptions,
}

impl From<SQLOptions> for PySQLOptions {
fn from(options: SQLOptions) -> Self {
Self { options }
}
}

#[pymethods]
impl PySQLOptions {
    /// Create a new set of SQL options, starting from `SQLOptions::new()`.
    #[new]
    fn new() -> Self {
        Self {
            options: SQLOptions::new(),
        }
    }

    /// Should DDL data definition commands (e.g. `CREATE TABLE`) be run? Defaults to `true`.
    ///
    /// Returns a new options object; the receiver is left unchanged.
    pub fn with_allow_ddl(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_ddl(allow))
    }

    /// Should DML data modification commands (e.g. `INSERT` and `COPY`) be run? Defaults to `true`.
    ///
    /// Returns a new options object; the receiver is left unchanged.
    pub fn with_allow_dml(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_dml(allow))
    }

    /// Should statements (e.g. `SET VARIABLE` and `BEGIN TRANSACTION`) be run? Defaults to `true`.
    ///
    /// Returns a new options object; the receiver is left unchanged.
    pub fn with_allow_statements(&self, allow: bool) -> Self {
        Self::from(self.options.with_allow_statements(allow))
    }
}

/// `PySessionContext` is able to plan and execute DataFusion plans.
/// It has a powerful optimizer, a physical planner for local execution, and a
/// multi-threaded execution engine to perform the execution.
Expand Down Expand Up @@ -285,6 +324,22 @@ impl PySessionContext {
Ok(PyDataFrame::new(df))
}

/// Run `query` through the context's `sql_with_options` and wrap the result
/// in a `PyDataFrame`.
///
/// When `options` is `None`, `SQLOptions::new()` is used instead. A failure
/// from the underlying call is converted via `DataFusionError` before being
/// returned to the caller.
pub fn sql_with_options(
    &mut self,
    query: &str,
    options: Option<PySQLOptions>,
    py: Python,
) -> PyResult<PyDataFrame> {
    // Unwrap the Python-facing wrapper, falling back to the defaults.
    let sql_options = options.map_or_else(SQLOptions::new, |wrapper| wrapper.options);
    let pending = self.ctx.sql_with_options(query, sql_options);
    let df = wait_for_future(py, pending).map_err(DataFusionError::from)?;
    Ok(PyDataFrame::new(df))
}

pub fn create_dataframe(
&mut self,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<context::PyRuntimeConfig>()?;
m.add_class::<context::PySessionConfig>()?;
m.add_class::<context::PySessionContext>()?;
m.add_class::<context::PySQLOptions>()?;
m.add_class::<dataframe::PyDataFrame>()?;
m.add_class::<udf::PyScalarUDF>()?;
m.add_class::<udaf::PyAggregateUDF>()?;
Expand Down

0 comments on commit aaa4bdd

Please sign in to comment.