diff --git a/cmd/root_test.go b/cmd/root_test.go index fb346569b..91c643af7 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -92,14 +92,16 @@ func TestServer(t *testing.T) { timeout := 5 * time.Second delay := 200 * time.Millisecond for { - _, err := http.DefaultClient.Get("http://localhost:3068/_info") - if err != nil { + rsp, err := http.DefaultClient.Get("http://localhost:3068/_info") + if err != nil || rsp.StatusCode != http.StatusOK { if counter*delay < timeout { counter++ <-time.After(delay) continue } - assert.FailNow(t, err.Error()) + if assert.FailNow(t, err.Error()) { + return + } } break } diff --git a/cmd/server_start.go b/cmd/server_start.go index 4000af6f6..5366b0335 100644 --- a/cmd/server_start.go +++ b/cmd/server_start.go @@ -16,24 +16,24 @@ func NewServerStart() *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { app := NewContainer( viper.GetViper(), - fx.Invoke(func(h *api.API) error { - listener, err := net.Listen("tcp", viper.GetString(serverHttpBindAddressFlag)) - if err != nil { - return err - } - - go http.Serve(listener, h) - go func() { - select { - case <-cmd.Context().Done(): - } - err := listener.Close() - if err != nil { - panic(err) - } - }() - - return nil + fx.Invoke(func(lc fx.Lifecycle, h *api.API) { + var ( + err error + listener net.Listener + ) + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + listener, err = net.Listen("tcp", viper.GetString(serverHttpBindAddressFlag)) + if err != nil { + return err + } + go http.Serve(listener, h) + return nil + }, + OnStop: func(ctx context.Context) error { + return listener.Close() + }, + }) }), ) errCh := make(chan error, 1) diff --git a/go.mod b/go.mod index 8acdbe471..41877581a 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/ThreeDotsLabs/watermill-http v1.1.4 // indirect github.com/ThreeDotsLabs/watermill-kafka/v2 v2.2.1 // indirect github.com/antlr/antlr4/runtime/Go/antlr v0.0.0-20211226235153-13a0add2f557 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cenkalti/backoff/v4 v4.1.2 // indirect github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect diff --git a/go.sum b/go.sum index 74072b0aa..f1d8d8e10 100644 --- a/go.sum +++ b/go.sum @@ -120,6 +120,8 @@ github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kB github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/bketelsen/crypt v0.0.3-0.20200106085610-5cbc8cc4026c/go.mod h1:MKsuJmJgSg28kpZDP6UIiPt0e0Oz0kqKNGyRaWEPv84= github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqOes/6LfM= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/casbin/casbin/v2 v2.1.2/go.mod h1:YcPU1XXisHhLzuxH9coDNf2FbKpjGlbCg3n9yuLkIJQ= github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= diff --git a/internal/pgtesting/testing.go b/internal/pgtesting/testing.go index 201349a6c..ba3c1931c 100644 --- a/internal/pgtesting/testing.go +++ b/internal/pgtesting/testing.go @@ -25,7 +25,7 @@ func (s *PGServer) Close() error { return s.close() } -const MaxConnections = 2 +const MaxConnections = 3 func PostgresServer() (*PGServer, error) { diff --git a/pkg/api/controllers/account_controller.go b/pkg/api/controllers/account_controller.go index a75cec6e7..6cacf216f 100644 --- a/pkg/api/controllers/account_controller.go +++ b/pkg/api/controllers/account_controller.go @@ -22,9 +22,14 @@ func NewAccountController() AccountController { func (ctl *AccountController) GetAccounts(c *gin.Context) { l, _ := c.Get("ledger") + cursor, err := l.(*ledger.Ledger).FindAccounts( c.Request.Context(), query.After(c.Query("after")), + query.Address(c.Query("address")), + func(q *query.Query) { + q.Params["metadata"] = c.QueryMap("metadata") + }, ) if err != nil { ResponseError(c, err) diff --git a/pkg/api/controllers/account_controller_test.go b/pkg/api/controllers/account_controller_test.go index 0cf60c68d..8f2d95736 100644 --- a/pkg/api/controllers/account_controller_test.go +++ b/pkg/api/controllers/account_controller_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "go.uber.org/fx" "net/http" + "net/url" "testing" ) @@ -27,7 +28,9 @@ func TestGetAccounts(t *testing.T) { }, }, }) - assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } rsp = internal.PostTransaction(t, h, core.TransactionData{ Postings: core.Postings{ @@ -39,14 +42,104 @@ func TestGetAccounts(t *testing.T) { }, }, }) - assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } - rsp = internal.GetAccounts(h) - assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) + rsp = internal.PostAccountMetadata(t, h, "bob", core.Metadata{ + "roles": json.RawMessage(`"admin"`), + "accountId": json.RawMessage("3"), + "enabled": json.RawMessage(`"true"`), + "a": json.RawMessage(`{"nested": {"key": "hello"}}`), + }) + if !assert.Equal(t, http.StatusNoContent, rsp.Result().StatusCode) { + return nil + } + + rsp = internal.GetAccounts(h, url.Values{}) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } cursor := internal.DecodeCursorResponse(t, rsp.Body, core.Account{}) - assert.EqualValues(t, 3, cursor.Total) - assert.Len(t, cursor.Data, 3) + if !assert.EqualValues(t, 3, cursor.Total) { + return nil + } + if !assert.Len(t, cursor.Data, 3) { + return nil + } + + rsp = internal.GetAccounts(h, url.Values{ + "metadata[roles]": []string{"admin"}, + }) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } + + cursor = internal.DecodeCursorResponse(t, rsp.Body, core.Account{}) + if !assert.EqualValues(t, 1, cursor.Total) { + return nil + } + if !assert.Len(t, cursor.Data, 1) { + return nil + } + + rsp = internal.GetAccounts(h, url.Values{ + "metadata[accountId]": []string{"3"}, + }) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } + + cursor = internal.DecodeCursorResponse(t, rsp.Body, core.Account{}) + if !assert.EqualValues(t, 1, cursor.Total) { + return nil + } + if !assert.Len(t, cursor.Data, 1) { + return nil + } + + rsp = internal.GetAccounts(h, url.Values{ + "metadata[enabled]": []string{"true"}, + }) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } + + cursor = internal.DecodeCursorResponse(t, rsp.Body, core.Account{}) + if !assert.EqualValues(t, 1, cursor.Total) { + return nil + } + if !assert.Len(t, cursor.Data, 1) { + return nil + } + + rsp = internal.GetAccounts(h, url.Values{ + "metadata[a.nested.key]": []string{"hello"}, + }) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } + + cursor = internal.DecodeCursorResponse(t, rsp.Body, core.Account{}) + if !assert.EqualValues(t, 1, cursor.Total) { + return nil + } + if !assert.Len(t, cursor.Data, 1) { + return nil + } + + rsp = internal.GetAccounts(h, url.Values{ + "metadata[unknown]": []string{"key"}, + }) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } + + cursor = internal.DecodeCursorResponse(t, rsp.Body, core.Account{}) + if !assert.EqualValues(t, 0, cursor.Total) { + return nil + } return nil }, }) @@ -88,7 +181,7 @@ func TestGetAccount(t *testing.T) { }, Volumes: map[string]map[string]int64{ "USD": { - "input": 100, + "input": 100, "output": 0, }, }, diff --git a/pkg/api/controllers/config_controller_test.go b/pkg/api/controllers/config_controller_test.go index 0716cefc7..84706378b 100644 --- a/pkg/api/controllers/config_controller_test.go +++ b/pkg/api/controllers/config_controller_test.go @@ -5,6 +5,7 @@ import ( "github.com/numary/ledger/pkg/api" "github.com/numary/ledger/pkg/api/controllers" "github.com/numary/ledger/pkg/api/internal" + "github.com/numary/ledger/pkg/storage" "github.com/stretchr/testify/assert" "go.uber.org/fx" "net/http" @@ -12,7 +13,7 @@ import ( ) func TestGetInfo(t *testing.T) { - internal.RunTest(t, fx.Invoke(func(lc fx.Lifecycle, h *api.API) { + internal.RunTest(t, fx.Invoke(func(lc fx.Lifecycle, h *api.API, driver storage.Driver) { lc.Append(fx.Hook{ OnStart: func(ctx context.Context) error { rsp := internal.GetInfo(h) @@ -20,13 +21,14 @@ func TestGetInfo(t *testing.T) { info := controllers.ConfigInfo{} internal.DecodeSingleResponse(t, rsp.Body, &info) + info.Config.LedgerStorage.Ledgers = []string{} assert.EqualValues(t, controllers.ConfigInfo{ Server: "numary-ledger", Version: "latest", Config: &controllers.Config{ LedgerStorage: &controllers.LedgerStorage{ - Driver: "sqlite", - Ledgers: []string{"quickstart"}, + Driver: driver.Name(), + Ledgers: []string{}, }, }, }, info) diff --git a/pkg/api/controllers/mapping_controller_test.go b/pkg/api/controllers/mapping_controller_test.go index 8f4832c44..6301e0596 100644 --- a/pkg/api/controllers/mapping_controller_test.go +++ b/pkg/api/controllers/mapping_controller_test.go @@ -27,14 +27,20 @@ func TestMapping(t *testing.T) { }, } rsp := internal.SaveMapping(t, h, m) - assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } rsp = internal.LoadMapping(h) - assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } m2 := core.Mapping{} internal.DecodeSingleResponse(t, rsp.Body, &m2) - assert.EqualValues(t, m, m2) + if !assert.EqualValues(t, m, m2) { + return nil + } return nil }, }) diff --git a/pkg/api/controllers/swagger.yaml b/pkg/api/controllers/swagger.yaml index bae6bfec6..27415164a 100644 --- a/pkg/api/controllers/swagger.yaml +++ b/pkg/api/controllers/swagger.yaml @@ -37,6 +37,21 @@ paths: description: pagination cursor, will return accounts after given address (in descending order) schema: type: string + - name: address + in: query + description: account address + required: false + schema: + type: string + - name: metadata + in: query + description: account address + required: false + style: deepObject + schema: + type: object + additionalProperties: + type: string responses: "200": description: OK diff --git a/pkg/api/controllers/transaction_controller_test.go b/pkg/api/controllers/transaction_controller_test.go index cd8b2623f..4104ccaad 100644 --- a/pkg/api/controllers/transaction_controller_test.go +++ b/pkg/api/controllers/transaction_controller_test.go @@ -309,29 +309,39 @@ func TestPostTransactionMetadata(t *testing.T) { }, }, }) - assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } tx := make([]core.Transaction, 0) internal.DecodeSingleResponse(t, rsp.Body, &tx) rsp = internal.PostTransactionMetadata(t, api, tx[0].ID, core.Metadata{ "foo": json.RawMessage(`"bar"`), }) - assert.Equal(t, http.StatusNoContent, rsp.Result().StatusCode) + if !assert.Equal(t, http.StatusNoContent, rsp.Result().StatusCode) { + return nil + } rsp = internal.GetTransaction(api, tx[0].ID) - assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) + if !assert.Equal(t, http.StatusOK, rsp.Result().StatusCode) { + return nil + } ret := core.Transaction{} internal.DecodeSingleResponse(t, rsp.Body, &ret) - assert.EqualValues(t, core.Metadata{ + if !assert.EqualValues(t, core.Metadata{ "foo": json.RawMessage(`"bar"`), - }, ret.Metadata) + }, ret.Metadata) { + return nil + } rsp = internal.PostTransactionMetadata(t, api, tx[0].ID, core.Metadata{ "foo": json.RawMessage(`"baz"`), }) - assert.Equal(t, http.StatusNoContent, rsp.Result().StatusCode) + if !assert.Equal(t, http.StatusNoContent, rsp.Result().StatusCode) { + return nil + } return nil }, }) diff --git a/pkg/api/internal/testing.go b/pkg/api/internal/testing.go index 29d98b1e7..e3b17741a 100644 --- a/pkg/api/internal/testing.go +++ b/pkg/api/internal/testing.go @@ -2,19 +2,24 @@ package internal import ( "bytes" + "context" "encoding/json" "fmt" "github.com/numary/go-libs/sharedapi" + "github.com/numary/go-libs/sharedlogging" + "github.com/numary/go-libs/sharedlogging/sharedlogginglogrus" "github.com/numary/ledger/pkg/api" "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/ledger" "github.com/numary/ledger/pkg/ledgertesting" "github.com/pborman/uuid" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "go.uber.org/fx" "io" "net/http" "net/http/httptest" + "net/url" "reflect" "testing" ) @@ -112,8 +117,9 @@ func GetTransaction(handler http.Handler, id uint64) *httptest.ResponseRecorder return rec } -func GetAccounts(handler http.Handler) *httptest.ResponseRecorder { +func GetAccounts(handler http.Handler, query url.Values) *httptest.ResponseRecorder { req, rec := NewRequest(http.MethodGet, "/"+testingLedger+"/accounts", nil) + req.URL.RawQuery = query.Encode() handler.ServeHTTP(rec, req) return rec } @@ -155,6 +161,13 @@ func GetInfo(handler http.Handler) *httptest.ResponseRecorder { } func WithNewModule(t *testing.T, options ...fx.Option) { + + l := logrus.New() + if testing.Verbose() { + l.Level = logrus.DebugLevel + } + sharedlogging.SetFactory(sharedlogging.StaticLoggerFactory(sharedlogginglogrus.New(l))) + testingLedger = uuid.New() module := api.Module(api.Config{ StorageDriver: "sqlite", @@ -167,16 +180,26 @@ func WithNewModule(t *testing.T, options ...fx.Option) { ledgertesting.StorageModule(), fx.NopLogger, }, options...) - options = append(options, fx.Invoke(func() { - close(ch) + options = append(options, fx.Invoke(func(lc fx.Lifecycle) { + lc.Append(fx.Hook{ + OnStop: func(ctx context.Context) error { + close(ch) + return nil + }, + }) })) app := fx.New(options...) + if !assert.NoError(t, app.Start(context.Background())) { + return + } select { case <-ch: default: - assert.Fail(t, app.Err().Error()) + if app.Err() != nil { + assert.Fail(t, app.Err().Error()) + } } } diff --git a/pkg/ledger/ledger.go b/pkg/ledger/ledger.go index 60fcae4d7..c3e20c60f 100644 --- a/pkg/ledger/ledger.go +++ b/pkg/ledger/ledger.go @@ -299,7 +299,7 @@ func (l *Ledger) GetLastTransaction(ctx context.Context) (core.Transaction, erro return tx, nil } -func (l *Ledger) FindTransactions(ctx context.Context, m ...query.QueryModifier) (sharedapi.Cursor, error) { +func (l *Ledger) FindTransactions(ctx context.Context, m ...query.Modifier) (sharedapi.Cursor, error) { q := query.New(m) c, err := l.store.FindTransactions(ctx, q) @@ -369,7 +369,7 @@ func (l *Ledger) RevertTransaction(ctx context.Context, id uint64) (*core.Transa } } -func (l *Ledger) FindAccounts(ctx context.Context, m ...query.QueryModifier) (sharedapi.Cursor, error) { +func (l *Ledger) FindAccounts(ctx context.Context, m ...query.Modifier) (sharedapi.Cursor, error) { q := query.New(m) c, err := l.store.FindAccounts(ctx, q) diff --git a/pkg/ledger/query/query.go b/pkg/ledger/query/query.go index f285d23ae..069fc4c9e 100644 --- a/pkg/ledger/query/query.go +++ b/pkg/ledger/query/query.go @@ -1,7 +1,7 @@ package query const ( - DEFAULT_LIMIT = 15 + DefaultLimit = 15 ) type Query struct { @@ -10,11 +10,11 @@ type Query struct { Params map[string]interface{} } -type QueryModifier func(*Query) +type Modifier func(*Query) -func New(qms ...[]QueryModifier) Query { +func New(qms ...[]Modifier) Query { q := Query{ - Limit: DEFAULT_LIMIT, + Limit: DefaultLimit, Params: map[string]interface{}{}, } @@ -25,13 +25,13 @@ func New(qms ...[]QueryModifier) Query { return q } -func (q *Query) Apply(modifiers []QueryModifier) { +func (q *Query) Apply(modifiers []Modifier) { for _, m := range modifiers { m(q) } } -func (q *Query) Modify(modifier QueryModifier) { +func (q *Query) Modify(modifier Modifier) { modifier(q) } @@ -57,6 +57,12 @@ func After(v string) func(*Query) { } } +func Address(v string) func(*Query) { + return func(q *Query) { + q.Params["address"] = v + } +} + func Account(v string) func(*Query) { return func(q *Query) { q.Params["account"] = v diff --git a/pkg/storage/sqlstorage/accounts.go b/pkg/storage/sqlstorage/accounts.go index f99f45e4c..173631166 100644 --- a/pkg/storage/sqlstorage/accounts.go +++ b/pkg/storage/sqlstorage/accounts.go @@ -3,14 +3,40 @@ package sqlstorage import ( "context" "database/sql" - "github.com/numary/go-libs/sharedapi" - "math" - "github.com/huandu/go-sqlbuilder" + "github.com/numary/go-libs/sharedapi" "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/ledger/query" + "math" + "strings" ) +func (s *Store) accountsQuery(p map[string]interface{}) *sqlbuilder.SelectBuilder { + + sb := sqlbuilder.NewSelectBuilder() + sb. + From(s.schema.Table("accounts")) + + if metadata, ok := p["metadata"]; ok { + for k, metaValue := range metadata.(map[string]string) { + arg := sb.Args.Add(metaValue) + // TODO: Need to find another way to specify the prefix since Table() methods does not make sense for functions and procedures + sb.Where(s.schema.Table("meta_compare(metadata, " + arg + ", '" + strings.Join(strings.Split(k, "."), "', '") + "')")) + } + } + if address, ok := p["address"]; ok && address.(string) != "" { + arg := sb.Args.Add("^" + address.(string) + "$") + switch s.Schema().Flavor() { + case sqlbuilder.PostgreSQL: + sb.Where("address ~* " + arg) + case sqlbuilder.SQLite: + sb.Where("address REGEXP " + arg) + } + } + + return sb +} + func (s *Store) findAccounts(ctx context.Context, exec executor, q query.Query) (sharedapi.Cursor, error) { // We fetch an additional account to know if we have more documents q.Limit = int(math.Max(-1, math.Min(float64(q.Limit), 100))) + 1 @@ -18,12 +44,10 @@ func (s *Store) findAccounts(ctx context.Context, exec executor, q query.Query) c := sharedapi.Cursor{} results := make([]core.Account, 0) - sb := sqlbuilder.NewSelectBuilder() - sb. + sb := s.accountsQuery(q.Params). Select("address", "metadata"). - From(s.schema.Table("accounts")). - OrderBy("address desc"). - Limit(q.Limit) + Limit(q.Limit). + OrderBy("address desc") if q.After != "" { sb.Where(sb.LessThan("address", q.After)) @@ -65,7 +89,7 @@ func (s *Store) findAccounts(ctx context.Context, exec executor, q query.Query) } c.Data = results - total, _ := s.CountAccounts(ctx) + total, _ := s.countAccounts(ctx, exec, q.Params) c.Total = total return c, nil diff --git a/pkg/storage/sqlstorage/aggregations.go b/pkg/storage/sqlstorage/aggregations.go index 8a994f7ca..170bff6d8 100644 --- a/pkg/storage/sqlstorage/aggregations.go +++ b/pkg/storage/sqlstorage/aggregations.go @@ -2,46 +2,38 @@ package sqlstorage import ( "context" + "fmt" "github.com/huandu/go-sqlbuilder" "github.com/numary/ledger/pkg/core" ) -func (s *Store) countTransactions(ctx context.Context, exec executor) (int64, error) { +func (s *Store) countTransactions(ctx context.Context, exec executor, params map[string]interface{}) (int64, error) { var count int64 - sb := sqlbuilder.NewSelectBuilder() - sb.Select("count(*)") - sb.From(s.schema.Table("transactions")) - - sqlq, args := sb.Build() + tq := s.transactionsQuery(params) + sqlq, args := tq.BuildWithFlavor(s.schema.Flavor()) + query := fmt.Sprintf(`SELECT count(*) FROM (%s) AS t`, sqlq) - err := exec.QueryRowContext(ctx, sqlq, args...).Scan(&count) + err := exec.QueryRowContext(ctx, query, args...).Scan(&count) return count, s.error(err) } func (s *Store) CountTransactions(ctx context.Context) (int64, error) { - return s.countTransactions(ctx, s.schema) + return s.countTransactions(ctx, s.schema, map[string]interface{}{}) } -func (s *Store) countAccounts(ctx context.Context, exec executor) (int64, error) { +func (s *Store) countAccounts(ctx context.Context, exec executor, p map[string]interface{}) (int64, error) { var count int64 - sb := sqlbuilder.NewSelectBuilder() - sb. - Select("count(*)"). - From(s.schema.Table("accounts")). - BuildWithFlavor(s.schema.Flavor()) - - sqlq, args := sb.Build() - + sqlq, args := s.accountsQuery(p).Select("count(*)").BuildWithFlavor(s.schema.Flavor()) err := exec.QueryRowContext(ctx, sqlq, args...).Scan(&count) return count, s.error(err) } func (s *Store) CountAccounts(ctx context.Context) (int64, error) { - return s.countAccounts(ctx, s.schema) + return s.countAccounts(ctx, s.schema, map[string]interface{}{}) } func (s *Store) aggregateVolumes(ctx context.Context, exec executor, address string) (core.Volumes, error) { @@ -55,6 +47,7 @@ func (s *Store) aggregateVolumes(ctx context.Context, exec executor, address str if err != nil { return nil, s.error(err) } + defer rows.Close() volumes := make(map[string]map[string]int64) for rows.Next() { diff --git a/pkg/storage/sqlstorage/driver.go b/pkg/storage/sqlstorage/driver.go index abea8529e..b36c201b9 100644 --- a/pkg/storage/sqlstorage/driver.go +++ b/pkg/storage/sqlstorage/driver.go @@ -2,9 +2,10 @@ package sqlstorage import ( "context" - "errors" "github.com/huandu/go-sqlbuilder" + "github.com/numary/go-libs/sharedlogging" "github.com/numary/ledger/pkg/storage" + "github.com/pkg/errors" "time" ) @@ -66,6 +67,8 @@ func (d *Driver) exists(ctx context.Context, ledger string) (bool, error) { if ret.Err() != nil { return false, nil } + var t string + _ = ret.Scan(&t) // Trigger close return true, nil } @@ -73,7 +76,7 @@ func (d *Driver) List(ctx context.Context) ([]string, error) { q, args := sqlbuilder. Select("ledger"). From(d.systemSchema.Table("ledgers")). - BuildWithFlavor(sqlbuilder.Flavor(d.systemSchema.Flavor())) + BuildWithFlavor(d.systemSchema.Flavor()) rows, err := d.systemSchema.QueryContext(ctx, q, args...) if err != nil { return nil, err @@ -97,6 +100,11 @@ func (s *Driver) Name() string { } func (s *Driver) Initialize(ctx context.Context) error { + + sharedlogging.GetLogger(ctx).Debugf("Initialize driver %s", s.name) + + <-time.After(2 * time.Second) + err := s.db.Initialize(ctx) if err != nil { return err @@ -152,7 +160,7 @@ func (s *Driver) GetStore(ctx context.Context, name string, create bool) (storag exists, err := s.exists(ctx, name) if err != nil { - return nil, false, err + return nil, false, errors.Wrap(err, "checking ledger existence") } if !exists && !create { return nil, false, errors.New("not exists") @@ -160,12 +168,12 @@ func (s *Driver) GetStore(ctx context.Context, name string, create bool) (storag schema, err := s.db.Schema(ctx, name) if err != nil { - return nil, false, err + return nil, false, errors.Wrap(err, "opening schema") } created, err := s.Register(ctx, name) if err != nil { - return nil, false, err + return nil, false, errors.Wrap(err, "registering ledger") } err = schema.Initialize(ctx) diff --git a/pkg/storage/sqlstorage/driver_test.go b/pkg/storage/sqlstorage/driver_test.go index 8ee5cfd3e..24e2edb0c 100644 --- a/pkg/storage/sqlstorage/driver_test.go +++ b/pkg/storage/sqlstorage/driver_test.go @@ -35,5 +35,5 @@ func TestNewDriver(t *testing.T) { if !assert.NotNil(t, err) { return } - assert.Equal(t, "sql: database is closed", err.Error()) + assert.Equal(t, "sql: database is closed [UNKNOWN]", err.Error()) } diff --git a/pkg/storage/sqlstorage/mapping.go b/pkg/storage/sqlstorage/mapping.go index 4165d5ee7..46219a74f 100644 --- a/pkg/storage/sqlstorage/mapping.go +++ b/pkg/storage/sqlstorage/mapping.go @@ -3,7 +3,6 @@ package sqlstorage import ( "context" "encoding/json" - "fmt" "github.com/huandu/go-sqlbuilder" "github.com/numary/ledger/pkg/core" ) @@ -79,9 +78,7 @@ func (s *Store) saveMapping(ctx context.Context, exec executor, mapping core.Map sqlq, args = ib.BuildWithFlavor(s.schema.Flavor()) } - fmt.Println("exec") _, err = exec.ExecContext(ctx, sqlq, args...) - fmt.Println("exec ok") return s.error(err) } diff --git a/pkg/storage/sqlstorage/migrations/postgresql/2.sql b/pkg/storage/sqlstorage/migrations/postgresql/2.sql new file mode 100644 index 000000000..18571d69a --- /dev/null +++ b/pkg/storage/sqlstorage/migrations/postgresql/2.sql @@ -0,0 +1,47 @@ +CREATE OR REPLACE FUNCTION "VAR_LEDGER_NAME".meta_compare(metadata jsonb, value varchar, variadic path TEXT[]) + RETURNS BOOLEAN +AS +$$ +BEGIN + return jsonb_extract_path_text(metadata, variadic path)::varchar = value::varchar; +EXCEPTION + WHEN others THEN + RAISE INFO 'Error Name: %', SQLERRM; + RAISE INFO 'Error State: %', SQLSTATE; + RETURN false; +END +$$ + LANGUAGE plpgsql + IMMUTABLE; +--statement +CREATE OR REPLACE FUNCTION meta_compare(metadata jsonb, value bool, variadic path TEXT[]) + RETURNS BOOLEAN +AS +$$ +BEGIN + return jsonb_extract_path(metadata, variadic path)::bool = value::bool; +EXCEPTION + WHEN others THEN + RAISE INFO 'Error Name: %', SQLERRM; + RAISE INFO 'Error State: %', SQLSTATE; + RETURN false; +END +$$ + LANGUAGE plpgsql + IMMUTABLE; +--statement +CREATE OR REPLACE FUNCTION "VAR_LEDGER_NAME".meta_compare(metadata jsonb, value numeric, variadic path TEXT[]) + RETURNS BOOLEAN +AS +$$ +BEGIN + return jsonb_extract_path(metadata, variadic path)::numeric = value::numeric; +EXCEPTION + WHEN others THEN + RAISE INFO 'Error Name: %', SQLERRM; + RAISE INFO 'Error State: %', SQLSTATE; + RETURN false; +END +$$ + LANGUAGE plpgsql + IMMUTABLE; \ No newline at end of file diff --git a/pkg/storage/sqlstorage/schema.go b/pkg/storage/sqlstorage/schema.go index b0ed08a40..931ed2665 100644 --- a/pkg/storage/sqlstorage/schema.go +++ b/pkg/storage/sqlstorage/schema.go @@ -81,6 +81,22 @@ func (s *PGSchema) Delete(ctx context.Context) error { return err } +func (s *PGSchema) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + rows, err := s.baseSchema.QueryContext(ctx, query, args...) + if err != nil { + return nil, errorFromFlavor(PostgreSQL, err) + } + return rows, nil +} + +func (s *PGSchema) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + ret, err := s.baseSchema.ExecContext(ctx, query, args...) + if err != nil { + return nil, errorFromFlavor(PostgreSQL, err) + } + return ret, nil +} + type SQLiteSchema struct { baseSchema file string @@ -98,6 +114,22 @@ func (s SQLiteSchema) Delete(ctx context.Context) error { return os.Remove(s.file) } +func (s *SQLiteSchema) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + rows, err := s.baseSchema.QueryContext(ctx, query, args...) + if err != nil { + return nil, errorFromFlavor(SQLite, err) + } + return rows, nil +} + +func (s *SQLiteSchema) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + ret, err := s.baseSchema.ExecContext(ctx, query, args...) + if err != nil { + return nil, errorFromFlavor(SQLite, err) + } + return ret, nil +} + type DB interface { Initialize(ctx context.Context) error Schema(ctx context.Context, name string) (Schema, error) diff --git a/pkg/storage/sqlstorage/sqlite.go b/pkg/storage/sqlstorage/sqlite.go index 17ffa54ed..6b7ac9015 100644 --- a/pkg/storage/sqlstorage/sqlite.go +++ b/pkg/storage/sqlstorage/sqlite.go @@ -10,9 +10,13 @@ package sqlstorage import ( "database/sql" "encoding/json" + "github.com/buger/jsonparser" + _ "github.com/buger/jsonparser" "github.com/mattn/go-sqlite3" "github.com/numary/ledger/pkg/core" "github.com/numary/ledger/pkg/storage" + "regexp" + "strconv" ) func init() { @@ -28,7 +32,7 @@ func init() { } sql.Register("sqlite3-custom", &sqlite3.SQLiteDriver{ ConnectHook: func(conn *sqlite3.SQLiteConn) error { - return conn.RegisterFunc("hash_log", func(v1, v2 string) string { + err := conn.RegisterFunc("hash_log", func(v1, v2 string) string { m1 := make(map[string]interface{}) m2 := make(map[string]interface{}) err := json.Unmarshal([]byte(v1), &m1) @@ -41,6 +45,56 @@ func init() { } return core.Hash(m1, m2) }, true) + if err != nil { + return err + } + err = conn.RegisterFunc("regexp", func(re, s string) (bool, error) { + b, e := regexp.MatchString(re, s) + return b, e + }, true) + if err != nil { + return err + } + err = conn.RegisterFunc("meta_compare", func(metadata string, value string, key ...string) bool { + bytes, dataType, _, err := jsonparser.Get([]byte(metadata), key...) + if err != nil { + return false + } + switch dataType { + case jsonparser.String: + str, err := jsonparser.ParseString(bytes) + if err != nil { + return false + } + return value == str + case jsonparser.Boolean: + b, err := jsonparser.ParseBoolean(bytes) + if err != nil { + return false + } + switch value { + case "true": + return b + case "false": + return !b + } + return false + case jsonparser.Number: + i, err := jsonparser.ParseInt(bytes) + if err != nil { + return false + } + vi, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return false + } + return i == vi + default: + return false + } + return false + }, true) + return err }, }) UpdateSQLDriverMapping(SQLite, "sqlite3-custom") diff --git a/pkg/storage/sqlstorage/store_test.go b/pkg/storage/sqlstorage/store_test.go index 3dc919386..e62442073 100644 --- a/pkg/storage/sqlstorage/store_test.go +++ b/pkg/storage/sqlstorage/store_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/davecgh/go-spew/spew" "github.com/numary/go-libs/sharedlogging" "github.com/numary/go-libs/sharedlogging/sharedlogginglogrus" "github.com/numary/ledger/internal/pgtesting" @@ -24,9 +23,11 @@ import ( func TestStore(t *testing.T) { + l := logrus.New() if testing.Verbose() { - logrus.StandardLogger().Level = logrus.DebugLevel + l.Level = logrus.DebugLevel } + sharedlogging.SetFactory(sharedlogging.StaticLoggerFactory(sharedlogginglogrus.New(l))) type testingFunction struct { name string @@ -234,20 +235,38 @@ func testAggregateVolumes(t *testing.T, store *sqlstorage.Store) { } func testFindAccounts(t *testing.T, store *sqlstorage.Store) { - tx := core.Transaction{ - TransactionData: core.TransactionData{ - Postings: []core.Posting{ - { - Source: "world", - Destination: "central_bank", - Amount: 100, - Asset: "USD", - }, - }, + account1 := core.NewSetMetadataLog(nil, core.SetMetadata{ + TargetType: core.MetaTargetTypeAccount, + TargetID: "world", + Metadata: core.Metadata{ + "foo": json.RawMessage(`"bar"`), }, - Timestamp: time.Now().Round(time.Second).Format(time.RFC3339), - } - _, err := store.AppendLog(context.Background(), core.NewTransactionLog(nil, tx)) + }) + account2 := core.NewSetMetadataLog(&account1, core.SetMetadata{ + TargetType: core.MetaTargetTypeAccount, + TargetID: "bank", + Metadata: core.Metadata{ + "hello": json.RawMessage(`"world"`), + }, + }) + account3 := core.NewSetMetadataLog(&account2, core.SetMetadata{ + TargetType: core.MetaTargetTypeAccount, + TargetID: "order:1", + Metadata: core.Metadata{ + "hello": json.RawMessage(`"world"`), + }, + }) + account4 := core.NewSetMetadataLog(&account3, core.SetMetadata{ + TargetType: core.MetaTargetTypeAccount, + TargetID: "order:2", + Metadata: core.Metadata{ + "number": json.RawMessage(`3`), + "boolean": json.RawMessage(`true`), + "a": json.RawMessage(`{"super": {"nested": {"key": "hello"}}}`), + }, + }) + + _, err := store.AppendLog(context.Background(), account1, account2, account3, account4) if !assert.NoError(t, err) { return } @@ -258,7 +277,7 @@ func testFindAccounts(t *testing.T, store *sqlstorage.Store) { if !assert.NoError(t, err) { return } - if !assert.EqualValues(t, 2, accounts.Total) { + if !assert.EqualValues(t, 4, accounts.Total) { return } if !assert.True(t, accounts.HasMore) { @@ -275,13 +294,119 @@ func testFindAccounts(t *testing.T, store *sqlstorage.Store) { if !assert.NoError(t, err) { return } + if !assert.EqualValues(t, 4, accounts.Total) { + return + } + if !assert.True(t, accounts.HasMore) { + return + } + if !assert.Equal(t, 1, accounts.PageSize) { + return + } + + accounts, err = store.FindAccounts(context.Background(), query.Query{ + Limit: 10, + Params: map[string]interface{}{ + "address": ".*der.*", + }, + }) + if !assert.NoError(t, err) { + return + } if !assert.EqualValues(t, 2, accounts.Total) { return } if !assert.False(t, accounts.HasMore) { return } - if !assert.Equal(t, 1, accounts.PageSize) { + if !assert.Len(t, accounts.Data, 2) { + return + } + if !assert.Equal(t, 10, accounts.PageSize) { + return + } + + accounts, err = store.FindAccounts(context.Background(), query.Query{ + Limit: 10, + Params: map[string]interface{}{ + "metadata": map[string]string{ + "foo": "bar", + }, + }, + }) + if !assert.NoError(t, err) { + return + } + if !assert.EqualValues(t, 1, accounts.Total) { + return + } + if !assert.False(t, accounts.HasMore) { + return + } + if !assert.Len(t, accounts.Data, 1) { + return + } + + accounts, err = store.FindAccounts(context.Background(), query.Query{ + Limit: 10, + Params: map[string]interface{}{ + "metadata": map[string]string{ + "number": "3", + }, + }, + }) + if !assert.NoError(t, err) { + return + } + if !assert.EqualValues(t, 1, accounts.Total) { + return + } + if !assert.False(t, accounts.HasMore) { + return + } + if !assert.Len(t, accounts.Data, 1) { + return + } + + accounts, err = store.FindAccounts(context.Background(), query.Query{ + Limit: 10, + Params: map[string]interface{}{ + "metadata": map[string]string{ + "boolean": "true", + }, + }, + }) + if !assert.NoError(t, err) { + return + } + if !assert.EqualValues(t, 1, accounts.Total) { + return + } + if !assert.False(t, accounts.HasMore) { + return + } + if !assert.Len(t, accounts.Data, 1) { + return + } + + accounts, err = store.FindAccounts(context.Background(), query.Query{ + Limit: 10, + Params: map[string]interface{}{ + "metadata": map[string]string{ + "a.super.nested.key": "hello", + }, + }, + }) + if !assert.NoError(t, err) { + return + } + if !assert.EqualValues(t, 1, accounts.Total) { + return + } + if !assert.False(t, accounts.HasMore) { + return + } + if !assert.Len(t, accounts.Data, 1) { return } } @@ -303,9 +428,8 @@ func testCountTransactions(t *testing.T, store *sqlstorage.Store) { }, Timestamp: time.Now().Round(time.Second).Format(time.RFC3339), } - ret, err := store.AppendLog(context.Background(), core.NewTransactionLog(nil, tx)) + _, err := store.AppendLog(context.Background(), core.NewTransactionLog(nil, tx)) if !assert.NoError(t, err) { - spew.Dump(ret) return } diff --git a/pkg/storage/sqlstorage/transactions.go b/pkg/storage/sqlstorage/transactions.go index 8b8bc4b87..17157be6b 100644 --- a/pkg/storage/sqlstorage/transactions.go +++ b/pkg/storage/sqlstorage/transactions.go @@ -12,14 +12,10 @@ import ( "github.com/numary/ledger/pkg/ledger/query" ) -func (s *Store) findTransactions(ctx context.Context, exec executor, q query.Query) (sharedapi.Cursor, error) { - q.Limit = int(math.Max(-1, math.Min(float64(q.Limit), 100))) + 1 - - c := sharedapi.Cursor{} +func (s *Store) transactionsQuery(p map[string]interface{}) *sqlbuilder.SelectBuilder { sb := sqlbuilder.NewSelectBuilder() - sb.Distinct() - sb.OrderBy("t.id desc") + sb.GroupBy("t.id", "t.postings", "t.metadata", "t.timestamp", "t.reference") sb.Select("t.id", "t.timestamp", "t.reference", "t.metadata", "t.postings") switch s.schema.Flavor() { case sqlbuilder.PostgreSQL: @@ -27,28 +23,38 @@ func (s *Store) findTransactions(ctx context.Context, exec executor, q query.Que case sqlbuilder.SQLite: sb.From("transactions t", "json_each(postings)") } - if q.After != "" { - sb.Where(sb.LessThan("t.id", q.After)) - } - sb.Limit(q.Limit) - if q.HasParam("account") { + if account, ok := p["account"]; ok && account.(string) != "" { switch s.schema.Flavor() { case sqlbuilder.PostgreSQL: sb.Where(sb.Or( - sb.Equal("source", q.Params["account"]), - sb.Equal("destination", q.Params["account"]), + sb.Equal("source", account.(string)), + sb.Equal("destination", account.(string)), )) case sqlbuilder.SQLite: sb.Where(sb.Or( - sb.Equal("json_extract(json_each.value, '$.source')", q.Params["account"]), - sb.Equal("json_extract(json_each.value, '$.destination')", q.Params["account"]), + sb.Equal("json_extract(json_each.value, '$.source')", account.(string)), + sb.Equal("json_extract(json_each.value, '$.destination')", account.(string)), )) } - } - if q.HasParam("reference") { - sb.Where(sb.E("reference", q.Params["reference"])) + if ref, ok := p["reference"]; ok && p["reference"].(string) != "" { + sb.Where(sb.E("reference", ref.(string))) } + return sb +} + +func (s *Store) findTransactions(ctx context.Context, exec executor, q query.Query) (sharedapi.Cursor, error) { + q.Limit = int(math.Max(-1, math.Min(float64(q.Limit), 100))) + 1 + + c := sharedapi.Cursor{} + + sb := s.transactionsQuery(q.Params) + sb.OrderBy("t.id desc") + if q.After != "" { + sb.Where(sb.LessThan("t.id", q.After)) + } + sb.Limit(q.Limit) + sqlq, args := sb.BuildWithFlavor(s.schema.Flavor()) rows, err := exec.QueryContext(ctx, sqlq, args...) if err != nil { @@ -97,8 +103,10 @@ func (s *Store) findTransactions(ctx context.Context, exec executor, q query.Que } c.Data = transactions - // TODO: The count should match the query - total, _ := s.countTransactions(ctx, exec) + total, err := s.countTransactions(ctx, exec, q.Params) + if err != nil { + return c, err + } c.Total = total return c, nil