Skip to content

Commit 863a7cd

Browse files
Allow array query params
1 parent cd823ae commit 863a7cd

File tree

4 files changed

+134
-30
lines changed

4 files changed

+134
-30
lines changed

src/structs.go

+10-5
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,23 @@ type credentials struct {
108108
}
109109

110110
type requestItem struct {
111-
Query string `json:"query"`
112-
Statement string `json:"statement"`
113-
NoFail bool `json:"noFail"`
114-
Values map[string]json.RawMessage `json:"values"`
115-
ValuesBatch []map[string]json.RawMessage `json:"valuesBatch"`
111+
Query string `json:"query"`
112+
Statement string `json:"statement"`
113+
NoFail bool `json:"noFail"`
114+
Values json.RawMessage `json:"values"`
115+
ValuesBatch []json.RawMessage `json:"valuesBatch"`
116116
}
117117

118118
type request struct {
119119
Credentials *credentials `json:"credentials"`
120120
Transaction []requestItem `json:"transaction"`
121121
}
122122

123+
type requestParams struct {
124+
UnmarshalledDict map[string]any
125+
UnmarshalledArray []any
126+
}
127+
123128
// These are for generating the response
124129

125130
type responseItem struct {

src/utils.go

+35
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ import (
2020
"bytes"
2121
"database/sql"
2222
"encoding/json"
23+
"errors"
2324
"github.com/iancoleman/orderedmap"
2425
"github.com/mitchellh/go-homedir"
2526
mllog "github.com/proofrock/go-mylittlelogger"
2627
"os"
2728
"path/filepath"
29+
"slices"
2830
"strings"
2931
)
3032

@@ -78,6 +80,39 @@ func vals2nameds(vals map[string]interface{}) []interface{} {
7880
return nameds
7981
}
8082

83+
func isEmptyRaw(raw json.RawMessage) bool {
84+
// the last check is for `null`
85+
return raw == nil || len(raw) == 0 || slices.Equal(raw, []byte{110, 117, 108, 108})
86+
}
87+
88+
func raw2params(raw json.RawMessage) (*requestParams, error) {
89+
params := requestParams{}
90+
if isEmptyRaw(raw) {
91+
params.UnmarshalledArray = []any{}
92+
return &params, nil
93+
}
94+
switch raw[0] {
95+
case '[':
96+
values := make([]any, 0)
97+
err := json.Unmarshal(raw, &values)
98+
if err != nil {
99+
return nil, err
100+
}
101+
params.UnmarshalledArray = values
102+
case '{':
103+
values := make(map[string]interface{})
104+
err := json.Unmarshal(raw, &values)
105+
if err != nil {
106+
return nil, err
107+
}
108+
params.UnmarshalledDict = values
109+
default:
110+
return nil, errors.New("values should be an array or an object")
111+
}
112+
113+
return &params, nil
114+
}
115+
81116
// Processes paths with home (tilde) expansion. Fails if not valid
82117
func expandHomeDir(path string, desc string) string {
83118
ePath, err := homedir.Expand(path)

src/web_service.go

+39-15
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,18 @@ func reportError(err error, code int, reqIdx int, noFail bool, results []respons
6363
// Processes a query, and returns a suitable responseItem
6464
//
6565
// This method is needed to execute properly the defers.
66-
func processWithResultSet(tx *sql.Tx, query string, values map[string]interface{}) (*responseItem, error) {
66+
func processWithResultSet(tx *sql.Tx, query string, params requestParams) (*responseItem, error) {
6767
resultSet := make([]orderedmap.OrderedMap, 0)
6868

69-
rows, err := tx.Query(query, vals2nameds(values)...)
69+
rows := (*sql.Rows)(nil)
70+
err := (error)(nil)
71+
if params.UnmarshalledDict == nil && params.UnmarshalledArray == nil {
72+
rows, err = nil, errors.New("processWithResultSet unreachable code")
73+
} else if params.UnmarshalledDict != nil {
74+
rows, err = tx.Query(query, vals2nameds(params.UnmarshalledDict)...)
75+
} else {
76+
rows, err = tx.Query(query, params.UnmarshalledArray...)
77+
}
7078
if err != nil {
7179
return nil, err
7280
}
@@ -99,8 +107,16 @@ func processWithResultSet(tx *sql.Tx, query string, values map[string]interface{
99107
}
100108

101109
// Process a single statement, and returns a suitable responseItem
102-
func processForExec(tx *sql.Tx, statement string, values map[string]interface{}) (*responseItem, error) {
103-
res, err := tx.Exec(statement, vals2nameds(values)...)
110+
func processForExec(tx *sql.Tx, statement string, params requestParams) (*responseItem, error) {
111+
res := (sql.Result)(nil)
112+
err := (error)(nil)
113+
if params.UnmarshalledDict == nil && params.UnmarshalledArray == nil {
114+
res, err = nil, errors.New("processWithResultSet unreachable code")
115+
} else if params.UnmarshalledDict != nil {
116+
res, err = tx.Exec(statement, vals2nameds(params.UnmarshalledDict)...)
117+
} else {
118+
res, err = tx.Exec(statement, params.UnmarshalledArray...)
119+
}
104120
if err != nil {
105121
return nil, err
106122
}
@@ -115,16 +131,24 @@ func processForExec(tx *sql.Tx, statement string, values map[string]interface{})
115131

116132
// Process a batch statement, and returns a suitable responseItem.
117133
// It prepares the statement, then executes it for each of the values' sets.
118-
func processForExecBatch(tx *sql.Tx, q string, valuesBatch []map[string]interface{}) (*responseItem, error) {
134+
func processForExecBatch(tx *sql.Tx, q string, paramsBatch []requestParams) (*responseItem, error) {
119135
ps, err := tx.Prepare(q)
120136
if err != nil {
121137
return nil, err
122138
}
123139
defer ps.Close()
124140

125141
var rowsUpdatedBatch []int64
126-
for i := range valuesBatch {
127-
res, err := ps.Exec(vals2nameds(valuesBatch[i])...)
142+
for _, params := range paramsBatch {
143+
res := (sql.Result)(nil)
144+
err := (error)(nil)
145+
if params.UnmarshalledDict == nil && params.UnmarshalledArray == nil {
146+
res, err = nil, errors.New("processWithResultSet unreachable code")
147+
} else if params.UnmarshalledDict != nil {
148+
res, err = tx.Exec(q, vals2nameds(params.UnmarshalledDict)...)
149+
} else {
150+
res, err = tx.Exec(q, params.UnmarshalledArray...)
151+
}
128152
if err != nil {
129153
return nil, err
130154
}
@@ -215,7 +239,7 @@ func handler(databaseId string) func(c *fiber.Ctx) error {
215239

216240
hasResultSet := txItem.Query != ""
217241

218-
if len(txItem.Values) != 0 && len(txItem.ValuesBatch) != 0 {
242+
if !isEmptyRaw(txItem.Values) && len(txItem.ValuesBatch) != 0 {
219243
reportError(errors.New("cannot specify both values and valuesBatch"), fiber.StatusBadRequest, i, txItem.NoFail, ret.Results)
220244
continue
221245
}
@@ -256,18 +280,18 @@ func handler(databaseId string) func(c *fiber.Ctx) error {
256280

257281
if len(txItem.ValuesBatch) > 0 {
258282
// Process a batch statement (multiple values)
259-
var valuesBatch []map[string]interface{}
283+
var paramsBatch []requestParams
260284
for i2 := range txItem.ValuesBatch {
261-
values, err := raw2vals(txItem.ValuesBatch[i2])
285+
params, err := raw2params(txItem.ValuesBatch[i2])
262286
if err != nil {
263287
reportError(err, fiber.StatusInternalServerError, i, txItem.NoFail, ret.Results)
264288
continue
265289
}
266290

267-
valuesBatch = append(valuesBatch, values)
291+
paramsBatch = append(paramsBatch, *params)
268292
}
269293

270-
retE, err := processForExecBatch(tx, sqll, valuesBatch)
294+
retE, err := processForExecBatch(tx, sqll, paramsBatch)
271295
if err != nil {
272296
reportError(err, fiber.StatusInternalServerError, i, txItem.NoFail, ret.Results)
273297
continue
@@ -276,7 +300,7 @@ func handler(databaseId string) func(c *fiber.Ctx) error {
276300
ret.Results[i] = *retE
277301
} else {
278302
// At most one values set (be it query or statement)
279-
values, err := raw2vals(txItem.Values)
303+
params, err := raw2params(txItem.Values)
280304
if err != nil {
281305
reportError(err, fiber.StatusInternalServerError, i, txItem.NoFail, ret.Results)
282306
continue
@@ -285,7 +309,7 @@ func handler(databaseId string) func(c *fiber.Ctx) error {
285309
if hasResultSet {
286310
// Query
287311
// Externalized in a func so that defer rows.Close() actually runs
288-
retWR, err := processWithResultSet(tx, sqll, values)
312+
retWR, err := processWithResultSet(tx, sqll, *params)
289313
if err != nil {
290314
reportError(err, fiber.StatusInternalServerError, i, txItem.NoFail, ret.Results)
291315
continue
@@ -294,7 +318,7 @@ func handler(databaseId string) func(c *fiber.Ctx) error {
294318
ret.Results[i] = *retWR
295319
} else {
296320
// Statement
297-
retE, err := processForExec(tx, sqll, values)
321+
retE, err := processForExec(tx, sqll, *params)
298322
if err != nil {
299323
reportError(err, fiber.StatusInternalServerError, i, txItem.NoFail, ret.Results)
300324
continue

src/ws4sqlite_test.go

+50-10
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,9 @@ func call(databaseId string, req request, t *testing.T) (int, string, response)
9393
return callBA(databaseId, req, "", "", t)
9494
}
9595

96-
func mkRaw(mapp map[string]interface{}) map[string]json.RawMessage {
97-
ret := make(map[string]json.RawMessage)
98-
for k, v := range mapp {
99-
bytes, _ := json.Marshal(v)
100-
ret[k] = bytes
101-
}
102-
return ret
96+
func mkRaw(mapp any) json.RawMessage {
97+
bs, _ := json.Marshal(mapp)
98+
return bs
10399
}
104100

105101
func TestSetup(t *testing.T) {
@@ -191,7 +187,7 @@ func TestTx(t *testing.T) {
191187
},
192188
{
193189
Statement: "INSERT INTO T1 (ID, VAL) VALUES (:ID, :VAL)",
194-
ValuesBatch: []map[string]json.RawMessage{
190+
ValuesBatch: []json.RawMessage{
195191
mkRaw(map[string]interface{}{
196192
"ID": 3,
197193
"VAL": "THREE",
@@ -328,7 +324,7 @@ func TestConcurrent(t *testing.T) {
328324
},
329325
{
330326
Statement: "INSERT INTO T1 (ID, VAL) VALUES (:ID, :VAL)",
331-
ValuesBatch: []map[string]json.RawMessage{
327+
ValuesBatch: []json.RawMessage{
332328
mkRaw(map[string]interface{}{
333329
"ID": 3,
334330
"VAL": "THREE",
@@ -442,6 +438,50 @@ func TestResultSetOrder(t *testing.T) {
442438
}
443439
}
444440

441+
func TestArrayParams(t *testing.T) {
442+
req := request{
443+
Transaction: []requestItem{
444+
{
445+
Statement: "CREATE TABLE table_with_many_columns (d INT, c INT, b INT, a INT)",
446+
},
447+
{
448+
Statement: "INSERT INTO table_with_many_columns VALUES (?, ?, ?, ?)",
449+
Values: mkRaw([]int{1, 1, 1, 1}),
450+
},
451+
{
452+
Statement: "INSERT INTO table_with_many_columns VALUES (?, ?, ?, ?)",
453+
ValuesBatch: []json.RawMessage{
454+
mkRaw([]int{2, 2, 2, 2}),
455+
mkRaw([]int{3, 3, 3, 3}),
456+
mkRaw([]int{4, 4, 4, 4}),
457+
},
458+
},
459+
{
460+
Query: "SELECT * FROM table_with_many_columns",
461+
},
462+
{
463+
Statement: "DROP TABLE table_with_many_columns",
464+
},
465+
},
466+
}
467+
code, _, resp := call("test", req, t)
468+
469+
if code != 200 {
470+
t.Error("did not succeed")
471+
return
472+
}
473+
queryResult := resp.Results[3]
474+
if !queryResult.Success {
475+
t.Error("could not query")
476+
return
477+
}
478+
records := queryResult.ResultSet
479+
if len(records) != 4 {
480+
t.Error("expected 4 records")
481+
return
482+
}
483+
}
484+
445485
// don't remove the file, we'll use it for the next tests for read-only
446486
func TestTeardown(t *testing.T) {
447487
time.Sleep(time.Second)
@@ -1137,7 +1177,7 @@ func TestItemFieldsInsertBatch(t *testing.T) {
11371177
Transaction: []requestItem{
11381178
{
11391179
Statement: "INSERT INTO T1 VALUES (:ID, :VAL)",
1140-
ValuesBatch: []map[string]json.RawMessage{
1180+
ValuesBatch: []json.RawMessage{
11411181
mkRaw(map[string]interface{}{
11421182
"ID": 3,
11431183
"VAL": "THREE",

0 commit comments

Comments
 (0)