diff --git a/go.mod b/go.mod index 05f92208f..b056b48b6 100644 --- a/go.mod +++ b/go.mod @@ -40,4 +40,7 @@ require ( go.uber.org/fx v1.16.0 ) -require github.com/go-logr/stdr v1.2.2 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-logr/stdr v1.2.2 // indirect +) diff --git a/pkg/api/controllers/controllers.go b/pkg/api/controllers/controllers.go index b71883a07..00bf92f7d 100644 --- a/pkg/api/controllers/controllers.go +++ b/pkg/api/controllers/controllers.go @@ -36,4 +36,5 @@ var Module = fx.Options( fx.Provide(NewScriptController), fx.Provide(NewAccountController), fx.Provide(NewTransactionController), + fx.Provide(NewMappingController), ) diff --git a/pkg/api/controllers/mapping_controller.go b/pkg/api/controllers/mapping_controller.go new file mode 100644 index 000000000..b49054e7e --- /dev/null +++ b/pkg/api/controllers/mapping_controller.go @@ -0,0 +1,79 @@ +package controllers + +import ( + "github.com/gin-gonic/gin" + "github.com/numary/ledger/pkg/core" + "github.com/numary/ledger/pkg/ledger" + "net/http" +) + +type MappingController struct { + BaseController +} + +func NewMappingController() MappingController { + return MappingController{} +} + +// PutMapping godoc +// @Summary Put mapping +// @Description Update ledger mapping +// @Tags mapping +// @Schemes +// @Param ledger path string true "ledger" +// @Accept json +// @Produce json +// @Success 200 {object} controllers.BaseResponse +// @Failure 404 {object} controllers.BaseResponse +// @Router /{ledger}/mapping [put] +func (ctl *MappingController) PutMapping(c *gin.Context) { + l, _ := c.Get("ledger") + + mapping := &core.Mapping{} + err := c.ShouldBind(mapping) + if err != nil { + ctl.responseError(c, http.StatusBadRequest, err) + return + } + + err = l.(*ledger.Ledger).SaveMapping(c.Request.Context(), *mapping) + if err != nil { + ctl.responseError( + c, + http.StatusInternalServerError, + err, + ) + return + } + ctl.response( + c, + http.StatusOK, + mapping, + ) +} + +// GetMapping godoc +// @Summary Get mapping +// @Description Get ledger mapping +// @Tags contracts +// @Schemes +// @Param ledger path string true "ledger" +// @Accept json +// @Produce json +// @Success 200 {object} controllers.BaseResponse +// @Failure 404 {object} controllers.BaseResponse +// @Router /{ledger}/mapping [get] +func (ctl *MappingController) GetMapping(c *gin.Context) { + l, _ := c.Get("ledger") + + mapping, err := l.(*ledger.Ledger).LoadMapping(c.Request.Context()) + if err != nil { + ctl.responseError(c, http.StatusInternalServerError, err) + return + } + ctl.response( + c, + http.StatusOK, + mapping, + ) +} diff --git a/pkg/api/routes/routes.go b/pkg/api/routes/routes.go index 94dc7ac34..85680e5f1 100644 --- a/pkg/api/routes/routes.go +++ b/pkg/api/routes/routes.go @@ -43,6 +43,7 @@ type Routes struct { scriptController controllers.ScriptController accountController controllers.AccountController transactionController controllers.TransactionController + mappingController controllers.MappingController globalMiddlewares []gin.HandlerFunc perLedgerMiddlewares []gin.HandlerFunc } @@ -59,6 +60,7 @@ func NewRoutes( scriptController controllers.ScriptController, accountController controllers.AccountController, transactionController controllers.TransactionController, + mappingController controllers.MappingController, ) *Routes { return &Routes{ globalMiddlewares: globalMiddlewares, @@ -71,6 +73,7 @@ func NewRoutes( scriptController: scriptController, accountController: accountController, transactionController: transactionController, + mappingController: mappingController, } } @@ -109,6 +112,10 @@ func (r *Routes) Engine(cc cors.Config) *gin.Engine { ledger.GET("/accounts/:address", r.accountController.GetAccount) ledger.POST("/accounts/:address/metadata", r.accountController.PostAccountMetadata) + // MappingController + ledger.GET("/mapping", r.mappingController.GetMapping) + ledger.PUT("/mapping", r.mappingController.PutMapping) + // ScriptController ledger.POST("/script", r.scriptController.PostScript) } diff --git a/pkg/core/account.go b/pkg/core/account.go index 1afa9c3f9..bb741c5f2 100644 --- a/pkg/core/account.go +++ b/pkg/core/account.go @@ -6,7 +6,6 @@ const ( type Account struct { Address string `json:"address" example:"users:001"` - Contract string `json:"contract" example:"default"` Type string `json:"type,omitempty" example:"virtual"` Balances map[string]int64 `json:"balances,omitempty" example:"COIN:100"` Volumes map[string]map[string]int64 `json:"volumes,omitempty"` diff --git a/pkg/core/contract.go b/pkg/core/contract.go new file mode 100644 index 000000000..4d6bb4091 --- /dev/null +++ b/pkg/core/contract.go @@ -0,0 +1,39 @@ +package core + +import ( + "encoding/json" + "regexp" + "strings" +) + +type Contract struct { + Expr Expr `json:"expr"` + Account string `json:"account"` +} + +func (c *Contract) UnmarshalJSON(data []byte) error { + type AuxContract Contract + type Aux struct { + AuxContract + Expr map[string]interface{} `json:"expr"` + } + aux := Aux{} + err := json.Unmarshal(data, &aux) + if err != nil { + return err + } + expr, err := ParseRuleExpr(aux.Expr) + if err != nil { + return err + } + *c = Contract{ + Expr: expr, + Account: aux.Account, + } + return nil +} + +func (c Contract) Match(addr string) bool { + r := strings.ReplaceAll(c.Account, "*", ".*") + return regexp.MustCompile(r).Match([]byte(addr)) +} diff --git a/pkg/core/contract_test.go b/pkg/core/contract_test.go new file mode 100644 index 000000000..c7f53ad5e --- /dev/null +++ b/pkg/core/contract_test.go @@ -0,0 +1,14 @@ +package core + +import ( + "encoding/json" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestContract_UnmarshalJSON(t *testing.T) { + contract := &Contract{} + data := `{"id": "foo", "account": "order:*", "expr": { "$gte": ["$balance", 0] }}` + err := json.Unmarshal([]byte(data), contract) + assert.NoError(t, err) +} diff --git a/pkg/core/expr.go b/pkg/core/expr.go new file mode 100644 index 000000000..324181a00 --- /dev/null +++ b/pkg/core/expr.go @@ -0,0 +1,294 @@ +package core + +import ( + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" +) + +type EvalContext struct { + Variables map[string]interface{} + Metadata Metadata + Asset string +} + +type Expr interface { + Eval(EvalContext) bool +} + +type Value interface { + eval(ctx EvalContext) interface{} +} + +type ExprOr []Expr + +func (o ExprOr) Eval(ctx EvalContext) bool { + for _, e := range o { + if e.Eval(ctx) { + return true + } + } + return false +} + +func (e ExprOr) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "$or": []Expr(e), + }) +} + +type ExprAnd []Expr + +func (o ExprAnd) Eval(ctx EvalContext) bool { + for _, e := range o { + if !e.Eval(ctx) { + return false + } + } + return true +} + +func (e ExprAnd) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "$and": []Expr(e), + }) +} + +type ExprEq struct { + Op1 Value + Op2 Value +} + +func (o *ExprEq) Eval(ctx EvalContext) bool { + return reflect.DeepEqual(o.Op1.eval(ctx), o.Op2.eval(ctx)) +} + +func (e ExprEq) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "$eq": []interface{}{e.Op1, e.Op2}, + }) +} + +type ExprGt struct { + Op1 Value + Op2 Value +} + +func (o *ExprGt) Eval(ctx EvalContext) bool { + return o.Op1.eval(ctx).(float64) > o.Op2.eval(ctx).(float64) +} + +func (e ExprGt) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "$gt": []interface{}{e.Op1, e.Op2}, + }) +} + +type ExprLt struct { + Op1 Value + Op2 Value +} + +func (o *ExprLt) Eval(ctx EvalContext) bool { + return o.Op1.eval(ctx).(float64) > o.Op2.eval(ctx).(float64) +} + +func (e ExprLt) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "$lt": []interface{}{e.Op1, e.Op2}, + }) +} + +type ExprGte struct { + Op1 Value + Op2 Value +} + +func (o *ExprGte) Eval(ctx EvalContext) bool { + return o.Op1.eval(ctx).(float64) >= o.Op2.eval(ctx).(float64) +} + +func (e ExprGte) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "$gte": []interface{}{e.Op1, e.Op2}, + }) +} + +type ExprLte struct { + Op1 Value + Op2 Value +} + +func (o *ExprLte) Eval(ctx EvalContext) bool { + return o.Op1.eval(ctx).(float64) >= o.Op2.eval(ctx).(float64) +} + +func (e ExprLte) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "$lte": []interface{}{e.Op1, e.Op2}, + }) +} + +type ConstantExpr struct { + Value interface{} +} + +func (e ConstantExpr) eval(ctx EvalContext) interface{} { + return e.Value +} + +func (e ConstantExpr) MarshalJSON() ([]byte, error) { + return json.Marshal(e.Value) +} + +type VariableExpr struct { + Name string +} + +func (e VariableExpr) eval(ctx EvalContext) interface{} { + return ctx.Variables[e.Name] +} + +func (e VariableExpr) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"$%s"`, e.Name)), nil +} + +type MetaExpr struct { + Name string +} + +func (e MetaExpr) eval(ctx EvalContext) interface{} { + return string(ctx.Metadata[e.Name]) +} + +func (e MetaExpr) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]interface{}{ + "$meta": e.Name, + }) +} + +func parse(v interface{}) (expr interface{}, err error) { + switch vv := v.(type) { + case map[string]interface{}: + if len(vv) != 1 { + return nil, errors.New("malformed expression") + } + for key, vvv := range vv { + switch { + case strings.HasPrefix(key, "$"): + switch key { + case "$meta": + value, ok := vvv.(string) + if !ok { + return nil, errors.New("$meta operator invalid") + } + return &MetaExpr{Name: value}, nil + case "$or", "$and": + slice, ok := vvv.([]interface{}) + if !ok { + return nil, errors.New("Expected slice for operator " + key) + } + exprs := make([]Expr, 0) + for _, item := range slice { + r, err := parse(item) + if err != nil { + return nil, err + } + expr, ok := r.(Expr) + if !ok { + return nil, errors.New("unexpected value when parsing " + key) + } + exprs = append(exprs, expr) + } + switch key { + case "$and": + expr = ExprAnd(exprs) + case "$or": + expr = ExprOr(exprs) + } + case "$eq", "$gt", "$gte", "$lt", "$lte": + vv, ok := vvv.([]interface{}) + if !ok { + return nil, errors.New("expected array when using $eq") + } + if len(vv) != 2 { + return nil, errors.New("expected 2 items when using $eq") + } + op1, err := parse(vv[0]) + if err != nil { + return nil, err + } + op1Value, ok := op1.(Value) + if !ok { + return nil, errors.New("op1 must be valuable") + } + op2, err := parse(vv[1]) + if err != nil { + return nil, err + } + op2Value, ok := op2.(Value) + if !ok { + return nil, errors.New("op2 must be valuable") + } + switch key { + case "$eq": + expr = &ExprEq{ + Op1: op1Value, + Op2: op2Value, + } + case "$gt": + expr = &ExprGt{ + Op1: op1Value, + Op2: op2Value, + } + case "$gte": + expr = &ExprGte{ + Op1: op1Value, + Op2: op2Value, + } + case "$lt": + expr = &ExprLt{ + Op1: op1Value, + Op2: op2Value, + } + case "$lte": + expr = &ExprLte{ + Op1: op1Value, + Op2: op2Value, + } + } + default: + return nil, errors.New("unknown operator '" + key + "'") + } + } + } + case string: + if !strings.HasPrefix(vv, "$") { + return ConstantExpr{v}, nil + } + return VariableExpr{vv[1:]}, nil + default: + return ConstantExpr{v}, nil + } + + return expr, nil +} + +func ParseRuleExpr(v map[string]interface{}) (Expr, error) { + ret, err := parse(v) + if err != nil { + return nil, err + } + return ret.(Expr), nil +} + +func ParseRule(data string) (Expr, error) { + m := make(map[string]interface{}) + err := json.Unmarshal([]byte(data), &m) + if err != nil { + return nil, err + } + return ParseRuleExpr(m) +} diff --git a/pkg/core/expr_test.go b/pkg/core/expr_test.go new file mode 100644 index 000000000..2dacf7282 --- /dev/null +++ b/pkg/core/expr_test.go @@ -0,0 +1,85 @@ +package core + +import ( + "encoding/json" + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestRules(t *testing.T) { + + type testCase struct { + rule map[string]interface{} + context EvalContext + shouldBeAccepted bool + } + + var tests = []testCase{ + { + rule: map[string]interface{}{ + "$or": []interface{}{ + map[string]interface{}{ + "$gt": []interface{}{ + "$balance", float64(0), + }, + }, + map[string]interface{}{ + "$eq": []interface{}{ + map[string]interface{}{ + "$meta": "approved", + }, + "yes", + }, + }, + }, + }, + context: EvalContext{ + Variables: map[string]interface{}{ + "balance": float64(-10), + }, + Metadata: map[string]json.RawMessage{ + "approved": json.RawMessage("yes"), + }, + }, + shouldBeAccepted: true, + }, + { + rule: map[string]interface{}{ + "$and": []interface{}{ + map[string]interface{}{ + "$gt": []interface{}{ + "$balance", float64(0), + }, + }, + map[string]interface{}{ + "$eq": []interface{}{ + map[string]interface{}{ + "$meta": "approved", + }, + "yes", + }, + }, + }, + }, + context: EvalContext{ + Variables: map[string]interface{}{ + "balance": float64(10), + }, + Metadata: map[string]json.RawMessage{ + "approved": json.RawMessage("no"), + }, + }, + shouldBeAccepted: false, + }, + } + + for i, test := range tests { + t.Run(fmt.Sprintf("test%d", i), func(t *testing.T) { + r, err := ParseRuleExpr(test.rule) + assert.NoError(t, err) + assert.Equal(t, test.shouldBeAccepted, r.Eval(test.context)) + }) + } + +} diff --git a/pkg/core/mapping.go b/pkg/core/mapping.go new file mode 100644 index 000000000..b2599e7b9 --- /dev/null +++ b/pkg/core/mapping.go @@ -0,0 +1,5 @@ +package core + +type Mapping struct { + Contracts []Contract `json:"contracts"` +} diff --git a/pkg/ledger/ledger.go b/pkg/ledger/ledger.go index c207cfc10..e9ca3f8c7 100644 --- a/pkg/ledger/ledger.go +++ b/pkg/ledger/ledger.go @@ -16,6 +16,20 @@ const ( targetTypeTransaction = "transaction" ) +var DefaultContracts = []core.Contract{ + { + Expr: &core.ExprGte{ + Op1: core.VariableExpr{ + Name: "balance", + }, + Op2: core.ConstantExpr{ + Value: float64(0), + }, + }, + Account: "*", // world still an exception + }, +} + type Ledger struct { locker Locker name string @@ -84,6 +98,17 @@ func (l *Ledger) Commit(ctx context.Context, ts []core.Transaction) ([]core.Tran } } + mapping, err := l.store.LoadMapping(ctx) + if err != nil { + return nil, err + } + + contracts := make([]core.Contract, 0) + if mapping != nil { + contracts = append(contracts, mapping.Contracts...) + } + contracts = append(contracts, DefaultContracts...) + for addr := range rf { if addr == "world" { continue @@ -109,10 +134,25 @@ func (l *Ledger) Commit(ctx context.Context, ts []core.Transaction) ([]core.Tran } for asset := range checks { - balance, ok := balances[asset] - - if !ok || balance < checks[asset] { - return ts, NewInsufficientFundError(asset) + expectedBalance := balances[asset] - checks[asset] + for _, contract := range contracts { + if contract.Match(addr) { + meta, err := l.store.GetMeta(ctx, "account", addr) + if err != nil { + return nil, err + } + ok := contract.Expr.Eval(core.EvalContext{ + Variables: map[string]interface{}{ + "balance": float64(expectedBalance), + }, + Metadata: meta, + Asset: asset, + }) + if !ok { + return nil, NewInsufficientFundError(asset) + } + break + } } } } @@ -161,6 +201,14 @@ func (l *Ledger) GetTransaction(ctx context.Context, id string) (core.Transactio return tx, err } +func (l *Ledger) SaveMapping(ctx context.Context, mapping core.Mapping) error { + return l.store.SaveMapping(ctx, mapping) +} + +func (l *Ledger) LoadMapping(ctx context.Context) (*core.Mapping, error) { + return l.store.LoadMapping(ctx) +} + func (l *Ledger) RevertTransaction(ctx context.Context, id string) error { tx, err := l.store.GetTransaction(ctx, id) if err != nil { @@ -190,8 +238,7 @@ func (l *Ledger) FindAccounts(ctx context.Context, m ...query.QueryModifier) (qu func (l *Ledger) GetAccount(ctx context.Context, address string) (core.Account, error) { account := core.Account{ - Address: address, - Contract: "default", + Address: address, } balances, err := l.store.AggregateBalances(ctx, address) diff --git a/pkg/opentelemetry/opentelemetrytraces/storage.go b/pkg/opentelemetry/opentelemetrytraces/storage.go index cae32bf4f..295438fd7 100644 --- a/pkg/opentelemetry/opentelemetrytraces/storage.go +++ b/pkg/opentelemetry/opentelemetrytraces/storage.go @@ -128,6 +128,20 @@ func (o *openTelemetryStorage) CountMeta(ctx context.Context) (count int64, err return } +func (o *openTelemetryStorage) LoadMapping(ctx context.Context) (m *core.Mapping, err error) { + o.handle(ctx, "FindContracts", func(ctx context.Context) error { + m, err = o.underlying.LoadMapping(ctx) + return err + }) + return +} + +func (o *openTelemetryStorage) SaveMapping(ctx context.Context, mapping core.Mapping) error { + return o.handle(ctx, "SaveMapping", func(ctx context.Context) error { + return o.underlying.SaveMapping(ctx, mapping) + }) +} + func (o *openTelemetryStorage) Initialize(ctx context.Context) error { return o.handle(ctx, "Initialize", func(ctx context.Context) error { return o.underlying.Initialize(ctx) diff --git a/pkg/storage/sqlstorage/accounts.go b/pkg/storage/sqlstorage/accounts.go index 06c596dac..995d39ec1 100644 --- a/pkg/storage/sqlstorage/accounts.go +++ b/pkg/storage/sqlstorage/accounts.go @@ -53,7 +53,6 @@ func (s *Store) FindAccounts(ctx context.Context, q query.Query) (query.Cursor, account := core.Account{ Address: address, - Contract: "default", } meta, err := s.GetMeta(ctx, "account", account.Address) diff --git a/pkg/storage/sqlstorage/mapping.go b/pkg/storage/sqlstorage/mapping.go new file mode 100644 index 000000000..46c19d735 --- /dev/null +++ b/pkg/storage/sqlstorage/mapping.go @@ -0,0 +1,92 @@ +package sqlstorage + +import ( + "context" + "encoding/json" + "github.com/huandu/go-sqlbuilder" + "github.com/numary/ledger/pkg/core" + "github.com/sirupsen/logrus" +) + +// We have only one mapping for a ledger, so hardcode the id +const mappingId = "0000" + +func (s *Store) LoadMapping(ctx context.Context) (*core.Mapping, error) { + + sb := sqlbuilder.NewSelectBuilder() + sb. + Select("mapping"). + From(s.table("mapping")) + + sqlq, args := sb.BuildWithFlavor(s.flavor) + logrus.Debugln(sqlq, args) + + rows, err := s.db.QueryContext( + ctx, + sqlq, + args..., + ) + if err != nil { + return nil, s.error(err) + } + if !rows.Next() { + return nil, nil + } + + var ( + mappingString string + ) + + err = rows.Scan(&mappingString) + if err != nil { + return nil, err + } + + m := &core.Mapping{} + err = json.Unmarshal([]byte(mappingString), m) + if err != nil { + return nil, err + } + + return m, nil +} + +func (s *Store) SaveMapping(ctx context.Context, mapping core.Mapping) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return s.error(err) + } + + data, err := json.Marshal(mapping) + if err != nil { + return err + } + + ib := sqlbuilder.NewInsertBuilder() + ib.InsertInto(s.table("mapping")) + ib.Cols("mapping_id", "mapping") + ib.Values(mappingId, string(data)) + + var ( + sqlq string + args []interface{} + ) + switch s.flavor { + case sqlbuilder.Flavor(PostgreSQL): + sqlq, args = ib.BuildWithFlavor(s.flavor) + sqlq += " ON CONFLICT (mapping_id) DO UPDATE SET mapping = $2" + default: + ib.ReplaceInto(s.table("mapping")) + sqlq, args = ib.BuildWithFlavor(s.flavor) + } + + logrus.Debugln(sqlq, args) + + _, err = tx.ExecContext(ctx, sqlq, args...) + if err != nil { + tx.Rollback() + + return s.error(err) + } + return tx.Commit() +} diff --git a/pkg/storage/sqlstorage/migrations/postgresql/v001.sql b/pkg/storage/sqlstorage/migrations/postgresql/v001.sql index 15e54875f..3730b4d10 100644 --- a/pkg/storage/sqlstorage/migrations/postgresql/v001.sql +++ b/pkg/storage/sqlstorage/migrations/postgresql/v001.sql @@ -39,6 +39,21 @@ CREATE TABLE IF NOT EXISTS "VAR_LEDGER_NAME".metadata ( UNIQUE("meta_id") ); --statement +CREATE TABLE IF NOT EXISTS "VAR_LEDGER_NAME".contract ( + "contract_id" integer, + "contract_account" varchar, + "contract_expr" varchar, + + UNIQUE("contract_id") +) +--statement +CREATE TABLE IF NOT EXISTS "VAR_LEDGER_NAME".mapping ( + "mapping_id" varchar, + "mapping" varchar, + + UNIQUE("mapping_id") +) +--statement CREATE INDEX IF NOT EXISTS m_i0 ON "VAR_LEDGER_NAME".metadata ( "meta_target_type", "meta_target_id" diff --git a/pkg/storage/sqlstorage/migrations/sqlite/v001.sql b/pkg/storage/sqlstorage/migrations/sqlite/v001.sql index 9cd0de961..b5d294ee0 100644 --- a/pkg/storage/sqlstorage/migrations/sqlite/v001.sql +++ b/pkg/storage/sqlstorage/migrations/sqlite/v001.sql @@ -37,6 +37,13 @@ CREATE TABLE IF NOT EXISTS metadata ( UNIQUE("meta_id") ); --statement +CREATE TABLE IF NOT EXISTS mapping ( + "mapping_id" varchar, + "mapping" varchar, + + UNIQUE("mapping_id") +) +--statement CREATE INDEX IF NOT EXISTS 'm_i0' ON "metadata" ( "meta_target_type", "meta_target_id" @@ -46,4 +53,4 @@ CREATE VIEW IF NOT EXISTS addresses AS SELECT address FROM ( SELECT source as address FROM postings GROUP BY source UNION SELECT destination as address FROM postings GROUP BY destination -) GROUP BY address; +) GROUP BY address; \ No newline at end of file diff --git a/pkg/storage/sqlstorage/store_test.go b/pkg/storage/sqlstorage/store_test.go index d176109dd..1f03a6cc6 100644 --- a/pkg/storage/sqlstorage/store_test.go +++ b/pkg/storage/sqlstorage/store_test.go @@ -114,6 +114,10 @@ func TestStore(t *testing.T) { name: "GetTransaction", fn: testGetTransaction, }, + { + name: "Mapping", + fn: testMapping, + }, } { t.Run(fmt.Sprintf("%s/%s", driver.driver, tf.name), func(t *testing.T) { ledger := uuid.New() @@ -479,6 +483,38 @@ func testFindTransactions(t *testing.T, store storage.Store) { } +func testMapping(t *testing.T, store storage.Store) { + + m := core.Mapping{ + Contracts: []core.Contract{ + { + Expr: &core.ExprGt{ + Op1: core.VariableExpr{Name: "balance"}, + Op2: core.ConstantExpr{Value: float64(0)}, + }, + Account: "orders:*", + }, + }, + } + err := store.SaveMapping(context.Background(), m) + assert.NoError(t, err) + + mapping, err := store.LoadMapping(context.Background()) + assert.NoError(t, err) + assert.Len(t, mapping.Contracts, 1) + assert.EqualValues(t, m.Contracts[0], mapping.Contracts[0]) + + m2 := core.Mapping{ + Contracts: []core.Contract{}, + } + err = store.SaveMapping(context.Background(), m2) + assert.NoError(t, err) + + mapping, err = store.LoadMapping(context.Background()) + assert.NoError(t, err) + assert.Len(t, mapping.Contracts, 0) +} + func testGetTransaction(t *testing.T, store storage.Store) { txs := []core.Transaction{ { diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 9fdcc188d..0e2ffd11c 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -43,6 +43,8 @@ type Store interface { SaveMeta(context.Context, int64, string, string, string, string, string) error GetMeta(context.Context, string, string) (core.Metadata, error) CountMeta(context.Context) (int64, error) + LoadMapping(ctx context.Context) (*core.Mapping, error) + SaveMapping(ctx context.Context, m core.Mapping) error Initialize(context.Context) error Name() string Close(context.Context) error @@ -107,6 +109,14 @@ func (n noOpStore) Initialize(ctx context.Context) error { return nil } +func (n noOpStore) LoadMapping(context.Context) (*core.Mapping, error) { + return nil, nil +} + +func (n noOpStore) SaveMapping(ctx context.Context, mapping core.Mapping) error { + return nil +} + func (n noOpStore) Name() string { return "noop" }