Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ON CONFLICT in InsertStmt #267

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,7 @@ type Dialect interface {
EncodeBytes(b []byte) string

Placeholder(n int) string

OnConflict(constraint string) string
Proposed(column string) string
}
8 changes: 8 additions & 0 deletions dialect/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,11 @@ func (d mssql) EncodeBytes(b []byte) string {
func (d mssql) Placeholder(n int) string {
return fmt.Sprintf("@p%d", n+1)
}

func (d mssql) OnConflict(_ string) string {
return ""
}

func (d mssql) Proposed(_ string) string {
return ""
}
8 changes: 8 additions & 0 deletions dialect/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,11 @@ func (d mysql) EncodeBytes(b []byte) string {
func (d mysql) Placeholder(_ int) string {
return "?"
}

func (d mysql) OnConflict(_ string) string {
return "ON DUPLICATE KEY UPDATE"
}

func (d mysql) Proposed(column string) string {
return fmt.Sprintf("VALUES(%s)", d.QuoteIdent(column))
}
8 changes: 8 additions & 0 deletions dialect/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,11 @@ func (d postgreSQL) EncodeBytes(b []byte) string {
func (d postgreSQL) Placeholder(n int) string {
return fmt.Sprintf("$%d", n+1)
}

func (d postgreSQL) OnConflict(constraint string) string {
return fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", d.QuoteIdent(constraint))
}

func (d postgreSQL) Proposed(column string) string {
return fmt.Sprintf("EXCLUDED.%s", d.QuoteIdent(column))
}
8 changes: 8 additions & 0 deletions dialect/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ func (d sqlite3) EncodeBytes(b []byte) string {
func (d sqlite3) Placeholder(_ int) string {
return "?"
}

func (d sqlite3) OnConflict(_ string) string {
return ""
}

func (d sqlite3) Proposed(_ string) string {
return ""
}
57 changes: 57 additions & 0 deletions insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@ package dbr
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"

"github.com/gocraft/dbr/v2/dialect"
)

// ConflictStmt is ` ON CONFLICT ...` part of InsertStmt
type ConflictStmt struct {
constraint string
actions map[string]interface{}
}

// InsertStmt builds `INSERT INTO ...`.
type InsertStmt struct {
Runner
Expand All @@ -24,6 +31,16 @@ type InsertStmt struct {
ReturnColumn []string
RecordID *int64
comments Comments

Conflict *ConflictStmt
}

// Proposed is reference to proposed value in on conflict clause
func Proposed(column string) Builder {
return BuildFunc(func(d Dialect, b Buffer) error {
_, err := b.WriteString(d.Proposed(column))
return err
})
}

type InsertBuilder = InsertStmt
Expand Down Expand Up @@ -90,6 +107,29 @@ func (b *InsertStmt) Build(d Dialect, buf Buffer) error {
buf.WriteValue(tuple...)
}

if b.Conflict != nil && len(b.Conflict.actions) > 0 {
keyword := d.OnConflict(b.Conflict.constraint)
if len(keyword) == 0 {
return fmt.Errorf("Dialect %s does not support OnConflict", d)
}
buf.WriteString(" ")
buf.WriteString(keyword)
buf.WriteString(" ")
needComma := false
for _, column := range b.Column {
if v, ok := b.Conflict.actions[column]; ok {
if needComma {
buf.WriteString(",")
}
buf.WriteString(d.QuoteIdent(column))
buf.WriteString("=")
buf.WriteString(placeholder)
buf.WriteValue(v)
needComma = true
}
}
}

if d != dialect.MSSQL && len(b.ReturnColumn) > 0 {
buf.WriteString(" RETURNING ")
for i, col := range b.ReturnColumn {
Expand Down Expand Up @@ -262,3 +302,20 @@ func (b *InsertStmt) LoadContext(ctx context.Context, value interface{}) error {
func (b *InsertStmt) Load(value interface{}) error {
return b.LoadContext(context.Background(), value)
}

// OnConflictMap allows to add actions for constraint violation, e.g UPSERT
func (b *InsertStmt) OnConflictMap(constraint string, actions map[string]interface{}) *InsertStmt {
b.Conflict = &ConflictStmt{constraint: constraint, actions: actions}
return b
}

// OnConflict creates an empty OnConflict section fo insert statement , e.g UPSERT
func (b *InsertStmt) OnConflict(constraint string) *ConflictStmt {
return b.OnConflictMap(constraint, make(map[string]interface{})).Conflict
}

// Action adds action for column which will do if conflict happens
func (b *ConflictStmt) Action(column string, action interface{}) *ConflictStmt {
b.actions[column] = action
return b
}
22 changes: 22 additions & 0 deletions insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ func TestPostgresReturning(t *testing.T) {
require.Equal(t, 0, sess.EventReceiver.(*testTraceReceiver).errored)
}

func TestOnConflict(t *testing.T) {
for _, sess := range testSession {
if sess.Dialect.OnConflict("") == "" {
// dialect does not support OnConflict
continue
}
t.Run(testSessionName(sess), func(t *testing.T) {
reset(t, sess)
for i := 0; i < 2; i++ {
b := sess.InsertInto("dbr_people").Columns("id", "name", "email").Values(1, "test", "[email protected]")
b.OnConflict("dbr_people_pkey").Action("email", Expr("CONCAT(?, 2)", Proposed("email")))
_, err := b.Exec()
require.NoError(t, err)
}
var value string
_, err := sess.SelectBySql("SELECT email FROM dbr_people WHERE id=?", "1").Load(&value)
require.NoError(t, err)
require.Equal(t, "[email protected]", value)
})
}
}

func BenchmarkInsertValuesSQL(b *testing.B) {
buf := NewBuffer()
for i := 0; i < b.N; i++ {
Expand Down