Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: support query strings containing multiple statements #2707

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/impl/sql/conn_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,14 @@ func (c *connSettings) apply(ctx context.Context, db *sql.DB, log *service.Logge

c.initOnce.Do(func() {
for _, fileStmt := range c.initFileStatements {
if _, err := db.ExecContext(ctx, fileStmt[1]); err != nil {
if err := execMultiWithContext(db, ctx, fileStmt[1]); err != nil {
log.Warnf("Failed to execute init_file '%v': %v", fileStmt[0], err)
} else {
log.Debugf("Successfully ran init_file '%v'", fileStmt[0])
}
}
if c.initStatement != "" {
if _, err := db.ExecContext(ctx, c.initStatement); err != nil {
if err := execMultiWithContext(db, ctx, c.initStatement); err != nil {
log.Warnf("Failed to execute init_statement: %v", err)
} else {
log.Debug("Successfully ran init_statement")
Expand Down
2 changes: 1 addition & 1 deletion internal/impl/sql/input_sql_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (s *sqlRawInput) Connect(ctx context.Context) (err error) {
}

var rows *sql.Rows
if rows, err = db.Query(s.queryStatic, args...); err != nil {
if rows, err = queryMultiWithContext(db, ctx, s.queryStatic, args...); err != nil {
return
} else if err = rows.Err(); err != nil {
s.logger.With("err", err).Warnf("unexpected error while execute raw query %q", s.queryStatic)
Expand Down
136 changes: 136 additions & 0 deletions internal/impl/sql/multi_statement.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2024 Redpanda Data, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"context"
"database/sql"
"strings"
)

// splitSQLStatements splits a query string into its individual statements,
// using semicolons that appear outside of quotes and comments as terminators.
//
// Recognised lexical elements:
//   - single-quoted strings, double-quoted and backtick-quoted identifiers
//   - single-line comments introduced by '#', "--" or "//"
//   - block comments delimited by "/*" and "*/"
//
// Each returned statement keeps its trailing semicolon and is trimmed of
// surrounding whitespace. A fragment that holds only comments, whitespace or
// a bare semicolon is not returned on its own: it is appended verbatim to the
// preceding statement, or, when there is no preceding statement yet, carried
// forward into the next one. No input text is ever dropped.
//
// The result always contains at least one element; input without any real
// statement (empty, blank or comment-only) is returned as a single trimmed
// element so callers can hand it to the driver unchanged.
func splitSQLStatements(statement string) []string {
	var (
		result     []string
		start      int  // offset at which the pending fragment begins
		hasContent bool // pending fragment holds more than comments/whitespace
	)
	n := len(statement)

	// flush closes the pending fragment at offset end.
	flush := func(end int) {
		part := statement[start:end]
		switch {
		case hasContent:
			result = append(result, strings.TrimSpace(part))
		case len(result) > 0:
			// Coalesce functionally empty fragments into the previous
			// statement so "statement; -- final comment" stays one statement.
			result[len(result)-1] += part
		default:
			// Nothing to attach to yet: leave start untouched so this
			// fragment becomes a prefix of the next real statement.
			return
		}
		start = end
		hasContent = false
	}

	for p := 0; p < n; {
		c := statement[p]
		rest := statement[p:]
		switch {
		case c == ';':
			// Include the semicolon in the fragment it terminates. The byte
			// after it is examined on the next iteration like any other.
			p++
			flush(p)
		case c == '\'' || c == '"' || c == '`':
			// Quoted string/identifier: skip to the matching quote. An
			// unterminated quote swallows the rest of the input.
			hasContent = true
			p++
			for p < n && statement[p] != c {
				p++
			}
			if p < n {
				p++ // step over the closing quote
			}
		case c == '#' || strings.HasPrefix(rest, "--") || strings.HasPrefix(rest, "//"):
			// Single-line comment: skip up to (not including) the newline,
			// which is then consumed as ordinary whitespace.
			for p < n && statement[p] != '\n' {
				p++
			}
		case strings.HasPrefix(rest, "/*"):
			// Block comment: search for the terminator after the opener so
			// "/*/" is not mistaken for a complete comment, and step past
			// the whole "*/" so its '/' is not counted as statement content.
			if end := strings.Index(rest[2:], "*/"); end >= 0 {
				p += 2 + end + 2
			} else {
				p = n // unterminated comment runs to the end of input
			}
		case c == ' ' || c == '\t' || c == '\r' || c == '\n':
			p++
		default:
			hasContent = true
			p++
		}
	}
	flush(n)

	if len(result) == 0 {
		// No real statement was found; pass the input through as-is.
		return []string{strings.TrimSpace(statement)}
	}
	return result
}

// execMultiWithContext executes query against db. The query may contain
// several semicolon-separated statements, as split by splitSQLStatements.
//
// A query that consists of a single statement is handed to db.ExecContext
// unchanged and without an explicit transaction, exactly as before
// multi-statement support existed. This keeps statements and drivers that
// cannot run inside a transaction block working.
//
// A query with several statements is executed sequentially inside one
// transaction: either every statement succeeds and the transaction is
// committed, or the first error is returned and everything is rolled back.
//
// args are bound to the first statement only; subsequent statements are
// executed without arguments. No result is returned, only the error.
func execMultiWithContext(db *sql.DB, ctx context.Context, query string, args ...any) error {
	statements := splitSQLStatements(query)
	if len(statements) <= 1 {
		_, err := db.ExecContext(ctx, query, args...)
		return err
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer func() {
		// After a successful Commit this returns sql.ErrTxDone, which is
		// expected and deliberately ignored.
		_ = tx.Rollback()
	}()

	for _, part := range statements {
		if _, err := tx.ExecContext(ctx, part, args...); err != nil {
			return err
		}
		// Only the first statement receives the caller's arguments.
		args = nil
	}

	return tx.Commit()
}

// queryMultiWithContext runs query against db and returns the rows produced
// by its final statement. The query may contain several semicolon-separated
// statements, as split by splitSQLStatements.
//
// A query that consists of a single statement is handed to db.QueryContext
// unchanged, exactly as before multi-statement support existed.
//
// For several statements, every statement but the last is executed inside
// one transaction, which is committed before the final statement is run with
// db.QueryContext. The final query is deliberately NOT issued on the
// transaction: database/sql closes any Rows obtained from a Tx as soon as the
// Tx is committed or rolled back, so rows created inside the transaction
// would already be closed by the time the caller sees them.
//
// NOTE(review): because the final query goes through the pool it may run on a
// different connection than the preceding statements; session-scoped state
// set by them (temporary tables, SET variables) is not guaranteed to be
// visible to it.
//
// args are bound to the first statement only, matching
// execMultiWithContext. The caller owns the returned rows and must close
// them.
func queryMultiWithContext(db *sql.DB, ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	statements := splitSQLStatements(query)
	if len(statements) <= 1 {
		return db.QueryContext(ctx, query, args...)
	}
	last := len(statements) - 1

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	defer func() {
		// After a successful Commit this returns sql.ErrTxDone, which is
		// expected and deliberately ignored.
		_ = tx.Rollback()
	}()

	for _, part := range statements[:last] {
		if _, err := tx.ExecContext(ctx, part, args...); err != nil {
			return nil, err
		}
		// Only the first statement receives the caller's arguments.
		args = nil
	}

	if err := tx.Commit(); err != nil {
		return nil, err
	}

	// The caller's arguments were consumed by the first statement above.
	return db.QueryContext(ctx, statements[last])
}
84 changes: 84 additions & 0 deletions internal/impl/sql/multi_statement_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2024 Redpanda Data, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
"testing"

"github.com/stretchr/testify/assert"
)

// assertSplitEquals splits statement with splitSQLStatements and asserts that
// the result equals wanted. message is attached to the assertion failure to
// identify the case.
func assertSplitEquals(t *testing.T, message string, statement string, wanted []string) {
	// Mark as a helper so failures are attributed to the calling test case
	// rather than to this function.
	t.Helper()

	result := splitSQLStatements(statement)
	assert.Equal(t, wanted, result, message)
}

// TestSplitStatements checks that splitSQLStatements only splits on
// semicolons outside of quotes and comments, and that functionally empty
// trailing fragments are coalesced into the preceding statement.
//
// Inputs use explicit escape sequences rather than raw string literals so the
// exact whitespace under test does not depend on how this file is indented.
func TestSplitStatements(t *testing.T) {
	cases := []struct {
		name      string
		statement string
		wanted    []string
	}{
		{
			name:      "no semicolon",
			statement: "select null",
			wanted:    []string{"select null"},
		},
		{
			name:      "basic semicolon",
			statement: "select 1; select 2",
			wanted:    []string{"select 1;", "select 2"},
		},
		{
			name:      "semicolon in single-quoted string",
			statement: "select 'singlequoted;string'; select null",
			wanted:    []string{"select 'singlequoted;string';", "select null"},
		},
		{
			name:      "semicolon in double-quoted identifier",
			statement: "select \"doublequoted;ident\"; select null",
			wanted:    []string{"select \"doublequoted;ident\";", "select null"},
		},
		{
			name:      "semicolon in backtick-quoted identifier",
			statement: "select `backtick;ident`; select null",
			wanted:    []string{"select `backtick;ident`;", "select null"},
		},
		{
			name:      "semicolon in hash-comment",
			statement: "\n\t\tselect #hash;comment\n\t\t1; select 2\n\t",
			wanted:    []string{"select #hash;comment\n\t\t1;", "select 2"},
		},
		{
			name:      "semicolon in double-dash comment",
			statement: "\n\t\tselect --double-dash;comment\n\t\t1; select 2\n\t",
			wanted:    []string{"select --double-dash;comment\n\t\t1;", "select 2"},
		},
		{
			name:      "semicolon in double-slash comment",
			statement: "\n\t\tselect //double-slash;comment\n\t\t1; select 2\n\t",
			wanted:    []string{"select //double-slash;comment\n\t\t1;", "select 2"},
		},
		{
			name:      "semicolon in multi-line comment",
			statement: "\n\t\tselect /*multi;\n\t\tline;comment*/\n\t\t1; select 2\n\t",
			wanted:    []string{"select /*multi;\n\t\tline;comment*/\n\t\t1;", "select 2"},
		},
		{
			name:      "semicolon at end should be single statement",
			statement: "select null;",
			wanted:    []string{"select null;"},
		},
		{
			name:      "comment with no newline should not fail",
			statement: "select null // comment with no newline",
			wanted:    []string{"select null // comment with no newline"},
		},
		{
			name:      "semicolon followed by comment at end should be single statement",
			statement: "select null; // trailing comment",
			wanted:    []string{"select null; // trailing comment"},
		},
		{
			name:      "coalesce empty statements into previous but not nonempty statements",
			statement: "select 1; // comment\n\t\t;\n\t\tselect 2;",
			wanted:    []string{"select 1; // comment\n\t\t;", "select 2;"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assertSplitEquals(t, tc.name, tc.statement, tc.wanted)
		})
	}
}
2 changes: 1 addition & 1 deletion internal/impl/sql/output_sql_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (s *sqlRawOutput) WriteBatch(ctx context.Context, batch service.MessageBatc
}
}

if _, err := s.db.ExecContext(ctx, queryStr, args...); err != nil {
if err := execMultiWithContext(s.db, ctx, queryStr, args...); err != nil {
return err
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/impl/sql/processor_sql_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,13 @@ func (s *sqlRawProcessor) ProcessBatch(ctx context.Context, batch service.Messag
}

if s.onlyExec {
if _, err := s.db.ExecContext(ctx, queryStr, args...); err != nil {
if err := execMultiWithContext(s.db, ctx, queryStr, args...); err != nil {
s.logger.Debugf("Failed to run query: %v", err)
msg.SetError(err)
continue
}
} else {
rows, err := s.db.QueryContext(ctx, queryStr, args...)
rows, err := queryMultiWithContext(s.db, ctx, queryStr, args...)
if err != nil {
s.logger.Debugf("Failed to run query: %v", err)
msg.SetError(err)
Expand Down
Loading