Skip to content

Commit

Permalink
Implement CREATE SCHEMA
Browse files Browse the repository at this point in the history
  • Loading branch information
nightscape committed Jan 8, 2025
1 parent d905d5a commit f8da236
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 5 deletions.
13 changes: 13 additions & 0 deletions exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ CREATE OR REPLACE TABLE recreate_table ( a string );
DROP TABLE recreate_table;
CREATE TABLE recreate_table ( b string );
INSERT recreate_table (b) VALUES ('hello');
`,
},
{
name: "create schema",
query: `
CREATE SCHEMA new_schema;
`,
},
{
name: "create schema if not exists",
query: `
CREATE SCHEMA IF NOT EXISTS new_schema_2;
CREATE SCHEMA IF NOT EXISTS new_schema_2;
`,
},
{
Expand Down
13 changes: 13 additions & 0 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) {
ast.CreateTableFunctionStmt,
ast.CreateViewStmt,
ast.DropFunctionStmt,
ast.CreateSchemaStmt,
})
// Enable QUALIFY without WHERE
// https://github.com/google/zetasql/issues/124
Expand Down Expand Up @@ -273,6 +274,8 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive
case ast.CreateViewStmt:
ctx = withUseColumnID(ctx)
return a.newCreateViewStmtAction(ctx, query, args, node.(*ast.CreateViewStmtNode))
case ast.CreateSchemaStmt:
return a.newCreateSchemaStmtAction(ctx, query, args, node.(*ast.CreateSchemaStmtNode))
case ast.DropStmt:
return a.newDropStmtAction(ctx, query, args, node.(*ast.DropStmtNode))
case ast.DropFunctionStmt:
Expand Down Expand Up @@ -370,6 +373,16 @@ func (a *Analyzer) newCreateViewStmtAction(ctx context.Context, _ string, _ []dr
}, nil
}

//nolint:unparam
func (a *Analyzer) newCreateSchemaStmtAction(_ context.Context, query string, _ []driver.NamedValue, node *ast.CreateSchemaStmtNode) (*CreateSchemaStmtAction, error) {
spec := newSchemaSpec(a.namePath, node)
return &CreateSchemaStmtAction{
query: query,
spec: spec,
catalog: a.catalog,
}, nil
}

func (a *Analyzer) resultTypeIsTemplatedType(sig *types.FunctionSignature) bool {
if !sig.IsTemplated() {
return false
Expand Down
119 changes: 114 additions & 5 deletions internal/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const (
TableSpecKind CatalogSpecKind = "table"
ViewSpecKind CatalogSpecKind = "view"
FunctionSpecKind CatalogSpecKind = "function"
SchemaSpecKind CatalogSpecKind = "schema"
catalogName = "zetasqlite"
)

Expand All @@ -60,9 +61,11 @@ type Catalog struct {
mu sync.Mutex
tables []*TableSpec
functions []*FunctionSpec
schemas []*SchemaSpec
catalog *types.SimpleCatalog
tableMap map[string]*TableSpec
funcMap map[string]*FunctionSpec
schemaMap map[string]*SchemaSpec
}

func newSimpleCatalog(name string) *types.SimpleCatalog {
Expand All @@ -73,10 +76,11 @@ func newSimpleCatalog(name string) *types.SimpleCatalog {

func NewCatalog(db *sql.DB) *Catalog {
return &Catalog{
db: db,
catalog: newSimpleCatalog(catalogName),
tableMap: map[string]*TableSpec{},
funcMap: map[string]*FunctionSpec{},
db: db,
catalog: newSimpleCatalog(catalogName),
tableMap: map[string]*TableSpec{},
funcMap: map[string]*FunctionSpec{},
schemaMap: map[string]*SchemaSpec{},
}
}

Expand Down Expand Up @@ -208,6 +212,10 @@ func (c *Catalog) Sync(ctx context.Context, conn *Conn) error {
if err := c.loadFunctionSpec(spec); err != nil {
return fmt.Errorf("failed to load function spec: %w", err)
}
case SchemaSpecKind:
if err := c.loadSchemaSpec(spec); err != nil {
return fmt.Errorf("failed to load schema spec: %w", err)
}
default:
return fmt.Errorf("unknown catalog spec kind %s", kind)
}
Expand Down Expand Up @@ -246,6 +254,19 @@ func (c *Catalog) AddNewFunctionSpec(ctx context.Context, conn *Conn, spec *Func
return nil
}

func (c *Catalog) AddNewSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error {
c.mu.Lock()
defer c.mu.Unlock()

if err := c.saveSchemaSpec(ctx, conn, spec); err != nil {
return fmt.Errorf("failed to save schema spec: %w", err)
}
if err := c.addSchemaSpec(spec); err != nil {
return fmt.Errorf("failed to add schema spec: %w", err)
}
return nil
}

func (c *Catalog) DeleteTableSpec(ctx context.Context, conn *Conn, name string) error {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down Expand Up @@ -310,12 +331,29 @@ func (c *Catalog) deleteFunctionSpecByName(name string) error {
return nil
}

func (c *Catalog) deleteSchemaSpecByName(name string) error {
spec, exists := c.schemaMap[name]
if !exists {
return nil
}
delete(c.schemaMap, name)
for i, s := range c.schemas {
if s == spec {
c.schemas = append(c.schemas[:i], c.schemas[i+1:]...)
break
}
}
return nil
}

func (c *Catalog) resetCatalog(tables []*TableSpec, functions []*FunctionSpec) error {
c.catalog = newSimpleCatalog(catalogName)
c.tables = []*TableSpec{}
c.functions = []*FunctionSpec{}
c.schemas = []*SchemaSpec{}
c.tableMap = map[string]*TableSpec{}
c.funcMap = map[string]*FunctionSpec{}
c.schemaMap = map[string]*SchemaSpec{}
for _, spec := range tables {
if err := c.addTableSpec(spec); err != nil {
return err
Expand Down Expand Up @@ -373,6 +411,26 @@ func (c *Catalog) saveFunctionSpec(ctx context.Context, conn *Conn, spec *Functi
return nil
}

func (c *Catalog) saveSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error {
if err := c.createCatalogTablesIfNotExists(ctx, conn); err != nil {
return err
}
data, err := json.Marshal(spec)
if err != nil {
return fmt.Errorf("failed to encode schema spec: %w", err)
}
if _, err := conn.ExecContext(ctx, upsertCatalogQuery,
spec.SchemaName(),
SchemaSpecKind,
string(data),
spec.UpdatedAt,
spec.CreatedAt,
); err != nil {
return fmt.Errorf("failed to save schema spec: %w", err)
}
return nil
}

func (c *Catalog) createCatalogTablesIfNotExists(ctx context.Context, conn *Conn) error {
if _, err := conn.ExecContext(ctx, createCatalogTableQuery); err != nil {
return fmt.Errorf("failed to create catalog table: %w", err)
Expand Down Expand Up @@ -402,6 +460,17 @@ func (c *Catalog) loadFunctionSpec(spec string) error {
return nil
}

func (c *Catalog) loadSchemaSpec(spec string) error {
var s SchemaSpec
if err := json.Unmarshal([]byte(spec), &s); err != nil {
return fmt.Errorf("failed to decode schema spec: %w", err)
}
if err := c.addSchemaSpec(&s); err != nil {
return fmt.Errorf("failed to add schema spec: %w", err)
}
return nil
}

func (c *Catalog) trimmedLastPath(path *NamePath) *NamePath {
return path.dropLast()
}
Expand Down Expand Up @@ -434,6 +503,18 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error {
return nil
}

func (c *Catalog) addSchemaSpec(spec *SchemaSpec) error {
name := spec.SchemaName()
if _, exists := c.schemaMap[name]; exists {
if err := c.deleteSchemaSpecByName(name); err != nil {
return err
}
}
c.schemas = append(c.schemas, spec)
c.schemaMap[name] = spec
return c.addSchemaSpecRecursive(c.catalog, spec)
}

func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpec) error {
if catalogId := spec.NamePath.GetCatalogId(); catalogId != "" {
subCatalogName := catalogId
Expand Down Expand Up @@ -461,11 +542,11 @@ func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpe
}
return nil
}
tableName := spec.NamePath.GetObjectId()
if spec.NamePath.Empty() {
return fmt.Errorf("table name is not found")
}

tableName := spec.NamePath.GetObjectId()
if c.existsTable(cat, tableName) {
return nil
}
Expand Down Expand Up @@ -536,6 +617,29 @@ func (c *Catalog) addFunctionSpecRecursive(cat *types.SimpleCatalog, spec *Funct
return nil
}

func (c *Catalog) addSchemaSpecRecursive(cat *types.SimpleCatalog, spec *SchemaSpec) error {
if spec.NamePath.Empty() {
return nil
}
name := spec.NamePath.path[0]
if c.existsSchema(cat, name) {
return nil
}
var subCat *types.SimpleCatalog
if !c.existsSchema(cat, name) {
subCat = types.NewSimpleCatalog(name)
cat.AddCatalog(subCat)
} else {
schema, _ := cat.Catalog(name)
subCat = schema
}
return c.addSchemaSpecRecursive(subCat, &SchemaSpec{
NamePath: NamePath{path: spec.NamePath.path[1:]},
UpdatedAt: spec.UpdatedAt,
CreatedAt: spec.CreatedAt,
})
}

func (c *Catalog) existsTable(cat *types.SimpleCatalog, name string) bool {
foundTable, _ := cat.FindTable([]string{name})
return !c.isNilTable(foundTable)
Expand All @@ -546,6 +650,11 @@ func (c *Catalog) existsFunction(cat *types.SimpleCatalog, name string) bool {
return foundFunc != nil
}

func (c *Catalog) existsSchema(cat *types.SimpleCatalog, name string) bool {
schema, _ := cat.Catalog(name)
return schema != nil
}

func (c *Catalog) isNilTable(t types.Table) bool {
v := reflect.ValueOf(t)
if !v.IsValid() {
Expand Down
22 changes: 22 additions & 0 deletions internal/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,28 @@ func (s *ColumnSpec) SQLiteSchema() string {
return schema
}

type SchemaSpec struct {
NamePath NamePath `json:"namePath"`
CreateMode ast.CreateMode `json:"createMode"`
UpdatedAt time.Time `json:"updatedAt"`
CreatedAt time.Time `json:"createdAt"`
}

func (s *SchemaSpec) SchemaName() string {
return s.NamePath.GetObjectId()
}

func newSchemaSpec(namePath *NamePath, stmt *ast.CreateSchemaStmtNode) *SchemaSpec {
schemaNamePath := stmt.NamePath()
now := time.Now()
return &SchemaSpec{
NamePath: *namePath.mergePath(schemaNamePath),
CreateMode: stmt.CreateMode(),
UpdatedAt: now,
CreatedAt: now,
}
}

func newTypeFromFunctionArgumentType(t *types.FunctionArgumentType) *Type {
if t.IsTemplated() {
return &Type{SignatureKind: t.Kind()}
Expand Down
46 changes: 46 additions & 0 deletions internal/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ var (
_ driver.Stmt = &CreateFunctionStmt{}
_ driver.Stmt = &DMLStmt{}
_ driver.Stmt = &QueryStmt{}
_ driver.Stmt = &CreateSchemaStmt{}
)

type CreateTableStmt struct {
Expand Down Expand Up @@ -125,6 +126,51 @@ func newCreateFunctionStmt(conn *Conn, catalog *Catalog, spec *FunctionSpec) *Cr
}
}

type CreateSchemaStmt struct {
conn *Conn
catalog *Catalog
spec *SchemaSpec
}

func newCreateSchemaStmt(conn *Conn, catalog *Catalog, spec *SchemaSpec) *CreateSchemaStmt {
return &CreateSchemaStmt{
conn: conn,
catalog: catalog,
spec: spec,
}
}

func (s *CreateSchemaStmt) Close() error {
return nil
}

func (s *CreateSchemaStmt) NumInput() int {
return 0
}

func (s *CreateSchemaStmt) Exec(args []driver.Value) (driver.Result, error) {
if err := s.catalog.AddNewSchemaSpec(context.Background(), s.conn, s.spec); err != nil {
return nil, fmt.Errorf("failed to add new schema spec: %w", err)
}
return nil, nil
}

func (s *CreateSchemaStmt) Query(args []driver.Value) (driver.Rows, error) {
return nil, fmt.Errorf("failed to query for CreateSchemaStmt")
}

func (s *CreateSchemaStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return nil, fmt.Errorf("unimplemented ExecContext for CreateSchemaStmt")
}

func (s *CreateSchemaStmt) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
return nil, fmt.Errorf("unsupported query for CreateSchemaStmt")
}

func (s *CreateSchemaStmt) CheckNamedValue(value *driver.NamedValue) error {
return nil
}

type DMLStmt struct {
stmt *sql.Stmt
args []*ast.ParameterNode
Expand Down
Loading

0 comments on commit f8da236

Please sign in to comment.