Skip to content

Commit 6081e00

Browse files
committed
Added SetSingleQuoteEscaper for #1
1 parent 1f3c062 commit 6081e00

File tree

2 files changed

+30
-6
lines changed

2 files changed

+30
-6
lines changed

sql.go

+12-6
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ import (
1111
)
1212

1313
var (
14-
tmFmtZero = "0000-00-00 00:00:00"
15-
tmFmtWithMS = "2006-01-02 15:04:05.999"
16-
escaper = "'"
17-
nullStr = "NULL"
14+
tmFmtZero = "0000-00-00 00:00:00"
15+
tmFmtWithMS = "2006-01-02 15:04:05.999"
16+
escaper = "'"
17+
nullStr = "NULL"
18+
singleQuoteEscaper = "\\"
1819
)
1920

2021
//Escape escape the val for sql
@@ -69,7 +70,7 @@ func EscapeInLocation(val interface{}, loc *time.Location) string {
6970
return fmt.Sprintf("%.6f", v)
7071

7172
case string:
72-
return escaper + strings.Replace(v, escaper, "\\"+escaper, -1) + escaper
73+
return escaper + strings.Replace(v, escaper, singleQuoteEscaper+escaper, -1) + escaper
7374
default:
7475
refValue := reflect.ValueOf(v)
7576
if v == nil || !refValue.IsValid() {
@@ -93,7 +94,7 @@ func EscapeInLocation(val interface{}, loc *time.Location) string {
9394
if err != nil {
9495
return nullStr
9596
}
96-
return escaper + strings.Replace(string(stringifyData), escaper, "\\"+escaper, -1) + escaper
97+
return escaper + strings.Replace(string(stringifyData), escaper, singleQuoteEscaper+escaper, -1) + escaper
9798

9899
}
99100
}
@@ -141,3 +142,8 @@ func FormatInLocation(query string, loc *time.Location, args ...interface{}) str
141142
}
142143
return sql.String()
143144
}
145+
146+
//SetSingleQuoteEscaper set the singleQuoteEscaper
147+
func SetSingleQuoteEscaper(escaper string) {
148+
singleQuoteEscaper = escaper
149+
}

sql_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,24 @@ func TestStringEscape(t *testing.T) {
8787
}
8888
}
8989

90+
func TestStringCustomEscape(t *testing.T) {
91+
s := "hello world"
92+
SetSingleQuoteEscaper("'")
93+
result := Escape(s)
94+
if result != "'hello world'" {
95+
t.Fatalf("escape string error")
96+
97+
}
98+
99+
s = "hello ' world"
100+
result = Escape(s)
101+
t.Logf("TestStringCustomEscape result: %s", result)
102+
if result != "'hello '' world'" {
103+
t.Fatalf("escape string error")
104+
105+
}
106+
}
107+
90108
func TestBytesEscape(t *testing.T) {
91109
s := []byte{0, 1, 254, 255}
92110
result := Escape(s)

0 commit comments

Comments
 (0)