Skip to content

Commit

Permalink
feat: added new parameters for parameterized query
Browse files Browse the repository at this point in the history
1. added new `whereSQLStmt` for parameterized query statement
2. added new `whereSQLParams` for parameterized query parameters
3. added new `whereSQLJSONParams` for parameterized query parameters in JSON format with base64 encoding

Signed-off-by: Neko Ayaka <[email protected]>
  • Loading branch information
nekomeowww committed Nov 27, 2023
1 parent e89faaf commit 8f10de8
Show file tree
Hide file tree
Showing 11 changed files with 5,028 additions and 27 deletions.
56 changes: 44 additions & 12 deletions pkg/storage/internalstorage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package internalstorage

import (
"fmt"
"os"
"testing"

"github.com/DATA-DOG/go-sqlmock"
gmysql "gorm.io/driver/mysql"
Expand All @@ -10,34 +12,64 @@ import (
)

var (
postgresDB *gorm.DB
postgresDB *gorm.DB
postgresDBMock sqlmock.Sqlmock

mysqlVersions = []string{"8.0.27", "5.7.22"}
mysqlDBs = make(map[string]*gorm.DB, 2)
mysqlDBMocks = make(map[string]sqlmock.Sqlmock, 2)
)

func init() {
db, _, err := sqlmock.New()
func newMockedPostgresDB() (*gorm.DB, sqlmock.Sqlmock, error) {
mockedDB, mock, err := sqlmock.New()
if err != nil {
panic(fmt.Sprintf("sqlmock.New() failed: %v", err))
return nil, nil, fmt.Errorf("sqlmock.New() failed: %w", err)
}

postgresDB, err = gorm.Open(gpostgres.New(gpostgres.Config{Conn: db}))
gormDB, err := gorm.Open(gpostgres.New(gpostgres.Config{Conn: mockedDB}))
if err != nil {
panic(fmt.Sprintf("init postgresDB failed: %v", err))
return nil, nil, fmt.Errorf("init postgresDB failed: %w", err)
}

for _, version := range mysqlVersions {
db, mock, err := sqlmock.New()
return gormDB, mock, nil
}

func newMockedMySQLDB(version string) (*gorm.DB, sqlmock.Sqlmock, error) {
mockedDB, mock, err := sqlmock.New()
if err != nil {
return nil, nil, fmt.Errorf("sqlmock.New() failed: %w", err)
}

mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"VERSION()"}).AddRow(version))

mysqlDB, err := gorm.Open(gmysql.New(gmysql.Config{Conn: mockedDB}))
if err != nil {
return nil, nil, fmt.Errorf("init mysqlDB(%s) failed: %w", version, err)
}

return mysqlDB, mock, nil
}

func TestMain(m *testing.M) {
{
mockedDB, mock, err := newMockedPostgresDB()
if err != nil {
panic(fmt.Sprintf("sqlmock.New() failed: %v", err))
panic(err)
}
mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"VERSION()"}).AddRow(version))

mysqlDB, err := gorm.Open(gmysql.New(gmysql.Config{Conn: db}))
postgresDB = mockedDB
postgresDBMock = mock
}

for _, version := range mysqlVersions {
mysqlDB, mock, err := newMockedMySQLDB(version)
if err != nil {
panic(fmt.Sprintf("init mysqlDB(%s) failed: %v", version, err))
panic(err)
}

mysqlDBs[version] = mysqlDB
mysqlDBMocks[version] = mock
}

os.Exit(m.Run())
}
163 changes: 156 additions & 7 deletions pkg/storage/internalstorage/util.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package internalstorage

import (
"encoding/base64"
"fmt"
"net/url"
"strconv"
"strings"

Expand All @@ -10,6 +12,7 @@ import (
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/selection"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/validation/field"
utilfeature "k8s.io/apiserver/pkg/util/feature"

Expand All @@ -19,9 +22,154 @@ import (
const (
SearchLabelFuzzyName = "internalstorage.clusterpedia.io/fuzzy-name"

// Raw query
URLQueryWhereSQL = "whereSQL"
// Parameterized query
URLQueryFieldWhereSQLStatement = "whereSQLStatement"
URLQueryFieldWhereSQLParam = "whereSQLParam"
URLQueryFieldWhereSQLJSONParams = "whereSQLJSONParams"
)

type URLQueryWhereSQLParams struct {
// Raw query
WhereSQL string
// Parameterized query
WhereSQLStatement string
WhereSQLParams []string
WhereSQLJSONParams []any
}

// NewURLQueryWhereSQLParamsFromURLValues resolves parameters from passed in url.Values.
// A k8s.io/apimachinery/pkg/api/errors.StatusError will be returned if decoding or unmarshalling failed
// only when the value of "whereSQLJSONParams" is present.
//
// It recognizes the following query fields for parameters:
//
// "whereSQL"
// "whereSQLStatement"
// "whereSQLParam"
// "whereSQLJSONParams"
func NewURLQueryWhereSQLParamsFromURLValues(urlQuery url.Values) (URLQueryWhereSQLParams, error) {
var params URLQueryWhereSQLParams

whereClause, ok := urlQuery[URLQueryWhereSQL]
if ok && len(whereClause) > 0 {
params.WhereSQL = whereClause[0]
}

whereClauseStatement, ok := urlQuery[URLQueryFieldWhereSQLStatement]
if ok && len(whereClauseStatement) > 0 {
params.WhereSQLStatement = whereClauseStatement[0]
}

whereClauseParams, ok := urlQuery[URLQueryFieldWhereSQLParam]
if ok {
params.WhereSQLParams = whereClauseParams
}

whereClauseJSONParams, ok := urlQuery[URLQueryFieldWhereSQLJSONParams]
if ok && len(whereClauseJSONParams) > 0 {
decodedBytesContent, err := base64.StdEncoding.DecodeString(whereClauseJSONParams[0])
if err != nil {
return URLQueryWhereSQLParams{}, apierrors.NewInvalid(
schema.GroupKind{Group: internal.GroupName, Kind: "ListOptions"},
"urlQuery",
field.ErrorList{
field.Invalid(
field.NewPath(URLQueryFieldWhereSQLJSONParams),
whereClauseJSONParams[0],
fmt.Sprintf("failed to decode base64 string: %v", err),
),
},
)
}

params.WhereSQLJSONParams = make([]any, 0)
err = json.Unmarshal(decodedBytesContent, &params.WhereSQLJSONParams)
if err != nil {
return URLQueryWhereSQLParams{}, apierrors.NewInvalid(
schema.GroupKind{Group: internal.GroupName, Kind: "ListOptions"},
"urlQuery",
field.ErrorList{
field.Invalid(
field.NewPath(URLQueryFieldWhereSQLJSONParams),
whereClauseJSONParams[0],
fmt.Sprintf("failed to unmarshal decoded base64 string to JSON array: %v", err),
),
},
)
}
}

if (len(params.WhereSQLParams) > 0 || len(params.WhereSQLJSONParams) > 0) && params.WhereSQLStatement == "" {
return URLQueryWhereSQLParams{}, apierrors.NewInvalid(
schema.GroupKind{Group: internal.GroupName, Kind: "ListOptions"},
"urlQuery",
field.ErrorList{
field.Invalid(
field.NewPath(URLQueryFieldWhereSQLStatement),
whereClauseStatement,
fmt.Sprintf("required when either %s or %s was provided", URLQueryFieldWhereSQLParam, URLQueryFieldWhereSQLJSONParams),
),
},
)
}

return params, nil
}

func applyListOptionsURLQueryParameterizedQueryToWhereClause(query *gorm.DB, params URLQueryWhereSQLParams) *gorm.DB {
if params.WhereSQLStatement == "" {
return query
}

// If a string of numbers is passed in from SQL, the query will be taken as ID by default.
// If the SQL contains English letter, it will be passed in as column.

if len(params.WhereSQLJSONParams) > 0 {
return query.Where(params.WhereSQLStatement, params.WhereSQLJSONParams...)
}
if len(params.WhereSQLParams) > 0 {
anyParameters := make([]any, len(params.WhereSQLParams))

for i := range params.WhereSQLParams {
anyParameters[i] = params.WhereSQLParams[i]
}

return query.Where(params.WhereSQLStatement, anyParameters...)
}

return query.Where(params.WhereSQLStatement)
}

func applyListOptionsURLQueryToWhereClause(allowRawSQLQueryEnabled bool, allowParameterizedSQLQueryEnabled bool, query *gorm.DB, urlValues url.Values) (*gorm.DB, error) {
if !allowRawSQLQueryEnabled && !allowParameterizedSQLQueryEnabled {
return query, nil
}

urlQueryParams, err := NewURLQueryWhereSQLParamsFromURLValues(urlValues)
if err != nil {
return query, err
}

if allowRawSQLQueryEnabled {
// use parameterized query first if statement was provided
if urlQueryParams.WhereSQLStatement != "" {
return applyListOptionsURLQueryParameterizedQueryToWhereClause(query, urlQueryParams), nil
}
// otherwise, fallbacks to raw query
if urlQueryParams.WhereSQL != "" {
return query.Where(urlQueryParams.WhereSQL), nil
}
}

if allowParameterizedSQLQueryEnabled && urlQueryParams.WhereSQLStatement != "" {
return applyListOptionsURLQueryParameterizedQueryToWhereClause(query, urlQueryParams), nil
}

return query, nil
}

func applyListOptionsToQuery(query *gorm.DB, opts *internal.ListOptions, applyFn func(query *gorm.DB, opts *internal.ListOptions) (*gorm.DB, error)) (int64, *int64, *gorm.DB, error) {
switch len(opts.ClusterNames) {
case 0:
Expand Down Expand Up @@ -55,13 +203,14 @@ func applyListOptionsToQuery(query *gorm.DB, opts *internal.ListOptions, applyFn
query = query.Where("created_at < ?", opts.Before.Time.UTC())
}

if utilfeature.DefaultMutableFeatureGate.Enabled(AllowRawSQLQuery) {
if len(opts.URLQuery[URLQueryWhereSQL]) > 0 {
// TODO: prevent SQL injection.
// If a string of numbers is passed in from SQL, the query will be taken as ID by default.
// If the SQL contains English letter, it will be passed in as column.
query = query.Where(opts.URLQuery[URLQueryWhereSQL][0])
}
query, err := applyListOptionsURLQueryToWhereClause(
utilfeature.DefaultMutableFeatureGate.Enabled(AllowRawSQLQuery),
false,
query,
opts.URLQuery,
)
if err != nil {
return 0, nil, nil, err
}

if opts.LabelSelector != nil {
Expand Down
Loading

0 comments on commit 8f10de8

Please sign in to comment.