From db3400fec4f111d845f1c2b2ea98544f7dd87ea8 Mon Sep 17 00:00:00 2001 From: Abhinav Gupta Date: Fri, 20 Dec 2024 19:44:10 -0500 Subject: [PATCH] state: Add 'version' file There's a possibility that we make breaking changes to the state store in the future. To detect and upgrade it automatically, add a 'version' file to the state store. Its current value is "1", and that's what we'll assume when the file is absent. The binary will refuse to run against a store with a version newer than the current binary. --- .../unreleased/Added-20241220-194631.yaml | 6 + internal/spice/state/mocks_test.go | 271 ++++++++++++++++++ internal/spice/state/storage/mem.go | 13 + internal/spice/state/store.go | 26 +- internal/spice/state/store_test.go | 44 +++ internal/spice/state/version.go | 75 +++++ internal/spice/state/version_test.go | 124 ++++++++ 7 files changed, 555 insertions(+), 4 deletions(-) create mode 100644 .changes/unreleased/Added-20241220-194631.yaml create mode 100644 internal/spice/state/mocks_test.go create mode 100644 internal/spice/state/version.go create mode 100644 internal/spice/state/version_test.go diff --git a/.changes/unreleased/Added-20241220-194631.yaml b/.changes/unreleased/Added-20241220-194631.yaml new file mode 100644 index 00000000..de1c856d --- /dev/null +++ b/.changes/unreleased/Added-20241220-194631.yaml @@ -0,0 +1,6 @@ +kind: Added +body: >- + state: Track version of the state store layout in use. + This should be a no-op for users, but it guards against corruption + in case of future changes to the layout. +time: 2024-12-20T19:46:31.685098-05:00 diff --git a/internal/spice/state/mocks_test.go b/internal/spice/state/mocks_test.go new file mode 100644 index 00000000..9acbf8b4 --- /dev/null +++ b/internal/spice/state/mocks_test.go @@ -0,0 +1,271 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: go.abhg.dev/gs/internal/spice/state (interfaces: DB) +// +// Generated by this command: +// +// mockgen -destination mocks_test.go -package state -typed . DB +// + +// Package state is a generated GoMock package. +package state + +import ( + context "context" + reflect "reflect" + + storage "go.abhg.dev/gs/internal/spice/state/storage" + gomock "go.uber.org/mock/gomock" +) + +// MockDB is a mock of DB interface. +type MockDB struct { + ctrl *gomock.Controller + recorder *MockDBMockRecorder + isgomock struct{} +} + +// MockDBMockRecorder is the mock recorder for MockDB. +type MockDBMockRecorder struct { + mock *MockDB +} + +// NewMockDB creates a new mock instance. +func NewMockDB(ctrl *gomock.Controller) *MockDB { + mock := &MockDB{ctrl: ctrl} + mock.recorder = &MockDBMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDB) EXPECT() *MockDBMockRecorder { + return m.recorder +} + +// Clear mocks base method. +func (m *MockDB) Clear(ctx context.Context, msg string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Clear", ctx, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// Clear indicates an expected call of Clear. +func (mr *MockDBMockRecorder) Clear(ctx, msg any) *MockDBClearCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockDB)(nil).Clear), ctx, msg) + return &MockDBClearCall{Call: call} +} + +// MockDBClearCall wrap *gomock.Call +type MockDBClearCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDBClearCall) Return(arg0 error) *MockDBClearCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDBClearCall) Do(f func(context.Context, string) error) *MockDBClearCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDBClearCall) DoAndReturn(f func(context.Context, string) error) *MockDBClearCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Delete mocks base method. +func (m *MockDB) Delete(ctx context.Context, k, msg string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", ctx, k, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockDBMockRecorder) Delete(ctx, k, msg any) *MockDBDeleteCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockDB)(nil).Delete), ctx, k, msg) + return &MockDBDeleteCall{Call: call} +} + +// MockDBDeleteCall wrap *gomock.Call +type MockDBDeleteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDBDeleteCall) Return(arg0 error) *MockDBDeleteCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDBDeleteCall) Do(f func(context.Context, string, string) error) *MockDBDeleteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDBDeleteCall) DoAndReturn(f func(context.Context, string, string) error) *MockDBDeleteCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Get mocks base method. +func (m *MockDB) Get(ctx context.Context, k string, v any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, k, v) + ret0, _ := ret[0].(error) + return ret0 +} + +// Get indicates an expected call of Get. +func (mr *MockDBMockRecorder) Get(ctx, k, v any) *MockDBGetCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDB)(nil).Get), ctx, k, v) + return &MockDBGetCall{Call: call} +} + +// MockDBGetCall wrap *gomock.Call +type MockDBGetCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDBGetCall) Return(arg0 error) *MockDBGetCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDBGetCall) Do(f func(context.Context, string, any) error) *MockDBGetCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDBGetCall) DoAndReturn(f func(context.Context, string, any) error) *MockDBGetCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Keys mocks base method. +func (m *MockDB) Keys(ctx context.Context, dir string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Keys", ctx, dir) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Keys indicates an expected call of Keys. +func (mr *MockDBMockRecorder) Keys(ctx, dir any) *MockDBKeysCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Keys", reflect.TypeOf((*MockDB)(nil).Keys), ctx, dir) + return &MockDBKeysCall{Call: call} +} + +// MockDBKeysCall wrap *gomock.Call +type MockDBKeysCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDBKeysCall) Return(arg0 []string, arg1 error) *MockDBKeysCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDBKeysCall) Do(f func(context.Context, string) ([]string, error)) *MockDBKeysCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDBKeysCall) DoAndReturn(f func(context.Context, string) ([]string, error)) *MockDBKeysCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Set mocks base method. +func (m *MockDB) Set(ctx context.Context, k string, v any, msg string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", ctx, k, v, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// Set indicates an expected call of Set. +func (mr *MockDBMockRecorder) Set(ctx, k, v, msg any) *MockDBSetCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockDB)(nil).Set), ctx, k, v, msg) + return &MockDBSetCall{Call: call} +} + +// MockDBSetCall wrap *gomock.Call +type MockDBSetCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDBSetCall) Return(arg0 error) *MockDBSetCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDBSetCall) Do(f func(context.Context, string, any, string) error) *MockDBSetCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDBSetCall) DoAndReturn(f func(context.Context, string, any, string) error) *MockDBSetCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Update mocks base method. +func (m *MockDB) Update(ctx context.Context, req storage.UpdateRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", ctx, req) + ret0, _ := ret[0].(error) + return ret0 +} + +// Update indicates an expected call of Update. +func (mr *MockDBMockRecorder) Update(ctx, req any) *MockDBUpdateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDB)(nil).Update), ctx, req) + return &MockDBUpdateCall{Call: call} +} + +// MockDBUpdateCall wrap *gomock.Call +type MockDBUpdateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockDBUpdateCall) Return(arg0 error) *MockDBUpdateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockDBUpdateCall) Do(f func(context.Context, storage.UpdateRequest) error) *MockDBUpdateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockDBUpdateCall) DoAndReturn(f func(context.Context, storage.UpdateRequest) error) *MockDBUpdateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/internal/spice/state/storage/mem.go b/internal/spice/state/storage/mem.go index 4622366c..d51b02ae 100644 --- a/internal/spice/state/storage/mem.go +++ b/internal/spice/state/storage/mem.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "fmt" + "iter" + "slices" "sort" "strings" "sync" @@ -24,6 +26,17 @@ func NewMemBackend() *MemBackend { } } +// AddFiles adds the given files to the memory backend, +// overwriting similarly named files. +func (m *MemBackend) AddFiles(files iter.Seq2[string, []byte]) { + m.mu.Lock() + defer m.mu.Unlock() + + for name, body := range files { + m.items[name] = slices.Clone(body) + } +} + // Get retrieves a value from the store func (m *MemBackend) Get(ctx context.Context, key string, dst any) error { m.mu.RLock() diff --git a/internal/spice/state/store.go b/internal/spice/state/store.go index 439897af..3708d129 100644 --- a/internal/spice/state/store.go +++ b/internal/spice/state/store.go @@ -24,6 +24,8 @@ type DB interface { var _ DB = (*storage.DB)(nil) +//go:generate mockgen -destination mocks_test.go -package state -typed . DB + // Store implements storage for state tracked by gs. type Store struct { db DB @@ -105,11 +107,23 @@ func InitStore(ctx context.Context, req InitStoreRequest) (*Store, error) { } } - info := repoInfo{ - Trunk: req.Trunk, - Remote: req.Remote, + update := storage.UpdateRequest{ + Sets: []storage.SetRequest{ + { + Key: _repoJSON, + Value: repoInfo{ + Trunk: req.Trunk, + Remote: req.Remote, + }, + }, + { + Key: _versionFile, + Value: LatestVersion, + }, + }, + Message: "initialize store", } - if err := db.Set(ctx, _repoJSON, info, "initialize store"); err != nil { + if err := db.Update(ctx, update); err != nil { return nil, fmt.Errorf("put repo state: %w", err) } @@ -165,6 +179,10 @@ func OpenStore(ctx context.Context, db DB, logger *log.Logger) (*Store, error) { logger = log.New(io.Discard) } + if err := checkVersion(ctx, db); err != nil { + return nil, fmt.Errorf("check store layout: %w", err) + } + var info repoInfo if err := db.Get(ctx, _repoJSON, &info); err != nil { if errors.Is(err, ErrNotExist) { diff --git a/internal/spice/state/store_test.go b/internal/spice/state/store_test.go index 1040aacc..5c7d5604 100644 --- a/internal/spice/state/store_test.go +++ b/internal/spice/state/store_test.go @@ -3,6 +3,7 @@ package state_test import ( "context" "encoding/json" + "maps" "testing" "github.com/stretchr/testify/assert" @@ -124,3 +125,46 @@ func TestStore(t *testing.T) { assert.Equal(t, "remoteBranch", res.UpstreamBranch) }) } + +func TestOpenStore_errors(t *testing.T) { + t.Run("VersionMismatch", func(t *testing.T) { + mem := storage.NewMemBackend() + mem.AddFiles(maps.All(map[string][]byte{ + "version": []byte("500"), + })) + + _, err := state.OpenStore(context.Background(), storage.NewDB(mem), nil) + require.Error(t, err) + assert.ErrorContains(t, err, "check store layout:") + assert.ErrorAs(t, err, new(*state.VersionMismatchError)) + }) + + t.Run("NotInitialized", func(t *testing.T) { + mem := storage.NewMemBackend() + _, err := state.OpenStore(context.Background(), storage.NewDB(mem), nil) + require.Error(t, err) + assert.ErrorIs(t, err, state.ErrUninitialized) + }) + + t.Run("CorruptRepo/Unparseable", func(t *testing.T) { + mem := storage.NewMemBackend() + mem.AddFiles(maps.All(map[string][]byte{ + "repo": []byte(`{`), + })) + + _, err := state.OpenStore(context.Background(), storage.NewDB(mem), nil) + require.Error(t, err) + assert.ErrorContains(t, err, "get repo state:") + }) + + t.Run("CorruptRepo/Incomplete", func(t *testing.T) { + mem := storage.NewMemBackend() + mem.AddFiles(maps.All(map[string][]byte{ + "repo": []byte(`{}`), + })) + + _, err := state.OpenStore(context.Background(), storage.NewDB(mem), nil) + require.Error(t, err) + assert.ErrorContains(t, err, "corrupt state:") + }) +} diff --git a/internal/spice/state/version.go b/internal/spice/state/version.go new file mode 100644 index 00000000..e4b669b1 --- /dev/null +++ b/internal/spice/state/version.go @@ -0,0 +1,75 @@ +package state + +import ( + "context" + "errors" + "fmt" +) + +const _versionFile = "version" + +// Version specifies the version of the state store. +// +// It is stored in a 'version' file in the root of the store. +// Absence of this file indicates version 1. +type Version int + +// Supported versions of the storage layout. +const ( + VersionOne Version = 1 + + // LatestVersion refers to the latest supported version. + LatestVersion = VersionOne +) + +// checkVersion verifies that the given DB +// uses a supported version of the layout. +func checkVersion(ctx context.Context, db DB) error { + version, err := loadVersion(ctx, db) + if err != nil { + return fmt.Errorf("load store version: %w", err) + } + + // If/when we make a breaking change to the storage format, + // we'll add migration code here. + switch version { + case VersionOne: + // ok + + default: + return &VersionMismatchError{ + Want: LatestVersion, + Got: version, + } + } + + return nil +} + +// loadVersion loads the version of the storage layout used by the given [DB]. +func loadVersion(ctx context.Context, db DB) (Version, error) { + var version Version + + if err := db.Get(ctx, _versionFile, &version); err != nil { + if errors.Is(err, ErrNotExist) { + // Version file was added during storage version 1. + // If file does not exist, it's an old v1 store. + return VersionOne, nil + } + + return 0, err + } + + return version, nil +} + +// VersionMismatchError indicates that the data store we attempted to open +// is using a version older than this binary knows how to handle. +type VersionMismatchError struct { + Want Version + Got Version +} + +func (e *VersionMismatchError) Error() string { + return fmt.Sprintf("expected store version <= %d, got %d", e.Want, e.Got) +} diff --git a/internal/spice/state/version_test.go b/internal/spice/state/version_test.go new file mode 100644 index 00000000..2ec16d2d --- /dev/null +++ b/internal/spice/state/version_test.go @@ -0,0 +1,124 @@ +package state + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.abhg.dev/gs/internal/spice/state/storage" + "go.uber.org/mock/gomock" +) + +func TestLoadVersion(t *testing.T) { + tests := []struct { + name string + files map[string]string + want Version + }{ + { + name: "Empty", + want: VersionOne, + }, + { + name: "ExplicitV1", + files: map[string]string{ + "version": "1", + }, + want: VersionOne, + }, + { + name: "FutureVersion", + files: map[string]string{ + "version": "42", + }, + want: Version(42), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mem := storage.NewMemBackend() + mem.AddFiles(func(yield func(string, []byte) bool) { + for name, body := range tt.files { + if !yield(name, []byte(body)) { + return + } + } + }) + + db := storage.NewDB(mem) + got, err := loadVersion(context.Background(), db) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCheckVersion(t *testing.T) { + tests := []struct { + name string + files map[string]string + err bool + }{ + {name: "ImplicitV1"}, + { + name: "ExplicitV1", + files: map[string]string{ + "version": "1", + }, + }, + { + name: "UnsupportedVersion", + files: map[string]string{ + "version": "500", + }, + err: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mem := storage.NewMemBackend() + mem.AddFiles(func(yield func(string, []byte) bool) { + for name, body := range tt.files { + if !yield(name, []byte(body)) { + return + } + } + }) + + db := storage.NewDB(mem) + err := checkVersion(context.Background(), db) + if tt.err { + require.Error(t, err) + assert.ErrorAs(t, err, new(*VersionMismatchError)) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestCheckVersion_loadError(t *testing.T) { + ctrl := gomock.NewController(t) + mockDB := NewMockDB(ctrl) + + mockDB.EXPECT(). + Get(gomock.Any(), gomock.Any(), gomock.Any()). + Return(assert.AnError) + + err := checkVersion(context.Background(), mockDB) + require.Error(t, err) + assert.ErrorContains(t, err, "load store version:") + assert.ErrorIs(t, err, assert.AnError) +} + +func TestVersionMismatchError(t *testing.T) { + err := &VersionMismatchError{ + Want: 42, + Got: 43, + } + + assert.Equal(t, "expected store version <= 42, got 43", err.Error()) +}