diff --git a/src/go.mod b/src/go.mod index e9a5687..83c40c2 100644 --- a/src/go.mod +++ b/src/go.mod @@ -18,6 +18,7 @@ require ( github.com/andybalholm/brotli v1.1.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/iancoleman/orderedmap v0.3.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect github.com/klauspost/compress v1.17.6 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/src/go.sum b/src/go.sum index 17b0c7d..61d4295 100644 --- a/src/go.sum +++ b/src/go.sum @@ -8,6 +8,8 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/iancoleman/orderedmap v0.3.0 h1:5cbR2grmZR/DiVt+VJopEhtVs9YGInGIxAoMJn+Ichc= +github.com/iancoleman/orderedmap v0.3.0/go.mod h1:XuLcCUkdL5owUCQeF2Ue9uuw1EptkJDkXXS7VoV7XGE= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= diff --git a/src/sched_tasks_test.go b/src/sched_tasks_test.go index 439ffa7..21156dc 100644 --- a/src/sched_tasks_test.go +++ b/src/sched_tasks_test.go @@ -236,7 +236,7 @@ func TestSchedTasksWithStatement(t *testing.T) { t.Error("did not succeed, but should have") } - if fmt.Sprint(res.Results[0].ResultSet[0]["num"]) != "17" { + if fmt.Sprint(getDefault[float64](res.Results[0].ResultSet[0], "num")) != "17" { t.Error("scheduled statement probably didn't execute") } } diff --git a/src/structs.go b/src/structs.go index 88b108c..ccf9804 100644 --- a/src/structs.go +++ b/src/structs.go @@ -20,6 +20,7 @@ import ( "database/sql" "encoding/json" "fmt" + "github.com/iancoleman/orderedmap" "sync" ) @@ -129,11 +130,11 @@ type request struct { // These are for generating the response type responseItem struct { - Success bool `json:"success"` - RowsUpdated *int64 `json:"rowsUpdated,omitempty"` - RowsUpdatedBatch []int64 `json:"rowsUpdatedBatch,omitempty"` - ResultSet []map[string]interface{} `json:"resultSet,omitnil"` // omitnil is used by jettison - Error string `json:"error,omitempty"` + Success bool `json:"success"` + RowsUpdated *int64 `json:"rowsUpdated,omitempty"` + RowsUpdatedBatch []int64 `json:"rowsUpdatedBatch,omitempty"` + ResultSet []orderedmap.OrderedMap `json:"resultSet,omitnil"` // omitnil is used by jettison + Error string `json:"error,omitempty"` } type response struct { diff --git a/src/utils.go b/src/utils.go index 01096bb..da4db72 100644 --- a/src/utils.go +++ b/src/utils.go @@ -20,6 +20,7 @@ import ( "bytes" "database/sql" "encoding/json" + "github.com/iancoleman/orderedmap" "github.com/mitchellh/go-homedir" mllog "github.com/proofrock/go-mylittlelogger" "os" @@ -121,3 +122,13 @@ func splitOnColon(toSplit string) (string, string) { } return toSplit, "" } + +func getDefault[T any](m orderedmap.OrderedMap, key string) T { + value, ok := m.Get(key) + if !ok { + var t T + return t + } + + return value.(T) +} diff --git a/src/web_service.go b/src/web_service.go index 2b188da..1fd5123 100644 --- a/src/web_service.go +++ b/src/web_service.go @@ -20,6 +20,7 @@ import ( "context" "database/sql" "errors" + "github.com/iancoleman/orderedmap" "strings" "time" @@ -74,20 +75,21 @@ func encrypt(encoder requestItemCrypto, values map[string]interface{}) error { } // Scans the results from a db request and decrypts them as needed -func decrypt(decoder requestItemCrypto, results map[string]interface{}) error { +func decrypt(decoder requestItemCrypto, results *orderedmap.OrderedMap) error { if decoder.CompressionLevel > 0 { return errors.New("cannot specify compression level for decryption") } for i := range decoder.Fields { - sval, ok := results[decoder.Fields[i]].(string) + // sval, ok := results[decoder.Fields[i]].(string) + sval, ok := results.Get(decoder.Fields[i]) if !ok { return errors.New("attempting to decrypt a non-string field") } - dval, err := crypgo.Decrypt(decoder.Password, sval) + dval, err := crypgo.Decrypt(decoder.Password, sval.(string)) if err != nil { return err } - results[decoder.Fields[i]] = dval + results.Set(decoder.Fields[i], dval) } return nil } @@ -106,7 +108,7 @@ func reportError(err error, code int, reqIdx int, noFail bool, results []respons // // This method is needed to execute properly the defers. func processWithResultSet(tx *sql.Tx, query string, decoder *requestItemCrypto, values map[string]interface{}) (*responseItem, error) { - resultSet := make([]map[string]interface{}, 0) + resultSet := make([]orderedmap.OrderedMap, 0) rows, err := tx.Query(query, vals2nameds(values)...) if err != nil { @@ -125,9 +127,9 @@ func processWithResultSet(tx *sql.Tx, query string, decoder *requestItemCrypto, return nil, err } - toAdd := make(map[string]interface{}) + toAdd := orderedmap.New() for i := range values { - toAdd[fields[i]] = values[i] + toAdd.Set(fields[i], values[i]) } if decoder != nil { @@ -135,7 +137,7 @@ func processWithResultSet(tx *sql.Tx, query string, decoder *requestItemCrypto, return nil, err } } - resultSet = append(resultSet, toAdd) + resultSet = append(resultSet, *toAdd) } if err = rows.Err(); err != nil { diff --git a/src/ws4sqlite_test.go b/src/ws4sqlite_test.go index cfbfd81..3561f05 100644 --- a/src/ws4sqlite_test.go +++ b/src/ws4sqlite_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/gofiber/fiber/v2" "os" + "slices" "sync" "testing" "time" @@ -111,7 +112,7 @@ func TestSetup(t *testing.T) { { Id: "test", Path: "../test/test.db", - //DisableWALMode: true, + // DisableWALMode: true, StoredStatement: []storedStatement{ { Id: "Q", @@ -224,7 +225,7 @@ func TestTx(t *testing.T) { t.Error("req 1 inconsistent") } - if !res.Results[2].Success || res.Results[2].ResultSet[0]["VAL"] != "ONE" { + if !res.Results[2].Success || getDefault[string](res.Results[2].ResultSet[0], "VAL") != "ONE" { t.Error("req 2 inconsistent") } @@ -367,7 +368,7 @@ func TestConcurrent(t *testing.T) { t.Error("req 1 inconsistent") } - if !res.Results[2].Success || res.Results[2].ResultSet[0]["VAL"] != "ONE" { + if !res.Results[2].Success || getDefault[string](res.Results[2].ResultSet[0], "VAL") != "ONE" { t.Error("req 2 inconsistent") } @@ -387,6 +388,60 @@ func TestConcurrent(t *testing.T) { wg.Wait() } +func TestResultSetOrder(t *testing.T) { + // See this issue for more context: https://github.com/proofrock/sqliterg/issues/5 + req := request{ + Transaction: []requestItem{ + { + Query: "CREATE TABLE table_with_many_columns (d INT, c INT, b INT, a INT)", + }, + { + Query: "INSERT INTO table_with_many_columns VALUES (4, 3, 2, 1)", + }, + { + Query: "SELECT * FROM table_with_many_columns", + }, + }, + } + code, _, res := call("test", req, t) + + if code != 200 { + t.Error("did not succeed") + return + } + + if !res.Results[0].Success || + !res.Results[1].Success || + !res.Results[2].Success { + t.Error("did not succeed") + return + } + + queryResult := res.Results[2].ResultSet[0] + expectedKeys := []string{"d", "c", "b", "a"} + if !slices.Equal( + queryResult.Keys(), + expectedKeys, + ) { + t.Error("should have the right order") + return + } + + expectedValues := []float64{4, 3, 2, 1} + for i, key := range expectedKeys { + value, ok := queryResult.Get(key) + if !ok { + t.Error("unreachable code") + return + } + expectedValue := expectedValues[i] + if value != expectedValue { + t.Error("wrong value") + return + } + } +} + // don't remove the file, we'll use it for the next tests for read-only func TestTeardown(t *testing.T) { time.Sleep(time.Second) @@ -403,7 +458,7 @@ func TestSetupRO(t *testing.T) { { Id: "test", Path: "../test/test.db", - //DisableWALMode: true, + // DisableWALMode: true, ReadOnly: true, StoredStatement: []storedStatement{ { @@ -451,7 +506,7 @@ func TestOkRO(t *testing.T) { return } - if !res.Results[0].Success || res.Results[0].ResultSet[3]["VAL"] != "FOUR" { + if !res.Results[0].Success || getDefault[string](res.Results[0].ResultSet[3], "VAL") != "FOUR" { t.Error("req is inconsistent") } } @@ -478,7 +533,7 @@ func TestConcurrentRO(t *testing.T) { return } - if !res.Results[0].Success || res.Results[0].ResultSet[3]["VAL"] != "FOUR" { + if !res.Results[0].Success || getDefault[string](res.Results[0].ResultSet[3], "VAL") != "FOUR" { t.Error("req is inconsistent") } }(t) @@ -501,7 +556,7 @@ func TestSetupSQO(t *testing.T) { { Id: "test", Path: "../test/test.db", - //DisableWALMode: true, + // DisableWALMode: true, ReadOnly: true, UseOnlyStoredStatements: true, StoredStatement: []storedStatement{ @@ -569,7 +624,7 @@ func TestSetupMEM(t *testing.T) { { Id: "test", Path: ":memory:", - //DisableWALMode: true, + // DisableWALMode: true, StoredStatement: []storedStatement{ { Id: "Q", @@ -640,7 +695,7 @@ func TestSetupMEM_RO(t *testing.T) { Id: "test", Path: ":memory:", ReadOnly: true, - //DisableWALMode: true, + // DisableWALMode: true, StoredStatement: []storedStatement{ { Id: "Q", @@ -689,7 +744,7 @@ func TestSetupWITH_ADD_PROPS(t *testing.T) { { Id: "test", Path: "file::memory:", - //DisableWALMode: true, + // DisableWALMode: true, StoredStatement: []storedStatement{ { Id: "Q", @@ -740,7 +795,7 @@ func TestRO_MEM_IS(t *testing.T) { Id: "test", Path: ":memory:", ReadOnly: true, - //DisableWALMode: true, + // DisableWALMode: true, InitStatements: []string{ "CREATE TABLE T1 (ID INT)", }, @@ -767,7 +822,7 @@ func Test_IS_Err(t *testing.T) { { Id: "test", Path: ":memory:", - //DisableWALMode: true, + // DisableWALMode: true, InitStatements: []string{ "CREATE TABLE T1 (ID INT)", "CREATE TABLE T1 (ID INT)", @@ -1213,7 +1268,7 @@ func TestUnicode(t *testing.T) { if code != 200 { t.Error("SELECT failed", body) } - if res.Results[0].ResultSet[0]["TXT"] != "世界" { + if getDefault[string](res.Results[0].ResultSet[0], "TXT") != "世界" { t.Error("Unicode extraction failed", body) }