Skip to content

Commit

Permalink
Fix MERGE handling named paths
Browse files Browse the repository at this point in the history
  • Loading branch information
bony2023 committed Sep 16, 2024
1 parent 9d65c7d commit a319e2e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 4 deletions.
43 changes: 39 additions & 4 deletions internal/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql/driver"
"fmt"
"regexp"
"strings"

"github.com/goccy/go-zetasql"
Expand Down Expand Up @@ -625,6 +626,40 @@ func (a *Analyzer) newMergeStmtAction(ctx context.Context, _ string, args []driv
mergedTableSourceColumnName,
mergedTableTargetColumnName,
)

sourceTableName := sourceColumn.TableName()
targetTableName := targetColumn.TableName()
namePath := a.namePath.path

if len(namePath) > 0 {
// prepend projectID and datasetID to source and target table names when applicable
projectID := namePath[0]
datasetID := ""

if len(namePath) == 2 {
// namePath already contains datasetID
datasetID = namePath[1]
} else if len(namePath) == 1 {
// namePath doesn't have a datasetID. Try to extract it from the query
datasetExtractPattern := fmt.Sprintf(`\b%s_([^_]+)_%s\b`, regexp.QuoteMeta(projectID), regexp.QuoteMeta(sourceTableName))
re, err := regexp.Compile(datasetExtractPattern)
if err != nil {
return nil, err
}
matches := re.FindStringSubmatch(sourceTable)

if matches == nil || len(matches) < 2 {
return nil, fmt.Errorf("cannot deduce dataset in the merge query")
}

datasetID = matches[1]
}

// format source and target names as `<projectID>_<datasetID>_<tableName>` to comply with sqlite syntax
sourceTableName = fmt.Sprintf("%s_%s_%s", projectID, datasetID, strings.TrimPrefix(sourceTableName, fmt.Sprintf("%s.%s.", projectID, datasetID)))
targetTableName = fmt.Sprintf("%s_%s_%s", projectID, datasetID, strings.TrimPrefix(targetTableName, fmt.Sprintf("%s.%s.", projectID, datasetID)))
}

for _, when := range node.WhenClauseList() {
var fromStmt string
switch when.MatchType() {
Expand Down Expand Up @@ -652,10 +687,10 @@ func (a *Analyzer) newMergeStmtAction(ctx context.Context, _ string, args []driv
}
stmts = append(stmts, fmt.Sprintf(
"INSERT INTO `%[1]s`(%[2]s) SELECT %[3]s FROM (SELECT * FROM `%[4]s` %[5]s)",
targetColumn.TableName(),
targetTableName,
strings.Join(columns, ","),
row,
sourceColumn.TableName(),
sourceTableName,
whereStmt,
))
case ast.ActionTypeUpdate:
Expand All @@ -669,14 +704,14 @@ func (a *Analyzer) newMergeStmtAction(ctx context.Context, _ string, args []driv
}
stmts = append(stmts, fmt.Sprintf(
"UPDATE `%s` SET %s %s",
targetColumn.TableName(),
targetTableName,
strings.Join(items, ","),
fromStmt,
))
case ast.ActionTypeDelete:
stmts = append(stmts, fmt.Sprintf(
"DELETE FROM `%s` %s",
targetColumn.TableName(),
targetTableName,
whereStmt,
))
}
Expand Down
42 changes: 42 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5960,6 +5960,48 @@ SELECT * FROM table2;
os.Unsetenv("TZ")
}

func TestMergeStatementWithNamedPath(t *testing.T) {
for i, namePath := range [][]string{{"project", "dataset"}, {"project"}} {
sql.Register(fmt.Sprintf("zetasqlite-merge-%v", i), &zetasqlite.ZetaSQLiteDriver{
ConnectHook: func(conn *zetasqlite.ZetaSQLiteConn) error {
return conn.SetNamePath(namePath)
},
})

db, err := sql.Open(fmt.Sprintf("zetasqlite-merge-%v", i), ":memory:")
if err != nil {
t.Fatal(err)
}

for _, stmt := range []string{
`CREATE TABLE project.dataset.target(id INT64, name STRING);`,
`CREATE TABLE project.dataset.source(id INT64, name STRING);`,
`INSERT INTO project.dataset.source(id, name) VALUES (1, "test");`,
`MERGE project.dataset.target T USING project.dataset.source S ON T.id = S.id
WHEN MATCHED THEN UPDATE SET id = S.id, name = S.name
WHEN NOT MATCHED THEN INSERT (id, name) VALUES (id, name);`,
} {
if _, err := db.Exec(stmt); err != nil {
t.Fatal(err)
}
}

row := db.QueryRow("SELECT * FROM project.dataset.target;")
if row.Err() != nil {
t.Fatal(row.Err())
}
var id int64
var name string
if err := row.Scan(&id, &name); err != nil {
t.Fatal(err)
}

if id != 1 || name != "test" {
t.Fatalf("failed to merge row %v, %s", id, name)
}
}
}

func createTimestampFormatFromTime(t time.Time) string {
unixmicro := t.UnixMicro()
sec := unixmicro / int64(time.Millisecond)
Expand Down

0 comments on commit a319e2e

Please sign in to comment.