diff --git a/catalog.go b/catalog.go index 9b16aed..648d550 100644 --- a/catalog.go +++ b/catalog.go @@ -17,8 +17,11 @@ type ( NameWithType = internal.NameWithType ColumnSpec = internal.ColumnSpec Type = internal.Type + NamePath = internal.NamePath ) +var NewNamePath = internal.NewNamePath + // ChangedCatalogFromRows retrieve modified catalog information from sql.Rows. // NOTE: This API relies on the internal structure of sql.Rows, so not will work for all Go versions. func ChangedCatalogFromRows(rows *sql.Rows) (*ChangedCatalog, error) { diff --git a/cmd/zetasqlite-cli/main.go b/cmd/zetasqlite-cli/main.go index 15e6094..f57b60a 100644 --- a/cmd/zetasqlite-cli/main.go +++ b/cmd/zetasqlite-cli/main.go @@ -203,7 +203,7 @@ func (cli *CLI) showTablesCommand(ctx context.Context) error { if err := json.Unmarshal([]byte(spec), &table); err != nil { return err } - fmt.Fprintf(cli.out, "%s\n", strings.Join(table.NamePath, ".")) + fmt.Fprintf(cli.out, "%s\n", table.NamePath.FormatNamePath()) } return nil } @@ -231,7 +231,7 @@ func (cli *CLI) showFunctionsCommand(ctx context.Context) error { if err := json.Unmarshal([]byte(spec), &fn); err != nil { return err } - fmt.Fprintf(cli.out, "%s\n", strings.Join(fn.NamePath, ".")) + fmt.Fprintf(cli.out, "%s\n", fn.NamePath.FormatNamePath()) } return nil } diff --git a/driver_test.go b/driver_test.go index 30d8030..07265b0 100644 --- a/driver_test.go +++ b/driver_test.go @@ -124,7 +124,7 @@ CREATE TABLE IF NOT EXISTS Singers ( if len(resultCatalog.Table.Added) != 1 { t.Fatal("failed to get created table spec") } - if diff := cmp.Diff(resultCatalog.Table.Added[0].NamePath, []string{"Singers"}); diff != "" { + if diff := cmp.Diff(resultCatalog.Table.Added[0].NamePath.Path(), []string{"Singers"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } rowsCatalog, err := zetasqlite.ChangedCatalogFromRows(rows) @@ -137,7 +137,7 @@ CREATE TABLE IF NOT EXISTS Singers ( if len(rowsCatalog.Table.Deleted) != 1 { t.Fatal("failed to get deleted table spec") } - if diff := cmp.Diff(rowsCatalog.Table.Deleted[0].NamePath, []string{"Singers"}); diff != "" { + if diff := cmp.Diff(rowsCatalog.Table.Deleted[0].NamePath.Path(), []string{"Singers"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } }) @@ -167,7 +167,7 @@ CREATE TABLE IF NOT EXISTS Singers ( if len(resultCatalog.Function.Added) != 1 { t.Fatal("failed to get created function spec") } - if diff := cmp.Diff(resultCatalog.Function.Added[0].NamePath, []string{"ANY_ADD"}); diff != "" { + if diff := cmp.Diff(resultCatalog.Function.Added[0].NamePath.Path(), []string{"ANY_ADD"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } rowsCatalog, err := zetasqlite.ChangedCatalogFromRows(rows) @@ -180,7 +180,7 @@ CREATE TABLE IF NOT EXISTS Singers ( if len(rowsCatalog.Function.Deleted) != 1 { t.Fatal("failed to get deleted function spec") } - if diff := cmp.Diff(rowsCatalog.Function.Deleted[0].NamePath, []string{"ANY_ADD"}); diff != "" { + if diff := cmp.Diff(rowsCatalog.Function.Deleted[0].NamePath.Path(), []string{"ANY_ADD"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } }) diff --git a/exec_test.go b/exec_test.go index cab263f..c265130 100644 --- a/exec_test.go +++ b/exec_test.go @@ -64,6 +64,34 @@ CREATE OR REPLACE TABLE recreate_table ( a string ); DROP TABLE recreate_table; CREATE TABLE recreate_table ( b string ); INSERT recreate_table (b) VALUES ('hello'); +`, + }, + { + name: "create schema", + query: ` +CREATE SCHEMA new_schema; +CREATE TABLE new_schema.new_table (a STRING); +`, + }, + { + name: "create schema with qualifiers", + query: ` +CREATE SCHEMA projectId.new_schema; +CREATE TABLE projectId.new_schema.new_table (a STRING); +`, + }, + { + name: "create schema if not exists", + query: ` +CREATE SCHEMA IF NOT EXISTS new_schema_2; +CREATE SCHEMA IF NOT EXISTS new_schema_2; +`, + }, + { + name: "drop schema", + query: ` +CREATE SCHEMA new_schema_3; +DROP SCHEMA IF EXISTS new_schema_3; `, }, { diff --git a/internal/analyzer.go b/internal/analyzer.go index 7274d29..d2392ca 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -94,6 +94,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) { ast.CreateTableFunctionStmt, ast.CreateViewStmt, ast.DropFunctionStmt, + ast.CreateSchemaStmt, }) // Enable QUALIFY without WHERE // https://github.com/google/zetasql/issues/124 @@ -116,7 +117,7 @@ func (a *Analyzer) SetExplainMode(enabled bool) { } func (a *Analyzer) NamePath() []string { - return a.namePath.path + return a.namePath.Path() } func (a *Analyzer) SetNamePath(path []string) error { @@ -273,6 +274,8 @@ func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []drive case ast.CreateViewStmt: ctx = withUseColumnID(ctx) return a.newCreateViewStmtAction(ctx, query, args, node.(*ast.CreateViewStmtNode)) + case ast.CreateSchemaStmt: + return a.newCreateSchemaStmtAction(ctx, query, args, node.(*ast.CreateSchemaStmtNode)) case ast.DropStmt: return a.newDropStmtAction(ctx, query, args, node.(*ast.DropStmtNode)) case ast.DropFunctionStmt: @@ -370,6 +373,16 @@ func (a *Analyzer) newCreateViewStmtAction(ctx context.Context, _ string, _ []dr }, nil } +//nolint:unparam +func (a *Analyzer) newCreateSchemaStmtAction(_ context.Context, query string, _ []driver.NamedValue, node *ast.CreateSchemaStmtNode) (*CreateSchemaStmtAction, error) { + spec := newSchemaSpec(a.namePath, node) + return &CreateSchemaStmtAction{ + query: query, + spec: spec, + catalog: a.catalog, + }, nil +} + func (a *Analyzer) resultTypeIsTemplatedType(sig *types.FunctionSignature) bool { if !sig.IsTemplated() { return false diff --git a/internal/catalog.go b/internal/catalog.go index c626dd1..a5eb09f 100644 --- a/internal/catalog.go +++ b/internal/catalog.go @@ -51,6 +51,7 @@ const ( TableSpecKind CatalogSpecKind = "table" ViewSpecKind CatalogSpecKind = "view" FunctionSpecKind CatalogSpecKind = "function" + SchemaSpecKind CatalogSpecKind = "schema" catalogName = "zetasqlite" ) @@ -60,9 +61,11 @@ type Catalog struct { mu sync.Mutex tables []*TableSpec functions []*FunctionSpec + schemas []*SchemaSpec catalog *types.SimpleCatalog tableMap map[string]*TableSpec funcMap map[string]*FunctionSpec + schemaMap map[string]*SchemaSpec } func newSimpleCatalog(name string) *types.SimpleCatalog { @@ -73,10 +76,11 @@ func newSimpleCatalog(name string) *types.SimpleCatalog { func NewCatalog(db *sql.DB) *Catalog { return &Catalog{ - db: db, - catalog: newSimpleCatalog(catalogName), - tableMap: map[string]*TableSpec{}, - funcMap: map[string]*FunctionSpec{}, + db: db, + catalog: newSimpleCatalog(catalogName), + tableMap: map[string]*TableSpec{}, + funcMap: map[string]*FunctionSpec{}, + schemaMap: map[string]*SchemaSpec{}, } } @@ -147,23 +151,22 @@ func (c *Catalog) SuggestConstant(mistypedPath []string) string { return c.catalog.SuggestConstant(mistypedPath) } -func (c *Catalog) formatNamePath(path []string) string { - return strings.Join(path, "_") +func (c *Catalog) formatNamePath(path *NamePath) string { + return path.CatalogPath() } func (c *Catalog) getFunctions(namePath *NamePath) []*FunctionSpec { - if namePath.empty() { + if namePath.Empty() { return c.functions } - key := c.formatNamePath(namePath.path) + key := c.formatNamePath(namePath) specs := make([]*FunctionSpec, 0, len(c.functions)) for _, fn := range c.functions { - if len(fn.NamePath) == 1 { - // function name only + if fn.NamePath.HasSimpleName() { specs = append(specs, fn) continue } - pathPrefixKey := c.formatNamePath(c.trimmedLastPath(fn.NamePath)) + pathPrefixKey := c.formatNamePath(c.trimmedLastPath(&fn.NamePath)) if strings.Contains(pathPrefixKey, key) { specs = append(specs, fn) } @@ -209,6 +212,10 @@ func (c *Catalog) Sync(ctx context.Context, conn *Conn) error { if err := c.loadFunctionSpec(spec); err != nil { return fmt.Errorf("failed to load function spec: %w", err) } + case SchemaSpecKind: + if err := c.loadSchemaSpec(spec); err != nil { + return fmt.Errorf("failed to load schema spec: %w", err) + } default: return fmt.Errorf("unknown catalog spec kind %s", kind) } @@ -247,6 +254,32 @@ func (c *Catalog) AddNewFunctionSpec(ctx context.Context, conn *Conn, spec *Func return nil } +func (c *Catalog) AddNewSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error { + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.saveSchemaSpec(ctx, conn, spec); err != nil { + return fmt.Errorf("failed to save schema spec: %w", err) + } + if err := c.addSchemaSpec(spec); err != nil { + return fmt.Errorf("failed to add schema spec: %w", err) + } + return nil +} + +func (c *Catalog) DeleteSchemaSpec(ctx context.Context, conn *Conn, name string) error { + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.deleteSchemaSpecByName(name); err != nil { + return err + } + if _, err := conn.ExecContext(ctx, deleteCatalogQuery, sql.Named("name", name)); err != nil { + return err + } + return nil +} + func (c *Catalog) DeleteTableSpec(ctx context.Context, conn *Conn, name string) error { c.mu.Lock() defer c.mu.Unlock() @@ -279,9 +312,9 @@ func (c *Catalog) deleteTableSpecByName(name string) error { return fmt.Errorf("failed to find table spec from map by %s", name) } tables := make([]*TableSpec, 0, len(c.tables)) - specName := c.formatNamePath(spec.NamePath) + specName := c.formatNamePath(&spec.NamePath) for _, table := range c.tables { - if specName == c.formatNamePath(table.NamePath) { + if specName == c.formatNamePath(&table.NamePath) { continue } tables = append(tables, table) @@ -298,9 +331,9 @@ func (c *Catalog) deleteFunctionSpecByName(name string) error { return fmt.Errorf("failed to find function spec from map by %s", name) } functions := make([]*FunctionSpec, 0, len(c.functions)) - specName := c.formatNamePath(spec.NamePath) + specName := c.formatNamePath(&spec.NamePath) for _, function := range c.functions { - if specName == c.formatNamePath(function.NamePath) { + if specName == c.formatNamePath(&function.NamePath) { continue } functions = append(functions, function) @@ -311,12 +344,29 @@ func (c *Catalog) deleteFunctionSpecByName(name string) error { return nil } +func (c *Catalog) deleteSchemaSpecByName(name string) error { + spec, exists := c.schemaMap[name] + if !exists { + return nil + } + delete(c.schemaMap, name) + for i, s := range c.schemas { + if s == spec { + c.schemas = append(c.schemas[:i], c.schemas[i+1:]...) + break + } + } + return nil +} + func (c *Catalog) resetCatalog(tables []*TableSpec, functions []*FunctionSpec) error { c.catalog = newSimpleCatalog(catalogName) c.tables = []*TableSpec{} c.functions = []*FunctionSpec{} + c.schemas = []*SchemaSpec{} c.tableMap = map[string]*TableSpec{} c.funcMap = map[string]*FunctionSpec{} + c.schemaMap = map[string]*SchemaSpec{} for _, spec := range tables { if err := c.addTableSpec(spec); err != nil { return err @@ -374,6 +424,26 @@ func (c *Catalog) saveFunctionSpec(ctx context.Context, conn *Conn, spec *Functi return nil } +func (c *Catalog) saveSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error { + if err := c.createCatalogTablesIfNotExists(ctx, conn); err != nil { + return err + } + data, err := json.Marshal(spec) + if err != nil { + return fmt.Errorf("failed to encode schema spec: %w", err) + } + if _, err := conn.ExecContext(ctx, upsertCatalogQuery, + spec.SchemaName(), + SchemaSpecKind, + string(data), + spec.UpdatedAt, + spec.CreatedAt, + ); err != nil { + return fmt.Errorf("failed to save schema spec: %w", err) + } + return nil +} + func (c *Catalog) createCatalogTablesIfNotExists(ctx context.Context, conn *Conn) error { if _, err := conn.ExecContext(ctx, createCatalogTableQuery); err != nil { return fmt.Errorf("failed to create catalog table: %w", err) @@ -403,11 +473,19 @@ func (c *Catalog) loadFunctionSpec(spec string) error { return nil } -func (c *Catalog) trimmedLastPath(path []string) []string { - if len(path) == 0 { - return path +func (c *Catalog) loadSchemaSpec(spec string) error { + var s SchemaSpec + if err := json.Unmarshal([]byte(spec), &s); err != nil { + return fmt.Errorf("failed to decode schema spec: %w", err) } - return path[:len(path)-1] + if err := c.addSchemaSpec(&s); err != nil { + return fmt.Errorf("failed to add schema spec: %w", err) + } + return nil +} + +func (c *Catalog) trimmedLastPath(path *NamePath) *NamePath { + return path.dropLast() } func (c *Catalog) addFunctionSpec(spec *FunctionSpec) error { @@ -438,15 +516,27 @@ func (c *Catalog) addTableSpec(spec *TableSpec) error { return nil } +func (c *Catalog) addSchemaSpec(spec *SchemaSpec) error { + name := spec.SchemaName() + if _, exists := c.schemaMap[name]; exists { + if err := c.deleteSchemaSpecByName(name); err != nil { + return err + } + } + c.schemas = append(c.schemas, spec) + c.schemaMap[name] = spec + return c.addSchemaSpecRecursive(c.catalog, spec) +} + func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpec) error { - if len(spec.NamePath) > 1 { - subCatalogName := spec.NamePath[0] + if catalogId := spec.NamePath.GetCatalogId(); catalogId != "" { + subCatalogName := catalogId subCatalog, _ := cat.Catalog(subCatalogName) if subCatalog == nil { subCatalog = newSimpleCatalog(subCatalogName) cat.AddCatalog(subCatalog) } - fullTableName := strings.Join(spec.NamePath, ".") + fullTableName := spec.NamePath.CatalogPath() if !c.existsTable(cat, fullTableName) { table, err := c.createSimpleTable(fullTableName, spec) if err != nil { @@ -454,7 +544,7 @@ func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpe } cat.AddTable(table) } - newNamePath := spec.NamePath[1:] + newNamePath := spec.NamePath.dropFirst() // add sub catalog to root catalog if err := c.addTableSpecRecursive(cat, c.copyTableSpec(spec, newNamePath)); err != nil { return fmt.Errorf("failed to add table spec to root catalog: %w", err) @@ -465,11 +555,11 @@ func (c *Catalog) addTableSpecRecursive(cat *types.SimpleCatalog, spec *TableSpe } return nil } - if len(spec.NamePath) == 0 { + if spec.NamePath.Empty() { return fmt.Errorf("table name is not found") } - tableName := spec.NamePath[0] + tableName := spec.NamePath.GetObjectId() if c.existsTable(cat, tableName) { return nil } @@ -496,14 +586,14 @@ func (c *Catalog) createSimpleTable(tableName string, spec *TableSpec) (*types.S } func (c *Catalog) addFunctionSpecRecursive(cat *types.SimpleCatalog, spec *FunctionSpec) error { - if len(spec.NamePath) > 1 { - subCatalogName := spec.NamePath[0] + if catalogId := spec.NamePath.GetCatalogId(); catalogId != "" { + subCatalogName := catalogId subCatalog, _ := cat.Catalog(subCatalogName) if subCatalog == nil { subCatalog = newSimpleCatalog(subCatalogName) cat.AddCatalog(subCatalog) } - newNamePath := spec.NamePath[1:] + newNamePath := spec.NamePath.dropFirst() // add sub catalog to root catalog if err := c.addFunctionSpecRecursive(cat, c.copyFunctionSpec(spec, newNamePath)); err != nil { return fmt.Errorf("failed to add function spec to root catalog: %w", err) @@ -514,11 +604,11 @@ func (c *Catalog) addFunctionSpecRecursive(cat *types.SimpleCatalog, spec *Funct } return nil } - if len(spec.NamePath) == 0 { + if spec.NamePath.Empty() { return fmt.Errorf("function name is not found") } - funcName := spec.NamePath[0] + funcName := spec.NamePath.GetObjectId() if c.existsFunction(cat, funcName) { return nil } @@ -540,6 +630,29 @@ func (c *Catalog) addFunctionSpecRecursive(cat *types.SimpleCatalog, spec *Funct return nil } +func (c *Catalog) addSchemaSpecRecursive(cat *types.SimpleCatalog, spec *SchemaSpec) error { + if spec.NamePath.Empty() { + return nil + } + name := spec.NamePath.path[0] + if c.existsSchema(cat, name) { + return nil + } + var subCat *types.SimpleCatalog + if !c.existsSchema(cat, name) { + subCat = types.NewSimpleCatalog(name) + cat.AddCatalog(subCat) + } else { + schema, _ := cat.Catalog(name) + subCat = schema + } + return c.addSchemaSpecRecursive(subCat, &SchemaSpec{ + NamePath: NamePath{path: spec.NamePath.path[1:]}, + UpdatedAt: spec.UpdatedAt, + CreatedAt: spec.CreatedAt, + }) +} + func (c *Catalog) existsTable(cat *types.SimpleCatalog, name string) bool { foundTable, _ := cat.FindTable([]string{name}) return !c.isNilTable(foundTable) @@ -550,6 +663,11 @@ func (c *Catalog) existsFunction(cat *types.SimpleCatalog, name string) bool { return foundFunc != nil } +func (c *Catalog) existsSchema(cat *types.SimpleCatalog, name string) bool { + schema, _ := cat.Catalog(name) + return schema != nil +} + func (c *Catalog) isNilTable(t types.Table) bool { v := reflect.ValueOf(t) if !v.IsValid() { @@ -558,17 +676,17 @@ func (c *Catalog) isNilTable(t types.Table) bool { return v.IsNil() } -func (c *Catalog) copyTableSpec(spec *TableSpec, newNamePath []string) *TableSpec { +func (c *Catalog) copyTableSpec(spec *TableSpec, newNamePath *NamePath) *TableSpec { return &TableSpec{ - NamePath: newNamePath, + NamePath: *newNamePath, Columns: spec.Columns, CreateMode: spec.CreateMode, } } -func (c *Catalog) copyFunctionSpec(spec *FunctionSpec, newNamePath []string) *FunctionSpec { +func (c *Catalog) copyFunctionSpec(spec *FunctionSpec, newPath *NamePath) *FunctionSpec { return &FunctionSpec{ - NamePath: newNamePath, + NamePath: *newPath, Language: spec.Language, Args: spec.Args, Return: spec.Return, diff --git a/internal/conn.go b/internal/conn.go index 2f7c2bb..f603e27 100644 --- a/internal/conn.go +++ b/internal/conn.go @@ -8,17 +8,28 @@ import ( type ChangedCatalog struct { Table *ChangedTable Function *ChangedFunction + Schema *ChangedSchema } func newChangedCatalog() *ChangedCatalog { return &ChangedCatalog{ Table: &ChangedTable{}, Function: &ChangedFunction{}, + Schema: &ChangedSchema{}, } } func (c *ChangedCatalog) Changed() bool { - return c.Table.Changed() || c.Function.Changed() + return c.Table.Changed() || c.Function.Changed() || c.Schema.Changed() +} + +type ChangedSchema struct { + Added []*SchemaSpec + Deleted []*SchemaSpec +} + +func (s *ChangedSchema) Changed() bool { + return len(s.Added) != 0 || len(s.Deleted) != 0 } type ChangedTable struct { @@ -75,6 +86,16 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface return c.conn.QueryContext(ctx, query, args...) } +func (c *Conn) addSchema(spec *SchemaSpec) { + c.removeFromDeletedSchemasIfExists(spec) + c.cc.Schema.Added = append(c.cc.Schema.Added, spec) +} + +func (c *Conn) deleteSchema(spec *SchemaSpec) { + c.removeFromAddedSchemasIfExists(spec) + c.cc.Schema.Deleted = append(c.cc.Schema.Deleted, spec) +} + func (c *Conn) addTable(spec *TableSpec) { c.removeFromDeletedTablesIfExists(spec) c.cc.Table.Added = append(c.cc.Table.Added, spec) @@ -100,6 +121,28 @@ func (c *Conn) deleteFunction(spec *FunctionSpec) { c.cc.Function.Deleted = append(c.cc.Function.Deleted, spec) } +func (c *Conn) removeFromDeletedSchemasIfExists(spec *SchemaSpec) { + schemas := make([]*SchemaSpec, 0, len(c.cc.Schema.Deleted)) + for _, schema := range c.cc.Schema.Deleted { + if schema.SchemaName() == spec.SchemaName() { + continue + } + schemas = append(schemas, schema) + } + c.cc.Schema.Deleted = schemas +} + +func (c *Conn) removeFromAddedSchemasIfExists(spec *SchemaSpec) { + schemas := make([]*SchemaSpec, 0, len(c.cc.Schema.Added)) + for _, schema := range c.cc.Schema.Added { + if schema.SchemaName() == spec.SchemaName() { + continue + } + schemas = append(schemas, schema) + } + c.cc.Schema.Added = schemas +} + func (c *Conn) removeFromDeletedTablesIfExists(spec *TableSpec) { tables := make([]*TableSpec, 0, len(c.cc.Table.Deleted)) for _, table := range c.cc.Table.Deleted { diff --git a/internal/name_path.go b/internal/name_path.go index 1813ef4..ff8a9ee 100644 --- a/internal/name_path.go +++ b/internal/name_path.go @@ -3,6 +3,8 @@ package internal import ( "fmt" "strings" + + "github.com/goccy/go-json" ) type NamePath struct { @@ -10,6 +12,18 @@ type NamePath struct { maxNum int } +func (p *NamePath) Clone() *NamePath { + return &NamePath{path: append([]string{}, p.path...)} +} + +func (p *NamePath) MarshalJSON() ([]byte, error) { + return json.Marshal(p.path) +} + +func (p *NamePath) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &p.path) +} + func (p *NamePath) isInformationSchema(path []string) bool { if len(path) == 0 { return false @@ -50,14 +64,14 @@ func (p *NamePath) normalizePath(path []string) []string { return ret } -func (p *NamePath) mergePath(path []string) []string { +func (p *NamePath) mergePath(path []string) *NamePath { path = p.normalizePath(path) maxNum := p.getMaxNum(path) - if maxNum > 0 && len(path) == maxNum { - return path + if maxNum > 0 && p.hasMaxComponents(path) { + return &NamePath{path: path} } if len(path) == 0 { - return p.path + return &NamePath{path: p.path} } merged := []string{} for _, basePath := range p.path { @@ -69,15 +83,41 @@ func (p *NamePath) mergePath(path []string) []string { } merged = append(merged, basePath) } - return append(merged, path...) + return &NamePath{path: append(merged, path...)} } func (p *NamePath) format(path []string) string { - return formatPath(p.mergePath(path)) + mergedPath := p.mergePath(path) + return mergedPath.CatalogPath() +} + +func (p *NamePath) dropFirst() *NamePath { + if p.Empty() { + return p + } + return &NamePath{path: p.path[1:]} } -func formatPath(path []string) string { - return strings.Join(path, "_") +func (p *NamePath) dropLast() *NamePath { + if p.Empty() { + return p + } + return &NamePath{path: p.path[:len(p.path)-1]} +} + +func (p *NamePath) CatalogPath() string { + return strings.Join(p.path, "_") +} + +func (p *NamePath) FormatNamePath() string { + if p.HasFullyQualifiedName() { + return fmt.Sprintf("%s.%s", p.GetProjectId(), strings.Join(p.path[1:], "_")) + } + return strings.Join(p.path, "_") +} + +func (p *NamePath) Path() []string { + return p.path } func (p *NamePath) setPath(path []string) error { @@ -106,6 +146,66 @@ func (p *NamePath) addPath(path string) error { return nil } -func (p *NamePath) empty() bool { - return len(p.path) == 0 +func (p *NamePath) replace(index int, value string) { + p.path[index] = value +} + +func (p *NamePath) Length() int { + return len(p.path) +} + +func (p *NamePath) Empty() bool { + return p.Length() == 0 +} + +func (p *NamePath) GetCatalogId() string { + if p.Length() < 2 { + return "" + } + return p.path[0] +} + +func (p *NamePath) GetProjectId() string { + if p.Length() < 3 { + return "" + } + return p.path[0] +} + +func (p *NamePath) GetDatasetId() string { + if p.Length() < 2 { + return "" + } else if p.Length() == 2 { + return p.path[0] + } else { + return p.path[1] + } +} + +func (p *NamePath) GetObjectId() string { + if p.Empty() { + return "" + } + return p.path[p.Length()-1] +} + +func (p *NamePath) HasSimpleName() bool { + return p.Length() == 1 +} + +func (p *NamePath) HasQualifiers() bool { + return p.Length() > 1 +} + +func (p *NamePath) HasFullyQualifiedName() bool { + return p.Length() > 2 +} + +func (p *NamePath) hasMaxComponents(path []string) bool { + maxNum := p.getMaxNum(path) + return maxNum > 0 && len(path) == maxNum +} + +func NewNamePath(path []string) *NamePath { + return &NamePath{path: path} } diff --git a/internal/name_path_test.go b/internal/name_path_test.go index 0de21a6..b1396a9 100644 --- a/internal/name_path_test.go +++ b/internal/name_path_test.go @@ -12,28 +12,28 @@ func TestNamePath(t *testing.T) { t.Fatal(err) } namePath.setMaxNum(3) - if diff := cmp.Diff(namePath.mergePath([]string{"project1", "dataset1", "table1"}), []string{"project1", "dataset1", "table1"}); diff != "" { + if diff := cmp.Diff(namePath.mergePath([]string{"project1", "dataset1", "table1"}).Path(), []string{"project1", "dataset1", "table1"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } - if diff := cmp.Diff(namePath.mergePath([]string{"dataset1", "table1"}), []string{"project1", "dataset1", "table1"}); diff != "" { + if diff := cmp.Diff(namePath.mergePath([]string{"dataset1", "table1"}).Path(), []string{"project1", "dataset1", "table1"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } - if diff := cmp.Diff(namePath.mergePath([]string{"project2", "dataset2", "table1"}), []string{"project2", "dataset2", "table1"}); diff != "" { + if diff := cmp.Diff(namePath.mergePath([]string{"project2", "dataset2", "table1"}).Path(), []string{"project2", "dataset2", "table1"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } - if diff := cmp.Diff(namePath.mergePath([]string{"dataset2", "table1"}), []string{"project1", "dataset2", "table1"}); diff != "" { + if diff := cmp.Diff(namePath.mergePath([]string{"dataset2", "table1"}).Path(), []string{"project1", "dataset2", "table1"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } - if diff := cmp.Diff(namePath.mergePath([]string{"table1"}), []string{"project1", "dataset1", "table1"}); diff != "" { + if diff := cmp.Diff(namePath.mergePath([]string{"table1"}).Path(), []string{"project1", "dataset1", "table1"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } - if diff := cmp.Diff(namePath.mergePath([]string{"project2", "dataset2", "INFORMATION_SCHEMA", "TABLES"}), []string{"project2", "dataset2", "INFORMATION_SCHEMA", "TABLES"}); diff != "" { + if diff := cmp.Diff(namePath.mergePath([]string{"project2", "dataset2", "INFORMATION_SCHEMA", "TABLES"}).Path(), []string{"project2", "dataset2", "INFORMATION_SCHEMA", "TABLES"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } - if diff := cmp.Diff(namePath.mergePath([]string{"dataset2", "INFORMATION_SCHEMA", "TABLES"}), []string{"project1", "dataset2", "INFORMATION_SCHEMA", "TABLES"}); diff != "" { + if diff := cmp.Diff(namePath.mergePath([]string{"dataset2", "INFORMATION_SCHEMA", "TABLES"}).Path(), []string{"project1", "dataset2", "INFORMATION_SCHEMA", "TABLES"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } - if diff := cmp.Diff(namePath.mergePath([]string{"INFORMATION_SCHEMA", "TABLES"}), []string{"project1", "dataset1", "INFORMATION_SCHEMA", "TABLES"}); diff != "" { + if diff := cmp.Diff(namePath.mergePath([]string{"INFORMATION_SCHEMA", "TABLES"}).Path(), []string{"project1", "dataset1", "INFORMATION_SCHEMA", "TABLES"}); diff != "" { t.Errorf("(-want +got):\n%s", diff) } } diff --git a/internal/spec.go b/internal/spec.go index fb1c831..2d1a443 100644 --- a/internal/spec.go +++ b/internal/spec.go @@ -35,7 +35,7 @@ func (t *NameWithType) FunctionArgumentType() (*types.FunctionArgumentType, erro type FunctionSpec struct { IsTemp bool `json:"isTemp"` - NamePath []string `json:"name"` + NamePath NamePath `json:"name"` Language string `json:"language"` Args []*NameWithType `json:"args"` Return *Type `json:"return"` @@ -46,7 +46,7 @@ type FunctionSpec struct { } func (s *FunctionSpec) FuncName() string { - return formatPath(s.NamePath) + return s.NamePath.CatalogPath() } func (s *FunctionSpec) SQL() string { @@ -78,7 +78,7 @@ func (s *FunctionSpec) CallSQL(ctx context.Context, callNode *ast.BaseFunctionCa fmt.Sprintf("%s %s", s.Args[idx].Name, typeName), ) } - funcName := strings.Join(s.NamePath, ".") + funcName := s.NamePath.FormatNamePath() runtimeDefinedFunc := fmt.Sprintf( "CREATE FUNCTION `%s`(%s) as (%s)", funcName, @@ -105,7 +105,7 @@ func (s *FunctionSpec) CallSQL(ctx context.Context, callNode *ast.BaseFunctionCa type TableSpec struct { IsTemp bool `json:"isTemp"` IsView bool `json:"isView"` - NamePath []string `json:"namePath"` + NamePath NamePath `json:"namePath"` Columns []*ColumnSpec `json:"columns"` PrimaryKey []string `json:"primaryKey"` CreateMode ast.CreateMode `json:"createMode"` @@ -124,7 +124,7 @@ func (s *TableSpec) Column(name string) *ColumnSpec { } func (s *TableSpec) TableName() string { - return formatPath(s.NamePath) + return s.NamePath.CatalogPath() } func (s *TableSpec) SQLiteSchema() string { @@ -327,6 +327,29 @@ func (s *ColumnSpec) SQLiteSchema() string { return schema } +type SchemaSpec struct { + NamePath NamePath `json:"namePath"` + CreateMode ast.CreateMode `json:"createMode"` + UpdatedAt time.Time `json:"updatedAt"` + CreatedAt time.Time `json:"createdAt"` +} + +func (s *SchemaSpec) SchemaName() string { + return s.NamePath.GetObjectId() + +} + +func newSchemaSpec(namePath *NamePath, stmt *ast.CreateSchemaStmtNode) *SchemaSpec { + schemaNamePath := stmt.NamePath() + now := time.Now() + return &SchemaSpec{ + NamePath: *namePath.mergePath(schemaNamePath), + CreateMode: stmt.CreateMode(), + UpdatedAt: now, + CreatedAt: now, + } +} + func newTypeFromFunctionArgumentType(t *types.FunctionArgumentType) *Type { if t.IsTemplated() { return &Type{SignatureKind: t.Kind()} @@ -392,7 +415,7 @@ func newFunctionSpec(ctx context.Context, namePath *NamePath, stmt *ast.CreateFu now := time.Now() return &FunctionSpec{ IsTemp: stmt.CreateScope() == ast.CreateScopeTemp, - NamePath: namePath.mergePath(stmt.NamePath()), + NamePath: *namePath.mergePath(stmt.NamePath()), Args: args, Return: newType(stmt.ReturnType()), Code: stmt.Code(), @@ -460,7 +483,7 @@ func newTemplatedFunctionSpec(ctx context.Context, namePath *NamePath, stmt *ast now := time.Now() return &FunctionSpec{ IsTemp: stmt.CreateScope() == ast.CreateScopeTemp, - NamePath: namePath.mergePath(stmt.NamePath()), + NamePath: *namePath.mergePath(stmt.NamePath()), Args: args, Return: retType, Code: stmt.Code(), @@ -517,7 +540,7 @@ func newTableSpec(namePath *NamePath, stmt *ast.CreateTableStmtNode) *TableSpec now := time.Now() return &TableSpec{ IsTemp: stmt.CreateScope() == ast.CreateScopeTemp, - NamePath: namePath.mergePath(stmt.NamePath()), + NamePath: *namePath.mergePath(stmt.NamePath()), Columns: newColumnsFromDef(stmt.ColumnDefinitionList()), PrimaryKey: newPrimaryKey(stmt.PrimaryKey()), CreateMode: stmt.CreateMode(), @@ -541,7 +564,7 @@ func newTableAsViewSpec(namePath *NamePath, query string, stmt *ast.CreateViewSt return &TableSpec{ IsTemp: stmt.CreateScope() == ast.CreateScopeTemp, IsView: true, - NamePath: namePath.mergePath(stmt.NamePath()), + NamePath: *namePath.mergePath(stmt.NamePath()), Columns: newColumnsFromOutputColumns(stmt.OutputColumnList()), CreateMode: stmt.CreateMode(), Query: fmt.Sprintf("SELECT %s FROM (%s)", strings.Join(outputColumns, ","), query), @@ -564,7 +587,7 @@ func newTableAsSelectSpec(namePath *NamePath, query string, stmt *ast.CreateTabl now := time.Now() return &TableSpec{ IsTemp: stmt.CreateScope() == ast.CreateScopeTemp, - NamePath: namePath.mergePath(stmt.NamePath()), + NamePath: *namePath.mergePath(stmt.NamePath()), Columns: newColumnsFromDef(stmt.ColumnDefinitionList()), PrimaryKey: newPrimaryKey(stmt.PrimaryKey()), CreateMode: stmt.CreateMode(), diff --git a/internal/stmt.go b/internal/stmt.go index 1aaba39..d0020f6 100644 --- a/internal/stmt.go +++ b/internal/stmt.go @@ -14,6 +14,7 @@ var ( _ driver.Stmt = &CreateFunctionStmt{} _ driver.Stmt = &DMLStmt{} _ driver.Stmt = &QueryStmt{} + _ driver.Stmt = &CreateSchemaStmt{} ) type CreateTableStmt struct { @@ -125,6 +126,51 @@ func newCreateFunctionStmt(conn *Conn, catalog *Catalog, spec *FunctionSpec) *Cr } } +type CreateSchemaStmt struct { + conn *Conn + catalog *Catalog + spec *SchemaSpec +} + +func newCreateSchemaStmt(conn *Conn, catalog *Catalog, spec *SchemaSpec) *CreateSchemaStmt { + return &CreateSchemaStmt{ + conn: conn, + catalog: catalog, + spec: spec, + } +} + +func (s *CreateSchemaStmt) Close() error { + return nil +} + +func (s *CreateSchemaStmt) NumInput() int { + return 0 +} + +func (s *CreateSchemaStmt) Exec(args []driver.Value) (driver.Result, error) { + if err := s.catalog.AddNewSchemaSpec(context.Background(), s.conn, s.spec); err != nil { + return nil, fmt.Errorf("failed to add new schema spec: %w", err) + } + return nil, nil +} + +func (s *CreateSchemaStmt) Query(args []driver.Value) (driver.Rows, error) { + return nil, fmt.Errorf("failed to query for CreateSchemaStmt") +} + +func (s *CreateSchemaStmt) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + return nil, fmt.Errorf("unimplemented ExecContext for CreateSchemaStmt") +} + +func (s *CreateSchemaStmt) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + return nil, fmt.Errorf("unsupported query for CreateSchemaStmt") +} + +func (s *CreateSchemaStmt) CheckNamedValue(value *driver.NamedValue) error { + return nil +} + type DMLStmt struct { stmt *sql.Stmt args []*ast.ParameterNode diff --git a/internal/stmt_action.go b/internal/stmt_action.go index 8dee05e..b19e6d6 100644 --- a/internal/stmt_action.go +++ b/internal/stmt_action.go @@ -4,7 +4,6 @@ import ( "context" "database/sql/driver" "fmt" - "strings" ast "github.com/goccy/go-zetasql/resolved_ast" ) @@ -46,7 +45,7 @@ func (a *CreateTableStmtAction) createIndexAutomatically(ctx context.Context, co if !col.Type.AvailableAutoIndex() { continue } - indexName := fmt.Sprintf("zetasqlite_autoindex_%s_%s", col.Name, strings.Join(a.spec.NamePath, "_")) + indexName := fmt.Sprintf("zetasqlite_autoindex_%s_%s", col.Name, a.spec.NamePath.CatalogPath()) createIndexQuery := fmt.Sprintf( "CREATE INDEX IF NOT EXISTS %s ON `%s`(`%s`)", indexName, @@ -246,6 +245,45 @@ func (a *CreateFunctionStmtAction) Cleanup(ctx context.Context, conn *Conn) erro return nil } +type CreateSchemaStmtAction struct { + query string + spec *SchemaSpec + catalog *Catalog +} + +func (a *CreateSchemaStmtAction) Args() []interface{} { + return nil +} + +func (a *CreateSchemaStmtAction) Prepare(ctx context.Context, conn *Conn) (driver.Stmt, error) { + return newCreateSchemaStmt(conn, a.catalog, a.spec), nil +} + +func (a *CreateSchemaStmtAction) ExecContext(ctx context.Context, conn *Conn) (driver.Result, error) { + if err := a.exec(ctx, conn); err != nil { + return nil, err + } + return &Result{conn: conn}, nil +} + +func (a *CreateSchemaStmtAction) QueryContext(ctx context.Context, conn *Conn) (*Rows, error) { + if err := a.exec(ctx, conn); err != nil { + return nil, err + } + return &Rows{conn: conn}, nil +} + +func (a *CreateSchemaStmtAction) exec(ctx context.Context, conn *Conn) error { + if err := a.catalog.AddNewSchemaSpec(ctx, conn, a.spec); err != nil { + return fmt.Errorf("failed to add new schema spec: %w", err) + } + return nil +} + +func (a *CreateSchemaStmtAction) Cleanup(ctx context.Context, conn *Conn) error { + return nil +} + type DropStmtAction struct { name string objectType string @@ -273,6 +311,12 @@ func (a *DropStmtAction) exec(ctx context.Context, conn *Conn) error { } conn.deleteFunction(a.funcMap[a.name]) delete(a.funcMap, a.name) + case "SCHEMA": + spec := a.catalog.schemaMap[a.name] + if err := a.catalog.DeleteSchemaSpec(ctx, conn, a.name); err != nil { + return fmt.Errorf("failed to delete schema spec: %w", err) + } + conn.deleteSchema(spec) default: return fmt.Errorf("currently unsupported DROP %s statement", a.objectType) } diff --git a/internal/wildcard_table.go b/internal/wildcard_table.go index a926c27..a51e0b1 100644 --- a/internal/wildcard_table.go +++ b/internal/wildcard_table.go @@ -54,7 +54,7 @@ func (t *WildcardTable) FormatSQL(ctx context.Context) (string, error) { columns = append(columns, fmt.Sprintf("NULL as %s", column.Name)) } } - fullName := strings.Join(table.NamePath, ".") + fullName := table.NamePath.FormatNamePath() if len(fullName) <= len(t.prefix) { return "", fmt.Errorf("failed to find table suffix from %s", fullName) } @@ -77,7 +77,7 @@ func (t *WildcardTable) FormatSQL(ctx context.Context) (string, error) { } func (t *WildcardTable) Name() string { - return strings.Join(t.spec.NamePath, ".") + return t.spec.NamePath.FormatNamePath() } func (t *WildcardTable) FullName() string { @@ -95,7 +95,7 @@ func (t *WildcardTable) Column(idx int) types.Column { return nil } return types.NewSimpleColumn( - strings.Join(t.spec.NamePath, "."), column.Name, typ, + t.spec.NamePath.FormatNamePath(), column.Name, typ, ) } @@ -165,20 +165,19 @@ func (c *Catalog) createWildcardTable(path []string) (types.Table, error) { spec := matchedSpecs[0] wildcardTable := new(TableSpec) *wildcardTable = *spec - wildcardTable.NamePath = append([]string{}, spec.NamePath...) + wildcardTable.NamePath = *spec.NamePath.Clone() wildcardTable.Columns = append(wildcardTable.Columns, &ColumnSpec{ Name: tableSuffixColumnName, Type: &Type{Kind: types.STRING}, }) - lastNamePath := spec.NamePath[len(spec.NamePath)-1] - lastNamePath = lastNamePath[:len(path)-1] - wildcardTable.NamePath[len(spec.NamePath)-1] = fmt.Sprintf( + lastNamePath := spec.NamePath.GetObjectId() + wildcardTable.NamePath.replace(spec.NamePath.Length()-1, fmt.Sprintf( "%s_wildcard_%d", lastNamePath, time.Now().Unix(), - ) + )) // firstIdentifier may be omitted, so we need to check it. prefix := name - firstIdentifier := spec.NamePath[0] + firstIdentifier := spec.NamePath.GetProjectId() if !strings.HasPrefix(prefix, firstIdentifier+".") { prefix = firstIdentifier + "." + prefix } diff --git a/test.sqlite3 b/test.sqlite3 new file mode 100644 index 0000000..cd86b54 Binary files /dev/null and b/test.sqlite3 differ diff --git a/test2.sqlite3 b/test2.sqlite3 new file mode 100644 index 0000000..8f68bfa Binary files /dev/null and b/test2.sqlite3 differ