From 93f2513ce442e13603b47406b6103ede2074abca Mon Sep 17 00:00:00 2001 From: Martin Mauch Date: Sun, 15 Dec 2024 08:06:59 +0100 Subject: [PATCH 1/2] Implement CREATE SCHEMA --- exec_test.go | 57 +++++++++++++++++++++++++++++---- internal/analyzer.go | 16 ++++++++++ internal/catalog.go | 58 +++++++++++++++++++++++++++++++--- internal/schema_spec.go | 22 +++++++++++++ internal/spec.go | 13 +++++++- internal/stmt_action_schema.go | 47 +++++++++++++++++++++++++++ 6 files changed, 201 insertions(+), 12 deletions(-) create mode 100644 internal/schema_spec.go create mode 100644 internal/stmt_action_schema.go diff --git a/exec_test.go b/exec_test.go index cab263f..8b4cdc7 100644 --- a/exec_test.go +++ b/exec_test.go @@ -22,11 +22,12 @@ func TestExec(t *testing.T) { t.Fatal(err) } defer db.Close() - for _, test := range []struct { - name string - query string - args []interface{} - expectedErr bool + tests := []struct { + name string + query string + args []interface{} + wantErr bool + validate func(t *testing.T, db *sql.DB) }{ { name: "create table with all types", @@ -64,6 +65,12 @@ CREATE OR REPLACE TABLE recreate_table ( a string ); DROP TABLE recreate_table; CREATE TABLE recreate_table ( b string ); INSERT recreate_table (b) VALUES ('hello'); +`, + }, + { + name: "create schema", + query: ` +CREATE SCHEMA new_schema; `, }, { @@ -134,11 +141,47 @@ TRUNCATE TABLE tmp; COMMIT TRANSACTION; `, }, - } { + { + name: "create schema", + query: "CREATE SCHEMA test_schema", + wantErr: false, + validate: func(t *testing.T, db *sql.DB) { + // Create a table in the schema + _, err := db.Exec("CREATE TABLE test_schema.test_table (id INT64)") + if err != nil { + t.Errorf("failed to create table in schema: %v", err) + } + // Query the table with schema qualification + rows, err := db.Query("SELECT * FROM test_schema.test_table") + if err != nil { + t.Errorf("failed to query table in schema: %v", err) + } + defer rows.Close() + }, + }, + { + name: "create schema if not exists", + query: "CREATE SCHEMA IF NOT EXISTS test_schema", + wantErr: false, + }, + { + name: "create schema - already exists", + query: "CREATE SCHEMA test_schema", + wantErr: true, + }, + } + for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { if _, err := db.ExecContext(ctx, test.query); err != nil { - t.Fatal(err) + if !test.wantErr { + t.Fatal(err) + } + } else if test.wantErr { + t.Fatal("expected error") + } + if test.validate != nil { + test.validate(t, db) } }) } diff --git a/internal/analyzer.go b/internal/analyzer.go index 7274d29..655b4df 100644 --- a/internal/analyzer.go +++ b/internal/analyzer.go @@ -93,6 +93,7 @@ func newAnalyzerOptions() (*zetasql.AnalyzerOptions, error) { ast.CreateFunctionStmt, ast.CreateTableFunctionStmt, ast.CreateViewStmt, + ast.CreateSchemaStmt, ast.DropFunctionStmt, }) // Enable QUALIFY without WHERE @@ -263,6 +264,8 @@ func (a *Analyzer) analyzeTemplatedFunctionWithRuntimeArgument(ctx context.Conte func (a *Analyzer) newStmtAction(ctx context.Context, query string, args []driver.NamedValue, node ast.StatementNode) (StmtAction, error) { switch node.Kind() { + case ast.CreateSchemaStmt: + return a.newCreateSchemaStmtAction(ctx, query, args, node.(*ast.CreateSchemaStmtNode)) case ast.CreateTableStmt: return a.newCreateTableStmtAction(ctx, query, args, node.(*ast.CreateTableStmtNode)) case ast.CreateTableAsSelectStmt: @@ -745,3 +748,16 @@ func getArgsFromParams(values []driver.NamedValue, params []*ast.ParameterNode) } return args, nil } + +func (a *Analyzer) newCreateSchemaStmtAction(_ context.Context, query string, _ []driver.NamedValue, node *ast.CreateSchemaStmtNode) (*CreateSchemaStmtAction, error) { + schemaName := node.NamePath()[0] + if _, exists := a.catalog.schemaMap[schemaName]; exists && node.CreateMode() != ast.CreateOrReplaceMode { + return nil, fmt.Errorf("schema %s already exists", schemaName) + } + spec := NewSchemaSpec(schemaName) + return &CreateSchemaStmtAction{ + query: query, + spec: spec, + catalog: a.catalog, + }, nil +} diff --git a/internal/catalog.go b/internal/catalog.go index c626dd1..0613d12 100644 --- a/internal/catalog.go +++ b/internal/catalog.go @@ -51,6 +51,7 @@ const ( TableSpecKind CatalogSpecKind = "table" ViewSpecKind CatalogSpecKind = "view" FunctionSpecKind CatalogSpecKind = "function" + SchemaSpecKind CatalogSpecKind = "schema" catalogName = "zetasqlite" ) @@ -61,8 +62,10 @@ type Catalog struct { tables []*TableSpec functions []*FunctionSpec catalog *types.SimpleCatalog + schemas []*SchemaSpec tableMap map[string]*TableSpec funcMap map[string]*FunctionSpec + schemaMap map[string]*SchemaSpec } func newSimpleCatalog(name string) *types.SimpleCatalog { @@ -73,10 +76,11 @@ func newSimpleCatalog(name string) *types.SimpleCatalog { func NewCatalog(db *sql.DB) *Catalog { return &Catalog{ - db: db, - catalog: newSimpleCatalog(catalogName), - tableMap: map[string]*TableSpec{}, - funcMap: map[string]*FunctionSpec{}, + db: db, + catalog: newSimpleCatalog(catalogName), + tableMap: map[string]*TableSpec{}, + funcMap: map[string]*FunctionSpec{}, + schemaMap: map[string]*SchemaSpec{}, } } @@ -209,6 +213,10 @@ func (c *Catalog) Sync(ctx context.Context, conn *Conn) error { if err := c.loadFunctionSpec(spec); err != nil { return fmt.Errorf("failed to load function spec: %w", err) } + case SchemaSpecKind: + if err := c.loadSchemaSpec(spec); err != nil { + return fmt.Errorf("failed to load schema spec: %w", err) + } default: return fmt.Errorf("unknown catalog spec kind %s", kind) } @@ -247,6 +255,18 @@ func (c *Catalog) AddNewFunctionSpec(ctx context.Context, conn *Conn, spec *Func return nil } +func (c *Catalog) AddNewSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error { + c.mu.Lock() + defer c.mu.Unlock() + + if err := c.saveSchemaSpec(ctx, conn, spec); err != nil { + return fmt.Errorf("failed to save schema spec: %w", err) + } + c.schemas = append(c.schemas, spec) + c.schemaMap[spec.Name] = spec + return nil +} + func (c *Catalog) DeleteTableSpec(ctx context.Context, conn *Conn, name string) error { c.mu.Lock() defer c.mu.Unlock() @@ -374,6 +394,26 @@ func (c *Catalog) saveFunctionSpec(ctx context.Context, conn *Conn, spec *Functi return nil } +func (c *Catalog) saveSchemaSpec(ctx context.Context, conn *Conn, spec *SchemaSpec) error { + encoded, err := json.Marshal(spec) + if err != nil { + return fmt.Errorf("failed to encode schema spec: %w", err) + } + now := time.Now() + if _, err := conn.ExecContext( + ctx, + upsertCatalogQuery, + sql.Named("name", spec.Name), + sql.Named("kind", string(SchemaSpecKind)), + sql.Named("spec", string(encoded)), + sql.Named("updatedAt", now), + sql.Named("createdAt", now), + ); err != nil { + return fmt.Errorf("failed to save schema spec: %w", err) + } + return nil +} + func (c *Catalog) createCatalogTablesIfNotExists(ctx context.Context, conn *Conn) error { if _, err := conn.ExecContext(ctx, createCatalogTableQuery); err != nil { return fmt.Errorf("failed to create catalog table: %w", err) @@ -403,6 +443,16 @@ func (c *Catalog) loadFunctionSpec(spec string) error { return nil } +func (c *Catalog) loadSchemaSpec(spec string) error { + var schemaSpec SchemaSpec + if err := json.Unmarshal([]byte(spec), &schemaSpec); err != nil { + return fmt.Errorf("failed to decode schema spec: %w", err) + } + c.schemas = append(c.schemas, &schemaSpec) + c.schemaMap[schemaSpec.Name] = &schemaSpec + return nil +} + func (c *Catalog) trimmedLastPath(path []string) []string { if len(path) == 0 { return path diff --git a/internal/schema_spec.go b/internal/schema_spec.go new file mode 100644 index 0000000..9ecbddc --- /dev/null +++ b/internal/schema_spec.go @@ -0,0 +1,22 @@ +package internal + +import ( + "time" +) + +// SchemaSpec represents a schema in the database +type SchemaSpec struct { + Name string `json:"name"` + UpdatedAt time.Time `json:"updatedAt"` + CreatedAt time.Time `json:"createdAt"` +} + +// NewSchemaSpec creates a new schema specification +func NewSchemaSpec(name string) *SchemaSpec { + now := time.Now() + return &SchemaSpec{ + Name: name, + UpdatedAt: now, + CreatedAt: now, + } +} diff --git a/internal/spec.go b/internal/spec.go index fb1c831..595269b 100644 --- a/internal/spec.go +++ b/internal/spec.go @@ -124,7 +124,18 @@ func (s *TableSpec) Column(name string) *ColumnSpec { } func (s *TableSpec) TableName() string { - return formatPath(s.NamePath) + if len(s.NamePath) > 1 { + // First element is schema, rest is table name + return fmt.Sprintf("%s.%s", s.NamePath[0], strings.Join(s.NamePath[1:], "_")) + } + return strings.Join(s.NamePath, "_") +} + +func (s *TableSpec) GetSchema() string { + if len(s.NamePath) > 1 { + return s.NamePath[0] + } + return "" } func (s *TableSpec) SQLiteSchema() string { diff --git a/internal/stmt_action_schema.go b/internal/stmt_action_schema.go new file mode 100644 index 0000000..8c74f8b --- /dev/null +++ b/internal/stmt_action_schema.go @@ -0,0 +1,47 @@ +package internal + +import ( + "context" + "database/sql/driver" + "fmt" + //ast "github.com/goccy/go-zetasql/resolved_ast" +) + +type CreateSchemaStmtAction struct { + query string + spec *SchemaSpec + catalog *Catalog +} + +func (a *CreateSchemaStmtAction) exec(ctx context.Context, conn *Conn) error { + if err := a.catalog.AddNewSchemaSpec(ctx, conn, a.spec); err != nil { + return fmt.Errorf("failed to add new schema spec: %w", err) + } + return nil +} + +func (a *CreateSchemaStmtAction) ExecContext(ctx context.Context, conn *Conn) (driver.Result, error) { + if err := a.exec(ctx, conn); err != nil { + return nil, err + } + return &Result{conn: conn}, nil +} + +func (a *CreateSchemaStmtAction) QueryContext(ctx context.Context, conn *Conn) (*Rows, error) { + if err := a.exec(ctx, conn); err != nil { + return nil, err + } + return &Rows{conn: conn}, nil +} + +func (a *CreateSchemaStmtAction) Prepare(ctx context.Context, conn *Conn) (driver.Stmt, error) { + return nil, fmt.Errorf("prepare not supported for CREATE SCHEMA") +} + +func (a *CreateSchemaStmtAction) Args() []interface{} { + return nil +} + +func (a *CreateSchemaStmtAction) Cleanup(ctx context.Context, conn *Conn) error { + return nil +} From 16feb633608f27c76e006a14d7a3b2fd9c72a8a8 Mon Sep 17 00:00:00 2001 From: Martin Mauch Date: Sun, 15 Dec 2024 08:54:00 +0100 Subject: [PATCH 2/2] Allow writing test DB to disk --- .gitignore | 3 +++ exec_test.go | 43 +++++++++++++++++++------------------------ 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 29b269e..ee9ad19 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,9 @@ # Test binary, built with `go test -c` *.test +# Test DBs +*.sqlite3 + # Output of the go coverage tool, specifically when used with LiteIDE *.out diff --git a/exec_test.go b/exec_test.go index 8b4cdc7..4d9ba95 100644 --- a/exec_test.go +++ b/exec_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "os" "reflect" "testing" "time" @@ -13,14 +14,23 @@ import ( zetasqlite "github.com/goccy/go-zetasqlite" ) +func openTestDB(t *testing.T) *sql.DB { + dbPath := ":memory:" + if path := os.Getenv("ZETASQLITE_TEST_DB"); path != "" { + dbPath = path + } + db, err := sql.Open("zetasqlite", dbPath) + if err != nil { + t.Fatal(err) + } + return db +} + func TestExec(t *testing.T) { now := time.Now() ctx := context.Background() ctx = zetasqlite.WithCurrentTime(ctx, now) - db, err := sql.Open("zetasqlite", ":memory:") - if err != nil { - t.Fatal(err) - } + db := openTestDB(t) defer db.Close() tests := []struct { name string @@ -191,10 +201,7 @@ func TestNestedStructFieldAccess(t *testing.T) { now := time.Now() ctx := context.Background() ctx = zetasqlite.WithCurrentTime(ctx, now) - db, err := sql.Open("zetasqlite", ":memory:") - if err != nil { - t.Fatal(err) - } + db := openTestDB(t) defer db.Close() if _, err := db.ExecContext(ctx, ` CREATE TABLE table ( @@ -256,10 +263,7 @@ func TestCreateTempTable(t *testing.T) { now := time.Now() ctx := context.Background() ctx = zetasqlite.WithCurrentTime(ctx, now) - db, err := sql.Open("zetasqlite", ":memory:") - if err != nil { - t.Fatal(err) - } + db := openTestDB(t) defer db.Close() if _, err := db.ExecContext(ctx, "CREATE TEMP TABLE tmp_table (id INT64)"); err != nil { t.Fatal(err) @@ -277,10 +281,7 @@ func TestCreateTempTable(t *testing.T) { func TestWildcardTable(t *testing.T) { ctx := context.Background() - db, err := sql.Open("zetasqlite", ":memory:") - if err != nil { - t.Fatal(err) - } + db := openTestDB(t) defer db.Close() if _, err := db.ExecContext( ctx, @@ -384,10 +385,7 @@ func TestWildcardTable(t *testing.T) { func TestTemplatedArgFunc(t *testing.T) { ctx := context.Background() - db, err := sql.Open("zetasqlite", ":memory:") - if err != nil { - t.Fatal(err) - } + db := openTestDB(t) defer db.Close() t.Run("simple any arguments", func(t *testing.T) { if _, err := db.ExecContext( @@ -481,10 +479,7 @@ func TestTemplatedArgFunc(t *testing.T) { func TestJavaScriptUDF(t *testing.T) { ctx := context.Background() - db, err := sql.Open("zetasqlite", ":memory:") - if err != nil { - t.Fatal(err) - } + db := openTestDB(t) defer db.Close() t.Run("operation", func(t *testing.T) { if _, err := db.ExecContext(