diff --git a/cmd/container.go b/cmd/container.go index 3a5abaf10..e95cdc49e 100644 --- a/cmd/container.go +++ b/cmd/container.go @@ -4,9 +4,8 @@ import ( "io" "github.com/formancehq/ledger/cmd/internal" - "github.com/formancehq/ledger/internal/api" "github.com/formancehq/ledger/internal/engine" - driver2 "github.com/formancehq/ledger/internal/storage/driver" + driver "github.com/formancehq/ledger/internal/storage/driver" "github.com/formancehq/stack/libs/go-libs/otlp/otlpmetrics" "github.com/formancehq/stack/libs/go-libs/otlp/otlptraces" "github.com/formancehq/stack/libs/go-libs/publish" @@ -24,21 +23,18 @@ func resolveOptions(output io.Writer, userOptions ...fx.Option) []fx.Option { v := viper.GetViper() debug := v.GetBool(service.DebugFlag) if debug { - driver2.InstrumentalizeSQLDriver() + driver.InstrumentalizeSQLDriver() } options = append(options, publish.CLIPublisherModule(v, ServiceName), otlptraces.CLITracesModule(v), otlpmetrics.CLIMetricsModule(v), - api.Module(api.Config{ - Version: Version, - }), - driver2.CLIModule(v, output, debug), + driver.CLIModule(v, output, debug), internal.NewAnalyticsModule(v, Version), engine.Module(engine.Configuration{ NumscriptCache: engine.NumscriptCacheConfiguration{ - MaxCount: v.GetInt(numscriptCacheMaxCount), + MaxCount: v.GetInt(numscriptCacheMaxCountFlag), }, }), ) diff --git a/cmd/serve.go b/cmd/serve.go index 629070549..aeb52f111 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -4,6 +4,10 @@ import ( "net/http" "time" + "github.com/formancehq/ledger/internal/storage/driver" + + "github.com/formancehq/ledger/internal/api" + ledger "github.com/formancehq/ledger/internal" "github.com/formancehq/stack/libs/go-libs/ballast" "github.com/formancehq/stack/libs/go-libs/httpserver" @@ -16,8 +20,10 @@ import ( ) const ( - ballastSizeInBytesFlag = "ballast-size" - numscriptCacheMaxCount = "numscript-cache-max-count" + ballastSizeInBytesFlag = "ballast-size" + numscriptCacheMaxCountFlag = "numscript-cache-max-count" + readOnlyFlag = "read-only" + autoUpgradeFlag = "auto-upgrade" ) func NewServe() *cobra.Command { @@ -27,6 +33,17 @@ func NewServe() *cobra.Command { return app.New(cmd.OutOrStdout(), resolveOptions( cmd.OutOrStdout(), ballast.Module(viper.GetUint(ballastSizeInBytesFlag)), + api.Module(api.Config{ + Version: Version, + ReadOnly: viper.GetBool(readOnlyFlag), + }), + fx.Invoke(func(lc fx.Lifecycle, driver *driver.Driver) { + if viper.GetBool(autoUpgradeFlag) { + lc.Append(fx.Hook{ + OnStart: driver.UpgradeAllLedgersSchemas, + }) + } + }), fx.Invoke(func(lc fx.Lifecycle, h chi.Router, logger logging.Logger) { wrappedRouter := chi.NewRouter() @@ -45,7 +62,9 @@ func NewServe() *cobra.Command { }, } cmd.Flags().Uint(ballastSizeInBytesFlag, 0, "Ballast size in bytes, default to 0") - cmd.Flags().Int(numscriptCacheMaxCount, 1024, "Numscript cache max count") + cmd.Flags().Int(numscriptCacheMaxCountFlag, 1024, "Numscript cache max count") + cmd.Flags().Bool(readOnlyFlag, false, "Read only mode") + cmd.Flags().Bool(autoUpgradeFlag, false, "Automatically upgrade all schemas") return cmd } diff --git a/cmd/storage.go b/cmd/storage.go index 83ad405ce..54cc5e249 100644 --- a/cmd/storage.go +++ b/cmd/storage.go @@ -156,31 +156,18 @@ func NewStorageUpgradeAll() *cobra.Command { if err != nil { return err } - defer sqlDB.Close() + defer func() { + if err := sqlDB.Close(); err != nil { + logger.Errorf("Error closing database: %s", err) + } + }() driver := driver.New(sqlDB) if err := driver.Initialize(ctx); err != nil { return err } - systemStore := driver.GetSystemStore() - ledgers, err := systemStore.ListLedgers(ctx) - if err != nil { - return err - } - - for _, ledger := range ledgers { - store, err := driver.GetLedgerStore(ctx, ledger) - if err != nil { - return err - } - logger.Infof("Upgrading storage '%s'", ledger) - if err := upgradeStore(ctx, store, ledger); err != nil { - return err - } - } - - return nil + return driver.UpgradeAllLedgersSchemas(ctx) }, } return cmd diff --git a/go.mod b/go.mod index f3477c2a6..cdcd2a1ce 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/formancehq/ledger -go 1.19 +go 1.20 require ( github.com/Masterminds/semver/v3 v3.2.0 @@ -17,7 +17,6 @@ require ( github.com/jackc/pgx/v5 v5.3.0 github.com/lib/pq v1.10.7 github.com/logrusorgru/aurora v2.0.3+incompatible - github.com/ory/dockertest/v3 v3.9.1 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/pborman/uuid v1.2.1 github.com/pkg/errors v0.9.1 @@ -33,8 +32,8 @@ require ( go.opentelemetry.io/otel v1.16.0 go.opentelemetry.io/otel/metric v1.16.0 go.opentelemetry.io/otel/trace v1.16.0 - go.uber.org/atomic v1.10.0 go.uber.org/fx v1.19.2 + go.uber.org/mock v0.3.0 gopkg.in/segmentio/analytics-go.v3 v3.1.0 ) @@ -96,6 +95,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect github.com/opencontainers/runc v1.1.3 // indirect + github.com/ory/dockertest/v3 v3.9.1 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pierrec/lz4/v4 v4.1.17 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -144,11 +144,12 @@ require ( go.opentelemetry.io/otel/sdk v1.16.0 // indirect go.opentelemetry.io/otel/sdk/metric v0.39.0 // indirect go.opentelemetry.io/proto/otlp v0.19.0 // indirect + go.uber.org/atomic v1.10.0 // indirect go.uber.org/dig v1.16.1 // indirect go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.24.0 // indirect golang.org/x/crypto v0.9.0 // indirect - golang.org/x/mod v0.8.0 // indirect + golang.org/x/mod v0.11.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.8.0 // indirect golang.org/x/text v0.9.0 // indirect diff --git a/go.sum b/go.sum index 025e432bd..ae34c81ca 100644 --- a/go.sum +++ b/go.sum @@ -552,6 +552,8 @@ go.uber.org/dig v1.16.1/go.mod h1:557JTAUZT5bUK0SvCwikmLPPtdQhfvLYtO5tJgQSbnk= go.uber.org/fx v1.19.2 h1:SyFgYQFr1Wl0AYstE8vyYIzP4bFz2URrScjwC4cwUvY= go.uber.org/fx v1.19.2/go.mod h1:43G1VcqSzbIv77y00p1DRAsyZS8WdzuYdhZXmEUkMyQ= go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +go.uber.org/mock v0.3.0 h1:3mUxI1No2/60yUYax92Pt8eNOEecx2D3lcXZh2NEZJo= +go.uber.org/mock v0.3.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= @@ -603,8 +605,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/internal/api/backend/backend.go b/internal/api/backend/backend.go index 806a33cce..16bbb0782 100644 --- a/internal/api/backend/backend.go +++ b/internal/api/backend/backend.go @@ -32,6 +32,8 @@ type Ledger interface { RevertTransaction(ctx context.Context, parameters command.Parameters, id *big.Int) (*ledger.Transaction, error) SaveMeta(ctx context.Context, parameters command.Parameters, targetType string, targetID any, m metadata.Metadata) error DeleteMetadata(ctx context.Context, parameters command.Parameters, targetType string, targetID any, key string) error + + IsDatabaseUpToDate(ctx context.Context) (bool, error) } type Backend interface { diff --git a/internal/api/backend/backend_generated.go b/internal/api/backend/backend_generated.go index a9f054f30..7744d4ca9 100644 --- a/internal/api/backend/backend_generated.go +++ b/internal/api/backend/backend_generated.go @@ -1,6 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: backend.go - +// +// Generated by this command: +// +// mockgen -source backend.go -destination backend_generated.go -package backend . Ledger +// // Package backend is a generated GoMock package. package backend @@ -9,14 +13,14 @@ import ( big "math/big" reflect "reflect" - internal "github.com/formancehq/ledger/internal" + ledger "github.com/formancehq/ledger/internal" engine "github.com/formancehq/ledger/internal/engine" command "github.com/formancehq/ledger/internal/engine/command" ledgerstore "github.com/formancehq/ledger/internal/storage/ledgerstore" api "github.com/formancehq/stack/libs/go-libs/api" metadata "github.com/formancehq/stack/libs/go-libs/metadata" migrations "github.com/formancehq/stack/libs/go-libs/migrations" - gomock "github.com/golang/mock/gomock" + gomock "go.uber.org/mock/gomock" ) // MockLedger is a mock of Ledger interface. @@ -52,7 +56,7 @@ func (m *MockLedger) CountAccounts(ctx context.Context, query *ledgerstore.GetAc } // CountAccounts indicates an expected call of CountAccounts. -func (mr *MockLedgerMockRecorder) CountAccounts(ctx, query interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) CountAccounts(ctx, query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccounts", reflect.TypeOf((*MockLedger)(nil).CountAccounts), ctx, query) } @@ -67,22 +71,22 @@ func (m *MockLedger) CountTransactions(ctx context.Context, query *ledgerstore.G } // CountTransactions indicates an expected call of CountTransactions. -func (mr *MockLedgerMockRecorder) CountTransactions(ctx, query interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) CountTransactions(ctx, query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountTransactions", reflect.TypeOf((*MockLedger)(nil).CountTransactions), ctx, query) } // CreateTransaction mocks base method. -func (m *MockLedger) CreateTransaction(ctx context.Context, parameters command.Parameters, data internal.RunScript) (*internal.Transaction, error) { +func (m *MockLedger) CreateTransaction(ctx context.Context, parameters command.Parameters, data ledger.RunScript) (*ledger.Transaction, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateTransaction", ctx, parameters, data) - ret0, _ := ret[0].(*internal.Transaction) + ret0, _ := ret[0].(*ledger.Transaction) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateTransaction indicates an expected call of CreateTransaction. -func (mr *MockLedgerMockRecorder) CreateTransaction(ctx, parameters, data interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) CreateTransaction(ctx, parameters, data any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTransaction", reflect.TypeOf((*MockLedger)(nil).CreateTransaction), ctx, parameters, data) } @@ -96,67 +100,67 @@ func (m *MockLedger) DeleteMetadata(ctx context.Context, parameters command.Para } // DeleteMetadata indicates an expected call of DeleteMetadata. -func (mr *MockLedgerMockRecorder) DeleteMetadata(ctx, parameters, targetType, targetID, key interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) DeleteMetadata(ctx, parameters, targetType, targetID, key any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMetadata", reflect.TypeOf((*MockLedger)(nil).DeleteMetadata), ctx, parameters, targetType, targetID, key) } // GetAccountWithVolumes mocks base method. -func (m *MockLedger) GetAccountWithVolumes(ctx context.Context, query ledgerstore.GetAccountQuery) (*internal.ExpandedAccount, error) { +func (m *MockLedger) GetAccountWithVolumes(ctx context.Context, query ledgerstore.GetAccountQuery) (*ledger.ExpandedAccount, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAccountWithVolumes", ctx, query) - ret0, _ := ret[0].(*internal.ExpandedAccount) + ret0, _ := ret[0].(*ledger.ExpandedAccount) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAccountWithVolumes indicates an expected call of GetAccountWithVolumes. -func (mr *MockLedgerMockRecorder) GetAccountWithVolumes(ctx, query interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) GetAccountWithVolumes(ctx, query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountWithVolumes", reflect.TypeOf((*MockLedger)(nil).GetAccountWithVolumes), ctx, query) } // GetAccountsWithVolumes mocks base method. -func (m *MockLedger) GetAccountsWithVolumes(ctx context.Context, query *ledgerstore.GetAccountsQuery) (*api.Cursor[internal.ExpandedAccount], error) { +func (m *MockLedger) GetAccountsWithVolumes(ctx context.Context, query *ledgerstore.GetAccountsQuery) (*api.Cursor[ledger.ExpandedAccount], error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAccountsWithVolumes", ctx, query) - ret0, _ := ret[0].(*api.Cursor[internal.ExpandedAccount]) + ret0, _ := ret[0].(*api.Cursor[ledger.ExpandedAccount]) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAccountsWithVolumes indicates an expected call of GetAccountsWithVolumes. -func (mr *MockLedgerMockRecorder) GetAccountsWithVolumes(ctx, query interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) GetAccountsWithVolumes(ctx, query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsWithVolumes", reflect.TypeOf((*MockLedger)(nil).GetAccountsWithVolumes), ctx, query) } // GetAggregatedBalances mocks base method. -func (m *MockLedger) GetAggregatedBalances(ctx context.Context, q *ledgerstore.GetAggregatedBalanceQuery) (internal.BalancesByAssets, error) { +func (m *MockLedger) GetAggregatedBalances(ctx context.Context, q *ledgerstore.GetAggregatedBalanceQuery) (ledger.BalancesByAssets, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAggregatedBalances", ctx, q) - ret0, _ := ret[0].(internal.BalancesByAssets) + ret0, _ := ret[0].(ledger.BalancesByAssets) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAggregatedBalances indicates an expected call of GetAggregatedBalances. -func (mr *MockLedgerMockRecorder) GetAggregatedBalances(ctx, q interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) GetAggregatedBalances(ctx, q any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAggregatedBalances", reflect.TypeOf((*MockLedger)(nil).GetAggregatedBalances), ctx, q) } // GetLogs mocks base method. -func (m *MockLedger) GetLogs(ctx context.Context, query *ledgerstore.GetLogsQuery) (*api.Cursor[internal.ChainedLog], error) { +func (m *MockLedger) GetLogs(ctx context.Context, query *ledgerstore.GetLogsQuery) (*api.Cursor[ledger.ChainedLog], error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetLogs", ctx, query) - ret0, _ := ret[0].(*api.Cursor[internal.ChainedLog]) + ret0, _ := ret[0].(*api.Cursor[ledger.ChainedLog]) ret1, _ := ret[1].(error) return ret0, ret1 } // GetLogs indicates an expected call of GetLogs. -func (mr *MockLedgerMockRecorder) GetLogs(ctx, query interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) GetLogs(ctx, query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLogs", reflect.TypeOf((*MockLedger)(nil).GetLogs), ctx, query) } @@ -171,52 +175,67 @@ func (m *MockLedger) GetMigrationsInfo(ctx context.Context) ([]migrations.Info, } // GetMigrationsInfo indicates an expected call of GetMigrationsInfo. -func (mr *MockLedgerMockRecorder) GetMigrationsInfo(ctx interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) GetMigrationsInfo(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMigrationsInfo", reflect.TypeOf((*MockLedger)(nil).GetMigrationsInfo), ctx) } // GetTransactionWithVolumes mocks base method. -func (m *MockLedger) GetTransactionWithVolumes(ctx context.Context, query ledgerstore.GetTransactionQuery) (*internal.ExpandedTransaction, error) { +func (m *MockLedger) GetTransactionWithVolumes(ctx context.Context, query ledgerstore.GetTransactionQuery) (*ledger.ExpandedTransaction, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetTransactionWithVolumes", ctx, query) - ret0, _ := ret[0].(*internal.ExpandedTransaction) + ret0, _ := ret[0].(*ledger.ExpandedTransaction) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTransactionWithVolumes indicates an expected call of GetTransactionWithVolumes. -func (mr *MockLedgerMockRecorder) GetTransactionWithVolumes(ctx, query interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) GetTransactionWithVolumes(ctx, query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTransactionWithVolumes", reflect.TypeOf((*MockLedger)(nil).GetTransactionWithVolumes), ctx, query) } // GetTransactions mocks base method. -func (m *MockLedger) GetTransactions(ctx context.Context, query *ledgerstore.GetTransactionsQuery) (*api.Cursor[internal.ExpandedTransaction], error) { +func (m *MockLedger) GetTransactions(ctx context.Context, query *ledgerstore.GetTransactionsQuery) (*api.Cursor[ledger.ExpandedTransaction], error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetTransactions", ctx, query) - ret0, _ := ret[0].(*api.Cursor[internal.ExpandedTransaction]) + ret0, _ := ret[0].(*api.Cursor[ledger.ExpandedTransaction]) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTransactions indicates an expected call of GetTransactions. -func (mr *MockLedgerMockRecorder) GetTransactions(ctx, query interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) GetTransactions(ctx, query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTransactions", reflect.TypeOf((*MockLedger)(nil).GetTransactions), ctx, query) } +// IsDatabaseUpToDate mocks base method. +func (m *MockLedger) IsDatabaseUpToDate(ctx context.Context) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsDatabaseUpToDate", ctx) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// IsDatabaseUpToDate indicates an expected call of IsDatabaseUpToDate. +func (mr *MockLedgerMockRecorder) IsDatabaseUpToDate(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsDatabaseUpToDate", reflect.TypeOf((*MockLedger)(nil).IsDatabaseUpToDate), ctx) +} + // RevertTransaction mocks base method. -func (m *MockLedger) RevertTransaction(ctx context.Context, parameters command.Parameters, id *big.Int) (*internal.Transaction, error) { +func (m *MockLedger) RevertTransaction(ctx context.Context, parameters command.Parameters, id *big.Int) (*ledger.Transaction, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RevertTransaction", ctx, parameters, id) - ret0, _ := ret[0].(*internal.Transaction) + ret0, _ := ret[0].(*ledger.Transaction) ret1, _ := ret[1].(error) return ret0, ret1 } // RevertTransaction indicates an expected call of RevertTransaction. -func (mr *MockLedgerMockRecorder) RevertTransaction(ctx, parameters, id interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) RevertTransaction(ctx, parameters, id any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RevertTransaction", reflect.TypeOf((*MockLedger)(nil).RevertTransaction), ctx, parameters, id) } @@ -230,7 +249,7 @@ func (m_2 *MockLedger) SaveMeta(ctx context.Context, parameters command.Paramete } // SaveMeta indicates an expected call of SaveMeta. -func (mr *MockLedgerMockRecorder) SaveMeta(ctx, parameters, targetType, targetID, m interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) SaveMeta(ctx, parameters, targetType, targetID, m any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveMeta", reflect.TypeOf((*MockLedger)(nil).SaveMeta), ctx, parameters, targetType, targetID, m) } @@ -245,7 +264,7 @@ func (m *MockLedger) Stats(ctx context.Context) (engine.Stats, error) { } // Stats indicates an expected call of Stats. -func (mr *MockLedgerMockRecorder) Stats(ctx interface{}) *gomock.Call { +func (mr *MockLedgerMockRecorder) Stats(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stats", reflect.TypeOf((*MockLedger)(nil).Stats), ctx) } @@ -283,7 +302,7 @@ func (m *MockBackend) GetLedger(ctx context.Context, name string) (Ledger, error } // GetLedger indicates an expected call of GetLedger. -func (mr *MockBackendMockRecorder) GetLedger(ctx, name interface{}) *gomock.Call { +func (mr *MockBackendMockRecorder) GetLedger(ctx, name any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLedger", reflect.TypeOf((*MockBackend)(nil).GetLedger), ctx, name) } @@ -312,7 +331,7 @@ func (m *MockBackend) ListLedgers(ctx context.Context) ([]string, error) { } // ListLedgers indicates an expected call of ListLedgers. -func (mr *MockBackendMockRecorder) ListLedgers(ctx interface{}) *gomock.Call { +func (mr *MockBackendMockRecorder) ListLedgers(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListLedgers", reflect.TypeOf((*MockBackend)(nil).ListLedgers), ctx) } diff --git a/internal/api/module.go b/internal/api/module.go index 024f1ba0f..484a035c8 100644 --- a/internal/api/module.go +++ b/internal/api/module.go @@ -3,6 +3,8 @@ package api import ( _ "embed" + "github.com/go-chi/chi/v5" + "github.com/formancehq/ledger/internal/api/backend" "github.com/formancehq/ledger/internal/engine" "github.com/formancehq/ledger/internal/opentelemetry/metrics" @@ -14,12 +16,18 @@ import ( ) type Config struct { - Version string + Version string + ReadOnly bool } func Module(cfg Config) fx.Option { return fx.Options( - fx.Provide(NewRouter), + fx.Provide(func( + backend backend.Backend, + healthController *health.HealthController, + globalMetricsRegistry metrics.GlobalRegistry) chi.Router { + return NewRouter(backend, healthController, globalMetricsRegistry, cfg.ReadOnly) + }), fx.Provide(func(storageDriver *driver.Driver, resolver *engine.Resolver) backend.Backend { return backend.NewDefaultBackend(storageDriver, cfg.Version, resolver) }), diff --git a/internal/api/read_only.go b/internal/api/read_only.go new file mode 100644 index 000000000..d2e7ee458 --- /dev/null +++ b/internal/api/read_only.go @@ -0,0 +1,18 @@ +package api + +import ( + "net/http" + + "github.com/formancehq/stack/libs/go-libs/api" + "github.com/pkg/errors" +) + +func ReadOnly(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodOptions && r.Method != http.MethodHead { + api.BadRequest(w, "READ_ONLY", errors.New("Read only mode")) + return + } + h.ServeHTTP(w, r) + }) +} diff --git a/internal/api/router.go b/internal/api/router.go index 9b1d22725..110fcbed8 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -15,8 +15,12 @@ func NewRouter( backend backend.Backend, healthController *health.HealthController, globalMetricsRegistry metrics.GlobalRegistry, + readOnly bool, ) chi.Router { mux := chi.NewRouter() + if readOnly { + mux.Use(ReadOnly) + } v2Router := v2.NewRouter(backend, healthController, globalMetricsRegistry) mux.Handle("/v2/*", http.StripPrefix("/v2", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { chi.RouteContext(r.Context()).Reset() diff --git a/internal/api/v1/context.go b/internal/api/shared/context.go similarity index 96% rename from internal/api/v1/context.go rename to internal/api/shared/context.go index 3b8b76937..a4a99e744 100644 --- a/internal/api/v1/context.go +++ b/internal/api/shared/context.go @@ -1,4 +1,4 @@ -package v1 +package shared import ( "context" diff --git a/internal/api/v2/errors.go b/internal/api/shared/errors.go similarity index 99% rename from internal/api/v2/errors.go rename to internal/api/shared/errors.go index 888749900..50b8be75f 100644 --- a/internal/api/v2/errors.go +++ b/internal/api/shared/errors.go @@ -1,4 +1,4 @@ -package v2 +package shared import ( "context" diff --git a/internal/api/v1/middlewares_resolver.go b/internal/api/shared/resolver.go similarity index 72% rename from internal/api/v1/middlewares_resolver.go rename to internal/api/shared/resolver.go index c0e986fb0..a0a8082d5 100644 --- a/internal/api/v1/middlewares_resolver.go +++ b/internal/api/shared/resolver.go @@ -1,11 +1,14 @@ -package v1 +package shared import ( "math/rand" "net/http" + "strings" "sync" "time" + "github.com/pkg/errors" + "github.com/formancehq/ledger/internal/api/backend" "github.com/formancehq/ledger/internal/opentelemetry/tracer" "github.com/formancehq/stack/libs/go-libs/logging" @@ -36,6 +39,7 @@ func randomTraceID(n int) string { func LedgerMiddleware( resolver backend.Backend, + excludePathFromSchemaCheck []string, ) func(handler http.Handler) http.Handler { return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -66,14 +70,28 @@ func LedgerMiddleware( ResponseError(w, r, err) return } - // TODO(polo/gfyrag): close ledger if not used for x minutes - // defer l.Close(context.Background()) - // When close, we have to decrease the active ledgers counter: - // globalMetricsRegistry.ActiveLedgers.Add(r.Context(), -1) - r = r.WithContext(ContextWithLedger(r.Context(), l)) + excluded := false + for _, path := range excludePathFromSchemaCheck { + if strings.HasSuffix(r.URL.Path, path) { + excluded = true + break + } + } + + if !excluded { + isUpToDate, err := l.IsDatabaseUpToDate(ctx) + if err != nil { + ResponseError(w, r, err) + return + } + if !isUpToDate { + ResponseError(w, r, errors.New("outdated schema")) + return + } + } - handler.ServeHTTP(w, r) + handler.ServeHTTP(w, r.WithContext(ContextWithLedger(r.Context(), l))) }) } } diff --git a/internal/api/v1/api_utils_test.go b/internal/api/v1/api_utils_test.go index bf8918f0e..77896d3c6 100644 --- a/internal/api/v1/api_utils_test.go +++ b/internal/api/v1/api_utils_test.go @@ -4,10 +4,10 @@ import ( "testing" "github.com/formancehq/ledger/internal/api/backend" - "github.com/golang/mock/gomock" + "go.uber.org/mock/gomock" ) -func newTestingBackend(t *testing.T) (*backend.MockBackend, *backend.MockLedger) { +func newTestingBackend(t *testing.T, expectedSchemaCheck bool) (*backend.MockBackend, *backend.MockLedger) { ctrl := gomock.NewController(t) mockLedger := backend.NewMockLedger(ctrl) backend := backend.NewMockBackend(ctrl) @@ -19,5 +19,10 @@ func newTestingBackend(t *testing.T) (*backend.MockBackend, *backend.MockLedger) t.Cleanup(func() { ctrl.Finish() }) + if expectedSchemaCheck { + mockLedger.EXPECT(). + IsDatabaseUpToDate(gomock.Any()). + Return(true, nil) + } return backend, mockLedger } diff --git a/internal/api/v1/controllers_accounts.go b/internal/api/v1/controllers_accounts.go index ff789aaf0..38fafcc80 100644 --- a/internal/api/v1/controllers_accounts.go +++ b/internal/api/v1/controllers_accounts.go @@ -7,6 +7,8 @@ import ( "strconv" "strings" + "github.com/formancehq/ledger/internal/api/shared" + ledger "github.com/formancehq/ledger/internal" "github.com/formancehq/ledger/internal/engine/command" "github.com/formancehq/ledger/internal/storage/ledgerstore" @@ -69,7 +71,7 @@ func buildAccountsFilterQuery(r *http.Request) (query.Builder, error) { } func countAccounts(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) options, err := getPaginatedQueryOptionsOfPITFilterWithVolumes(r) if err != nil { @@ -88,7 +90,7 @@ func countAccounts(w http.ResponseWriter, r *http.Request) { } func getAccounts(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) q := &ledgerstore.GetAccountsQuery{} @@ -119,7 +121,7 @@ func getAccounts(w http.ResponseWriter, r *http.Request) { } func getAccount(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) query := ledgerstore.NewGetAccountQuery(chi.URLParam(r, "address")) if collectionutils.Contains(r.URL.Query()["expand"], "volumes") { @@ -139,7 +141,7 @@ func getAccount(w http.ResponseWriter, r *http.Request) { } func postAccountMetadata(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) if !ledger.ValidateAddress(chi.URLParam(r, "address")) { ResponseError(w, r, errorsutil.NewError(command.ErrValidation, @@ -164,7 +166,7 @@ func postAccountMetadata(w http.ResponseWriter, r *http.Request) { } func deleteAccountMetadata(w http.ResponseWriter, r *http.Request) { - if err := LedgerFromContext(r.Context()). + if err := shared.LedgerFromContext(r.Context()). DeleteMetadata( r.Context(), getCommandParameters(r), diff --git a/internal/api/v1/controllers_accounts_test.go b/internal/api/v1/controllers_accounts_test.go index a234af769..a6ed93cdf 100644 --- a/internal/api/v1/controllers_accounts_test.go +++ b/internal/api/v1/controllers_accounts_test.go @@ -115,7 +115,7 @@ func TestGetAccounts(t *testing.T) { }, } - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { mockLedger.EXPECT(). GetAccountsWithVolumes(gomock.Any(), ledgerstore.NewGetAccountsQuery(testCase.expectQuery)). @@ -153,7 +153,7 @@ func TestGetAccount(t *testing.T) { }, } - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) mock.EXPECT(). GetAccountWithVolumes(gomock.Any(), ledgerstore.NewGetAccountQuery("foo")). Return(&account, nil) @@ -212,7 +212,7 @@ func TestPostAccountMetadata(t *testing.T) { testCase.expectStatusCode = http.StatusNoContent } - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) if testCase.expectStatusCode == http.StatusNoContent { mock.EXPECT(). SaveMeta(gomock.Any(), command.Parameters{}, ledger.MetaTargetTypeAccount, testCase.account, testCase.body). diff --git a/internal/api/v1/controllers_balances.go b/internal/api/v1/controllers_balances.go index bbb0f7c34..5d5bea8cb 100644 --- a/internal/api/v1/controllers_balances.go +++ b/internal/api/v1/controllers_balances.go @@ -4,6 +4,8 @@ import ( "math/big" "net/http" + "github.com/formancehq/ledger/internal/api/shared" + "github.com/formancehq/ledger/internal/engine/command" "github.com/formancehq/ledger/internal/storage/ledgerstore" "github.com/formancehq/ledger/internal/storage/paginate" @@ -31,7 +33,7 @@ func getBalancesAggregated(w http.ResponseWriter, r *http.Request) { query := ledgerstore.NewGetAggregatedBalancesQuery(*options) query.Options.QueryBuilder, err = buildAggregatedBalancesQuery(r) - balances, err := LedgerFromContext(r.Context()).GetAggregatedBalances(r.Context(), query) + balances, err := shared.LedgerFromContext(r.Context()).GetAggregatedBalances(r.Context(), query) if err != nil { ResponseError(w, r, err) return @@ -41,7 +43,7 @@ func getBalancesAggregated(w http.ResponseWriter, r *http.Request) { } func getBalances(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) q := &ledgerstore.GetAccountsQuery{} diff --git a/internal/api/v1/controllers_balances_test.go b/internal/api/v1/controllers_balances_test.go index 97a9205bb..bab78c5a7 100644 --- a/internal/api/v1/controllers_balances_test.go +++ b/internal/api/v1/controllers_balances_test.go @@ -47,7 +47,7 @@ func TestGetBalancesAggregated(t *testing.T) { expectedBalances := ledger.BalancesByAssets{ "world": big.NewInt(-100), } - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) mock.EXPECT(). GetAggregatedBalances(gomock.Any(), ledgerstore.NewGetAggregatedBalancesQuery(testCase.expectQuery)). Return(expectedBalances, nil) diff --git a/internal/api/v1/controllers_config_test.go b/internal/api/v1/controllers_config_test.go index 1b7fe085c..698bc1167 100644 --- a/internal/api/v1/controllers_config_test.go +++ b/internal/api/v1/controllers_config_test.go @@ -15,7 +15,7 @@ import ( func TestGetInfo(t *testing.T) { t.Parallel() - backend, _ := newTestingBackend(t) + backend, _ := newTestingBackend(t, false) router := v2.NewRouter(backend, nil, metrics.NewNoOpRegistry()) backend. diff --git a/internal/api/v1/controllers_info.go b/internal/api/v1/controllers_info.go index 2220eb46a..434da171a 100644 --- a/internal/api/v1/controllers_info.go +++ b/internal/api/v1/controllers_info.go @@ -3,6 +3,8 @@ package v1 import ( "net/http" + "github.com/formancehq/ledger/internal/api/shared" + "github.com/formancehq/ledger/internal/engine/command" "github.com/formancehq/ledger/internal/storage/ledgerstore" "github.com/formancehq/ledger/internal/storage/paginate" @@ -24,7 +26,7 @@ type StorageInfo struct { } func getLedgerInfo(w http.ResponseWriter, r *http.Request) { - ledger := LedgerFromContext(r.Context()) + ledger := shared.LedgerFromContext(r.Context()) var err error res := Info{ @@ -41,7 +43,7 @@ func getLedgerInfo(w http.ResponseWriter, r *http.Request) { } func getStats(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) stats, err := l.Stats(r.Context()) if err != nil { @@ -76,7 +78,7 @@ func buildGetLogsQuery(r *http.Request) (query.Builder, error) { } func getLogs(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) query := &ledgerstore.GetLogsQuery{} diff --git a/internal/api/v1/controllers_info_test.go b/internal/api/v1/controllers_info_test.go index af94bd844..ffd7f8b6c 100644 --- a/internal/api/v1/controllers_info_test.go +++ b/internal/api/v1/controllers_info_test.go @@ -25,7 +25,7 @@ import ( func TestGetLedgerInfo(t *testing.T) { t.Parallel() - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, false) router := v1.NewRouter(backend, nil, metrics.NewNoOpRegistry()) migrationInfo := []migrations.Info{ @@ -33,13 +33,13 @@ func TestGetLedgerInfo(t *testing.T) { Version: "1", Name: "init", State: "ready", - Date: time.Now().Add(-2 * time.Minute).Round(time.Second), + Date: time.Now().Add(-2 * time.Minute).Round(time.Second).UTC(), }, { Version: "2", Name: "fix", State: "ready", - Date: time.Now().Add(-time.Minute).Round(time.Second), + Date: time.Now().Add(-time.Minute).Round(time.Second).UTC(), }, } @@ -68,7 +68,7 @@ func TestGetLedgerInfo(t *testing.T) { func TestGetStats(t *testing.T) { t.Parallel() - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) router := v1.NewRouter(backend, nil, metrics.NewNoOpRegistry()) expectedStats := engine.Stats{ @@ -156,7 +156,7 @@ func TestGetLogs(t *testing.T) { }, } - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { mockLedger.EXPECT(). GetLogs(gomock.Any(), ledgerstore.NewGetLogsQuery(testCase.expectQuery)). diff --git a/internal/api/v1/controllers_transactions.go b/internal/api/v1/controllers_transactions.go index ca538d94e..919c60cdb 100644 --- a/internal/api/v1/controllers_transactions.go +++ b/internal/api/v1/controllers_transactions.go @@ -8,6 +8,8 @@ import ( "strconv" "strings" + "github.com/formancehq/ledger/internal/api/shared" + ledger "github.com/formancehq/ledger/internal" "github.com/formancehq/ledger/internal/engine/command" "github.com/formancehq/ledger/internal/storage/ledgerstore" @@ -97,7 +99,7 @@ func countTransactions(w http.ResponseWriter, r *http.Request) { return } - count, err := LedgerFromContext(r.Context()). + count, err := shared.LedgerFromContext(r.Context()). CountTransactions(r.Context(), ledgerstore.NewGetTransactionsQuery(*options)) if err != nil { ResponseError(w, r, err) @@ -109,7 +111,7 @@ func countTransactions(w http.ResponseWriter, r *http.Request) { } func getTransactions(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) query := &ledgerstore.GetTransactionsQuery{} @@ -172,7 +174,7 @@ type PostTransactionRequest struct { } func postTransaction(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) payload := PostTransactionRequest{} if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { @@ -227,7 +229,7 @@ func postTransaction(w http.ResponseWriter, r *http.Request) { } func getTransaction(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) txId, ok := big.NewInt(0).SetString(chi.URLParam(r, "id"), 10) if !ok { @@ -254,7 +256,7 @@ func getTransaction(w http.ResponseWriter, r *http.Request) { } func revertTransaction(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) transactionID, ok := big.NewInt(0).SetString(chi.URLParam(r, "id"), 10) if !ok { @@ -272,7 +274,7 @@ func revertTransaction(w http.ResponseWriter, r *http.Request) { } func postTransactionMetadata(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) var m metadata.Metadata if err := json.NewDecoder(r.Body).Decode(&m); err != nil { @@ -296,7 +298,7 @@ func postTransactionMetadata(w http.ResponseWriter, r *http.Request) { } func deleteTransactionMetadata(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) transactionID, err := strconv.ParseUint(chi.URLParam(r, "id"), 10, 64) if err != nil { diff --git a/internal/api/v1/controllers_transactions_test.go b/internal/api/v1/controllers_transactions_test.go index 73f1c7d43..153c57195 100644 --- a/internal/api/v1/controllers_transactions_test.go +++ b/internal/api/v1/controllers_transactions_test.go @@ -222,7 +222,7 @@ func TestPostTransactions(t *testing.T) { ledger.NewPosting("world", "bank", "USD", big.NewInt(100)), ) - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectedStatusCode < 300 && testCase.expectedStatusCode >= 200 { mockLedger.EXPECT(). CreateTransaction(gomock.Any(), command.Parameters{ @@ -286,7 +286,7 @@ func TestPostTransactionMetadata(t *testing.T) { testCase.expectStatusCode = http.StatusNoContent } - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) if testCase.expectStatusCode == http.StatusNoContent { mock.EXPECT(). SaveMeta(gomock.Any(), command.Parameters{}, ledger.MetaTargetTypeTransaction, big.NewInt(0), testCase.body). @@ -321,7 +321,7 @@ func TestGetTransaction(t *testing.T) { nil, ) - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) mock.EXPECT(). GetTransactionWithVolumes(gomock.Any(), ledgerstore.NewGetTransactionQuery(big.NewInt(0))). Return(&tx, nil) @@ -462,7 +462,7 @@ func TestGetTransactions(t *testing.T) { }, } - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { mockLedger.EXPECT(). GetTransactions(gomock.Any(), ledgerstore.NewGetTransactionsQuery(testCase.expectQuery)). @@ -572,7 +572,7 @@ func TestCountTransactions(t *testing.T) { testCase.expectStatusCode = http.StatusNoContent } - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { mockLedger.EXPECT(). CountTransactions(gomock.Any(), ledgerstore.NewGetTransactionsQuery(testCase.expectQuery)). @@ -605,7 +605,7 @@ func TestRevertTransaction(t *testing.T) { ledger.NewPosting("world", "bank", "USD", big.NewInt(100)), ) - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) mockLedger. EXPECT(). RevertTransaction(gomock.Any(), command.Parameters{}, big.NewInt(0)). diff --git a/internal/api/v1/routes.go b/internal/api/v1/routes.go index 98ebffa4e..4316a3e9b 100644 --- a/internal/api/v1/routes.go +++ b/internal/api/v1/routes.go @@ -3,6 +3,8 @@ package v1 import ( "net/http" + "github.com/formancehq/ledger/internal/api/shared" + "github.com/formancehq/ledger/internal/api/backend" "github.com/formancehq/ledger/internal/opentelemetry/metrics" "github.com/formancehq/stack/libs/go-libs/health" @@ -12,11 +14,7 @@ import ( "github.com/riandyrn/otelchi" ) -func NewRouter( - backend backend.Backend, - healthController *health.HealthController, - globalMetricsRegistry metrics.GlobalRegistry, -) chi.Router { +func NewRouter(backend backend.Backend, healthController *health.HealthController, globalMetricsRegistry metrics.GlobalRegistry) chi.Router { router := chi.NewMux() router.Use( @@ -42,7 +40,7 @@ func NewRouter( handler.ServeHTTP(w, r) }) }) - router.Use(LedgerMiddleware(backend)) + router.Use(shared.LedgerMiddleware(backend, []string{"/_info"})) // LedgerController router.Get("/_info", getLedgerInfo) diff --git a/internal/api/v2/api_utils_test.go b/internal/api/v2/api_utils_test.go index 8a42c93ac..81e6e6639 100644 --- a/internal/api/v2/api_utils_test.go +++ b/internal/api/v2/api_utils_test.go @@ -3,11 +3,12 @@ package v2_test import ( "testing" + "go.uber.org/mock/gomock" + "github.com/formancehq/ledger/internal/api/backend" - "github.com/golang/mock/gomock" ) -func newTestingBackend(t *testing.T) (*backend.MockBackend, *backend.MockLedger) { +func newTestingBackend(t *testing.T, expectedSchemaCheck bool) (*backend.MockBackend, *backend.MockLedger) { ctrl := gomock.NewController(t) mockLedger := backend.NewMockLedger(ctrl) backend := backend.NewMockBackend(ctrl) @@ -19,5 +20,10 @@ func newTestingBackend(t *testing.T) (*backend.MockBackend, *backend.MockLedger) t.Cleanup(func() { ctrl.Finish() }) + if expectedSchemaCheck { + mockLedger.EXPECT(). + IsDatabaseUpToDate(gomock.Any()). + Return(true, nil) + } return backend, mockLedger } diff --git a/internal/api/v2/context.go b/internal/api/v2/context.go deleted file mode 100644 index a2029ed9b..000000000 --- a/internal/api/v2/context.go +++ /dev/null @@ -1,19 +0,0 @@ -package v2 - -import ( - "context" - - "github.com/formancehq/ledger/internal/api/backend" -) - -type ledgerKey struct{} - -var _ledgerKey = ledgerKey{} - -func ContextWithLedger(ctx context.Context, ledger backend.Ledger) context.Context { - return context.WithValue(ctx, _ledgerKey, ledger) -} - -func LedgerFromContext(ctx context.Context) backend.Ledger { - return ctx.Value(_ledgerKey).(backend.Ledger) -} diff --git a/internal/api/v2/controllers_accounts.go b/internal/api/v2/controllers_accounts.go index fad853191..1f7101459 100644 --- a/internal/api/v2/controllers_accounts.go +++ b/internal/api/v2/controllers_accounts.go @@ -5,6 +5,8 @@ import ( "fmt" "net/http" + "github.com/formancehq/ledger/internal/api/shared" + ledger "github.com/formancehq/ledger/internal" "github.com/formancehq/ledger/internal/engine/command" "github.com/formancehq/ledger/internal/storage/ledgerstore" @@ -18,17 +20,17 @@ import ( ) func countAccounts(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) options, err := getPaginatedQueryOptionsOfPITFilterWithVolumes(r) if err != nil { - sharedapi.BadRequest(w, ErrValidation, err) + sharedapi.BadRequest(w, shared.ErrValidation, err) return } count, err := l.CountAccounts(r.Context(), ledgerstore.NewGetAccountsQuery(*options)) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -37,21 +39,21 @@ func countAccounts(w http.ResponseWriter, r *http.Request) { } func getAccounts(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) query := &ledgerstore.GetAccountsQuery{} if r.URL.Query().Get(QueryKeyCursor) != "" { err := paginate.UnmarshalCursor(r.URL.Query().Get(QueryKeyCursor), query) if err != nil { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.Errorf("invalid '%s' query param", QueryKeyCursor))) return } } else { options, err := getPaginatedQueryOptionsOfPITFilterWithVolumes(r) if err != nil { - sharedapi.BadRequest(w, ErrValidation, err) + sharedapi.BadRequest(w, shared.ErrValidation, err) return } query = ledgerstore.NewGetAccountsQuery(*options) @@ -59,7 +61,7 @@ func getAccounts(w http.ResponseWriter, r *http.Request) { cursor, err := l.GetAccountsWithVolumes(r.Context(), query) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -67,7 +69,7 @@ func getAccounts(w http.ResponseWriter, r *http.Request) { } func getAccount(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) query := ledgerstore.NewGetAccountQuery(chi.URLParam(r, "address")) if collectionutils.Contains(r.URL.Query()["expand"], "volumes") { @@ -78,14 +80,14 @@ func getAccount(w http.ResponseWriter, r *http.Request) { } pitFilter, err := getPITFilter(r) if err != nil { - sharedapi.BadRequest(w, ErrValidation, err) + sharedapi.BadRequest(w, shared.ErrValidation, err) return } query.PITFilter = *pitFilter acc, err := l.GetAccountWithVolumes(r.Context(), query) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -93,24 +95,24 @@ func getAccount(w http.ResponseWriter, r *http.Request) { } func postAccountMetadata(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) if !ledger.ValidateAddress(chi.URLParam(r, "address")) { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.New("invalid account address format"))) return } var m metadata.Metadata if err := json.NewDecoder(r.Body).Decode(&m); err != nil { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.New("invalid metadata format"))) return } err := l.SaveMeta(r.Context(), getCommandParameters(r), ledger.MetaTargetTypeAccount, chi.URLParam(r, "address"), m) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -118,7 +120,7 @@ func postAccountMetadata(w http.ResponseWriter, r *http.Request) { } func deleteAccountMetadata(w http.ResponseWriter, r *http.Request) { - if err := LedgerFromContext(r.Context()). + if err := shared.LedgerFromContext(r.Context()). DeleteMetadata( r.Context(), getCommandParameters(r), @@ -126,7 +128,7 @@ func deleteAccountMetadata(w http.ResponseWriter, r *http.Request) { chi.URLParam(r, "address"), chi.URLParam(r, "key"), ); err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } diff --git a/internal/api/v2/controllers_accounts_test.go b/internal/api/v2/controllers_accounts_test.go index bc51bab78..ba294ab0d 100644 --- a/internal/api/v2/controllers_accounts_test.go +++ b/internal/api/v2/controllers_accounts_test.go @@ -7,6 +7,8 @@ import ( "net/url" "testing" + "github.com/formancehq/ledger/internal/api/shared" + ledger "github.com/formancehq/ledger/internal" v2 "github.com/formancehq/ledger/internal/api/v2" "github.com/formancehq/ledger/internal/engine/command" @@ -65,7 +67,7 @@ func TestGetAccounts(t *testing.T) { "cursor": []string{"XXX"}, }, expectStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, { name: "invalid page size", @@ -73,7 +75,7 @@ func TestGetAccounts(t *testing.T) { "pageSize": []string{"nan"}, }, expectStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, { name: "page size over maximum", @@ -110,7 +112,7 @@ func TestGetAccounts(t *testing.T) { }, } - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { mockLedger.EXPECT(). GetAccountsWithVolumes(gomock.Any(), ledgerstore.NewGetAccountsQuery(testCase.expectQuery)). @@ -150,7 +152,7 @@ func TestGetAccount(t *testing.T) { }, } - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) mock.EXPECT(). GetAccountWithVolumes(gomock.Any(), ledgerstore.NewGetAccountQuery("foo")). Return(&account, nil) @@ -191,14 +193,14 @@ func TestPostAccountMetadata(t *testing.T) { name: "invalid account address format", account: "invalid-acc", expectStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, { name: "invalid body", account: "world", body: "invalid - not an object", expectStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, } for _, testCase := range testCases { @@ -209,7 +211,7 @@ func TestPostAccountMetadata(t *testing.T) { testCase.expectStatusCode = http.StatusNoContent } - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) if testCase.expectStatusCode == http.StatusNoContent { mock.EXPECT(). SaveMeta(gomock.Any(), command.Parameters{}, ledger.MetaTargetTypeAccount, testCase.account, testCase.body). diff --git a/internal/api/v2/controllers_balances.go b/internal/api/v2/controllers_balances.go index a66c35de8..d2063667e 100644 --- a/internal/api/v2/controllers_balances.go +++ b/internal/api/v2/controllers_balances.go @@ -3,6 +3,8 @@ package v2 import ( "net/http" + "github.com/formancehq/ledger/internal/api/shared" + "github.com/formancehq/ledger/internal/storage/ledgerstore" sharedapi "github.com/formancehq/stack/libs/go-libs/api" ) @@ -10,14 +12,14 @@ import ( func getBalancesAggregated(w http.ResponseWriter, r *http.Request) { options, err := getPaginatedQueryOptionsOfPITFilter(r) if err != nil { - sharedapi.BadRequest(w, ErrValidation, err) + sharedapi.BadRequest(w, shared.ErrValidation, err) return } - balances, err := LedgerFromContext(r.Context()). + balances, err := shared.LedgerFromContext(r.Context()). GetAggregatedBalances(r.Context(), ledgerstore.NewGetAggregatedBalancesQuery(*options)) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } diff --git a/internal/api/v2/controllers_balances_test.go b/internal/api/v2/controllers_balances_test.go index 8842c91a4..3e5620a9d 100644 --- a/internal/api/v2/controllers_balances_test.go +++ b/internal/api/v2/controllers_balances_test.go @@ -47,7 +47,7 @@ func TestGetBalancesAggregated(t *testing.T) { expectedBalances := ledger.BalancesByAssets{ "world": big.NewInt(-100), } - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) mock.EXPECT(). GetAggregatedBalances(gomock.Any(), ledgerstore.NewGetAggregatedBalancesQuery(testCase.expectQuery)). Return(expectedBalances, nil) diff --git a/internal/api/v2/controllers_config_test.go b/internal/api/v2/controllers_config_test.go index cf87e6435..c2fe56c3b 100644 --- a/internal/api/v2/controllers_config_test.go +++ b/internal/api/v2/controllers_config_test.go @@ -15,7 +15,7 @@ import ( func TestGetInfo(t *testing.T) { t.Parallel() - backend, _ := newTestingBackend(t) + backend, _ := newTestingBackend(t, false) router := v2.NewRouter(backend, nil, metrics.NewNoOpRegistry()) backend. diff --git a/internal/api/v2/controllers_info.go b/internal/api/v2/controllers_info.go index 18e17b381..d5801e0fd 100644 --- a/internal/api/v2/controllers_info.go +++ b/internal/api/v2/controllers_info.go @@ -3,6 +3,8 @@ package v2 import ( "net/http" + "github.com/formancehq/ledger/internal/api/shared" + "github.com/formancehq/ledger/internal/engine/command" "github.com/formancehq/ledger/internal/storage/ledgerstore" "github.com/formancehq/ledger/internal/storage/paginate" @@ -23,7 +25,7 @@ type StorageInfo struct { } func getLedgerInfo(w http.ResponseWriter, r *http.Request) { - ledger := LedgerFromContext(r.Context()) + ledger := shared.LedgerFromContext(r.Context()) var err error res := Info{ @@ -32,7 +34,7 @@ func getLedgerInfo(w http.ResponseWriter, r *http.Request) { } res.Storage.Migrations, err = ledger.GetMigrationsInfo(r.Context()) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -40,11 +42,11 @@ func getLedgerInfo(w http.ResponseWriter, r *http.Request) { } func getStats(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) stats, err := l.Stats(r.Context()) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -52,14 +54,14 @@ func getStats(w http.ResponseWriter, r *http.Request) { } func getLogs(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) query := &ledgerstore.GetLogsQuery{} if r.URL.Query().Get(QueryKeyCursor) != "" { err := paginate.UnmarshalCursor(r.URL.Query().Get(QueryKeyCursor), query) if err != nil { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.Errorf("invalid '%s' query param", QueryKeyCursor))) return } @@ -68,13 +70,13 @@ func getLogs(w http.ResponseWriter, r *http.Request) { pageSize, err := getPageSize(r) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } qb, err := getQueryBuilder(r) if err != nil { - sharedapi.BadRequest(w, ErrValidation, err) + sharedapi.BadRequest(w, shared.ErrValidation, err) return } @@ -86,7 +88,7 @@ func getLogs(w http.ResponseWriter, r *http.Request) { cursor, err := l.GetLogs(r.Context(), query) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } diff --git a/internal/api/v2/controllers_info_test.go b/internal/api/v2/controllers_info_test.go index d22bfa177..39c9fdde9 100644 --- a/internal/api/v2/controllers_info_test.go +++ b/internal/api/v2/controllers_info_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/formancehq/ledger/internal/api/shared" + ledger "github.com/formancehq/ledger/internal" v2 "github.com/formancehq/ledger/internal/api/v2" "github.com/formancehq/ledger/internal/engine" @@ -27,7 +29,7 @@ import ( func TestGetLedgerInfo(t *testing.T) { t.Parallel() - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, false) router := v2.NewRouter(backend, nil, metrics.NewNoOpRegistry()) migrationInfo := []migrations.Info{ @@ -35,13 +37,13 @@ func TestGetLedgerInfo(t *testing.T) { Version: "1", Name: "init", State: "ready", - Date: time.Now().Add(-2 * time.Minute).Round(time.Second), + Date: time.Now().Add(-2 * time.Minute).Round(time.Second).UTC(), }, { Version: "2", Name: "fix", State: "ready", - Date: time.Now().Add(-time.Minute).Round(time.Second), + Date: time.Now().Add(-time.Minute).Round(time.Second).UTC(), }, } @@ -70,7 +72,7 @@ func TestGetLedgerInfo(t *testing.T) { func TestGetStats(t *testing.T) { t.Parallel() - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) router := v2.NewRouter(backend, nil, metrics.NewNoOpRegistry()) expectedStats := engine.Stats{ @@ -137,7 +139,7 @@ func TestGetLogs(t *testing.T) { "cursor": []string{"xxx"}, }, expectStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, } for _, testCase := range testCases { @@ -155,7 +157,7 @@ func TestGetLogs(t *testing.T) { }, } - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { mockLedger.EXPECT(). GetLogs(gomock.Any(), ledgerstore.NewGetLogsQuery(testCase.expectQuery)). diff --git a/internal/api/v2/controllers_transactions.go b/internal/api/v2/controllers_transactions.go index cdb67f0d1..7b3c5c38b 100644 --- a/internal/api/v2/controllers_transactions.go +++ b/internal/api/v2/controllers_transactions.go @@ -5,7 +5,8 @@ import ( "fmt" "math/big" "net/http" - "strconv" + + "github.com/formancehq/ledger/internal/api/shared" ledger "github.com/formancehq/ledger/internal" "github.com/formancehq/ledger/internal/engine/command" @@ -23,14 +24,14 @@ func countTransactions(w http.ResponseWriter, r *http.Request) { options, err := getPaginatedQueryOptionsOfPITFilterWithVolumes(r) if err != nil { - sharedapi.BadRequest(w, ErrValidation, err) + sharedapi.BadRequest(w, shared.ErrValidation, err) return } - count, err := LedgerFromContext(r.Context()). + count, err := shared.LedgerFromContext(r.Context()). CountTransactions(r.Context(), ledgerstore.NewGetTransactionsQuery(*options)) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -39,21 +40,21 @@ func countTransactions(w http.ResponseWriter, r *http.Request) { } func getTransactions(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) query := &ledgerstore.GetTransactionsQuery{} if r.URL.Query().Get(QueryKeyCursor) != "" { err := paginate.UnmarshalCursor(r.URL.Query().Get(QueryKeyCursor), &query) if err != nil { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.Errorf("invalid '%s' query param", QueryKeyCursor))) return } } else { options, err := getPaginatedQueryOptionsOfPITFilterWithVolumes(r) if err != nil { - sharedapi.BadRequest(w, ErrValidation, err) + sharedapi.BadRequest(w, shared.ErrValidation, err) return } query = ledgerstore.NewGetTransactionsQuery(*options) @@ -61,7 +62,7 @@ func getTransactions(w http.ResponseWriter, r *http.Request) { cursor, err := l.GetTransactions(r.Context(), query) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -97,11 +98,11 @@ type PostTransactionRequest struct { } func postTransaction(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) payload := PostTransactionRequest{} if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { - ResponseError(w, r, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.New("invalid transaction format"))) return @@ -109,12 +110,12 @@ func postTransaction(w http.ResponseWriter, r *http.Request) { if len(payload.Postings) > 0 && payload.Script.Plain != "" || len(payload.Postings) == 0 && payload.Script.Plain == "" { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.New("invalid payload: should contain either postings or script"))) return } else if len(payload.Postings) > 0 { if i, err := payload.Postings.Validate(); err != nil { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.Wrap(err, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.Wrap(err, fmt.Sprintf("invalid posting %d", i)))) return } @@ -127,7 +128,7 @@ func postTransaction(w http.ResponseWriter, r *http.Request) { res, err := l.CreateTransaction(r.Context(), getCommandParameters(r), ledger.TxToScriptData(txData)) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -144,7 +145,7 @@ func postTransaction(w http.ResponseWriter, r *http.Request) { res, err := l.CreateTransaction(r.Context(), getCommandParameters(r), script) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -152,11 +153,11 @@ func postTransaction(w http.ResponseWriter, r *http.Request) { } func getTransaction(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) txId, ok := big.NewInt(0).SetString(chi.URLParam(r, "id"), 10) if !ok { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.New("invalid transaction ID"))) return } @@ -171,14 +172,14 @@ func getTransaction(w http.ResponseWriter, r *http.Request) { pitFilter, err := getPITFilter(r) if err != nil { - sharedapi.BadRequest(w, ErrValidation, err) + sharedapi.BadRequest(w, shared.ErrValidation, err) return } query.PITFilter = *pitFilter tx, err := l.GetTransactionWithVolumes(r.Context(), query) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -186,7 +187,7 @@ func getTransaction(w http.ResponseWriter, r *http.Request) { } func revertTransaction(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) transactionID, ok := big.NewInt(0).SetString(chi.URLParam(r, "id"), 10) if !ok { @@ -196,7 +197,7 @@ func revertTransaction(w http.ResponseWriter, r *http.Request) { tx, err := l.RevertTransaction(r.Context(), getCommandParameters(r), transactionID) if err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -204,11 +205,11 @@ func revertTransaction(w http.ResponseWriter, r *http.Request) { } func postTransactionMetadata(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) var m metadata.Metadata if err := json.NewDecoder(r.Body).Decode(&m); err != nil { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.New("invalid metadata format"))) return } @@ -220,7 +221,7 @@ func postTransactionMetadata(w http.ResponseWriter, r *http.Request) { } if err := l.SaveMeta(r.Context(), getCommandParameters(r), ledger.MetaTargetTypeTransaction, txID, m); err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } @@ -228,11 +229,11 @@ func postTransactionMetadata(w http.ResponseWriter, r *http.Request) { } func deleteTransactionMetadata(w http.ResponseWriter, r *http.Request) { - l := LedgerFromContext(r.Context()) + l := shared.LedgerFromContext(r.Context()) - transactionID, err := strconv.ParseUint(chi.URLParam(r, "id"), 10, 64) - if err != nil { - ResponseError(w, r, errorsutil.NewError(command.ErrValidation, + transactionID, ok := big.NewInt(0).SetString(chi.URLParam(r, "id"), 10) + if !ok { + shared.ResponseError(w, r, errorsutil.NewError(command.ErrValidation, errors.New("invalid transaction ID"))) return } @@ -240,7 +241,7 @@ func deleteTransactionMetadata(w http.ResponseWriter, r *http.Request) { metadataKey := chi.URLParam(r, "key") if err := l.DeleteMetadata(r.Context(), getCommandParameters(r), ledger.MetaTargetTypeTransaction, transactionID, metadataKey); err != nil { - ResponseError(w, r, err) + shared.ResponseError(w, r, err) return } diff --git a/internal/api/v2/controllers_transactions_test.go b/internal/api/v2/controllers_transactions_test.go index 5c5d3aee1..2a49490e7 100644 --- a/internal/api/v2/controllers_transactions_test.go +++ b/internal/api/v2/controllers_transactions_test.go @@ -9,6 +9,8 @@ import ( "net/url" "testing" + "github.com/formancehq/ledger/internal/api/shared" + ledger "github.com/formancehq/ledger/internal" v2 "github.com/formancehq/ledger/internal/api/v2" "github.com/formancehq/ledger/internal/engine/command" @@ -179,7 +181,7 @@ func TestPostTransactions(t *testing.T) { name: "no postings or script", payload: v2.PostTransactionRequest{}, expectedStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, { name: "postings and script", @@ -203,13 +205,13 @@ func TestPostTransactions(t *testing.T) { }, }, expectedStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, { name: "using invalid body", payload: "not a valid payload", expectedStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, } @@ -224,7 +226,7 @@ func TestPostTransactions(t *testing.T) { ledger.NewPosting("world", "bank", "USD", big.NewInt(100)), ) - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectedStatusCode < 300 && testCase.expectedStatusCode >= 200 { mockLedger.EXPECT(). CreateTransaction(gomock.Any(), command.Parameters{ @@ -277,7 +279,7 @@ func TestPostTransactionMetadata(t *testing.T) { name: "invalid body", body: "invalid - not an object", expectStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, } for _, testCase := range testCases { @@ -288,7 +290,7 @@ func TestPostTransactionMetadata(t *testing.T) { testCase.expectStatusCode = http.StatusNoContent } - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) if testCase.expectStatusCode == http.StatusNoContent { mock.EXPECT(). SaveMeta(gomock.Any(), command.Parameters{}, ledger.MetaTargetTypeTransaction, big.NewInt(0), testCase.body). @@ -323,7 +325,7 @@ func TestGetTransaction(t *testing.T) { nil, ) - backend, mock := newTestingBackend(t) + backend, mock := newTestingBackend(t, true) mock.EXPECT(). GetTransactionWithVolumes(gomock.Any(), ledgerstore.NewGetTransactionQuery(big.NewInt(0))). Return(&tx, nil) @@ -413,7 +415,7 @@ func TestGetTransactions(t *testing.T) { "cursor": []string{"XXX"}, }, expectStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, { name: "invalid page size", @@ -421,7 +423,7 @@ func TestGetTransactions(t *testing.T) { "pageSize": []string{"nan"}, }, expectStatusCode: http.StatusBadRequest, - expectedErrorCode: v2.ErrValidation, + expectedErrorCode: shared.ErrValidation, }, { name: "page size over maximum", @@ -451,7 +453,7 @@ func TestGetTransactions(t *testing.T) { }, } - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { mockLedger.EXPECT(). GetTransactions(gomock.Any(), ledgerstore.NewGetTransactionsQuery(testCase.expectQuery)). @@ -550,7 +552,7 @@ func TestCountTransactions(t *testing.T) { testCase.expectStatusCode = http.StatusNoContent } - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) if testCase.expectStatusCode < 300 && testCase.expectStatusCode >= 200 { mockLedger.EXPECT(). CountTransactions(gomock.Any(), ledgerstore.NewGetTransactionsQuery(testCase.expectQuery)). @@ -585,7 +587,7 @@ func TestRevertTransaction(t *testing.T) { ledger.NewPosting("world", "bank", "USD", big.NewInt(100)), ) - backend, mockLedger := newTestingBackend(t) + backend, mockLedger := newTestingBackend(t, true) mockLedger. EXPECT(). RevertTransaction(gomock.Any(), command.Parameters{}, big.NewInt(0)). diff --git a/internal/api/v2/middlewares_resolver.go b/internal/api/v2/middlewares_resolver.go deleted file mode 100644 index ee5548bcd..000000000 --- a/internal/api/v2/middlewares_resolver.go +++ /dev/null @@ -1,79 +0,0 @@ -package v2 - -import ( - "math/rand" - "net/http" - "sync" - "time" - - "github.com/formancehq/ledger/internal/api/backend" - "github.com/formancehq/ledger/internal/opentelemetry/tracer" - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/go-chi/chi/v5" -) - -var ( - r *rand.Rand - mu sync.Mutex -) - -func init() { - r = rand.New(rand.NewSource(time.Now().UnixNano())) -} - -var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") - -func randomTraceID(n int) string { - mu.Lock() - defer mu.Unlock() - - b := make([]rune, n) - for i := range b { - b[i] = letterRunes[r.Intn(len(letterRunes))] - } - return string(b) -} - -func LedgerMiddleware( - resolver backend.Backend, -) func(handler http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - name := chi.URLParam(r, "ledger") - if name == "" { - w.WriteHeader(http.StatusNotFound) - return - } - - ctx, span := tracer.Start(r.Context(), name) - defer span.End() - - r = r.WithContext(ctx) - - loggerFields := map[string]any{ - "ledger": name, - } - if span.SpanContext().TraceID().IsValid() { - loggerFields["trace-id"] = span.SpanContext().TraceID().String() - } else { - loggerFields["trace-id"] = randomTraceID(10) - } - - r = r.WithContext(logging.ContextWithFields(r.Context(), loggerFields)) - - l, err := resolver.GetLedger(r.Context(), name) - if err != nil { - ResponseError(w, r, err) - return - } - // TODO(polo/gfyrag): close ledger if not used for x minutes - // defer l.Close(context.Background()) - // When close, we have to decrease the active ledgers counter: - // globalMetricsRegistry.ActiveLedgers.Add(r.Context(), -1) - - r = r.WithContext(ContextWithLedger(r.Context(), l)) - - handler.ServeHTTP(w, r) - }) - } -} diff --git a/internal/api/v2/routes.go b/internal/api/v2/routes.go index 30a729bee..6bc5a652d 100644 --- a/internal/api/v2/routes.go +++ b/internal/api/v2/routes.go @@ -3,6 +3,8 @@ package v2 import ( "net/http" + "github.com/formancehq/ledger/internal/api/shared" + "github.com/formancehq/ledger/internal/api/backend" "github.com/formancehq/ledger/internal/opentelemetry/metrics" "github.com/formancehq/stack/libs/go-libs/health" @@ -37,7 +39,7 @@ func NewRouter( router.Get("/_info", getInfo(backend)) router.Route("/{ledger}", func(router chi.Router) { - router.Use(LedgerMiddleware(backend)) + router.Use(shared.LedgerMiddleware(backend, []string{"/_info"})) // LedgerController router.Get("/_info", getLedgerInfo) diff --git a/internal/bus/message.go b/internal/bus/message.go index 87e0fcd58..883f7822b 100644 --- a/internal/bus/message.go +++ b/internal/bus/message.go @@ -16,7 +16,7 @@ type EventMessage struct { type CommittedTransactions struct { Ledger string `json:"ledger"` - Transaction ledger.Transaction `json:"transaction"` + Transactions []ledger.Transaction `json:"transactions"` AccountMetadata map[string]metadata.Metadata `json:"accountMetadata"` } @@ -66,7 +66,7 @@ func newEventRevertedTransaction(tx RevertedTransaction) EventMessage { type DeletedMetadata struct { Ledger string `json:"ledger"` TargetType string `json:"targetType"` - TargetID any `json:"targetID"` + TargetID any `json:"targetId"` Key string `json:"key"` } diff --git a/internal/bus/monitor.go b/internal/bus/monitor.go index 926460aa7..88a6cf974 100644 --- a/internal/bus/monitor.go +++ b/internal/bus/monitor.go @@ -55,7 +55,7 @@ func (l *ledgerMonitor) CommittedTransactions(ctx context.Context, txs ledger.Tr l.publish(ctx, events.EventTypeCommittedTransactions, newEventCommittedTransactions(CommittedTransactions{ Ledger: l.ledgerName, - Transaction: txs, + Transactions: []ledger.Transaction{txs}, AccountMetadata: accountMetadata, })) } diff --git a/internal/bus/monitor_test.go b/internal/bus/monitor_test.go index dc76884b1..c1fa881dd 100644 --- a/internal/bus/monitor_test.go +++ b/internal/bus/monitor_test.go @@ -2,10 +2,11 @@ package bus import ( "context" - ledger "github.com/formancehq/ledger/internal" "testing" "time" + ledger "github.com/formancehq/ledger/internal" + "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/pubsub/gochannel" "github.com/formancehq/stack/libs/go-libs/publish" diff --git a/internal/engine/command/commander.go b/internal/engine/command/commander.go index a204aaee7..f7bc459f2 100644 --- a/internal/engine/command/commander.go +++ b/internal/engine/command/commander.go @@ -309,7 +309,7 @@ func (commander *Commander) DeleteMetadata(ctx context.Context, parameters Param } log = ledger.NewDeleteMetadataLog(at, ledger.DeleteMetadataLogPayload{ TargetType: ledger.MetaTargetTypeTransaction, - TargetID: targetID.(uint64), + TargetID: targetID.(*big.Int), Key: key, }) case ledger.MetaTargetTypeAccount: diff --git a/internal/engine/ledger.go b/internal/engine/ledger.go index 111cd73ee..75abda917 100644 --- a/internal/engine/ledger.go +++ b/internal/engine/ledger.go @@ -108,3 +108,7 @@ func (l *Ledger) SaveMeta(ctx context.Context, parameters command.Parameters, ta func (l *Ledger) DeleteMetadata(ctx context.Context, parameters command.Parameters, targetType string, targetID any, key string) error { return l.commander.DeleteMetadata(ctx, parameters, targetType, targetID, key) } + +func (l *Ledger) IsDatabaseUpToDate(ctx context.Context) (bool, error) { + return l.store.IsSchemaUpToDate(ctx) +} diff --git a/internal/engine/resolver.go b/internal/engine/resolver.go index 7f2f25830..70cfc1fbe 100644 --- a/internal/engine/resolver.go +++ b/internal/engine/resolver.go @@ -8,9 +8,7 @@ import ( "github.com/formancehq/ledger/internal/engine/command" "github.com/formancehq/ledger/internal/opentelemetry/metrics" "github.com/formancehq/ledger/internal/storage/driver" - "github.com/formancehq/ledger/internal/storage/ledgerstore" "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -73,6 +71,7 @@ func (r *Resolver) GetLedger(ctx context.Context, name string) (*Ledger, error) r.lock.RLock() ledger, ok := r.ledgers[name] r.lock.RUnlock() + if !ok { r.lock.Lock() defer r.lock.Unlock() @@ -84,27 +83,11 @@ func (r *Resolver) GetLedger(ctx context.Context, name string) (*Ledger, error) return ledger, nil } - exists, err := r.storageDriver.GetSystemStore().Exists(ctx, name) - if err != nil { - return nil, err - } - - var store *ledgerstore.Store - if !exists { - store, err = r.storageDriver.CreateLedgerStore(ctx, name) - } else { - store, err = r.storageDriver.GetLedgerStore(ctx, name) - } + store, err := r.storageDriver.GetLedgerStore(ctx, name) if err != nil { return nil, err } - if !store.IsInitialized() { - if _, err := store.Migrate(ctx); err != nil { - return nil, errors.Wrap(err, "initializing ledger store") - } - } - ledger = New(store, r.publisher, r.compiler) ledger.Start(logging.ContextWithLogger(context.Background(), r.logger)) r.ledgers[name] = ledger diff --git a/internal/machine/vm/machine.go b/internal/machine/vm/machine.go index f53e4500e..f8b987b99 100644 --- a/internal/machine/vm/machine.go +++ b/internal/machine/vm/machine.go @@ -43,7 +43,7 @@ type Machine struct { } type Posting struct { - Source string `json:"source"` + Source string `json:"source"` Destination string `json:"destination"` Amount *internal.MonetaryInt `json:"amount"` Asset string `json:"asset"` diff --git a/internal/storage/driver/cli.go b/internal/storage/driver/cli.go index 424f5392d..5290220d8 100644 --- a/internal/storage/driver/cli.go +++ b/internal/storage/driver/cli.go @@ -5,7 +5,7 @@ import ( "io" "time" - storage2 "github.com/formancehq/ledger/internal/storage" + storage "github.com/formancehq/ledger/internal/storage" "github.com/formancehq/stack/libs/go-libs/health" "github.com/formancehq/stack/libs/go-libs/logging" "github.com/spf13/cobra" @@ -18,12 +18,12 @@ import ( // Or make the inverse (move analytics flags to pkg/analytics) // IMO, flags are more easily discoverable if located inside cmd/ func InitCLIFlags(cmd *cobra.Command) { - cmd.PersistentFlags().Int(storage2.StoreWorkerMaxPendingSize, 0, "Max pending size for store worker") - cmd.PersistentFlags().Int(storage2.StoreWorkerMaxWriteChanSize, 1024, "Max write channel size for store worker") - cmd.PersistentFlags().String(storage2.StoragePostgresConnectionStringFlag, "postgresql://localhost/postgres", "Postgres connection string") - cmd.PersistentFlags().Int(storage2.StoragePostgresMaxIdleConnsFlag, 20, "Max idle connections to database") - cmd.PersistentFlags().Duration(storage2.StoragePostgresConnMaxIdleTimeFlag, time.Minute, "Max idle time of idle connections") - cmd.PersistentFlags().Int(storage2.StoragePostgresMaxOpenConns, 20, "Max open connections") + cmd.PersistentFlags().Int(storage.StoreWorkerMaxPendingSize, 0, "Max pending size for store worker") + cmd.PersistentFlags().Int(storage.StoreWorkerMaxWriteChanSize, 1024, "Max write channel size for store worker") + cmd.PersistentFlags().String(storage.StoragePostgresConnectionStringFlag, "postgresql://localhost/postgres", "Postgres connection string") + cmd.PersistentFlags().Int(storage.StoragePostgresMaxIdleConnsFlag, 20, "Max idle connections to database") + cmd.PersistentFlags().Duration(storage.StoragePostgresConnMaxIdleTimeFlag, time.Minute, "Max idle time of idle connections") + cmd.PersistentFlags().Int(storage.StoragePostgresMaxOpenConns, 20, "Max open connections") } type PostgresConfig struct { @@ -31,7 +31,7 @@ type PostgresConfig struct { } type ModuleConfig struct { - PostgresConnectionOptions storage2.ConnectionOptions + PostgresConnectionOptions storage.ConnectionOptions Debug bool } @@ -39,9 +39,9 @@ func CLIModule(v *viper.Viper, output io.Writer, debug bool) fx.Option { options := make([]fx.Option, 0) options = append(options, fx.Provide(func(logger logging.Logger) (*bun.DB, error) { - configuration := storage2.ConnectionOptionsFromFlags(v, output, debug) + configuration := storage.ConnectionOptionsFromFlags(v, output, debug) logger.WithField("config", configuration).Infof("Opening connection to database...") - return storage2.OpenSQLDB(configuration) + return storage.OpenSQLDB(configuration) })) options = append(options, fx.Provide(func(db *bun.DB) (*Driver, error) { return New(db), nil diff --git a/internal/storage/driver/driver.go b/internal/storage/driver/driver.go index 701df0df0..6cdcf801a 100644 --- a/internal/storage/driver/driver.go +++ b/internal/storage/driver/driver.go @@ -77,22 +77,15 @@ func (d *Driver) GetSystemStore() *systemstore.Store { } func (d *Driver) newStore(name string) (*ledgerstore.Store, error) { - store, err := ledgerstore.New(d.db, name, func(ctx context.Context) error { + return ledgerstore.New(d.db, name, func(ctx context.Context) error { return d.GetSystemStore().DeleteLedger(ctx, name) }) - if err != nil { - return nil, err - } - - return store, nil } -func (d *Driver) CreateLedgerStore(ctx context.Context, name string) (*ledgerstore.Store, error) { +func (d *Driver) createLedgerStore(ctx context.Context, name string) (*ledgerstore.Store, error) { if name == SystemSchema { return nil, errors.New("reserved name") } - d.lock.Lock() - defer d.lock.Unlock() exists, err := d.systemStore.Exists(ctx, name) if err != nil { @@ -117,19 +110,33 @@ func (d *Driver) CreateLedgerStore(ctx context.Context, name string) (*ledgersto return store, err } +func (d *Driver) CreateLedgerStore(ctx context.Context, name string) (*ledgerstore.Store, error) { + d.lock.Lock() + defer d.lock.Unlock() + + return d.createLedgerStore(ctx, name) +} + func (d *Driver) GetLedgerStore(ctx context.Context, name string) (*ledgerstore.Store, error) { d.lock.Lock() defer d.lock.Unlock() exists, err := d.systemStore.Exists(ctx, name) if err != nil { - return nil, errors.Wrap(err, "checking ledger existence") + return nil, err } + + var store *ledgerstore.Store if !exists { - return nil, storage.ErrStoreNotFound + store, err = d.createLedgerStore(ctx, name) + } else { + store, err = d.newStore(name) + } + if err != nil { + return nil, err } - return d.newStore(name) + return store, nil } func (d *Driver) Initialize(ctx context.Context) error { @@ -154,6 +161,28 @@ func (d *Driver) Initialize(ctx context.Context) error { return nil } +func (d *Driver) UpgradeAllLedgersSchemas(ctx context.Context) error { + systemStore := d.GetSystemStore() + ledgers, err := systemStore.ListLedgers(ctx) + if err != nil { + return err + } + + for _, ledger := range ledgers { + store, err := d.GetLedgerStore(ctx, ledger) + if err != nil { + return err + } + + logging.FromContext(ctx).Infof("Upgrading storage '%s'", ledger) + if _, err := store.Migrate(ctx); err != nil { + return err + } + } + + return nil +} + func New(db *bun.DB) *Driver { return &Driver{ db: db, diff --git a/internal/storage/driver/driver_test.go b/internal/storage/driver/driver_test.go index 03adbdac4..58d1146f5 100644 --- a/internal/storage/driver/driver_test.go +++ b/internal/storage/driver/driver_test.go @@ -2,28 +2,16 @@ package driver_test import ( "context" - "os" "testing" + "github.com/formancehq/stack/libs/go-libs/logging" + "github.com/google/uuid" + "github.com/formancehq/ledger/internal/storage" "github.com/formancehq/ledger/internal/storage/storagetesting" - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/formancehq/stack/libs/go-libs/pgtesting" "github.com/stretchr/testify/require" ) -func TestMain(t *testing.M) { - if err := pgtesting.CreatePostgresServer(); err != nil { - logging.Error(err) - os.Exit(1) - } - code := t.Run() - if err := pgtesting.DestroyPostgresServer(); err != nil { - logging.Error(err) - } - os.Exit(code) -} - func TestConfiguration(t *testing.T) { d := storagetesting.StorageDriver(t) @@ -40,3 +28,19 @@ func TestConfigurationError(t *testing.T) { require.Error(t, err) require.True(t, storage.IsNotFoundError(err)) } + +func TestErrorOnOutdatedSchema(t *testing.T) { + d := storagetesting.StorageDriver(t) + ctx := logging.TestingContext() + + name := uuid.NewString() + _, err := d.GetSystemStore().Register(ctx, name) + require.NoError(t, err) + + store, err := d.GetLedgerStore(ctx, name) + require.NoError(t, err) + + upToDate, err := store.IsSchemaUpToDate(ctx) + require.NoError(t, err) + require.False(t, upToDate) +} diff --git a/internal/storage/driver/main_test.go b/internal/storage/driver/main_test.go new file mode 100644 index 000000000..70c6d21da --- /dev/null +++ b/internal/storage/driver/main_test.go @@ -0,0 +1,21 @@ +package driver + +import ( + "os" + "testing" + + "github.com/formancehq/stack/libs/go-libs/logging" + "github.com/formancehq/stack/libs/go-libs/pgtesting" +) + +func TestMain(t *testing.M) { + if err := pgtesting.CreatePostgresServer(); err != nil { + logging.Error(err) + os.Exit(1) + } + code := t.Run() + if err := pgtesting.DestroyPostgresServer(); err != nil { + logging.Error(err) + } + os.Exit(code) +} diff --git a/internal/storage/ledgerstore/store.go b/internal/storage/ledgerstore/store.go index 9cc0f3583..935ef572d 100644 --- a/internal/storage/ledgerstore/store.go +++ b/internal/storage/ledgerstore/store.go @@ -15,8 +15,7 @@ type Store struct { db *bun.DB onDelete func(ctx context.Context) error - isInitialized bool - name string + name string } func (store *Store) Name() string { @@ -35,10 +34,6 @@ func (store *Store) Delete(ctx context.Context) error { return errors.Wrap(store.onDelete(ctx), "deleting ledger store") } -func (store *Store) IsInitialized() bool { - return store.isInitialized -} - func (store *Store) prepareTransaction(ctx context.Context) (bun.Tx, error) { txOptions := &sql.TxOptions{} @@ -64,6 +59,10 @@ func (store *Store) withTransaction(ctx context.Context, callback func(tx bun.Tx return tx.Commit() } +func (store *Store) IsSchemaUpToDate(ctx context.Context) (bool, error) { + return store.getMigrator().IsUpToDate(ctx, store.db) +} + func New( db *bun.DB, name string, diff --git a/internal/storage/ledgerstore/transactions_test.go b/internal/storage/ledgerstore/transactions_test.go index e6cafa4a6..f1fd6d358 100644 --- a/internal/storage/ledgerstore/transactions_test.go +++ b/internal/storage/ledgerstore/transactions_test.go @@ -2,12 +2,13 @@ package ledgerstore_test import ( "context" - "github.com/formancehq/stack/libs/go-libs/logging" - "github.com/formancehq/stack/libs/go-libs/pointer" "math/big" "testing" "time" + "github.com/formancehq/stack/libs/go-libs/logging" + "github.com/formancehq/stack/libs/go-libs/pointer" + ledger "github.com/formancehq/ledger/internal" "github.com/formancehq/ledger/internal/storage/ledgerstore" "github.com/formancehq/ledger/internal/storage/query" diff --git a/internal/storage/storagetesting/storage.go b/internal/storage/storagetesting/storage.go index e9fe2fafb..8edca816c 100644 --- a/internal/storage/storagetesting/storage.go +++ b/internal/storage/storagetesting/storage.go @@ -24,7 +24,7 @@ func StorageDriver(t pgtesting.TestingT) *driver.Driver { require.NoError(t, err) t.Cleanup(func() { - db.Close() + _ = db.Close() }) d := driver.New(db) diff --git a/internal/storage/systemstore/configuration.go b/internal/storage/systemstore/configuration.go index def2efe22..4a54cd6c7 100644 --- a/internal/storage/systemstore/configuration.go +++ b/internal/storage/systemstore/configuration.go @@ -16,15 +16,6 @@ type configuration struct { AddedAt ledger.Time `bun:"addedAt,type:timestamp"` } -func (s *Store) CreateConfigurationTable(ctx context.Context) error { - _, err := s.db.NewCreateTable(). - Model((*configuration)(nil)). - IfNotExists(). - Exec(ctx) - - return storageerrors.PostgresError(err) -} - func (s *Store) GetConfiguration(ctx context.Context, key string) (string, error) { query := s.db.NewSelect(). Model((*configuration)(nil)). diff --git a/internal/storage/systemstore/ledgers.go b/internal/storage/systemstore/ledgers.go index a2f3b2582..942df66a6 100644 --- a/internal/storage/systemstore/ledgers.go +++ b/internal/storage/systemstore/ledgers.go @@ -16,15 +16,6 @@ type Ledgers struct { AddedAt ledger.Time `bun:"addedat,type:timestamp"` } -func (s *Store) CreateLedgersTable(ctx context.Context) error { - _, err := s.db.NewCreateTable(). - Model((*Ledgers)(nil)). - IfNotExists(). - Exec(ctx) - - return storageerrors.PostgresError(err) -} - func (s *Store) ListLedgers(ctx context.Context) ([]string, error) { query := s.db.NewSelect(). Model((*Ledgers)(nil)). diff --git a/internal/storage/systemstore/migrations.go b/internal/storage/systemstore/migrations.go new file mode 100644 index 000000000..e572a38ee --- /dev/null +++ b/internal/storage/systemstore/migrations.go @@ -0,0 +1,34 @@ +package systemstore + +import ( + "context" + + "github.com/formancehq/ledger/internal/storage" + "github.com/formancehq/stack/libs/go-libs/migrations" + "github.com/uptrace/bun" +) + +func (s *Store) getMigrator() *migrations.Migrator { + migrator := migrations.NewMigrator(migrations.WithSchema("_system", true)) + migrator.RegisterMigrations( + migrations.Migration{ + Name: "Init schema", + UpWithContext: func(ctx context.Context, tx bun.Tx) error { + _, err := tx.NewCreateTable(). + Model((*Ledgers)(nil)). + IfNotExists(). + Exec(ctx) + if err != nil { + return storage.PostgresError(err) + } + + _, err = s.db.NewCreateTable(). + Model((*configuration)(nil)). + IfNotExists(). + Exec(ctx) + return storage.PostgresError(err) + }, + }, + ) + return migrator +} diff --git a/internal/storage/systemstore/store.go b/internal/storage/systemstore/store.go index 2aaae497a..124f8824a 100644 --- a/internal/storage/systemstore/store.go +++ b/internal/storage/systemstore/store.go @@ -16,9 +16,5 @@ func NewStore(db *bun.DB) *Store { } func (s *Store) Initialize(ctx context.Context) error { - if err := s.CreateLedgersTable(ctx); err != nil { - return storage.PostgresError(err) - } - - return storage.PostgresError(s.CreateConfigurationTable(ctx)) + return storage.PostgresError(s.getMigrator().Up(ctx, s.db)) } diff --git a/libs/.github/dependabot.yml b/libs/.github/dependabot.yml deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/.github/workflows/codeql.yml b/libs/.github/workflows/codeql.yml deleted file mode 100644 index 773d9fba0..000000000 --- a/libs/.github/workflows/codeql.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: "CodeQL" - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - schedule: - - cron: '35 21 * * 5' - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - strategy: - fail-fast: false - matrix: - language: [ 'go' ] - steps: - - name: Checkout repository - uses: actions/checkout@v3 - - name: Initialize CodeQL - uses: github/codeql-action/init@v2 - with: - languages: ${{ matrix.language }} - - name: Autobuild - uses: github/codeql-action/autobuild@v2 - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 diff --git a/libs/.github/workflows/main.yml b/libs/.github/workflows/main.yml deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/.github/workflows/pr-open.yml b/libs/.github/workflows/pr-open.yml deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/.gitignore b/libs/.gitignore index e69de29bb..2a9b1e54c 100644 --- a/libs/.gitignore +++ b/libs/.gitignore @@ -0,0 +1,3 @@ +.idea +vendor +coverage.* diff --git a/libs/.golangci.yml b/libs/.golangci.yml deleted file mode 100644 index e6e416e7a..000000000 --- a/libs/.golangci.yml +++ /dev/null @@ -1,8 +0,0 @@ -linters: - enable: - - gofmt - - gci - - goimports - -run: - timeout: 5m diff --git a/libs/.pre-commit-config.yaml b/libs/.pre-commit-config.yaml index e69de29bb..a4c584c91 100644 --- a/libs/.pre-commit-config.yaml +++ b/libs/.pre-commit-config.yaml @@ -0,0 +1,22 @@ +exclude: client +fail_fast: true +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + exclude: .cloud + - id: check-added-large-files +- repo: https://github.com/formancehq/pre-commit-hooks + rev: dd079f7c30ad72446d615f55a000d4f875e79633 + hooks: + - id: gogenerate + files: swagger.yaml + - id: gomodtidy + - id: goimports + - id: gofmt + - id: golangci-lint + - id: gotests + - id: commitlint diff --git a/libs/collectionutils/slice.go b/libs/collectionutils/slice.go index be5f4a55d..58d2faead 100644 --- a/libs/collectionutils/slice.go +++ b/libs/collectionutils/slice.go @@ -12,6 +12,14 @@ func Map[FROM any, TO any](input []FROM, mapper func(FROM) TO) []TO { return ret } +func CopyMap[KEY comparable, VALUE any](m map[KEY]VALUE) map[KEY]VALUE { + ret := make(map[KEY]VALUE) + for k, v := range m { + ret[k] = v + } + return ret +} + func Filter[TYPE any](input []TYPE, filter func(TYPE) bool) []TYPE { ret := make([]TYPE, 0) for _, i := range input { @@ -59,3 +67,18 @@ func Contains[T any](slice []T, t T) bool { } return false } + +type Set[T comparable] map[T]struct{} + +func (s Set[T]) Put(t T) { + s[t] = struct{}{} +} + +func (s Set[T]) Contains(t T) bool { + _, ok := s[t] + return ok +} + +func NewSet[T comparable]() Set[T] { + return make(Set[T], 0) +} diff --git a/libs/contextutil/contextutil.go b/libs/contextutil/contextutil.go new file mode 100644 index 000000000..7aaa323a4 --- /dev/null +++ b/libs/contextutil/contextutil.go @@ -0,0 +1,40 @@ +package contextutil + +import ( + "context" + "time" +) + +type detachedContext struct { + parent context.Context +} + +var _ context.Context = (*detachedContext)(nil) + +func (c *detachedContext) Done() <-chan struct{} { + return nil +} + +func (c *detachedContext) Deadline() (deadline time.Time, ok bool) { + return c.parent.Deadline() +} + +func (c *detachedContext) Err() error { + return c.parent.Err() +} + +func (c *detachedContext) Value(key interface{}) interface{} { + return c.parent.Value(key) +} + +func Detached(parent context.Context) (context.Context, context.CancelFunc) { + c := &detachedContext{parent: parent} + if deadline, ok := parent.Deadline(); ok { + return context.WithDeadline(c, deadline) + } + return context.WithCancel(c) +} + +func DetachedWithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(&detachedContext{parent: parent}, timeout) +} diff --git a/libs/migrations/migrator.go b/libs/migrations/migrator.go index 858f12d61..77fb68d1b 100644 --- a/libs/migrations/migrator.go +++ b/libs/migrations/migrator.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "github.com/lib/pq" "time" "github.com/pkg/errors" @@ -74,9 +75,19 @@ func (m *Migrator) getLastVersion(ctx context.Context, querier interface { }) (int64, error) { row := querier.QueryRowContext(ctx, fmt.Sprintf(`select max(version_id) from "%s";`, migrationTable)) if err := row.Err(); err != nil { - if err == sql.ErrNoRows { + switch { + case err == sql.ErrNoRows: return -1, nil + default: + switch err := err.(type) { + case *pq.Error: + switch err.Code { + case "42P01": // Table not exists + return -1, nil + } + } } + return -1, errors.Wrap(err, "selecting max id from version table") } var number sql.NullInt64 @@ -99,11 +110,35 @@ func (m *Migrator) insertVersion(ctx context.Context, tx bun.Tx, version int) er } func (m *Migrator) GetDBVersion(ctx context.Context, db *bun.DB) (int64, error) { - return m.getLastVersion(ctx, db) + tx, err := m.newTx(ctx, db) + if err != nil { + return -1, err + } + defer func() { + _ = tx.Rollback() + }() + + return m.getLastVersion(ctx, tx) } -func (m *Migrator) Up(ctx context.Context, db bun.IDB) error { +func (m *Migrator) newTx(ctx context.Context, db bun.IDB) (bun.Tx, error) { tx, err := db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return bun.Tx{}, err + } + + if m.schema != "" { + _, err := tx.ExecContext(ctx, fmt.Sprintf(`set search_path = "%s"`, m.schema)) + if err != nil { + return bun.Tx{}, err + } + } + + return tx, err +} + +func (m *Migrator) Up(ctx context.Context, db bun.IDB) error { + tx, err := m.newTx(ctx, db) if err != nil { return err } @@ -111,14 +146,8 @@ func (m *Migrator) Up(ctx context.Context, db bun.IDB) error { _ = tx.Rollback() }() - if m.schema != "" { - if m.createSchema { - _, err := tx.ExecContext(ctx, fmt.Sprintf(`create schema if not exists "%s"`, m.schema)) - if err != nil { - return err - } - } - _, err := tx.ExecContext(ctx, fmt.Sprintf(`set search_path = "%s"`, m.schema)) + if m.schema != "" && m.createSchema { + _, err := tx.ExecContext(ctx, fmt.Sprintf(`create schema if not exists "%s"`, m.schema)) if err != nil { return err } @@ -136,16 +165,16 @@ func (m *Migrator) Up(ctx context.Context, db bun.IDB) error { if len(m.migrations) > int(lastMigration)-1 { for ind, migration := range m.migrations[lastMigration:] { if migration.UpWithContext != nil { - if err := migration.UpWithContext(ctx, tx); err != nil { - return err + if err := migration.UpWithContext(ctx, tx); err != nil { + return err + } + } else if migration.Up != nil { + if err := migration.Up(tx); err != nil { + return err + } + } else { + return errors.New("no code defined for migration") } - } else if migration.Up != nil { - if err := migration.Up(tx); err != nil { - return err - } - } else { - return errors.New("no code defined for migration") - } if err := m.insertVersion(ctx, tx, int(lastMigration)+ind+1); err != nil { return err @@ -157,13 +186,21 @@ func (m *Migrator) Up(ctx context.Context, db bun.IDB) error { } func (m *Migrator) GetMigrations(ctx context.Context, db bun.IDB) ([]Info, error) { + tx, err := m.newTx(ctx, db) + if err != nil { + return nil, err + } + defer func() { + _ = tx.Rollback() + }() + migrationTableName := migrationTable if m.schema != "" { migrationTableName = fmt.Sprintf(`"%s".%s`, m.schema, migrationTableName) } ret := make([]Info, 0) - if err := db.NewSelect(). + if err := tx.NewSelect(). TableExpr(migrationTableName). Order("version_id"). Where("version_id >= 1"). @@ -188,6 +225,22 @@ func (m *Migrator) GetMigrations(ctx context.Context, db bun.IDB) ([]Info, error return ret, nil } +func (m *Migrator) IsUpToDate(ctx context.Context, db *bun.DB) (bool, error) { + tx, err := m.newTx(ctx, db) + if err != nil { + return false, err + } + defer func() { + _ = tx.Rollback() + }() + version, err := m.getLastVersion(ctx, tx) + if err != nil { + return false, err + } + + return int(version) == len(m.migrations), nil +} + func NewMigrator(opts ...option) *Migrator { ret := &Migrator{} for _, opt := range opts { diff --git a/libs/pgtesting/postgres.go b/libs/pgtesting/postgres.go index 372001e6b..c77151916 100644 --- a/libs/pgtesting/postgres.go +++ b/libs/pgtesting/postgres.go @@ -189,7 +189,7 @@ func WithDockerHostConfigOption(opt func(hostConfig *docker.HostConfig)) option var defaultOptions = []option{ WithStatusCheckInterval(200 * time.Millisecond), WithInitialUser("root", "root"), - WithMaximumWaitingTime(5 * time.Second), + WithMaximumWaitingTime(15 * time.Second), WithInitialDatabaseName("formance"), WithContext(context.Background()), } diff --git a/openapi.yaml b/openapi.yaml index 2e28c49a4..c43b1a760 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -271,6 +271,40 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /v2/{ledger}/accounts/{address}/metadata/{key}: + delete: + description: Delete metadata by key + operationId: deleteAccountMetadata + tags: + - Ledger + - Transactions + summary: Delete metadata by key + parameters: + - name: ledger + in: path + description: Name of the ledger. + required: true + schema: + type: string + example: ledger001 + - name: address + in: path + description: Account address + required: true + schema: + type: string + - name: key + in: path + description: The key to remove. + required: true + schema: + type: string + example: foo + responses: + 2XX: + description: Key deleted + content: {} + /v2/{ledger}/stats: get: tags: @@ -563,6 +597,43 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /v2/{ledger}/transactions/{id}/metadata/{key}: + delete: + description: Delete metadata by key + operationId: deleteTransactionMetadata + summary: Delete metadata by key + tags: + - Ledger + - Transactions + parameters: + - name: ledger + in: path + description: Name of the ledger. + required: true + schema: + type: string + example: ledger001 + - name: id + in: path + description: Transaction ID. + required: true + schema: + type: integer + format: int64 + minimum: 0 + example: 1234 + - name: key + in: path + required: true + description: The key to remove. + schema: + type: string + example: foo + responses: + 2XX: + description: Key deleted + content: { } + /v2/{ledger}/transactions/{id}/revert: post: tags: @@ -1145,9 +1216,7 @@ components: - data ConfigInfoResponse: - properties: - data: - $ref: '#/components/schemas/ConfigInfo' + $ref: '#/components/schemas/ConfigInfo' Volume: type: object