Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CREATE TABLE column DEFAULT value expressions #211

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) {
zetasql.FeatureV11WithOnSubquery,
zetasql.FeatureV13Pivot,
zetasql.FeatureV13Unpivot,
zetasql.FeatureV13ColumnDefaultValue,
})
langOpt.SetSupportedStatementKinds([]ast.Kind{
ast.BeginStmt,
Expand Down Expand Up @@ -294,8 +295,11 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive
return nil, fmt.Errorf("unsupported stmt %s", node.DebugString())
}

func (a *Analyzer) newCreateTableStmtAction(_ context.Context, query string, args []driver.NamedValue, node *ast.CreateTableStmtNode) (*CreateTableStmtAction, error) {
spec := newTableSpec(a.namePath, node)
func (a *Analyzer) newCreateTableStmtAction(ctx context.Context, query string, args []driver.NamedValue, node *ast.CreateTableStmtNode) (*CreateTableStmtAction, error) {
spec, err := newTableSpec(ctx, a.namePath, node)
if err != nil {
return nil, err
}
params := getParamsFromNode(node)
queryArgs, err := getArgsFromParams(args, params)
if err != nil {
Expand All @@ -315,7 +319,10 @@ func (a *Analyzer) newCreateTableAsSelectStmtAction(ctx context.Context, _ strin
if err != nil {
return nil, err
}
spec := newTableAsSelectSpec(a.namePath, query, node)
spec, err := newTableAsSelectSpec(ctx, a.namePath, query, node)
if err != nil {
return nil, err
}
params := getParamsFromNode(node)
queryArgs, err := getArgsFromParams(args, params)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ func (n *GeneratedColumnInfoNode) FormatSQL(ctx context.Context) (string, error)
}

func (n *ColumnDefaultValueNode) FormatSQL(ctx context.Context) (string, error) {
return "", nil
return newNode(n.node.Expression()).FormatSQL(ctx)
}

func (n *ColumnDefinitionNode) FormatSQL(ctx context.Context) (string, error) {
Expand Down
63 changes: 49 additions & 14 deletions internal/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"reflect"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -170,9 +171,10 @@ func viewSQLiteSchema(s *TableSpec) string {
}

type ColumnSpec struct {
Name string `json:"name"`
Type *Type `json:"type"`
IsNotNull bool `json:"isNotNull"`
Name string `json:"name"`
DefaultValue string `json:"defaultValue"`
Type *Type `json:"type"`
IsNotNull bool `json:"isNotNull"`
}

type Type struct {
Expand Down Expand Up @@ -274,6 +276,14 @@ func (t *Type) FormatType() string {
return types.TypeKind(t.Kind).String()
}

var matchDoubleQuotesRegex = regexp.MustCompile(`"(.*?)"`)

func replaceDoubleQuotedStringsWithSingleQuotes(s string) string {
return matchDoubleQuotesRegex.ReplaceAllStringFunc(s, func(value string) string {
return `'` + strings.ReplaceAll(value[1:len(value)-1], `'`, `\'`) + `'`
})
}

func (s *ColumnSpec) SQLiteSchema() string {
var typ string
switch types.TypeKind(s.Type.Kind) {
Expand Down Expand Up @@ -324,6 +334,13 @@ func (s *ColumnSpec) SQLiteSchema() string {
if s.IsNotNull {
schema += " NOT NULL"
}
if s.DefaultValue != "" {
// Default value expressions must be considered constant, so we must replace double quotes with single quotes.
// For the purposes of the DEFAULT clause, an expression is considered constant if it contains no sub-queries,
// column or table references, bound parameters, or string literals enclosed in double-quotes instead of single-quotes.
// https://www.sqlite.org/lang_createtable.html#the_default_clause
schema += fmt.Sprintf(" DEFAULT (%s)", replaceDoubleQuotedStringsWithSingleQuotes(s.DefaultValue))
}
return schema
}

Expand Down Expand Up @@ -471,7 +488,7 @@ func newTemplatedFunctionSpec(ctx context.Context, namePath *NamePath, stmt *ast
}, nil
}

func newColumnsFromDef(def []*ast.ColumnDefinitionNode) []*ColumnSpec {
func newColumnsFromDef(ctx context.Context, def []*ast.ColumnDefinitionNode) ([]*ColumnSpec, error) {
columns := []*ColumnSpec{}
for _, columnNode := range def {
annotation := columnNode.Annotations()
Expand All @@ -484,13 +501,23 @@ func newColumnsFromDef(def []*ast.ColumnDefinitionNode) []*ColumnSpec {
}
isNotNull = annotation.NotNull()
}
defaultValue := columnNode.DefaultValue()
var defaultValueSQL string
if defaultValue != nil {
var err error
defaultValueSQL, err = newNode(defaultValue).FormatSQL(ctx)
if err != nil {
return nil, err
}
}
columns = append(columns, &ColumnSpec{
Name: columnNode.Name(),
Type: newType(columnNode.Type()),
IsNotNull: isNotNull,
Name: columnNode.Name(),
Type: newType(columnNode.Type()),
IsNotNull: isNotNull,
DefaultValue: defaultValueSQL,
})
}
return columns
return columns, nil
}

func newColumnsFromOutputColumns(def []*ast.OutputColumnNode) []*ColumnSpec {
Expand All @@ -513,17 +540,21 @@ func newPrimaryKey(key *ast.PrimaryKeyNode) []string {
return key.ColumnNameList()
}

func newTableSpec(namePath *NamePath, stmt *ast.CreateTableStmtNode) *TableSpec {
func newTableSpec(ctx context.Context, namePath *NamePath, stmt *ast.CreateTableStmtNode) (*TableSpec, error) {
now := time.Now()
columns, err := newColumnsFromDef(ctx, stmt.ColumnDefinitionList())
if err != nil {
return nil, err
}
return &TableSpec{
IsTemp: stmt.CreateScope() == ast.CreateScopeTemp,
NamePath: namePath.mergePath(stmt.NamePath()),
Columns: newColumnsFromDef(stmt.ColumnDefinitionList()),
Columns: columns,
PrimaryKey: newPrimaryKey(stmt.PrimaryKey()),
CreateMode: stmt.CreateMode(),
UpdatedAt: now,
CreatedAt: now,
}
}, nil
}

func newTableAsViewSpec(namePath *NamePath, query string, stmt *ast.CreateViewStmtNode) *TableSpec {
Expand All @@ -550,7 +581,7 @@ func newTableAsViewSpec(namePath *NamePath, query string, stmt *ast.CreateViewSt
}
}

func newTableAsSelectSpec(namePath *NamePath, query string, stmt *ast.CreateTableAsSelectStmtNode) *TableSpec {
func newTableAsSelectSpec(ctx context.Context, namePath *NamePath, query string, stmt *ast.CreateTableAsSelectStmtNode) (*TableSpec, error) {
var outputColumns []string
for _, column := range stmt.OutputColumnList() {
colName := column.Name()
Expand All @@ -562,16 +593,20 @@ func newTableAsSelectSpec(namePath *NamePath, query string, stmt *ast.CreateTabl
)
}
now := time.Now()
columns, err := newColumnsFromDef(ctx, stmt.ColumnDefinitionList())
if err != nil {
return nil, err
}
return &TableSpec{
IsTemp: stmt.CreateScope() == ast.CreateScopeTemp,
NamePath: namePath.mergePath(stmt.NamePath()),
Columns: newColumnsFromDef(stmt.ColumnDefinitionList()),
Columns: columns,
PrimaryKey: newPrimaryKey(stmt.PrimaryKey()),
CreateMode: stmt.CreateMode(),
Query: fmt.Sprintf("SELECT %s FROM (%s)", strings.Join(outputColumns, ","), query),
UpdatedAt: now,
CreatedAt: now,
}
}, nil
}

func newType(t types.Type) *Type {
Expand Down
17 changes: 17 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5872,6 +5872,23 @@ SELECT c1 * ? * ? FROM t1;
args: []interface{}{int64(1), int64(2), int64(3)},
expectedRows: [][]interface{}{{int64(6)}},
},
{
name: "table default value",
query: `
CREATE TEMP TABLE t1 (
id INT64,
name STRING DEFAULT LOWER("DEFAULT EXPRESSION"),
ts DATE DEFAULT DATE "2024-04-14",
test_ts_complex INT64 DEFAULT EXTRACT(YEAR FROM CURRENT_TIMESTAMP()) - MOD(EXTRACT(YEAR FROM CURRENT_TIMESTAMP()), 2000),
state STRING DEFAULT "' escape test"
);
INSERT INTO t1 (id) VALUES (1);
SELECT * FROM t1;
`,
expectedRows: [][]interface{}{
{int64(1), "default expression", "2024-04-14", int64(2000), "' escape test"},
},
},
} {
test := test
t.Run(test.name, func(t *testing.T) {
Expand Down
Loading