Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically ref names of CTETables in DELETE and UPDATE statements #179

Merged
merged 1 commit into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ func (cteb *CTEBuilder) TableNames() []string {
return tableNames
}

// tableNamesForSelect returns a list of table names which should be automatically added to FROM clause.
// It's not public, as this feature is designed only for SelectBuilder right now.
func (cteb *CTEBuilder) tableNamesForSelect() []string {
// tableNamesForFrom returns a list of table names which should be automatically added to FROM clause.
// It's not public, as this feature is designed only for SelectBuilder/UpdateBuilder/DeleteBuilder right now.
func (cteb *CTEBuilder) tableNamesForFrom() []string {
cnt := 0

// It's rare that the ShouldAddToTableList() returns true.
// ShouldAddToTableList() unlikely returns true.
// Count it before allocating any memory for better performance.
for _, query := range cteb.queries {
if query.ShouldAddToTableList() {
Expand Down
37 changes: 37 additions & 0 deletions cte_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,43 @@ func ExampleCTEBuilder() {
// [users valid_users]
}

func ExampleCTEBuilder_update() {
builder := With(
CTETable("users", "user_id").As(
Select("user_id").From("vip_users"),
),
).Update("orders").Set(
"orders.transport_fee = 0",
).Where(
"users.user_id = orders.user_id",
)

sqlForMySQL, _ := builder.BuildWithFlavor(MySQL)
sqlForPostgreSQL, _ := builder.BuildWithFlavor(PostgreSQL)

fmt.Println(sqlForMySQL)
fmt.Println(sqlForPostgreSQL)

// Output:
// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders, users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
// WITH users (user_id) AS (SELECT user_id FROM vip_users) UPDATE orders FROM users SET orders.transport_fee = 0 WHERE users.user_id = orders.user_id
}

func ExampleCTEBuilder_delete() {
sql := With(
CTETable("users", "user_id").As(
Select("user_id").From("cheaters"),
),
).DeleteFrom("awards").Where(
"users.user_id = awards.user_id",
).String()

fmt.Println(sql)

// Output:
// WITH users (user_id) AS (SELECT user_id FROM cheaters) DELETE FROM awards, users WHERE users.user_id = awards.user_id
}

func TestCTEBuilder(t *testing.T) {
a := assert.New(t)
cteb := newCTEBuilder()
Expand Down
51 changes: 40 additions & 11 deletions delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ type DeleteBuilder struct {
whereClauseProxy *whereClauseProxy
whereClauseExpr string

cteBuilder string
table string
cteBuilderVar string
cteBuilder *CTEBuilder

tables []string
orderByCols []string
order string
limit int
Expand All @@ -60,24 +62,48 @@ type DeleteBuilder struct {
var _ Builder = new(DeleteBuilder)

// DeleteFrom sets table name in DELETE.
func DeleteFrom(table string) *DeleteBuilder {
return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table)
func DeleteFrom(table ...string) *DeleteBuilder {
return DefaultFlavor.NewDeleteBuilder().DeleteFrom(table...)
}

// With sets WITH clause (the Common Table Expression) before DELETE.
func (db *DeleteBuilder) With(builder *CTEBuilder) *DeleteBuilder {
db.marker = deleteMarkerAfterWith
db.cteBuilder = db.Var(builder)
db.cteBuilderVar = db.Var(builder)
db.cteBuilder = builder
return db
}

// DeleteFrom sets table name in DELETE.
func (db *DeleteBuilder) DeleteFrom(table string) *DeleteBuilder {
db.table = Escape(table)
func (db *DeleteBuilder) DeleteFrom(table ...string) *DeleteBuilder {
db.tables = table
db.marker = deleteMarkerAfterDeleteFrom
return db
}

// TableNames returns all table names in this DELETE statement.
func (db *DeleteBuilder) TableNames() []string {
var additionalTableNames []string

if db.cteBuilder != nil {
additionalTableNames = db.cteBuilder.tableNamesForFrom()
}

var tableNames []string

if len(db.tables) > 0 && len(additionalTableNames) > 0 {
tableNames = make([]string, len(db.tables)+len(additionalTableNames))
copy(tableNames, db.tables)
copy(tableNames[len(db.tables):], additionalTableNames)
} else if len(db.tables) > 0 {
tableNames = db.tables
} else if len(additionalTableNames) > 0 {
tableNames = additionalTableNames
}

return tableNames
}

// Where sets expressions of WHERE in DELETE.
func (db *DeleteBuilder) Where(andExpr ...string) *DeleteBuilder {
if len(andExpr) == 0 || estimateStringsBytes(andExpr) == 0 {
Expand Down Expand Up @@ -146,17 +172,20 @@ func (db *DeleteBuilder) Build() (sql string, args []interface{}) {
// BuildWithFlavor returns compiled DELETE string and args with flavor and initial args.
// They can be used in `DB#Query` of package `database/sql` directly.
func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {

buf := newStringBuilder()
db.injection.WriteTo(buf, deleteMarkerInit)

if db.cteBuilder != "" {
buf.WriteLeadingString(db.cteBuilder)
if db.cteBuilder != nil {
buf.WriteLeadingString(db.cteBuilderVar)
db.injection.WriteTo(buf, deleteMarkerAfterWith)
}

if len(db.table) > 0 {
tableNames := db.TableNames()

if len(tableNames) > 0 {
buf.WriteLeadingString("DELETE FROM ")
buf.WriteString(db.table)
buf.WriteStrings(tableNames, ", ")
}

db.injection.WriteTo(buf, deleteMarkerAfterDeleteFrom)
Expand Down
4 changes: 2 additions & 2 deletions select.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ func Select(col ...string) *SelectBuilder {
return DefaultFlavor.NewSelectBuilder().Select(col...)
}

// TableNames returns all table names in a SELECT.
// TableNames returns all table names in this SELECT statement.
func (sb *SelectBuilder) TableNames() []string {
var additionalTableNames []string

if sb.cteBuilder != nil {
additionalTableNames = sb.cteBuilder.tableNamesForSelect()
additionalTableNames = sb.cteBuilder.tableNamesForFrom()
}

var tableNames []string
Expand Down
70 changes: 58 additions & 12 deletions update.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ type UpdateBuilder struct {
whereClauseProxy *whereClauseProxy
whereClauseExpr string

cteBuilder string
table string
cteBuilderVar string
cteBuilder *CTEBuilder

tables []string
assignments []string
orderByCols []string
order string
Expand All @@ -63,24 +65,46 @@ type UpdateBuilder struct {
var _ Builder = new(UpdateBuilder)

// Update sets table name in UPDATE.
func Update(table string) *UpdateBuilder {
return DefaultFlavor.NewUpdateBuilder().Update(table)
func Update(table ...string) *UpdateBuilder {
return DefaultFlavor.NewUpdateBuilder().Update(table...)
}

// With sets WITH clause (the Common Table Expression) before UPDATE.
func (ub *UpdateBuilder) With(builder *CTEBuilder) *UpdateBuilder {
ub.marker = updateMarkerAfterWith
ub.cteBuilder = ub.Var(builder)
ub.cteBuilderVar = ub.Var(builder)
ub.cteBuilder = builder
return ub
}

// Update sets table name in UPDATE.
func (ub *UpdateBuilder) Update(table string) *UpdateBuilder {
ub.table = Escape(table)
func (ub *UpdateBuilder) Update(table ...string) *UpdateBuilder {
ub.tables = table
ub.marker = updateMarkerAfterUpdate
return ub
}

// TableNames returns all table names in this UPDATE statement.
func (ub *UpdateBuilder) TableNames() (tableNames []string) {
var additionalTableNames []string

if ub.cteBuilder != nil {
additionalTableNames = ub.cteBuilder.tableNamesForFrom()
}

if len(ub.tables) > 0 && len(additionalTableNames) > 0 {
tableNames = make([]string, len(ub.tables)+len(additionalTableNames))
copy(tableNames, ub.tables)
copy(tableNames[len(ub.tables):], additionalTableNames)
} else if len(ub.tables) > 0 {
tableNames = ub.tables
} else if len(additionalTableNames) > 0 {
tableNames = additionalTableNames
}

return tableNames
}

// Set sets the assignments in SET.
func (ub *UpdateBuilder) Set(assignment ...string) *UpdateBuilder {
ub.assignments = assignment
Expand Down Expand Up @@ -212,14 +236,36 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
buf := newStringBuilder()
ub.injection.WriteTo(buf, updateMarkerInit)

if ub.cteBuilder != "" {
buf.WriteLeadingString(ub.cteBuilder)
if ub.cteBuilder != nil {
buf.WriteLeadingString(ub.cteBuilderVar)
ub.injection.WriteTo(buf, updateMarkerAfterWith)
}

if len(ub.table) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteString(ub.table)
switch flavor {
case MySQL:
// CTE table names should be written after UPDATE keyword in MySQL.
tableNames := ub.TableNames()

if len(tableNames) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteStrings(tableNames, ", ")
}

default:
if len(ub.tables) > 0 {
buf.WriteLeadingString("UPDATE ")
buf.WriteStrings(ub.tables, ", ")

// For ISO SQL, CTE table names should be written after FROM keyword.
if ub.cteBuilder != nil {
cteTableNames := ub.cteBuilder.tableNamesForFrom()

if len(cteTableNames) > 0 {
buf.WriteLeadingString("FROM ")
buf.WriteStrings(cteTableNames, ", ")
}
}
}
}

ub.injection.WriteTo(buf, updateMarkerAfterUpdate)
Expand Down
Loading