diff --git a/driver_test.go b/driver_test.go index 30d8030..ec83b6c 100644 --- a/driver_test.go +++ b/driver_test.go @@ -10,24 +10,196 @@ import ( zetasqlite "github.com/goccy/go-zetasqlite" ) +func TestDriverAlter(t *testing.T) { + db, err := sql.Open("zetasqlite", ":memory:") + if err != nil { + t.Fatal(err) + } + if _, err := db.Exec(` + CREATE TABLE IF NOT EXISTS Artists ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) + `); err != nil { + t.Fatal(err) + } + if _, err := db.Exec(`INSERT Artists (SingerId, FirstName, LastName) VALUES (1, 'John', 'Titor')`); err != nil { + t.Fatal(err) + } + row := db.QueryRow(`SELECT SingerId, FirstName, LastName FROM Artists WHERE SingerId = @id`, 1) + if row.Err() != nil { + t.Fatal(row.Err()) + } + var ( + singerID int64 + firstName string + lastName string + ) + if err := row.Scan(&singerID, &firstName, &lastName); err != nil { + t.Fatal(err) + } + if singerID != 1 || firstName != "John" || lastName != "Titor" { + t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName) + } + + if _, err := db.Exec(` + CREATE VIEW IF NOT EXISTS + SingerNames AS SELECT FirstName || ' ' || LastName AS Name + FROM Artists + `); err != nil { + t.Fatal(err) + } + + viewRow := db.QueryRow(`SELECT Name FROM SingerNames LIMIT 1`) + if viewRow.Err() != nil { + t.Fatal(viewRow.Err()) + } + + var name string + + if err := viewRow.Scan(&name); err != nil { + t.Fatal(err) + } + if name != "John Titor" { + t.Fatalf("failed to find view row") + } + + // Test ALTER TABLE SET OPTIONS + if _, err := db.Exec(`ALTER TABLE Artists SET OPTIONS (description="Famous Artists")`); err != nil { + t.Fatal(err) + } + + // Test ALTER TABLE ADD COLUMN + if _, err := db.Exec(`ALTER TABLE Artists ADD COLUMN Age INT64, ADD COLUMN IsSingle BOOL`); err != nil { + t.Fatal(err) + } + + // Verify the changes + row = db.QueryRow(` + SELECT SingerId, FirstName, LastName, Age, IsSingle + FROM Artists + WHERE SingerId = @id`, + 1, + ) + if row.Err() != nil { + t.Fatal(row.Err()) + } + + var age sql.NullInt64 + var isSingle sql.NullBool + if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle); err != nil { + t.Fatal(err) + } + if singerID != 1 || firstName != "John" || lastName != "Titor" || age.Valid || isSingle.Valid { + t.Fatalf("failed to find row after ALTER TABLE statements") + } + + if _, err := db.Exec(` + INSERT Artists (SingerId, FirstName, LastName, Age, IsSingle) + VALUES (2, 'Mike', 'Bit', 11, TRUE) + `); err != nil { + t.Fatal(err) + } + row = db.QueryRow(` + SELECT SingerId, FirstName, LastName, Age, isSingle + FROM Artists + WHERE SingerId = @id AND isSingle IS NOT NULL`, + 2, + ) + if row.Err() != nil { + t.Fatal(row.Err()) + } + if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle); err != nil { + t.Fatal(err) + } + if singerID != 2 || firstName != "Mike" || lastName != "Bit" || age.Int64 != 11 || isSingle.Bool != true { + t.Fatalf("Failed to find row %v %v %v %v %v", singerID, firstName, lastName, age, isSingle) + } + + if _, err := db.Exec(` + ALTER TABLE Artists + ADD COLUMN Nationality STRING + `); err != nil { + t.Fatal(err) + } + + if _, err := db.Exec(` + ALTER TABLE Artists + ALTER COLUMN Nationality SET DEFAULT 'Unknown' + `); err != nil { + t.Fatal(err) + } + + // Verify the changes + row = db.QueryRow(` + SELECT SingerID, FirstName, LastName, Age, IsSingle, Nationality + FROM Artists + WHERE SingerId = @id`, + 2, + ) + if row.Err() != nil { + t.Fatal(row.Err()) + } + + var nationality sql.NullString + if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle, &nationality); err != nil { + t.Fatal(err) + } + + if singerID != 2 || firstName != "Mike" || lastName != "Bit" || age.Int64 != 11 || isSingle.Bool != true || nationality.Valid { + t.Fatalf("failed to find row after multi-action ALTER TABLE statement") + } + + if _, err := db.Exec(` + INSERT Artists (SingerId, FirstName, LastName, Age, IsSingle) + VALUES (3, 'Mark', 'Byte', 12, FALSE) + `); err != nil { + t.Fatal(err) + } + + // Verify the changes + row = db.QueryRow(` + SELECT SingerID, FirstName, LastName, Age, IsSingle, Nationality + FROM Artists + WHERE SingerId = @id`, + 3, + ) + if row.Err() != nil { + t.Fatal(row.Err()) + } + + if err := row.Scan(&singerID, &firstName, &lastName, &age, &isSingle, &nationality); err != nil { + t.Fatal(err) + } + if singerID != 3 || firstName != "Mark" || lastName != "Byte" || age.Int64 != 12 || isSingle.Bool != false || nationality.String != "Unknown" { + t.Fatalf("failed to find row after multi-action ALTER TABLE statement") + } +} + func TestDriver(t *testing.T) { db, err := sql.Open("zetasqlite", ":memory:") if err != nil { t.Fatal(err) } if _, err := db.Exec(` -CREATE TABLE IF NOT EXISTS Singers ( - SingerId INT64 NOT NULL, - FirstName STRING(1024), - LastName STRING(1024), - SingerInfo BYTES(MAX) -)`); err != nil { + CREATE TABLE IF NOT EXISTS Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo BYTES(MAX) + ) + `); err != nil { t.Fatal(err) } - if _, err := db.Exec(`INSERT Singers (SingerId, FirstName, LastName) VALUES (1, 'John', 'Titor')`); err != nil { + if _, err := db.Exec(` + INSERT Singers (SingerId, FirstName, LastName) + VALUES (1, 'John', 'Titor') + `); err != nil { t.Fatal(err) } - row := db.QueryRow("SELECT SingerID, FirstName, LastName FROM Singers WHERE SingerId = @id", 1) + row := db.QueryRow(`SELECT SingerID, FirstName, LastName FROM Singers WHERE SingerId = @id`, 1) if row.Err() != nil { t.Fatal(row.Err()) } @@ -43,7 +215,10 @@ CREATE TABLE IF NOT EXISTS Singers ( t.Fatalf("failed to find row %v %v %v", singerID, firstName, lastName) } if _, err := db.Exec(` -CREATE VIEW IF NOT EXISTS SingerNames AS SELECT FirstName || ' ' || LastName AS Name FROM Singers`); err != nil { + CREATE VIEW IF NOT EXISTS + SingerNames AS SELECT FirstName || ' ' || LastName AS Name + FROM Singers + `); err != nil { t.Fatal(err) } diff --git a/internal/analyzer.go b/internal/analyzer.go index fa0aac6..e21d891 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -75,6 +75,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) { zetasql.FeatureV11WithOnSubquery, zetasql.FeatureV13Pivot, zetasql.FeatureV13Unpivot, + zetasql.FeatureV13ColumnDefaultValue, }) langOpt.SetSupportedStatementKinds([]ast.Kind{ ast.BeginStmt, @@ -87,6 +88,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) { ast.DropStmt, ast.TruncateStmt, ast.CreateTableStmt, + ast.AlterTableStmt, ast.CreateTableAsSelectStmt, ast.CreateProcedureStmt, ast.CreateFunctionStmt, @@ -290,10 +292,32 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive return a.newBeginStmtAction(ctx, query, args, node) case ast.CommitStmt: return a.newCommitStmtAction(ctx, query, args, node) + case ast.AlterTableStmt: + return a.alterTableStmtAction(ctx, query, args, node.(*ast.AlterTableStmtNode)) } return nil, fmt.Errorf("unsupported stmt %s", node.DebugString()) } +func (a *Analyzer) alterTableStmtAction(ctx context.Context, query string, args []driver.NamedValue, node *ast.AlterTableStmtNode) (*AlterTableStmtAction, error) { + spec, err := newAlterSpec(ctx, a.namePath, node) + if err != nil { + return nil, err + } + params := getParamsFromNode(node) + queryArgs, err := getArgsFromParams(args, params) + if err != nil { + return nil, err + } + return &AlterTableStmtAction{ + query: query, + spec: spec, + node: node, + args: queryArgs, + rawArgs: args, + catalog: a.catalog, + }, nil +} + func (a *Analyzer) newCreateTableStmtAction(_ context.Context, query string, args []driver.NamedValue, node *ast.CreateTableStmtNode) (*CreateTableStmtAction, error) { spec := newTableSpec(a.namePath, node) params := getParamsFromNode(node) diff --git a/internal/catalog.go b/internal/catalog.go index c626dd1..8bec441 100644 --- a/internal/catalog.go +++ b/internal/catalog.go @@ -438,6 +438,42 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error { return nil } +func (c *Catalog) modifyTableSpec(spec *AlterTableSpec) error { + tableName := spec.TableName() + foundSpecToUpdate, exists := c.tableMap[tableName] + + if !exists { + return fmt.Errorf("table %s does not exist", tableName) + } + + formattedPath := formatPath(spec.NamePath) + + err := c.deleteTableSpecByName(formattedPath) + if err != nil { + return err + } + + for _, column := range spec.ColumnsWithNewDefaultValue { + if foundSpecToUpdate.Column(column.ColumnName) == nil { + return fmt.Errorf("cannot update column %s to have a default value, table %s does not have this column", tableName, column.ColumnName) + } + } + + addedColumns := make([]*ColumnSpec, len(foundSpecToUpdate.Columns)) + copy(addedColumns, foundSpecToUpdate.Columns) + addedColumns = append(addedColumns, spec.AddedColumns...) + + foundSpecToUpdate.Columns = addedColumns + foundSpecToUpdate.UpdatedAt = spec.UpdatedAt + + err = c.addTableSpec(foundSpecToUpdate) + if err != nil { + return err + } + + return nil +} + func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpec) error { if len(spec.NamePath) > 1 { subCatalogName := spec.NamePath[0] diff --git a/internal/spec.go b/internal/spec.go index fb1c831..b2e631d 100644 --- a/internal/spec.go +++ b/internal/spec.go @@ -114,6 +114,18 @@ type TableSpec struct { CreatedAt time.Time `json:"createdAt"` } +type ColumnWithDefaultSpec struct { + ColumnName string + DefaultValue string +} + +type AlterTableSpec struct { + NamePath []string `json:"namePath"` + AddedColumns []*ColumnSpec `json:"addedColumns"` + ColumnsWithNewDefaultValue []*ColumnWithDefaultSpec `json:"columnsWithNewDefaultValue"` + UpdatedAt time.Time `json:"updatedAt"` +} + func (s *TableSpec) Column(name string) *ColumnSpec { for _, col := range s.Columns { if col.Name == name { @@ -123,6 +135,10 @@ func (s *TableSpec) Column(name string) *ColumnSpec { return nil } +func (s *AlterTableSpec) TableName() string { + return formatPath(s.NamePath) +} + func (s *TableSpec) TableName() string { return formatPath(s.NamePath) } @@ -513,6 +529,54 @@ func newPrimaryKey(key *ast.PrimaryKeyNode) []string { return key.ColumnNameList() } +func newAlterSpec(ctx context.Context, namePath *NamePath, stmt *ast.AlterTableStmtNode) (*AlterTableSpec, error) { + list := stmt.AlterActionList() + var columns []*ast.ColumnDefinitionNode + var columnsAddDefault []*ColumnWithDefaultSpec + + var err error + + for i := range list { + action := list[i] + if err != nil { + return nil, err + } + switch action.Kind() { + case ast.AddColumnAction | ast.AlterColumnSetDefaultAction: + err = fmt.Errorf("adding field with default value to an existing table schema is not supported") + case ast.AddColumnAction: + addColumnAction := action.(*ast.AddColumnActionNode) + columns = append(columns, addColumnAction.ColumnDefinition()) + case ast.AlterColumnSetDefaultAction: + setDefaultAction := action.(*ast.AlterColumnSetDefaultActionNode) + columnName := setDefaultAction.Column() + defaultValueExpr := setDefaultAction.DefaultValue().Expression() + var defaultValue string + if defaultValueExpr != nil { + // TODO: figure out the timestamp thing here? + defaultValue, err = newNode(defaultValueExpr).FormatSQL(ctx) // assuming newNode has a method to format SQL + if err != nil { + return nil, fmt.Errorf("failed to format default value: %w", err) + } + } + columnsAddDefault = append(columnsAddDefault, &ColumnWithDefaultSpec{ + ColumnName: columnName, + DefaultValue: defaultValue, + }) + default: + err = fmt.Errorf("unknown alter action kind: %v", action.Kind()) + } + } + + now := time.Now() + return &AlterTableSpec{ + NamePath: namePath.mergePath(stmt.NamePath()), + AddedColumns: newColumnsFromDef(columns), + ColumnsWithNewDefaultValue: columnsAddDefault, + UpdatedAt: now, + }, nil +} + func newTableSpec(namePath *NamePath, stmt *ast.CreateTableStmtNode) *TableSpec { now := time.Now() return &TableSpec{ diff --git a/internal/stmt_action.go b/internal/stmt_action.go index 8dee05e..90e7951 100644 --- a/internal/stmt_action.go +++ b/internal/stmt_action.go @@ -17,6 +17,81 @@ type StmtAction interface { Args() []interface{} } +type AlterTableStmtAction struct { + query string + node *ast.AlterTableStmtNode + spec *AlterTableSpec + args []interface{} + rawArgs []driver.NamedValue + catalog *Catalog +} + +func (a *AlterTableStmtAction) Prepare(ctx context.Context, conn *Conn) (driver.Stmt, error) { + stmt, err := conn.PrepareContext(ctx, a.query) + if err != nil { + return nil, fmt.Errorf("failed to prepare %s: %w", a.query, err) + } + return newDMLStmt(stmt, []*ast.ParameterNode{}, ""), nil +} + +func (a *AlterTableStmtAction) exec(ctx context.Context, conn *Conn) error { + var statementsToExecute []string + + for _, column := range a.spec.AddedColumns { + statementsToExecute = append( + statementsToExecute, + fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN `%s` %s;", formatPath(a.spec.NamePath), column.Name, column.SQLiteSchema()), + ) + } + + for _, column := range a.spec.ColumnsWithNewDefaultValue { + foundColumn := a.catalog.tableMap[a.spec.TableName()].Column(column.ColumnName) + renameOldColumn := fmt.Sprintf("ALTER TABLE `%s` RENAME COLUMN `%s` TO `%s_old`;", formatPath(a.spec.NamePath), column.ColumnName, column.ColumnName) + addNewColumn := fmt.Sprintf("ALTER TABLE `%s` ADD COLUMN `%s` %s DEFAULT %s;", formatPath(a.spec.NamePath), column.ColumnName, foundColumn.SQLiteSchema(), column.DefaultValue) + copyData := fmt.Sprintf("UPDATE `%s` SET `%s` = `%s_old`;", formatPath(a.spec.NamePath), column.ColumnName, column.ColumnName) + dropOldColumn := fmt.Sprintf("ALTER TABLE `%s` DROP COLUMN `%s_old`;", formatPath(a.spec.NamePath), column.ColumnName) + + statementsToExecute = append(statementsToExecute, renameOldColumn, addNewColumn, copyData, dropOldColumn) + } + + fullQuery := strings.Join(statementsToExecute, "\n") + + // TODO: improve + fullQuery = strings.ReplaceAll(fullQuery, "zetasqlite_current_timestamp()", "CURRENT_TIMESTAMP") + + if _, err := conn.ExecContext(ctx, fullQuery); err != nil { + return fmt.Errorf("failed to exec %s: %w", a.query, err) + } + + if err := a.catalog.modifyTableSpec(a.spec); err != nil { + return fmt.Errorf("failed to add new table spec: %w", err) + } + + return nil +} + +func (a *AlterTableStmtAction) ExecContext(ctx context.Context, conn *Conn) (driver.Result, error) { + if err := a.exec(ctx, conn); err != nil { + return nil, err + } + return &Result{conn: conn}, nil +} + +func (a *AlterTableStmtAction) QueryContext(ctx context.Context, conn *Conn) (*Rows, error) { + if err := a.exec(ctx, conn); err != nil { + return nil, err + } + return &Rows{conn: conn}, nil +} + +func (a *AlterTableStmtAction) Args() []interface{} { + return a.args +} + +func (a *AlterTableStmtAction) Cleanup(ctx context.Context, conn *Conn) error { + return nil +} + type CreateTableStmtAction struct { query string args []interface{}