diff --git a/queries/query.go b/queries/query.go index 5e99d4ebd..dc811dc39 100644 --- a/queries/query.go +++ b/queries/query.go @@ -30,22 +30,23 @@ type Query struct { load []string loadMods map[string]Applicator - delete bool - update map[string]interface{} - withs []argClause - selectCols []string - count bool - from []string - joins []join - where []where - groupBy []string - orderBy []argClause - having []argClause - limit *int - offset int - forlock string - distinct string - comment string + delete bool + update map[string]interface{} + withs []argClause + selectCols []string + count bool + from []string + joins []join + where []where + groupBy []string + orderBy []argClause + having []argClause + limit *int + offset int + forlock string + distinct string + comment string + appendComment string // This field is a hack to allow a query to strip out the reference // to deleted at is null. @@ -281,6 +282,11 @@ func SetComment(q *Query, comment string) { q.comment = comment } +// SetAppendComment on the query. +func SetAppendComment(q *Query, appendComment string) { + q.appendComment = appendComment +} + // SetUpdate on the query. func SetUpdate(q *Query, cols map[string]interface{}) { q.update = cols diff --git a/queries/query_builders.go b/queries/query_builders.go index 7f5cc1a72..de6e7693d 100644 --- a/queries/query_builders.go +++ b/queries/query_builders.go @@ -134,6 +134,9 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { writeModifiers(q, buf, &args) buf.WriteByte(';') + + writeAppendComment(q, buf) + return buf, args } @@ -157,6 +160,8 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { buf.WriteByte(';') + writeAppendComment(q, buf) + return buf, args } @@ -201,6 +206,8 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { buf.WriteByte(';') + writeAppendComment(q, buf) + return buf, args } @@ -604,6 +611,16 @@ func writeComment(q *Query, buf *bytes.Buffer) { } } +func writeAppendComment(q *Query, buf *bytes.Buffer) { + if len(q.appendComment) == 0 { + return + } + + buf.WriteString(" /* ") + buf.WriteString(q.appendComment) + buf.WriteString(" */ ") +} + func writeCTEs(q *Query, buf *bytes.Buffer, args *[]interface{}) { if len(q.withs) == 0 { return diff --git a/queries/query_builders_test.go b/queries/query_builders_test.go index 892d8dde0..f94dc2441 100644 --- a/queries/query_builders_test.go +++ b/queries/query_builders_test.go @@ -699,3 +699,33 @@ func TestWriteComment(t *testing.T) { t.Errorf(`bad two lines comment, got: %s`, got) } } + +func TestWriteAppendComment(t *testing.T) { + t.Parallel() + + tests := []struct { + appendComment string + expectPredicate func(sql string) bool + }{ + {"", func(sql string) bool { + return !strings.Contains(sql, "/*") + }}, + {"comment", func(sql string) bool { + return strings.Contains(sql, "; /* comment */") + }}, + {"first\nsecond", func(sql string) bool { + return strings.Contains(sql, "; /* first\nsecond */") + }}, + } + + for i, test := range tests { + q := &Query{ + appendComment: test.appendComment, + dialect: &drivers.Dialect{LQ: '"', RQ: '"', UseIndexPlaceholders: true, UseTopClause: false}, + } + sql, _ := BuildQuery(q) + if !test.expectPredicate(sql) { + t.Errorf("%d) Unexpected built SQL query: %s", i, sql) + } + } +} diff --git a/queries/query_test.go b/queries/query_test.go index f813ea61e..72557e1cc 100644 --- a/queries/query_test.go +++ b/queries/query_test.go @@ -654,6 +654,17 @@ func TestSetComment(t *testing.T) { } } +func TestSetAppendComment(t *testing.T) { + t.Parallel() + + q := &Query{} + SetAppendComment(q, "my comment") + + if q.appendComment != "my comment" { + t.Errorf("Got invalid comment: %s", q.appendComment) + } +} + func TestRemoveSoftDeleteWhere(t *testing.T) { t.Parallel()