From 5d0434461de094e6fbf8311245b2c6b0efc2cb6e Mon Sep 17 00:00:00 2001
From: Bony Roopchandani
Date: Sun, 15 Sep 2024 19:07:53 -0400
Subject: [PATCH] Fix MERGE handling named paths

---
Review notes (everything between the "---" line and the first "diff --git"
line is ignored when the patch is applied):

- Line breaks and indentation were rebuilt from a whitespace-collapsed copy.
  Indentation assumes gofmt tabs; if a context line fails to match, apply
  with `git apply --ignore-whitespace`. The commit id on the first line and
  the blob ids on the `index` lines were not recomputed.
- TestMergeStatementWithNamedPath now registers each driver name only once
  (sql.Register panics on a duplicate name), runs each name path as a
  subtest, and closes the database handle.
- NOTE(review): in newMergeStmtAction, a name path with more than two
  elements matches neither branch, so projectID and datasetID stay empty and
  both table names get a "__" prefix. Confirm whether SetNamePath can
  produce such a path.
- NOTE(review): the `([^_]+)` capture excludes underscores, so with a
  one-element name path a dataset ID containing "_" presumably ends in
  "cannot deduce dataset in the merge query". Confirm against the format of
  sourceTable.

 internal/analyzer.go | 39 +++++++++++++++++++++++++++++++++++----
 query_test.go        | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 93 insertions(+), 4 deletions(-)

diff --git a/internal/analyzer.go b/internal/analyzer.go
index 7274d29..ea50f8b 100644
--- a/internal/analyzer.go
+++ b/internal/analyzer.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"database/sql/driver"
 	"fmt"
+	"regexp"
 	"strings"
 
 	"github.com/goccy/go-zetasql"
@@ -625,6 +626,36 @@ func (a *Analyzer) newMergeStmtAction(ctx context.Context, _ string, args []driv
 		mergedTableSourceColumnName,
 		mergedTableTargetColumnName,
 	)
+
+	sourceTableName := sourceColumn.TableName()
+	targetTableName := targetColumn.TableName()
+	namePath := a.namePath.path
+	// prepend projectID and datasetID to source and target table names when applicable
+	if len(namePath) > 0 {
+		var projectID = ""
+		var datasetID = ""
+
+		if len(namePath) == 2 {
+			projectID, datasetID = namePath[0], namePath[1]
+		} else if len(namePath) == 1 {
+			projectID = namePath[0]
+			datasetExtractPattern := fmt.Sprintf(`\b%s_([^_]+)_%s\b`, regexp.QuoteMeta(projectID), regexp.QuoteMeta(sourceTableName))
+			re, err := regexp.Compile(datasetExtractPattern)
+			if err != nil {
+				return nil, err
+			}
+			matches := re.FindStringSubmatch(sourceTable)
+
+			if matches == nil || len(matches) < 2 {
+				return nil, fmt.Errorf("cannot deduce dataset in the merge query")
+			}
+
+			datasetID = matches[1]
+		}
+		sourceTableName = fmt.Sprintf("%s_%s_%s", projectID, datasetID, strings.TrimPrefix(sourceTableName, fmt.Sprintf("%s.%s.", projectID, datasetID)))
+		targetTableName = fmt.Sprintf("%s_%s_%s", projectID, datasetID, strings.TrimPrefix(targetTableName, fmt.Sprintf("%s.%s.", projectID, datasetID)))
+	}
+
 	for _, when := range node.WhenClauseList() {
 		var fromStmt string
 		switch when.MatchType() {
@@ -652,10 +683,10 @@ func (a *Analyzer) newMergeStmtAction(ctx context.Context, _ string, args []driv
 			}
 			stmts = append(stmts, fmt.Sprintf(
 				"INSERT INTO `%[1]s`(%[2]s) SELECT %[3]s FROM (SELECT * FROM `%[4]s` %[5]s)",
-				targetColumn.TableName(),
+				targetTableName,
 				strings.Join(columns, ","),
 				row,
-				sourceColumn.TableName(),
+				sourceTableName,
 				whereStmt,
 			))
 		case ast.ActionTypeUpdate:
@@ -669,14 +700,14 @@ func (a *Analyzer) newMergeStmtAction(ctx context.Context, _ string, args []driv
 			}
 			stmts = append(stmts, fmt.Sprintf(
 				"UPDATE `%s` SET %s %s",
-				targetColumn.TableName(),
+				targetTableName,
 				strings.Join(items, ","),
 				fromStmt,
 			))
 		case ast.ActionTypeDelete:
 			stmts = append(stmts, fmt.Sprintf(
 				"DELETE FROM `%s` %s",
-				targetColumn.TableName(),
+				targetTableName,
 				whereStmt,
 			))
 		}
diff --git a/query_test.go b/query_test.go
index 39a1d0d..39941b7 100644
--- a/query_test.go
+++ b/query_test.go
@@ -5960,6 +5960,64 @@ SELECT * FROM table2;
 	os.Unsetenv("TZ")
 }
 
+// TestMergeStatementWithNamedPath checks that MERGE resolves fully qualified
+// source and target tables when the connection has a one- or two-element
+// name path.
+func TestMergeStatementWithNamedPath(t *testing.T) {
+	for i, namePath := range [][]string{{"project", "dataset"}, {"project"}} {
+		namePath := namePath // per-iteration copy for the closures below (needed before Go 1.22)
+		driverName := fmt.Sprintf("zetasqlite-merge-%v", i)
+		t.Run(driverName, func(t *testing.T) {
+			// sql.Register panics on a duplicate name; register once so the test
+			// can run repeatedly in one process (e.g. go test -count=2).
+			registered := false
+			for _, name := range sql.Drivers() {
+				if name == driverName {
+					registered = true
+					break
+				}
+			}
+			if !registered {
+				sql.Register(driverName, &zetasqlite.ZetaSQLiteDriver{
+					ConnectHook: func(conn *zetasqlite.ZetaSQLiteConn) error {
+						return conn.SetNamePath(namePath)
+					},
+				})
+			}
+
+			db, err := sql.Open(driverName, ":memory:")
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer db.Close()
+
+			for _, stmt := range []string{
+				`CREATE TABLE project.dataset.target(id INT64, name STRING);`,
+				`CREATE TABLE project.dataset.source(id INT64, name STRING);`,
+				`INSERT INTO project.dataset.source(id, name) VALUES (1, "test");`,
+				`MERGE project.dataset.target T USING project.dataset.source S ON T.id = S.id
+				WHEN MATCHED THEN UPDATE SET id = S.id, name = S.name
+				WHEN NOT MATCHED THEN INSERT (id, name) VALUES (id, name);`,
+			} {
+				if _, err := db.Exec(stmt); err != nil {
+					t.Fatal(err)
+				}
+			}
+
+			var id int64
+			var name string
+			// Scan also reports any error deferred from QueryRow.
+			row := db.QueryRow("SELECT * FROM project.dataset.target;")
+			if err := row.Scan(&id, &name); err != nil {
+				t.Fatal(err)
+			}
+			if id != 1 || name != "test" {
+				t.Fatalf("failed to merge row %v, %s", id, name)
+			}
+		})
+	}
+}
+
 func createTimestampFormatFromTime(t time.Time) string {
 	unixmicro := t.UnixMicro()
 	sec := unixmicro / int64(time.Millisecond)