From 4b4195629e4818644aa35d9078fc294232371d0a Mon Sep 17 00:00:00 2001 From: dss-vipps Date: Fri, 10 Nov 2023 18:40:25 +0100 Subject: [PATCH] Parse arguments of stored procedures to make them available in DOM --- sqlparser/dom.go | 36 ++++++++- sqlparser/parser.go | 161 ++++++++++++++++++++++++++++++--------- sqlparser/parser_test.go | 69 ++++++++++++++++- 3 files changed, 225 insertions(+), 41 deletions(-) diff --git a/sqlparser/dom.go b/sqlparser/dom.go index cc661f4..3d2c1d3 100644 --- a/sqlparser/dom.go +++ b/sqlparser/dom.go @@ -2,9 +2,10 @@ package sqlparser import ( "fmt" - "gopkg.in/yaml.v3" "io" "strings" + + "gopkg.in/yaml.v3" ) type Unparsed struct { @@ -59,12 +60,35 @@ func (p PosString) String() string { return p.Value } +type Parameter struct { + Start Pos + Stop Pos + + VariableName string + Datatype Type + + // Attributes only relevant for procedures: + DefaultValue Unparsed + IsReadonly bool + IsOutput bool +} + +func (p Parameter) WithoutPos() (result Parameter) { + result = p + result.Start = Pos{} + result.Stop = Pos{} + result.DefaultValue.Start = Pos{} + result.DefaultValue.Stop = Pos{} + return +} + type Create struct { CreateType string // "procedure", "function" or "type" QuotedName PosString // proc/func/type name, including [] Body []Unparsed DependsOn []PosString Docstring []PosString // comment lines before the create statement. Note: this is also part of Body + Parameters []Parameter } func (c Create) DocstringAsString() string { @@ -100,9 +124,14 @@ func (c Create) ParseYamlInDocstring(out any) error { return yaml.Unmarshal([]byte(yamldoc), out) } +// Type indicates the type of a parameter. It can either be a basic type in which case BaseType and Args are set; +// or a table type in which case TableTypeSchema and TableTypeName is set. 
type Type struct { BaseType string Args []string + + TableTypeSchema string + TableTypeName string } func (t Type) String() (result string) { @@ -165,11 +194,16 @@ func (c Create) WithoutPos() Create { for _, x := range c.Body { body = append(body, x.WithoutPos()) } + var parameters []Parameter + for _, x := range c.Parameters { + parameters = append(parameters, x.WithoutPos()) + } return Create{ CreateType: c.CreateType, QuotedName: c.QuotedName, DependsOn: c.DependsOn, Body: body, + Parameters: parameters, } } diff --git a/sqlparser/parser.go b/sqlparser/parser.go index 304ded4..d716af3 100644 --- a/sqlparser/parser.go +++ b/sqlparser/parser.go @@ -39,27 +39,6 @@ func NextTokenCopyingWhitespace(s *Scanner, target *[]Unparsed) { } -// AdvanceAndCopy is like NextToken; advance to next token that is not whitespace and return -// Note: The 'go' and EOF tokens are *not* copied -func AdvanceAndCopy(s *Scanner, target *[]Unparsed) { - for { - tt := s.NextToken() - switch tt { - case EOFToken, BatchSeparatorToken: - // do not copy - return - case WhitespaceToken, MultilineCommentToken, SinglelineCommentToken: - // copy, and loop around - CopyToken(s, target) - continue - default: - // copy, and return - CopyToken(s, target) - return - } - } -} - func CreateUnparsed(s *Scanner) Unparsed { return Unparsed{ Type: s.TokenType(), @@ -80,27 +59,29 @@ func (d *Document) unexpectedTokenError(s *Scanner) { d.addError(s, "Unexpected: "+s.Token()) } -func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { +func (doc *Document) parseTypeExpression(s *Scanner, allowTableTypes bool, target *[]Unparsed) (result Type) { parseArgs := func() { // parses *after* the initial (; consumes trailing ) for { + CopyToken(s, target) switch { case s.TokenType() == NumberToken: - t.Args = append(t.Args, s.Token()) + result.Args = append(result.Args, s.Token()) case s.TokenType() == UnquotedIdentifierToken && s.TokenLower() == "max": - t.Args = append(t.Args, "max") + result.Args = 
append(result.Args, "max") default: doc.unexpectedTokenError(s) doc.recoverToNextStatement(s) return } - s.NextNonWhitespaceCommentToken() + NextTokenCopyingWhitespace(s, target) + CopyToken(s, target) switch { case s.TokenType() == CommaToken: - s.NextNonWhitespaceCommentToken() + NextTokenCopyingWhitespace(s, target) continue case s.TokenType() == RightParenToken: - s.NextNonWhitespaceCommentToken() + NextTokenCopyingWhitespace(s, target) return default: doc.unexpectedTokenError(s) @@ -110,14 +91,37 @@ func (doc *Document) parseTypeExpression(s *Scanner) (t Type) { } } - if s.TokenType() != UnquotedIdentifierToken { - panic("assertion failed, bug in caller") + if s.TokenType() != UnquotedIdentifierToken && s.TokenType() != QuotedIdentifierToken { + doc.addError(s, "expected type, got: "+s.Token()) + return } - t.BaseType = s.Token() - s.NextNonWhitespaceCommentToken() - if s.TokenType() == LeftParenToken { - s.NextNonWhitespaceCommentToken() - parseArgs() + // We will assume that a table type will have a schema name; types in 'default schema' we just don't support. + // So an identifier followed by a `.` indicates table type. 
+ firstToken := s.Token() + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + + if s.TokenType() == DotToken { + if !allowTableTypes { + doc.addError(s, "expected basic type (no table types), got: .") + return + } + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + + // parse a table type + result.TableTypeSchema = firstToken + result.TableTypeName = s.Token() + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + } else { + // parse a basic type + result.BaseType = firstToken + if s.TokenType() == LeftParenToken { + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + parseArgs() + } } return } @@ -146,7 +150,10 @@ loop: doc.addError(s, "sqlcode constants needs a type declared explicitly") s.NextNonWhitespaceCommentToken() case UnquotedIdentifierToken: - variableType = doc.parseTypeExpression(s) + // parseTypeExpression is also used in a context where we are copying Unparsed nodes into stored procedure body; + // to use it here too just use a dummy output + var dummy []Unparsed + variableType = doc.parseTypeExpression(s, false, &dummy) } if s.TokenType() != EqualToken { @@ -366,6 +373,81 @@ func (d *Document) parseCodeschemaName(s *Scanner, target *[]Unparsed) PosString } } +func (d *Document) parseArgumentList(s *Scanner, target *[]Unparsed) (result []Parameter) { + if s.TokenType() != LeftParenToken { + panic("assertion failed: should only be called on the ( position") + } + // Copy the `(` + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + + for s.TokenType() != RightParenToken { + var parameter Parameter + + // `@parameter` + if s.TokenType() != VariableIdentifierToken { + d.addError(s, "expected a parameter name starting with @, got: "+s.Token()) + return + } + + parameter.Start = s.Start() + parameter.VariableName = s.Token() + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + + // datatype. This can either be a table type or a basic type... 
+ parameter.Datatype = d.parseTypeExpression(s, true, target) + + // Do we have a default value? + if s.TokenType() == EqualToken { + // Default value. AFAICT this can only be a single literal, not a full expression + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + switch s.TokenType() { + case NVarcharLiteralToken, VarcharLiteralToken, NumberToken: + parameter.DefaultValue = CreateUnparsed(s) + default: + d.addError(s, "expecting default value literal, got: "+s.Token()) + return + } + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + } + + // Do we have an option? This can be *either* readonly or output, both would not be relevant on the same parameter + if s.TokenType() == UnquotedIdentifierToken { + // readonly or output + switch s.TokenLower() { + case "readonly": + parameter.IsReadonly = true + case "output": + parameter.IsOutput = true + default: + d.addError(s, "parsing argument list, unexpected: "+s.Token()) + return + } + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + } + + // At this point we should have a comma or a right paren... + switch s.TokenType() { + case CommaToken: + CopyToken(s, target) + NextTokenCopyingWhitespace(s, target) + // Trailing comma won't be an error in this parser; but SQL will complain later.. + case RightParenToken: + // fall through to break out of loop + default: + d.addError(s, "parsing argument list, unexpected: "+s.Token()) + return + } + + result = append(result, parameter) + } + return +} + // parseCreate parses anything that starts with "create". Position is // *on* the create token. // At this stage in sqlcode parser development we're only interested @@ -411,8 +493,13 @@ func (d *Document) parseCreate(s *Scanner, createCountInBatch int) (result Creat return } - // We have matched "create [code]."; at this - // point we copy the rest until the batch ends; *but* track dependencies + // We have matched "create [code].". Try to parse + // parameters. 
Procedures do not need an argument list, so only do this if we see () + if createType == "procedure" && s.tokenType == LeftParenToken { + result.Parameters = d.parseArgumentList(s, &result.Body) + } + + // At this point we copy the rest until the batch ends; *but* track dependencies // + some other details mentioned below tailloop: diff --git a/sqlparser/parser_test.go b/sqlparser/parser_test.go index af5c29f..9aaa4de 100644 --- a/sqlparser/parser_test.go +++ b/sqlparser/parser_test.go @@ -1,10 +1,11 @@ package sqlparser import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestParserSmokeTest(t *testing.T) { @@ -271,7 +272,6 @@ create procedure [code].FirstProc as table (x int) assert.Equal(t, emsg, doc.Errors[0].Message) } - func TestGoWithoutNewline(t *testing.T) { doc := ParseString("test.sql", ` create procedure [code].Foo() as begin @@ -352,3 +352,66 @@ create procedure [code].Foo as begin end err.Error()) } + +func TestProcedureArgs(t *testing.T) { + doc := parseAndVerifyCreate(t, "test.sql", ` +create procedure [code].Foo + ( + @a bigint, + + @b varchar(max) = N'asdfas + +lkjlkjlkjasdf' + + output, + @c [code].[something:asf asdf -- as +df/ MyTableType] readonly, + @d numeric(1,2),@e tinyint +) as begin end +`) + create := doc.Creates[0].WithoutPos() + + assert.Equal(t, []Parameter{ + { + VariableName: "@a", + Datatype: Type{BaseType: "bigint"}, + }, + { + VariableName: "@b", + Datatype: Type{BaseType: "varchar", Args: []string{"max"}}, + DefaultValue: Unparsed{Type: NVarcharLiteralToken, RawValue: "N'asdfas\n\nlkjlkjlkjasdf'"}, + IsOutput: true, + }, + { + VariableName: "@c", + Datatype: Type{ + TableTypeSchema: "[code]", + TableTypeName: "[something:asf asdf -- as\ndf/ MyTableType]", + }, + IsReadonly: true, + }, + { + VariableName: "@d", + Datatype: Type{ + BaseType: "numeric", + Args: []string{"1", "2"}, + }, + }, 
+ { + VariableName: "@e", + Datatype: Type{ + BaseType: "tinyint", + }, + }, + }, create.Parameters) + +} + +// parseAndVerifyCreate expects to parse a single `create` statement, and verifies that serializing +// it back produces the same string. +func parseAndVerifyCreate(t *testing.T, filename FileRef, createStatement string) Document { + doc := ParseString(filename, createStatement) + require.Equal(t, 1, len(doc.Creates)) + require.Equal(t, strings.TrimLeft(createStatement, "\n "), doc.Creates[0].String()) + return doc +}