Skip to content

Commit

Permalink
feat(batch): support extended declare query cursor (#19043)
Browse files Browse the repository at this point in the history
  • Loading branch information
KeXiangWang authored Nov 18, 2024
1 parent f5537d9 commit 7ba6650
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package com.risingwave;

import org.junit.jupiter.api.Test;

import java.sql.*;

import org.junit.jupiter.api.Assertions;

public class TestCursor {

public static void createTable() throws SQLException {
try (Connection connection = TestUtils.establishConnection()) {
String createTableSQL = "CREATE TABLE test_table (" +
"id INT PRIMARY KEY, " +
"trading_date DATE, " +
"volume INT)";
Statement statement = connection.createStatement();
statement.execute(createTableSQL);

String insertSQL = "INSERT INTO test_table (id, trading_date, volume) VALUES (1, '2024-07-10', 23)";
statement.execute(insertSQL);
System.out.println("Table test_table created successfully.");
}
}

public static void dropTable() throws SQLException {
String dropSourceQuery = "DROP TABLE test_table;";
try (Connection connection = TestUtils.establishConnection()) {
Statement statement = connection.createStatement();
statement.executeUpdate(dropSourceQuery);
System.out.println("Table test_table dropped successfully.");
}
}


public static void readWithExtendedCursor() throws SQLException {
try (Connection connection = TestUtils.establishConnection()) {
connection.setAutoCommit(false);
Statement statement = connection.createStatement();
statement.execute("START TRANSACTION ISOLATION LEVEL REPEATABLE READ");

String declareCursorSql = "DECLARE c1 CURSOR FOR SELECT id, trading_date, volume FROM public.test_table WHERE ((id = CAST(? AS INT)))";
PreparedStatement pstmt = connection.prepareStatement(declareCursorSql);
pstmt.setInt(1, 1);
pstmt.execute();

statement.execute("FETCH 100 FROM c1");
ResultSet resultSet = statement.getResultSet();

while (resultSet != null && resultSet.next()) {
Assertions.assertEquals(resultSet.getInt("id"), 1);
Assertions.assertEquals(resultSet.getString("trading_date"), "2024-07-10");
Assertions.assertEquals(resultSet.getInt("volume"), 23);
}

statement.execute("CLOSE c1");
statement.execute("COMMIT");

System.out.println("Data in table read with cursor successfully.");
}
}

@Test
public void testCursor() throws SQLException {
createTable();
readWithExtendedCursor();
dropTable();
}
}
32 changes: 32 additions & 0 deletions src/frontend/src/binder/declare_cursor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_sqlparser::ast::ObjectName;

use super::statement::RewriteExprsRecursive;
use crate::binder::BoundQuery;
use crate::expr::ExprRewriter;

#[derive(Debug, Clone)]
pub struct BoundDeclareCursor {
pub cursor_name: ObjectName,
// Currently we only support cursor with query
pub query: Box<BoundQuery>, // reuse the BoundQuery struct
}

impl RewriteExprsRecursive for BoundDeclareCursor {
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl ExprRewriter) {
self.query.rewrite_exprs_recursive(rewriter);
}
}
1 change: 1 addition & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ mod bind_context;
mod bind_param;
mod create;
mod create_view;
mod declare_cursor;
mod delete;
mod expr;
pub mod fetch_cursor;
Expand Down
21 changes: 20 additions & 1 deletion src/frontend/src/binder/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

use risingwave_common::bail_not_implemented;
use risingwave_common::catalog::Field;
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{DeclareCursor, Statement};

use super::declare_cursor::BoundDeclareCursor;
use super::delete::BoundDelete;
use super::fetch_cursor::BoundFetchCursor;
use super::update::BoundUpdate;
Expand All @@ -30,6 +31,7 @@ pub enum BoundStatement {
Delete(Box<BoundDelete>),
Update(Box<BoundUpdate>),
Query(Box<BoundQuery>),
DeclareCursor(Box<BoundDeclareCursor>),
FetchCursor(Box<BoundFetchCursor>),
CreateView(Box<BoundCreateView>),
}
Expand All @@ -50,6 +52,7 @@ impl BoundStatement {
.as_ref()
.map_or(vec![], |s| s.fields().into()),
BoundStatement::Query(q) => q.schema().fields().into(),
BoundStatement::DeclareCursor(_) => vec![],
BoundStatement::FetchCursor(f) => f
.returning_schema
.as_ref()
Expand Down Expand Up @@ -92,6 +95,21 @@ impl Binder {

Statement::Query(q) => Ok(BoundStatement::Query(self.bind_query(*q)?.into())),

Statement::DeclareCursor { stmt } => {
if let DeclareCursor::Query(body) = stmt.declare_cursor {
let query = self.bind_query(*body)?;
Ok(BoundStatement::DeclareCursor(
BoundDeclareCursor {
cursor_name: stmt.cursor_name,
query: query.into(),
}
.into(),
))
} else {
bail_not_implemented!("unsupported statement {:?}", stmt)
}
}

// Note(eric): Can I just bind CreateView to Query??
Statement::CreateView {
or_replace,
Expand Down Expand Up @@ -133,6 +151,7 @@ impl RewriteExprsRecursive for BoundStatement {
BoundStatement::Delete(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::Update(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::Query(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::DeclareCursor(inner) => inner.rewrite_exprs_recursive(rewriter),
BoundStatement::FetchCursor(_) => {}
BoundStatement::CreateView(inner) => inner.rewrite_exprs_recursive(rewriter),
}
Expand Down
17 changes: 17 additions & 0 deletions src/frontend/src/handler/declare_cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,23 @@ async fn handle_declare_query_cursor(
Ok(PgResponse::empty_result(StatementType::DECLARE_CURSOR))
}

pub async fn handle_bound_declare_query_cursor(
handle_args: HandlerArgs,
cursor_name: ObjectName,
plan_fragmenter_result: BatchPlanFragmenterResult,
) -> Result<RwPgResponse> {
let session = handle_args.session.clone();
let (chunk_stream, fields) =
create_chunk_stream_for_cursor(session, plan_fragmenter_result).await?;

handle_args
.session
.get_cursor_manager()
.add_query_cursor(cursor_name, chunk_stream, fields)
.await?;
Ok(PgResponse::empty_result(StatementType::DECLARE_CURSOR))
}

pub async fn create_stream_for_cursor_stmt(
handle_args: HandlerArgs,
stmt: Statement,
Expand Down
9 changes: 8 additions & 1 deletion src/frontend/src/handler/extended_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::bail_not_implemented;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{CreateSink, Query, Statement};
use risingwave_sqlparser::ast::{CreateSink, DeclareCursor, Query, Statement};

use super::query::BoundResult;
use super::{fetch_cursor, handle, query, HandlerArgs, RwPgResponse};
Expand Down Expand Up @@ -112,6 +112,13 @@ pub async fn handle_parse(
Statement::FetchCursor { .. } => {
fetch_cursor::handle_parse(handler_args, statement, specific_param_types).await
}
Statement::DeclareCursor { stmt } => {
if let DeclareCursor::Query(_) = stmt.declare_cursor {
query::handle_parse(handler_args, statement, specific_param_types)
} else {
bail_not_implemented!("DECLARE SUBSCRIPTION CURSOR with parameters");
}
}
Statement::CreateView {
query,
materialized,
Expand Down
3 changes: 3 additions & 0 deletions src/frontend/src/handler/privilege.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ pub(crate) fn resolve_privileges(stmt: &BoundStatement) -> Vec<ObjectCheckItem>
objects.push(object);
}
BoundStatement::Query(ref query) => objects.extend(resolve_query_privileges(query)),
BoundStatement::DeclareCursor(ref declare_cursor) => {
objects.extend(resolve_query_privileges(&declare_cursor.query))
}
BoundStatement::FetchCursor(_) => unimplemented!(),
BoundStatement::CreateView(ref create_view) => {
objects.extend(resolve_query_privileges(&create_view.query))
Expand Down
16 changes: 15 additions & 1 deletion src/frontend/src/handler/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use risingwave_common::types::{DataType, Datum};
use risingwave_sqlparser::ast::{SetExpr, Statement};

use super::extended_handle::{PortalResult, PrepareStatement, PreparedResult};
use super::{create_mv, PgResponseStream, RwPgResponse};
use super::{create_mv, declare_cursor, PgResponseStream, RwPgResponse};
use crate::binder::{Binder, BoundCreateView, BoundStatement};
use crate::catalog::TableId;
use crate::error::{ErrorCode, Result, RwError};
Expand Down Expand Up @@ -149,6 +149,20 @@ pub async fn handle_execute(
)
.await
}
Statement::DeclareCursor { stmt } => {
let session = handler_args.session.clone();
let plan_fragmenter_result = {
let context = OptimizerContext::from_handler_args(handler_args.clone());
let plan_result = gen_batch_query_plan(&session, context.into(), bound_result)?;
gen_batch_plan_fragmenter(&session, plan_result)?
};
declare_cursor::handle_bound_declare_query_cursor(
handler_args,
stmt.cursor_name,
plan_fragmenter_result,
)
.await
}
_ => unreachable!(),
}
}
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/planner/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ impl Planner {
BoundStatement::Delete(d) => self.plan_delete(*d),
BoundStatement::Update(u) => self.plan_update(*u),
BoundStatement::Query(q) => self.plan_query(*q),
BoundStatement::DeclareCursor(d) => self.plan_query(*d.query),
BoundStatement::FetchCursor(_) => unimplemented!(),
BoundStatement::CreateView(c) => self.plan_query(*c.query),
}
Expand Down

0 comments on commit 7ba6650

Please sign in to comment.