Skip to content

Commit

Permalink
Implement CREATE SCHEMA
Browse files Browse the repository at this point in the history
  • Loading branch information
nightscape committed Jan 8, 2025
1 parent 0c73346 commit 4b4bc95
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 6 deletions.
28 changes: 28 additions & 0 deletions exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,34 @@ CREATE OR REPLACE TABLE recreate_table ( a string );
DROP TABLE recreate_table;
CREATE TABLE recreate_table ( b string );
INSERT recreate_table (b) VALUES ('hello');
`,
},
{
name: "create schema",
query: `
CREATE SCHEMA new_schema;
CREATE TABLE new_schema.new_table (a STRING);
`,
},
{
name: "create schema with qualifiers",
query: `
CREATE SCHEMA projectId.new_schema;
CREATE TABLE projectId.new_schema.new_table (a STRING);
`,
},
{
name: "create schema if not exists",
query: `
CREATE SCHEMA IF NOT EXISTS new_schema_2;
CREATE SCHEMA IF NOT EXISTS new_schema_2;
`,
},
{
name: "drop schema",
query: `
CREATE SCHEMA new_schema_3;
DROP SCHEMA IF EXISTS new_schema_3;
`,
},
{
Expand Down
13 changes: 13 additions & 0 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) {
ast.CreateTableFunctionStmt,
ast.CreateViewStmt,
ast.DropFunctionStmt,
ast.CreateSchemaStmt,
})
// Enable QUALIFY without WHERE
// https://github.com/google/zetasql/issues/124
Expand Down Expand Up @@ -273,6 +274,8 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive
case ast.CreateViewStmt:
ctx = withUseColumnID(ctx)
return a.newCreateViewStmtAction(ctx, query, args, node.(*ast.CreateViewStmtNode))
case ast.CreateSchemaStmt:
return a.newCreateSchemaStmtAction(ctx, query, args, node.(*ast.CreateSchemaStmtNode))
case ast.DropStmt:
return a.newDropStmtAction(ctx, query, args, node.(*ast.DropStmtNode))
case ast.DropFunctionStmt:
Expand Down Expand Up @@ -370,6 +373,16 @@ func (a *Analyzer) newCreateViewStmtAction(ctx context.Context, _ string, _ []dr
}, nil
}

//nolint:unparam
func (a *Analyzer) newCreateSchemaStmtAction(_ context.Context, query string, _ []driver.NamedValue, node *ast.CreateSchemaStmtNode) (*CreateSchemaStmtAction, error) {
spec := newSchemaSpec(a.namePath, node)
return &CreateSchemaStmtAction{
query: query,
spec: spec,
catalog: a.catalog,
}, nil
}

func (a *Analyzer) resultTypeIsTemplatedType(sig *types.FunctionSignature) bool {
if !sig.IsTemplated() {
return false
Expand Down
132 changes: 127 additions & 5 deletions internal/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const (
TableSpecKind CatalogSpecKind = "table"
ViewSpecKind CatalogSpecKind = "view"
FunctionSpecKind CatalogSpecKind = "function"
SchemaSpecKind CatalogSpecKind = "schema"
catalogName = "zetasqlite"
)

Expand All @@ -60,9 +61,11 @@ type Catalog struct {
mu sync.Mutex
tables []*TableSpec
functions []*FunctionSpec
schemas []*SchemaSpec
catalog *types.SimpleCatalog
tableMap map[string]*TableSpec
funcMap map[string]*FunctionSpec
schemaMap map[string]*SchemaSpec
}

func newSimpleCatalog(name string) *types.SimpleCatalog {
Expand All @@ -73,10 +76,11 @@ func newSimpleCatalog(name string) *types.SimpleCatalog {

func NewCatalog(db *sql.DB) *Catalog {
return &Catalog{
db: db,
catalog: newSimpleCatalog(catalogName),
tableMap: map[string]*TableSpec{},
funcMap: map[string]*FunctionSpec{},
db: db,
catalog: newSimpleCatalog(catalogName),
tableMap: map[string]*TableSpec{},
funcMap: map[string]*FunctionSpec{},
schemaMap: map[string]*SchemaSpec{},
}
}

Expand Down Expand Up @@ -208,6 +212,10 @@ func (c *Catalog) Sync(ctx context.Context, conn *Conn) error {
if err := c.loadFunctionSpec(spec); err != nil {
return fmt.Errorf("failed to load function spec: %w", err)
}
case SchemaSpecKind:
if err := c.loadSchemaSpec(spec); err != nil {
return fmt.Errorf("failed to load schema spec: %w", err)
}
default:
return fmt.Errorf("unknown catalog spec kind %s", kind)
}
Expand Down Expand Up @@ -246,6 +254,32 @@ func (c *Catalog) AddNewFunctionSpec(ctx context.Context, conn *Conn, spec *Func
return nil
}

func (c *Catalog) AddNewSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error {
c.mu.Lock()
defer c.mu.Unlock()

if err := c.saveSchemaSpec(ctx, conn, spec); err != nil {
return fmt.Errorf("failed to save schema spec: %w", err)
}
if err := c.addSchemaSpec(spec); err != nil {
return fmt.Errorf("failed to add schema spec: %w", err)
}
return nil
}

func (c *Catalog) DeleteSchemaSpec(ctx context.Context, conn *Conn, name string) error {
c.mu.Lock()
defer c.mu.Unlock()

if err := c.deleteSchemaSpecByName(name); err != nil {
return err
}
if _, err := conn.ExecContext(ctx, deleteCatalogQuery, sql.Named("name", name)); err != nil {
return err
}
return nil
}

func (c *Catalog) DeleteTableSpec(ctx context.Context, conn *Conn, name string) error {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down Expand Up @@ -310,12 +344,29 @@ func (c *Catalog) deleteFunctionSpecByName(name string) error {
return nil
}

func (c *Catalog) deleteSchemaSpecByName(name string) error {
spec, exists := c.schemaMap[name]
if !exists {
return nil
}
delete(c.schemaMap, name)
for i, s := range c.schemas {
if s == spec {
c.schemas = append(c.schemas[:i], c.schemas[i+1:]...)
break
}
}
return nil
}

func (c *Catalog) resetCatalog(tables []*TableSpec, functions []*FunctionSpec) error {
c.catalog = newSimpleCatalog(catalogName)
c.tables = []*TableSpec{}
c.functions = []*FunctionSpec{}
c.schemas = []*SchemaSpec{}
c.tableMap = map[string]*TableSpec{}
c.funcMap = map[string]*FunctionSpec{}
c.schemaMap = map[string]*SchemaSpec{}
for _, spec := range tables {
if err := c.addTableSpec(spec); err != nil {
return err
Expand Down Expand Up @@ -373,6 +424,26 @@ func (c *Catalog) saveFunctionSpec(ctx context.Context, conn *Conn, spec *Functi
return nil
}

func (c *Catalog) saveSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error {
if err := c.createCatalogTablesIfNotExists(ctx, conn); err != nil {
return err
}
data, err := json.Marshal(spec)
if err != nil {
return fmt.Errorf("failed to encode schema spec: %w", err)
}
if _, err := conn.ExecContext(ctx, upsertCatalogQuery,
spec.SchemaName(),
SchemaSpecKind,
string(data),
spec.UpdatedAt,
spec.CreatedAt,
); err != nil {
return fmt.Errorf("failed to save schema spec: %w", err)
}
return nil
}

func (c *Catalog) createCatalogTablesIfNotExists(ctx context.Context, conn *Conn) error {
if _, err := conn.ExecContext(ctx, createCatalogTableQuery); err != nil {
return fmt.Errorf("failed to create catalog table: %w", err)
Expand Down Expand Up @@ -402,6 +473,17 @@ func (c *Catalog) loadFunctionSpec(spec string) error {
return nil
}

func (c *Catalog) loadSchemaSpec(spec string) error {
var s SchemaSpec
if err := json.Unmarshal([]byte(spec), &s); err != nil {
return fmt.Errorf("failed to decode schema spec: %w", err)
}
if err := c.addSchemaSpec(&s); err != nil {
return fmt.Errorf("failed to add schema spec: %w", err)
}
return nil
}

func (c *Catalog) trimmedLastPath(path *NamePath) *NamePath {
return path.dropLast()
}
Expand Down Expand Up @@ -434,6 +516,18 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error {
return nil
}

func (c *Catalog) addSchemaSpec(spec *SchemaSpec) error {
name := spec.SchemaName()
if _, exists := c.schemaMap[name]; exists {
if err := c.deleteSchemaSpecByName(name); err != nil {
return err
}
}
c.schemas = append(c.schemas, spec)
c.schemaMap[name] = spec
return c.addSchemaSpecRecursive(c.catalog, spec)
}

func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpec) error {
if catalogId := spec.NamePath.GetCatalogId(); catalogId != "" {
subCatalogName := catalogId
Expand Down Expand Up @@ -461,11 +555,11 @@ func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpe
}
return nil
}
tableName := spec.NamePath.GetObjectId()
if spec.NamePath.Empty() {
return fmt.Errorf("table name is not found")
}

tableName := spec.NamePath.GetObjectId()
if c.existsTable(cat, tableName) {
return nil
}
Expand Down Expand Up @@ -536,6 +630,29 @@ func (c *Catalog) addFunctionSpecRecursive(cat *types.SimpleCatalog, spec *Funct
return nil
}

func (c *Catalog) addSchemaSpecRecursive(cat *types.SimpleCatalog, spec *SchemaSpec) error {
if spec.NamePath.Empty() {
return nil
}
name := spec.NamePath.path[0]
if c.existsSchema(cat, name) {
return nil
}
var subCat *types.SimpleCatalog
if !c.existsSchema(cat, name) {
subCat = types.NewSimpleCatalog(name)
cat.AddCatalog(subCat)
} else {
schema, _ := cat.Catalog(name)
subCat = schema
}
return c.addSchemaSpecRecursive(subCat, &SchemaSpec{
NamePath: NamePath{path: spec.NamePath.path[1:]},
UpdatedAt: spec.UpdatedAt,
CreatedAt: spec.CreatedAt,
})
}

func (c *Catalog) existsTable(cat *types.SimpleCatalog, name string) bool {
foundTable, _ := cat.FindTable([]string{name})
return !c.isNilTable(foundTable)
Expand All @@ -546,6 +663,11 @@ func (c *Catalog) existsFunction(cat *types.SimpleCatalog, name string) bool {
return foundFunc != nil
}

func (c *Catalog) existsSchema(cat *types.SimpleCatalog, name string) bool {
schema, _ := cat.Catalog(name)
return schema != nil
}

func (c *Catalog) isNilTable(t types.Table) bool {
v := reflect.ValueOf(t)
if !v.IsValid() {
Expand Down
45 changes: 44 additions & 1 deletion internal/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,28 @@ import (
type ChangedCatalog struct {
Table *ChangedTable
Function *ChangedFunction
Schema *ChangedSchema
}

func newChangedCatalog() *ChangedCatalog {
return &ChangedCatalog{
Table: &ChangedTable{},
Function: &ChangedFunction{},
Schema: &ChangedSchema{},
}
}

func (c *ChangedCatalog) Changed() bool {
return c.Table.Changed() || c.Function.Changed()
return c.Table.Changed() || c.Function.Changed() || c.Schema.Changed()
}

type ChangedSchema struct {
Added []*SchemaSpec
Deleted []*SchemaSpec
}

func (s *ChangedSchema) Changed() bool {
return len(s.Added) != 0 || len(s.Deleted) != 0
}

type ChangedTable struct {
Expand Down Expand Up @@ -75,6 +86,16 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface
return c.conn.QueryContext(ctx, query, args...)
}

func (c *Conn) addSchema(spec *SchemaSpec) {

Check failure on line 89 in internal/conn.go

View workflow job for this annotation

GitHub Actions / lint

func `(*Conn).addSchema` is unused (unused)
c.removeFromDeletedSchemasIfExists(spec)
c.cc.Schema.Added = append(c.cc.Schema.Added, spec)
}

func (c *Conn) deleteSchema(spec *SchemaSpec) {
c.removeFromAddedSchemasIfExists(spec)
c.cc.Schema.Deleted = append(c.cc.Schema.Deleted, spec)
}

func (c *Conn) addTable(spec *TableSpec) {
c.removeFromDeletedTablesIfExists(spec)
c.cc.Table.Added = append(c.cc.Table.Added, spec)
Expand All @@ -100,6 +121,28 @@ func (c *Conn) deleteFunction(spec *FunctionSpec) {
c.cc.Function.Deleted = append(c.cc.Function.Deleted, spec)
}

func (c *Conn) removeFromDeletedSchemasIfExists(spec *SchemaSpec) {

Check failure on line 124 in internal/conn.go

View workflow job for this annotation

GitHub Actions / lint

func `(*Conn).removeFromDeletedSchemasIfExists` is unused (unused)
schemas := make([]*SchemaSpec, 0, len(c.cc.Schema.Deleted))
for _, schema := range c.cc.Schema.Deleted {
if schema.SchemaName() == spec.SchemaName() {
continue
}
schemas = append(schemas, schema)
}
c.cc.Schema.Deleted = schemas
}

func (c *Conn) removeFromAddedSchemasIfExists(spec *SchemaSpec) {
schemas := make([]*SchemaSpec, 0, len(c.cc.Schema.Added))
for _, schema := range c.cc.Schema.Added {
if schema.SchemaName() == spec.SchemaName() {
continue
}
schemas = append(schemas, schema)
}
c.cc.Schema.Added = schemas
}

func (c *Conn) removeFromDeletedTablesIfExists(spec *TableSpec) {
tables := make([]*TableSpec, 0, len(c.cc.Table.Deleted))
for _, table := range c.cc.Table.Deleted {
Expand Down
Loading

0 comments on commit 4b4bc95

Please sign in to comment.