diff --git a/exec_test.go b/exec_test.go index cab263f..bba0726 100644 --- a/exec_test.go +++ b/exec_test.go @@ -64,6 +64,19 @@ CREATE OR REPLACE TABLE recreate_table ( a string ); DROP TABLE recreate_table; CREATE TABLE recreate_table ( b string ); INSERT recreate_table (b) VALUES ('hello'); +`, + }, + { + name: "create schema", + query: ` +CREATE SCHEMA new_schema; +`, + }, + { + name: "create schema if not exists", + query: ` +CREATE SCHEMA IF NOT EXISTS new_schema_2; +CREATE SCHEMA IF NOT EXISTS new_schema_2; `, }, { diff --git a/internal/analyzer.go b/internal/analyzer.go index 4e7319c..cec355c 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -94,6 +94,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) { ast.CreateTableFunctionStmt, ast.CreateViewStmt, ast.DropFunctionStmt, + ast.CreateSchemaStmt, }) // Enable QUALIFY without WHERE // https://github.com/google/zetasql/issues/124 @@ -273,6 +274,8 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive case ast.CreateViewStmt: ctx = withUseColumnID(ctx) return a.newCreateViewStmtAction(ctx, query, args, node.(*ast.CreateViewStmtNode)) + case ast.CreateSchemaStmt: + return a.newCreateSchemaStmtAction(ctx, query, args, node.(*ast.CreateSchemaStmtNode)) case ast.DropStmt: return a.newDropStmtAction(ctx, query, args, node.(*ast.DropStmtNode)) case ast.DropFunctionStmt: @@ -370,6 +373,15 @@ func (a *Analyzer) newCreateViewStmtAction(ctx context.Context, _ string, _ []dr }, nil } +func (a *Analyzer) newCreateSchemaStmtAction(_ context.Context, query string, _ []driver.NamedValue, node *ast.CreateSchemaStmtNode) (*CreateSchemaStmtAction, error) { + spec := newSchemaSpec(a.namePath, node) + return &CreateSchemaStmtAction{ + query: query, + spec: spec, + catalog: a.catalog, + }, nil +} + func (a *Analyzer) resultTypeIsTemplatedType(sig *types.FunctionSignature) bool { if !sig.IsTemplated() { return false diff --git a/internal/catalog.go b/internal/catalog.go index 15a5b5f..d22c341 100644 --- a/internal/catalog.go +++ b/internal/catalog.go @@ -51,6 +51,7 @@ const ( TableSpecKind CatalogSpecKind = "table" ViewSpecKind CatalogSpecKind = "view" FunctionSpecKind CatalogSpecKind = "function" + SchemaSpecKind CatalogSpecKind = "schema" catalogName = "zetasqlite" ) @@ -60,9 +61,11 @@ type Catalog struct { mu sync.Mutex tables []*TableSpec functions []*FunctionSpec + schemas []*SchemaSpec catalog *types.SimpleCatalog tableMap map[string]*TableSpec funcMap map[string]*FunctionSpec + schemaMap map[string]*SchemaSpec } func newSimpleCatalog(name string) *types.SimpleCatalog { @@ -73,10 +76,11 @@ func newSimpleCatalog(name string) *types.SimpleCatalog { func NewCatalog(db *sql.DB) *Catalog { return &Catalog{ - db: db, - catalog: newSimpleCatalog(catalogName), - tableMap: map[string]*TableSpec{}, - funcMap: map[string]*FunctionSpec{}, + db: db, + catalog: newSimpleCatalog(catalogName), + tableMap: map[string]*TableSpec{}, + funcMap: map[string]*FunctionSpec{}, + schemaMap: map[string]*SchemaSpec{}, } } @@ -208,6 +212,10 @@ func (c *Catalog) Sync(ctx context.Context, conn *Conn) error { if err := c.loadFunctionSpec(spec); err != nil { return fmt.Errorf("failed to load function spec: %w", err) } + case SchemaSpecKind: + if err := c.loadSchemaSpec(spec); err != nil { + return fmt.Errorf("failed to load schema spec: %w", err) + } default: return fmt.Errorf("unknown catalog spec kind %s", kind) } @@ -246,6 +254,19 @@ func (c *Catalog) AddNewFunctionSpec(ctx context.Context, conn *Conn, spec *Func return nil } +func (c *Catalog) AddNewSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error { + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.saveSchemaSpec(ctx, conn, spec); err != nil { + return fmt.Errorf("failed to save schema spec: %w", err) + } + if err := c.addSchemaSpec(spec); err != nil { + return fmt.Errorf("failed to add schema spec: %w", err) + } + return nil +} + func (c *Catalog) DeleteTableSpec(ctx context.Context, conn *Conn, name string) error { c.mu.Lock() defer c.mu.Unlock() @@ -310,12 +331,29 @@ func (c *Catalog) deleteFunctionSpecByName(name string) error { return nil } +func (c *Catalog) deleteSchemaSpecByName(name string) error { + spec, exists := c.schemaMap[name] + if !exists { + return nil + } + delete(c.schemaMap, name) + for i, s := range c.schemas { + if s == spec { + c.schemas = append(c.schemas[:i], c.schemas[i+1:]...) + break + } + } + return nil +} + func (c *Catalog) resetCatalog(tables []*TableSpec, functions []*FunctionSpec) error { c.catalog = newSimpleCatalog(catalogName) c.tables = []*TableSpec{} c.functions = []*FunctionSpec{} + c.schemas = []*SchemaSpec{} c.tableMap = map[string]*TableSpec{} c.funcMap = map[string]*FunctionSpec{} + c.schemaMap = map[string]*SchemaSpec{} for _, spec := range tables { if err := c.addTableSpec(spec); err != nil { return err @@ -373,6 +411,26 @@ func (c *Catalog) saveFunctionSpec(ctx context.Context, conn *Conn, spec *Functi return nil } +func (c *Catalog) saveSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error { + if err := c.createCatalogTablesIfNotExists(ctx, conn); err != nil { + return err + } + data, err := json.Marshal(spec) + if err != nil { + return fmt.Errorf("failed to encode schema spec: %w", err) + } + if _, err := conn.ExecContext(ctx, upsertCatalogQuery, + spec.SchemaName(), + SchemaSpecKind, + string(data), + spec.UpdatedAt, + spec.CreatedAt, + ); err != nil { + return fmt.Errorf("failed to save schema spec: %w", err) + } + return nil +} + func (c *Catalog) createCatalogTablesIfNotExists(ctx context.Context, conn *Conn) error { if _, err := conn.ExecContext(ctx, createCatalogTableQuery); err != nil { return fmt.Errorf("failed to create catalog table: %w", err) @@ -402,6 +460,17 @@ func (c *Catalog) loadFunctionSpec(spec string) error { return nil } +func (c *Catalog) loadSchemaSpec(spec string) error { + var s SchemaSpec + if err := json.Unmarshal([]byte(spec), &s); err != nil { + return fmt.Errorf("failed to decode schema spec: %w", err) + } + if err := c.addSchemaSpec(&s); err != nil { + return fmt.Errorf("failed to add schema spec: %w", err) + } + return nil +} + func (c *Catalog) trimmedLastPath(path *NamePath) *NamePath { return path.dropLast() } @@ -434,6 +503,18 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error { return nil } +func (c *Catalog) addSchemaSpec(spec *SchemaSpec) error { + name := spec.SchemaName() + if _, exists := c.schemaMap[name]; exists { + if err := c.deleteSchemaSpecByName(name); err != nil { + return err + } + } + c.schemas = append(c.schemas, spec) + c.schemaMap[name] = spec + return c.addSchemaSpecRecursive(c.catalog, spec) +} + func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpec) error { if spec.NamePath.hasQualifiers() { subCatalogName := spec.NamePath.getProjectName() @@ -536,6 +617,29 @@ func (c *Catalog) addFunctionSpecRecursive(cat *types.SimpleCatalog, spec *Funct return nil } +func (c *Catalog) addSchemaSpecRecursive(cat *types.SimpleCatalog, spec *SchemaSpec) error { + if spec.NamePath.empty() { + return nil + } + name := spec.NamePath.path[0] + if c.existsSchema(cat, name) { + return nil + } + var subCat *types.SimpleCatalog + if !c.existsSchema(cat, name) { + subCat = types.NewSimpleCatalog(name) + cat.AddCatalog(subCat) + } else { + schema, _ := cat.Catalog(name) + subCat = schema + } + return c.addSchemaSpecRecursive(subCat, &SchemaSpec{ + NamePath: NamePath{path: spec.NamePath.path[1:]}, + UpdatedAt: spec.UpdatedAt, + CreatedAt: spec.CreatedAt, + }) +} + func (c *Catalog) existsTable(cat *types.SimpleCatalog, name string) bool { foundTable, _ := cat.FindTable([]string{name}) return !c.isNilTable(foundTable) @@ -546,6 +650,11 @@ func (c *Catalog) existsFunction(cat *types.SimpleCatalog, name string) bool { return foundFunc != nil } +func (c *Catalog) existsSchema(cat *types.SimpleCatalog, name string) bool { + schema, _ := cat.Catalog(name) + return schema != nil +} + func (c *Catalog) isNilTable(t types.Table) bool { v := reflect.ValueOf(t) if !v.IsValid() { diff --git a/internal/spec.go b/internal/spec.go index e4d551d..4bdbdb1 100644 --- a/internal/spec.go +++ b/internal/spec.go @@ -327,6 +327,28 @@ func (s *ColumnSpec) SQLiteSchema() string { return schema } +type SchemaSpec struct { + NamePath NamePath `json:"namePath"` + CreateMode ast.CreateMode `json:"createMode"` + UpdatedAt time.Time `json:"updatedAt"` + CreatedAt time.Time `json:"createdAt"` +} + +func (s *SchemaSpec) SchemaName() string { + return s.NamePath.getObjectName() +} + +func newSchemaSpec(namePath *NamePath, stmt *ast.CreateSchemaStmtNode) *SchemaSpec { + schemaNamePath := stmt.NamePath() + now := time.Now() + return &SchemaSpec{ + NamePath: *namePath.mergePath(schemaNamePath), + CreateMode: stmt.CreateMode(), + UpdatedAt: now, + CreatedAt: now, + } +} + func newTypeFromFunctionArgumentType(t *types.FunctionArgumentType) *Type { if t.IsTemplated() { return &Type{SignatureKind: t.Kind()} diff --git a/internal/stmt.go b/internal/stmt.go index 1aaba39..d0020f6 100644 --- a/internal/stmt.go +++ b/internal/stmt.go @@ -14,6 +14,7 @@ var ( _ driver.Stmt = &CreateFunctionStmt{} _ driver.Stmt = &DMLStmt{} _ driver.Stmt = &QueryStmt{} + _ driver.Stmt = &CreateSchemaStmt{} ) type CreateTableStmt struct { @@ -125,6 +126,51 @@ func newCreateFunctionStmt(conn *Conn, catalog *Catalog, spec *FunctionSpec) *Cr } } +type CreateSchemaStmt struct { + conn *Conn + catalog *Catalog + spec *SchemaSpec +} + +func newCreateSchemaStmt(conn *Conn, catalog *Catalog, spec *SchemaSpec) *CreateSchemaStmt { + return &CreateSchemaStmt{ + conn: conn, + catalog: catalog, + spec: spec, + } +} + +func (s *CreateSchemaStmt) Close() error { + return nil +} + +func (s *CreateSchemaStmt) NumInput() int { + return 0 +} + +func (s *CreateSchemaStmt) Exec(args []driver.Value) (driver.Result, error) { + if err := s.catalog.AddNewSchemaSpec(context.Background(), s.conn, s.spec); err != nil { + return nil, fmt.Errorf("failed to add new schema spec: %w", err) + } + return nil, nil +} + +func (s *CreateSchemaStmt) Query(args []driver.Value) (driver.Rows, error) { + return nil, fmt.Errorf("failed to query for CreateSchemaStmt") +} + +func (s *CreateSchemaStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return nil, fmt.Errorf("unimplemented ExecContext for CreateSchemaStmt") +} + +func (s *CreateSchemaStmt) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return nil, fmt.Errorf("unsupported query for CreateSchemaStmt") +} + +func (s *CreateSchemaStmt) CheckNamedValue(value *driver.NamedValue) error { + return nil +} + type DMLStmt struct { stmt *sql.Stmt args []*ast.ParameterNode diff --git a/internal/stmt_action.go b/internal/stmt_action.go index 8d494a7..89634eb 100644 --- a/internal/stmt_action.go +++ b/internal/stmt_action.go @@ -542,3 +542,42 @@ func (a *MergeStmtAction) Args() []interface{} { func (a *MergeStmtAction) Cleanup(ctx context.Context, conn *Conn) error { return nil } + +type CreateSchemaStmtAction struct { + query string + spec *SchemaSpec + catalog *Catalog +} + +func (a *CreateSchemaStmtAction) Args() []interface{} { + return nil +} + +func (a *CreateSchemaStmtAction) Prepare(ctx context.Context, conn *Conn) (driver.Stmt, error) { + return newCreateSchemaStmt(conn, a.catalog, a.spec), nil +} + +func (a *CreateSchemaStmtAction) ExecContext(ctx context.Context, conn *Conn) (driver.Result, error) { + if err := a.exec(ctx, conn); err != nil { + return nil, err + } + return &Result{conn: conn}, nil +} + +func (a *CreateSchemaStmtAction) QueryContext(ctx context.Context, conn *Conn) (*Rows, error) { + if err := a.exec(ctx, conn); err != nil { + return nil, err + } + return &Rows{conn: conn}, nil +} + +func (a *CreateSchemaStmtAction) exec(ctx context.Context, conn *Conn) error { + if err := a.catalog.AddNewSchemaSpec(ctx, conn, a.spec); err != nil { + return fmt.Errorf("failed to add new schema spec: %w", err) + } + return nil +} + +func (a *CreateSchemaStmtAction) Cleanup(ctx context.Context, conn *Conn) error { + return nil +}