Skip to content

Commit

Permalink
Merge pull request #596 from dolthub/zachmu/discard
Browse files Browse the repository at this point in the history
Implemented DISCARD
  • Loading branch information
zachmu authored Aug 12, 2024
2 parents fb68084 + ff0c0a8 commit 8d9ec55
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 10 deletions.
13 changes: 10 additions & 3 deletions server/ast/discard.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,19 @@ import (
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
"github.com/dolthub/doltgresql/server/node"
)

// nodeDiscard handles *tree.Discard nodes.
func nodeDiscard(node *tree.Discard) (vitess.Statement, error) {
if node == nil {
func nodeDiscard(discard *tree.Discard) (vitess.Statement, error) {
if discard == nil {
return nil, nil
}
return nil, fmt.Errorf("DISCARD is not yet supported")
if discard.Mode != tree.DiscardModeAll {
return nil, fmt.Errorf("unhandled DISCARD mode: %v", discard.Mode)
}

return vitess.InjectedStatement{
Statement: node.DiscardStatement{},
}, nil
}
48 changes: 42 additions & 6 deletions server/connection_handler.go
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/dolthub/doltgresql/postgres/parser/parser"
"github.com/dolthub/doltgresql/server/ast"
pgexprs "github.com/dolthub/doltgresql/server/expression"
"github.com/dolthub/doltgresql/server/node"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

Expand Down Expand Up @@ -347,15 +348,29 @@ func (h *ConnectionHandler) handleQuery(message messages.Query) error {
delete(h.preparedStatements, "")
delete(h.portals, "")

// The Deallocate message does not get passed to the engine, since we handle allocation / deallocation of
// prepared statements at this layer
// Certain statement types get handled directly by the handler instead of being passed to the engine
err, handled = h.handleQueryOutsideEngine(query)
if handled {
return err
}

return h.query(query)
}

// handleQueryOutsideEngine handles any queries that should be handled by the handler directly, rather than being
// passed to the engine. Returns true if the query was handled and any error that occurred while doing so.
func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (error, bool) {
switch stmt := query.AST.(type) {
case *sqlparser.Deallocate:
// TODO: handle ALL keyword
return h.deallocatePreparedStatement(stmt.Name, h.preparedStatements, query, h.Conn())
return h.deallocatePreparedStatement(stmt.Name, h.preparedStatements, query, h.Conn()), true
case sqlparser.InjectedStatement:
switch stmt.Statement.(type) {
case node.DiscardStatement:
return h.discardAll(query, h.Conn()), true
}
}

return h.query(query)
return nil, false
}

// handleParse handles a parse message, returning any error that occurs
Expand Down Expand Up @@ -497,7 +512,13 @@ func (h *ConnectionHandler) handleExecute(message messages.Execute) error {
return connection.Send(h.Conn(), messages.EmptyQueryResponse{})
}

err := h.handler.(mysql.ExtendedHandler).ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true))
// Certain statement types get handled directly by the handler instead of being passed to the engine
err, handled := h.handleQueryOutsideEngine(query)
if handled {
return err
}

err = h.handler.(mysql.ExtendedHandler).ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true))
if err != nil {
return err
}
Expand Down Expand Up @@ -979,3 +1000,18 @@ func (h *ConnectionHandler) bindParams(

return plan, fields, err
}

// discardAll handles the DISCARD ALL command
func (h *ConnectionHandler) discardAll(query ConvertedQuery, conn net.Conn) error {
err := h.handler.ComResetConnection(h.mysqlConn)
if err != nil {
return err
}

commandComplete := messages.CommandComplete{
Query: query.String,
Tag: query.StatementTag,
}

return connection.Send(conn, commandComplete)
}
50 changes: 50 additions & 0 deletions server/node/discard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package node

import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/vt/sqlparser"
)

// DiscardStatement is just a marker type, since all functionality is handled by the connection handler,
// rather than the engine. It has to conform to the sql.ExecSourceRel interface to be used in the handler, but this
// functionality is all unused.
type DiscardStatement struct{}

var _ sqlparser.Injectable = DiscardStatement{}
var _ sql.ExecSourceRel = DiscardStatement{}

func (d DiscardStatement) Resolved() bool {
return true
}

func (d DiscardStatement) String() string {
return "DISCARD ALL"
}

func (d DiscardStatement) Schema() sql.Schema {
return nil
}

func (d DiscardStatement) Children() []sql.Node {
return nil
}

func (d DiscardStatement) WithChildren(children ...sql.Node) (sql.Node, error) {
return d, nil
}

func (d DiscardStatement) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return true
}

func (d DiscardStatement) IsReadOnly() bool {
return true
}

func (d DiscardStatement) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) {
panic("DISCARD ALL should be handled by the connection handler")
}

func (d DiscardStatement) WithResolvedChildren(children []any) (any, error) {
return d, nil
}
2 changes: 1 addition & 1 deletion testing/generation/command_docs/output/discard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import "testing"

func TestDiscard(t *testing.T) {
tests := []QueryParses{
Parses("DISCARD ALL"),
Converts("DISCARD ALL"),
Unimplemented("DISCARD PLANS"),
Unimplemented("DISCARD SEQUENCES"),
Unimplemented("DISCARD TEMPORARY"),
Expand Down
71 changes: 71 additions & 0 deletions testing/go/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package _go

import (
"testing"

"github.com/dolthub/go-mysql-server/sql"
)

func TestDiscard(t *testing.T) {
RunScripts(t, []ScriptTest{
{
Name: "Test discard",
SetUpScript: []string{
`CREATE temporary TABLE test (a INT)`,
`insert into test values (1)`,
},
Assertions: []ScriptTestAssertion{
{
Query: "select * from test",
Expected: []sql.Row{
{1},
},
},
{
Query: "DISCARD ALL",
Expected: []sql.Row{},
},
{
Query: "select * from test",
ExpectedErr: "table not found",
},
},
},
{
Name: "Test discard errors",
SetUpScript: []string{
`CREATE temporary TABLE test (a INT)`,
`insert into test values (1)`,
},
Assertions: []ScriptTestAssertion{
{
Query: "DISCARD SEQUENCES",
ExpectedErr: "unimplemented",
},
{
Query: "select * from test",
Expected: []sql.Row{
{1},
},
},
},
},
{
Name: "Test discard in transaction",
SetUpScript: []string{
`CREATE temporary TABLE test (a INT)`,
`insert into test values (1)`,
},
Assertions: []ScriptTestAssertion{
{
Query: "BEGIN",
},
{
Query: "DISCARD ALL",
ExpectedErr: "DISCARD ALL cannot run inside a transaction block",
Skip: true, // not yet implemented
},
},
},
})
}

0 comments on commit 8d9ec55

Please sign in to comment.