diff --git a/internal/analyzer.go b/internal/analyzer.go index fa0aac6..a845c67 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -75,6 +75,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) { zetasql.FeatureV11WithOnSubquery, zetasql.FeatureV13Pivot, zetasql.FeatureV13Unpivot, + zetasql.FeatureV13ColumnDefaultValue, }) langOpt.SetSupportedStatementKinds([]ast.Kind{ ast.BeginStmt, @@ -294,8 +295,11 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive return nil, fmt.Errorf("unsupported stmt %s", node.DebugString()) } -func (a *Analyzer) newCreateTableStmtAction(_ context.Context, query string, args []driver.NamedValue, node *ast.CreateTableStmtNode) (*CreateTableStmtAction, error) { - spec := newTableSpec(a.namePath, node) +func (a *Analyzer) newCreateTableStmtAction(ctx context.Context, query string, args []driver.NamedValue, node *ast.CreateTableStmtNode) (*CreateTableStmtAction, error) { + spec, err := newTableSpec(ctx, a.namePath, node) + if err != nil { + return nil, err + } params := getParamsFromNode(node) queryArgs, err := getArgsFromParams(args, params) if err != nil { @@ -315,7 +319,10 @@ func (a *Analyzer) newCreateTableAsSelectStmtAction(ctx context.Context, _ strin if err != nil { return nil, err } - spec := newTableAsSelectSpec(a.namePath, query, node) + spec, err := newTableAsSelectSpec(ctx, a.namePath, query, node) + if err != nil { + return nil, err + } params := getParamsFromNode(node) queryArgs, err := getArgsFromParams(args, params) if err != nil { diff --git a/internal/formatter.go b/internal/formatter.go index 4b0dba4..3bc1980 100644 --- a/internal/formatter.go +++ b/internal/formatter.go @@ -1215,7 +1215,7 @@ func (n *GeneratedColumnInfoNode) FormatSQL(ctx context.Context) (string, error) } func (n *ColumnDefaultValueNode) FormatSQL(ctx context.Context) (string, error) { - return "", nil + return newNode(n.node.Expression()).FormatSQL(ctx) } func (n *ColumnDefinitionNode) FormatSQL(ctx context.Context) (string, error) { diff --git a/internal/spec.go b/internal/spec.go index fb1c831..544cda4 100644 --- a/internal/spec.go +++ b/internal/spec.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "regexp" "strings" "time" @@ -170,9 +171,10 @@ func viewSQLiteSchema(s *TableSpec) string { } type ColumnSpec struct { - Name string `json:"name"` - Type *Type `json:"type"` - IsNotNull bool `json:"isNotNull"` + Name string `json:"name"` + DefaultValue string `json:"defaultValue"` + Type *Type `json:"type"` + IsNotNull bool `json:"isNotNull"` } type Type struct { @@ -274,6 +276,14 @@ func (t *Type) FormatType() string { return types.TypeKind(t.Kind).String() } +var matchDoubleQuotesRegex = regexp.MustCompile(`"(.*?)"`) + +func replaceDoubleQuotedStringsWithSingleQuotes(s string) string { + return matchDoubleQuotesRegex.ReplaceAllStringFunc(s, func(value string) string { + return `'` + strings.ReplaceAll(value[1:len(value)-1], `'`, `\'`) + `'` + }) +} + func (s *ColumnSpec) SQLiteSchema() string { var typ string switch types.TypeKind(s.Type.Kind) { @@ -324,6 +334,13 @@ func (s *ColumnSpec) SQLiteSchema() string { if s.IsNotNull { schema += " NOT NULL" } + if s.DefaultValue != "" { + // Default value expressions must be considered constant, so we must replace double quotes with single quotes. + // For the purposes of the DEFAULT clause, an expression is considered constant if it contains no sub-queries, + // column or table references, bound parameters, or string literals enclosed in double-quotes instead of single-quotes. + // https://www.sqlite.org/lang_createtable.html#the_default_clause + schema += fmt.Sprintf(" DEFAULT (%s)", replaceDoubleQuotedStringsWithSingleQuotes(s.DefaultValue)) + } return schema } @@ -471,7 +488,7 @@ func newTemplatedFunctionSpec(ctx context.Context, namePath *NamePath, stmt *ast }, nil } -func newColumnsFromDef(def []*ast.ColumnDefinitionNode) []*ColumnSpec { +func newColumnsFromDef(ctx context.Context, def []*ast.ColumnDefinitionNode) ([]*ColumnSpec, error) { columns := []*ColumnSpec{} for _, columnNode := range def { annotation := columnNode.Annotations() @@ -484,13 +501,23 @@ func newColumnsFromDef(def []*ast.ColumnDefinitionNode) []*ColumnSpec { } isNotNull = annotation.NotNull() } + defaultValue := columnNode.DefaultValue() + var defaultValueSQL string + if defaultValue != nil { + var err error + defaultValueSQL, err = newNode(defaultValue).FormatSQL(ctx) + if err != nil { + return nil, err + } + } columns = append(columns, &ColumnSpec{ - Name: columnNode.Name(), - Type: newType(columnNode.Type()), - IsNotNull: isNotNull, + Name: columnNode.Name(), + Type: newType(columnNode.Type()), + IsNotNull: isNotNull, + DefaultValue: defaultValueSQL, }) } - return columns + return columns, nil } func newColumnsFromOutputColumns(def []*ast.OutputColumnNode) []*ColumnSpec { @@ -513,17 +540,21 @@ func newPrimaryKey(key *ast.PrimaryKeyNode) []string { return key.ColumnNameList() } -func newTableSpec(namePath *NamePath, stmt *ast.CreateTableStmtNode) *TableSpec { +func newTableSpec(ctx context.Context, namePath *NamePath, stmt *ast.CreateTableStmtNode) (*TableSpec, error) { now := time.Now() + columns, err := newColumnsFromDef(ctx, stmt.ColumnDefinitionList()) + if err != nil { + return nil, err + } return &TableSpec{ IsTemp: stmt.CreateScope() == ast.CreateScopeTemp, NamePath: namePath.mergePath(stmt.NamePath()), - Columns: newColumnsFromDef(stmt.ColumnDefinitionList()), + Columns: columns, PrimaryKey: newPrimaryKey(stmt.PrimaryKey()), CreateMode: stmt.CreateMode(), UpdatedAt: now, CreatedAt: now, - } + }, nil } func newTableAsViewSpec(namePath *NamePath, query string, stmt *ast.CreateViewStmtNode) *TableSpec { @@ -550,7 +581,7 @@ func newTableAsViewSpec(namePath *NamePath, query string, stmt *ast.CreateViewSt } } -func newTableAsSelectSpec(namePath *NamePath, query string, stmt *ast.CreateTableAsSelectStmtNode) *TableSpec { +func newTableAsSelectSpec(ctx context.Context, namePath *NamePath, query string, stmt *ast.CreateTableAsSelectStmtNode) (*TableSpec, error) { var outputColumns []string for _, column := range stmt.OutputColumnList() { colName := column.Name() @@ -562,16 +593,20 @@ func newTableAsSelectSpec(namePath *NamePath, query string, stmt *ast.CreateTabl ) } now := time.Now() + columns, err := newColumnsFromDef(ctx, stmt.ColumnDefinitionList()) + if err != nil { + return nil, err + } return &TableSpec{ IsTemp: stmt.CreateScope() == ast.CreateScopeTemp, NamePath: namePath.mergePath(stmt.NamePath()), - Columns: newColumnsFromDef(stmt.ColumnDefinitionList()), + Columns: columns, PrimaryKey: newPrimaryKey(stmt.PrimaryKey()), CreateMode: stmt.CreateMode(), Query: fmt.Sprintf("SELECT %s FROM (%s)", strings.Join(outputColumns, ","), query), UpdatedAt: now, CreatedAt: now, - } + }, nil } func newType(t types.Type) *Type { diff --git a/query_test.go b/query_test.go index de6c697..c6076ff 100644 --- a/query_test.go +++ b/query_test.go @@ -5872,6 +5872,23 @@ SELECT c1 * ? * ? FROM t1; args: []interface{}{int64(1), int64(2), int64(3)}, expectedRows: [][]interface{}{{int64(6)}}, }, + { + name: "table default value", + query: ` +CREATE TEMP TABLE t1 ( + id INT64, + name STRING DEFAULT LOWER("DEFAULT EXPRESSION"), + ts DATE DEFAULT DATE "2024-04-14", + test_ts_complex INT64 DEFAULT EXTRACT(YEAR FROM CURRENT_TIMESTAMP()) - MOD(EXTRACT(YEAR FROM CURRENT_TIMESTAMP()), 2000), + state STRING DEFAULT "' escape test" +); + INSERT INTO t1 (id) VALUES (1); +SELECT * FROM t1; +`, + expectedRows: [][]interface{}{ + {int64(1), "default expression", "2024-04-14", int64(2000), "' escape test"}, + }, + }, } { test := test t.Run(test.name, func(t *testing.T) {