diff --git a/db/mysql.go b/db/mysql.go index 66c08b9..7cacd79 100644 --- a/db/mysql.go +++ b/db/mysql.go @@ -2,6 +2,8 @@ package isudb import ( "database/sql" + "regexp" + "sync" "github.com/go-sql-driver/mysql" ) @@ -23,7 +25,51 @@ func (msb mysqlSegmentBuilder) parseDSN(dsn string) *measureSegment { } return &measureSegment{ - driver: msb.driver(), - addr: cfg.Addr, + driver: msb.driver(), + addr: cfg.Addr, + normalizer: msb.normalizer, } } + +var ( + mysqlReList = []struct { + re *regexp.Regexp + to string + }{{ + re: regexp.MustCompile(`(\?\s*,\s*)+`), + to: "..., ", + }, { + re: regexp.MustCompile(`(\(..., \?\)\s*,\s*)+`), + to: "..., ", + }} + mysqlNormalizeCacheLocker = &sync.RWMutex{} + mysqlNormalizeCache = make(map[string]string, 50) +) + +func (mysqlSegmentBuilder) normalizer(query string) string { + var ( + normalizedQuery string + ok bool + ) + func() { + mysqlNormalizeCacheLocker.RLock() + defer mysqlNormalizeCacheLocker.RUnlock() + normalizedQuery, ok = mysqlNormalizeCache[query] + }() + if ok { + return normalizedQuery + } + + normalizedQuery = query + for _, re := range mysqlReList { + normalizedQuery = re.re.ReplaceAllString(normalizedQuery, re.to) + } + + func() { + mysqlNormalizeCacheLocker.Lock() + defer mysqlNormalizeCacheLocker.Unlock() + mysqlNormalizeCache[query] = normalizedQuery + }() + + return normalizedQuery +} diff --git a/db/mysql_test.go b/db/mysql_test.go new file mode 100644 index 0000000..634729b --- /dev/null +++ b/db/mysql_test.go @@ -0,0 +1,31 @@ +package isudb + +import ( + "fmt" + "strings" + "testing" +) + +func BenchmarkMySQLNormalizer(b *testing.B) { + msb := mysqlSegmentBuilder{} + + queryPart := fmt.Sprintf("(%s?)", strings.Repeat("?, ", 5)) + query := fmt.Sprintf("INSERT INTO users (name, email, password, salt, created_at, updated_at) VALUES %s", strings.Repeat(queryPart+", ", 999)+queryPart) + + for i := 0; i < b.N; i++ { + msb.normalizer(query) + } +} + +func TestMySQLNormalizer(t *testing.T) { + msb := mysqlSegmentBuilder{} + + queryPart := fmt.Sprintf("(%s?)", strings.Repeat("?, ", 5)) + query := fmt.Sprintf("INSERT INTO users (name, email, password, salt, created_at, updated_at) VALUES %s", strings.Repeat(queryPart+", ", 999)+queryPart) + + normalizedQuery := msb.normalizer(query) + + if normalizedQuery != "INSERT INTO users (name, email, password, salt, created_at, updated_at) VALUES ..., (..., ?)" { + t.Errorf("unexpected query: %s", normalizedQuery) + } +} diff --git a/db/sqlite3.go b/db/sqlite3.go index 3f4d0e7..3b4e0ec 100644 --- a/db/sqlite3.go +++ b/db/sqlite3.go @@ -2,6 +2,8 @@ package isudb import ( "database/sql" + "regexp" + "sync" "github.com/mattn/go-sqlite3" ) @@ -18,7 +20,51 @@ func (sqlite3SegmentBuilder) driver() string { func (ssb sqlite3SegmentBuilder) parseDSN(dsn string) *measureSegment { return &measureSegment{ - driver: ssb.driver(), - addr: dsn, + driver: ssb.driver(), + addr: dsn, + normalizer: ssb.normalizer, } } + +var ( + sqliteReList = []struct { + re *regexp.Regexp + to string + }{{ + re: regexp.MustCompile(`((?:\?(\d*)|[@:$][0-9A-Fa-f]+)\s*,\s*)+`), + to: "..., ", + }, { + re: regexp.MustCompile(`(\(\.\.\., ((\?[0-9]*)|[@:$][0-9A-Fa-f]+)\)\s*,\s*)+`), + to: "..., ", + }} + sqlite3NormalizeCacheLocker = &sync.RWMutex{} + sqlite3NormalizeCache = make(map[string]string, 50) +) + +func (sqlite3SegmentBuilder) normalizer(query string) string { + var ( + normalizedQuery string + ok bool + ) + func() { + sqlite3NormalizeCacheLocker.RLock() + defer sqlite3NormalizeCacheLocker.RUnlock() + normalizedQuery, ok = sqlite3NormalizeCache[query] + }() + if ok { + return normalizedQuery + } + + normalizedQuery = query + for _, re := range sqliteReList { + normalizedQuery = re.re.ReplaceAllString(normalizedQuery, re.to) + } + + func() { + sqlite3NormalizeCacheLocker.Lock() + defer sqlite3NormalizeCacheLocker.Unlock() + sqlite3NormalizeCache[query] = normalizedQuery + }() + + return normalizedQuery +} diff --git a/db/sqlite_test.go b/db/sqlite_test.go new file mode 100644 index 0000000..0c4a9ad --- /dev/null +++ b/db/sqlite_test.go @@ -0,0 +1,46 @@ +package isudb + +import ( + "fmt" + "strings" + "testing" +) + +func BenchmarkSQLite3Normalizer(b *testing.B) { + ssb := sqlite3SegmentBuilder{} + + queryPart := fmt.Sprintf("(%s?)", strings.Repeat("?, ", 5)) + query := fmt.Sprintf("INSERT INTO users (name, email, password, salt, created_at, updated_at) VALUES %s", strings.Repeat(queryPart+", ", 999)+queryPart) + + for i := 0; i < b.N; i++ { + ssb.normalizer(query) + } +} + +func TestSQLite3Normalizer(t *testing.T) { + ssb := sqlite3SegmentBuilder{} + + tests := []string{ + "?", + "?1234", + ":a1", + "@a1", + "$a1", + } + + for _, test := range tests { + test := test + t.Run(test, func(t *testing.T) { + t.Parallel() + + queryPart := fmt.Sprintf("(%s%s)", strings.Repeat(fmt.Sprintf("%s, ", test), 5), test) + query := fmt.Sprintf("INSERT INTO users (name, email, password, salt, created_at, updated_at) VALUES %s", strings.Repeat(queryPart+", ", 999)+queryPart) + + normalizedQuery := ssb.normalizer(query) + + if normalizedQuery != fmt.Sprintf("INSERT INTO users (name, email, password, salt, created_at, updated_at) VALUES ..., (..., %s)", test) { + t.Errorf("unexpected query: %s", normalizedQuery) + } + }) + } +} diff --git a/tools.go b/tools.go index 01421b1..6ac1563 100644 --- a/tools.go +++ b/tools.go @@ -1,3 +1,6 @@ +//go:build tools +// +build tools + package isutools import (