From 88dbc57db6d32341876cdf9c005a425bca6c1e74 Mon Sep 17 00:00:00 2001 From: zhi Date: Sat, 19 Apr 2025 14:24:59 +0800 Subject: [PATCH] unit test for view --- action/protocol/staking/protocol_test.go | 2 +- action/protocol/staking/viewdata_test.go | 88 +++++++++++++ state/factory/factory_test.go | 156 +++++++++++++++++++++++ state/factory/statedb.go | 16 +-- 4 files changed, 253 insertions(+), 9 deletions(-) create mode 100644 action/protocol/staking/viewdata_test.go diff --git a/action/protocol/staking/protocol_test.go b/action/protocol/staking/protocol_test.go index 2c38ede197..2134e66493 100644 --- a/action/protocol/staking/protocol_test.go +++ b/action/protocol/staking/protocol_test.go @@ -127,8 +127,8 @@ func TestProtocol(t *testing.T) { ctx = protocol.WithBlockCtx(ctx, protocol.BlockCtx{BlockHeight: 10}) ctx = protocol.WithFeatureCtx(ctx) v, err := stk.Start(ctx, sm) - sm.WriteView(_protocolID, v) r.NoError(err) + r.NoError(sm.WriteView(_protocolID, v)) _, ok := v.(*ViewData) r.True(ok) diff --git a/action/protocol/staking/viewdata_test.go b/action/protocol/staking/viewdata_test.go new file mode 100644 index 0000000000..17517d319a --- /dev/null +++ b/action/protocol/staking/viewdata_test.go @@ -0,0 +1,88 @@ +package staking + +import ( + "context" + "math/big" + "testing" + + "github.com/golang/mock/gomock" + "github.com/iotexproject/iotex-core/v2/test/identityset" + "github.com/iotexproject/iotex-core/v2/test/mock/mock_chainmanager" + "github.com/stretchr/testify/require" +) + +func TestViewData_Clone(t *testing.T) { + viewData, _ := prepareViewData(t) + clone, ok := viewData.Clone().(*ViewData) + require.True(t, ok) + require.NotNil(t, clone) + + require.Equal(t, viewData.candCenter.size, clone.candCenter.size) + require.Equal(t, viewData.candCenter.base, clone.candCenter.base) + require.Equal(t, viewData.candCenter.change, clone.candCenter.change) + require.NotSame(t, viewData.bucketPool, clone.bucketPool) + require.Equal(t, viewData.snapshots, clone.snapshots) + + sr := mock_chainmanager.NewMockStateReader(gomock.NewController(t)) + sr.EXPECT().Height().Return(uint64(100), nil).Times(1) + require.NoError(t, viewData.Commit(context.Background(), sr)) + + clone, ok = viewData.Clone().(*ViewData) + require.True(t, ok) + require.NotNil(t, clone) + require.Equal(t, viewData.candCenter.size, clone.candCenter.size) + require.Equal(t, viewData.candCenter.base, clone.candCenter.base) + require.Equal(t, viewData.candCenter.change, clone.candCenter.change) + require.Equal(t, viewData.bucketPool, clone.bucketPool) + require.Equal(t, viewData.snapshots, clone.snapshots) +} + +func prepareViewData(t *testing.T) (*ViewData, int) { + owner := identityset.Address(0) + cand := &Candidate{ + Owner: owner, + Operator: owner, + Reward: owner, + Identifier: owner, + Name: "name", + Votes: big.NewInt(100), + SelfStakeBucketIdx: 0, + SelfStake: big.NewInt(0), + } + candCenter, err := NewCandidateCenter([]*Candidate{cand}) + require.NoError(t, err) + require.NoError(t, candCenter.Upsert(cand)) + bucketPool := &BucketPool{ + enableSMStorage: false, + dirty: true, + total: &totalAmount{ + amount: big.NewInt(100), + count: 1, + }, + } + viewData := &ViewData{ + candCenter: candCenter, + bucketPool: bucketPool, + snapshots: []Snapshot{}, + } + return viewData, viewData.Snapshot() +} + +func TestViewData_Commit(t *testing.T) { + viewData, _ := prepareViewData(t) + require.True(t, viewData.IsDirty()) + mockStateReader := mock_chainmanager.NewMockStateReader(gomock.NewController(t)) + mockStateReader.EXPECT().Height().Return(uint64(100), nil).Times(1) + require.NoError(t, viewData.Commit(context.Background(), mockStateReader)) + require.False(t, viewData.IsDirty()) + require.Empty(t, viewData.candCenter.change.dirty) + require.False(t, viewData.bucketPool.dirty) + require.Empty(t, viewData.snapshots) +} + +func TestViewData_Snapshot_Revert(t *testing.T) { + viewData, ss := prepareViewData(t) + require.Equal(t, 1, len(viewData.snapshots)) + require.NoError(t, viewData.Revert(ss)) + require.Equal(t, 0, len(viewData.snapshots)) +} diff --git a/state/factory/factory_test.go b/state/factory/factory_test.go index 71e915972f..4c02329997 100644 --- a/state/factory/factory_test.go +++ b/state/factory/factory_test.go @@ -9,6 +9,7 @@ import ( "context" "encoding/csv" "encoding/hex" + "math" "math/big" "math/rand" "os" @@ -34,12 +35,14 @@ import ( "github.com/iotexproject/iotex-core/v2/action/protocol/poll" "github.com/iotexproject/iotex-core/v2/action/protocol/rewarding" "github.com/iotexproject/iotex-core/v2/action/protocol/rolldpos" + "github.com/iotexproject/iotex-core/v2/action/protocol/staking" "github.com/iotexproject/iotex-core/v2/action/protocol/vote/candidatesutil" "github.com/iotexproject/iotex-core/v2/blockchain" "github.com/iotexproject/iotex-core/v2/blockchain/block" "github.com/iotexproject/iotex-core/v2/blockchain/genesis" "github.com/iotexproject/iotex-core/v2/db" "github.com/iotexproject/iotex-core/v2/pkg/enc" + "github.com/iotexproject/iotex-core/v2/pkg/unit" "github.com/iotexproject/iotex-core/v2/pkg/util/fileutil" "github.com/iotexproject/iotex-core/v2/state" "github.com/iotexproject/iotex-core/v2/test/identityset" @@ -1203,6 +1206,159 @@ func TestDeleteAndPutSameKey(t *testing.T) { }) } +func TestMintBlocksWithCandidateUpdate(t *testing.T) { + require := require.New(t) + testStateDBPath, err := testutil.PathOfTempFile(_stateDBPath) + require.NoError(err) + defer testutil.CleanupPath(testStateDBPath) + a := identityset.Address(28) + b := identityset.Address(29) + priKeyA := identityset.PrivateKey(28) + priKeyB := identityset.PrivateKey(29) + + cfg := DefaultConfig + cfg.Chain.TrieDBPath = testStateDBPath + cfg.Genesis.InitBalanceMap[a.String()] = unit.ConvertIotxToRau(10000000).String() + cfg.Genesis.InitBalanceMap[b.String()] = unit.ConvertIotxToRau(5000000).String() + + registry := protocol.NewRegistry() + require.NoError(account.NewProtocol(rewarding.DepositGas).Register(registry)) + sp, err := staking.NewProtocol( + staking.HelperCtx{ + DepositGas: rewarding.DepositGas, + BlockInterval: func(u uint64) time.Duration { + return time.Second + }, + }, + &staking.BuilderConfig{ + Staking: genesis.TestDefault().Staking, + PersistStakingPatchBlock: math.MaxUint64, + }, + nil, + nil, + nil, + ) + require.NoError(err) + require.NoError(sp.Register(registry)) + + db2, err := db.CreateKVStoreWithCache(db.DefaultConfig, cfg.Chain.TrieDBPath, cfg.Chain.StateDBCacheSize) + require.NoError(err) + sdb, err := NewStateDB(cfg, db2, SkipBlockValidationStateDBOption(), RegistryStateDBOption(registry)) + require.NoError(err) + + ctx := protocol.WithBlockCtx( + genesis.WithGenesisContext(context.Background(), cfg.Genesis), + protocol.BlockCtx{}, + ) + ctx = protocol.WithFeatureWithHeightCtx(ctx) + require.NoError(sdb.Start(ctx)) + defer func() { + require.NoError(sdb.Stop(ctx)) + }() + + tsf1, err := action.NewCandidateRegister("cand1", a.String(), a.String(), a.String(), unit.ConvertIotxToRau(1200000).String(), 0, false, nil) + require.NoError(err) + elp1 := (&action.EnvelopeBuilder{}).SetNonce(1).SetGasLimit(20000).SetAction(tsf1).Build() + selp1, err := action.Sign(elp1, priKeyA) + require.NoError(err) + + ctx = protocol.WithBlockCtx( + ctx, + protocol.BlockCtx{ + BlockHeight: 1, + Producer: identityset.Address(27), + GasLimit: testutil.TestGasLimit, + }, + ) + ctx = protocol.WithFeatureCtx(ctx) + mockActPool := mock_actpool.NewMockActPool(gomock.NewController(t)) + mockActPool.EXPECT().PendingActionMap().Return(map[string][]*action.SealedEnvelope{ + a.String(): {selp1}, + }).Times(1) + + blk1, err := sdb.Mint( + protocol.WithBlockchainCtx( + ctx, + protocol.BlockchainCtx{ + ChainID: 1, + Tip: protocol.TipInfo{ + Height: 0, + Hash: hash.ZeroHash256, + }, + }, + ), + mockActPool, + identityset.PrivateKey(27)) + require.NoError(err) + require.NotNil(blk1) + + ws1, exist, err := sdb.(*stateDB).getFromWorkingSets(ctx, blk1.HashBlock()) + require.NoError(err) + require.True(exist) + require.NotNil(ws1) + + tsf2, err := action.NewCandidateRegister("cand2", b.String(), b.String(), b.String(), unit.ConvertIotxToRau(1200000).String(), 0, false, nil) + require.NoError(err) + elp2 := (&action.EnvelopeBuilder{}).SetNonce(1).SetGasLimit(20000).SetAction(tsf2).Build() + selp2, err := action.Sign(elp2, priKeyB) + require.NoError(err) + + ctx = protocol.WithBlockCtx( + ctx, + protocol.BlockCtx{ + BlockHeight: 1, + Producer: identityset.Address(26), + GasLimit: testutil.TestGasLimit, + }, + ) + ctx = protocol.WithFeatureCtx(ctx) + mockActPool.EXPECT().PendingActionMap().Return(map[string][]*action.SealedEnvelope{ + a.String(): {selp2}, + }).Times(1) + + blk2, err := sdb.Mint( + protocol.WithBlockchainCtx( + ctx, + protocol.BlockchainCtx{ + ChainID: 1, + Tip: protocol.TipInfo{ + Height: 0, + Hash: hash.ZeroHash256, + }, + }, + ), + mockActPool, + identityset.PrivateKey(26)) + require.NoError(err) + require.NotNil(blk2) + + ws2, exist, err := sdb.(*stateDB).getFromWorkingSets(ctx, blk2.HashBlock()) + require.NoError(err) + require.True(exist) + require.NotNil(ws2) + + csr, err := staking.ConstructBaseView(sdb) + require.NoError(err) + csr1, err := staking.ConstructBaseView(ws1) + require.NoError(err) + csr2, err := staking.ConstructBaseView(ws2) + require.NoError(err) + require.NotNil(csr1.GetCandidateByName("cand1")) + require.Nil(csr.GetCandidateByName("cand1")) + require.Nil(csr2.GetCandidateByName("cand1")) + require.NotNil(csr2.GetCandidateByName("cand2")) + require.Nil(csr.GetCandidateByName("cand2")) + require.Nil(csr1.GetCandidateByName("cand2")) + + require.NoError(sdb.PutBlock(ctx, blk1)) + csr, err = staking.ConstructBaseView(sdb) + require.NoError(err) + require.NotNil(csr.GetCandidateByName("cand1")) + require.Nil(csr2.GetCandidateByName("cand1")) + require.NotNil(csr2.GetCandidateByName("cand2")) + require.Nil(csr.GetCandidateByName("cand2")) +} + func BenchmarkSDBInMemRunAction(b *testing.B) { cfg := DefaultConfig sdb, err := NewStateDB(cfg, db.NewMemKVStore(), SkipBlockValidationStateDBOption()) diff --git a/state/factory/statedb.go b/state/factory/statedb.go index 9a15440fa0..8f2709d3f6 100644 --- a/state/factory/statedb.go +++ b/state/factory/statedb.go @@ -51,7 +51,7 @@ type ( dao daoRetrofitter timerFactory *prometheustimer.TimerFactory workingsets cache.LRUCache // lru cache for workingsets - protocolView *protocol.Views + protocolViews *protocol.Views skipBlockValidationOnPut bool ps *patchStore } @@ -98,7 +98,7 @@ func NewStateDB(cfg Config, dao db.KVStore, opts ...StateDBOption) (Factory, err cfg: cfg, currentChainHeight: 0, registry: protocol.NewRegistry(), - protocolView: &protocol.Views{}, + protocolViews: &protocol.Views{}, workingsets: cache.NewThreadSafeLruCache(int(cfg.Chain.WorkingSetCacheSize)), } for _, opt := range opts { @@ -132,7 +132,7 @@ func (sdb *stateDB) Start(ctx context.Context) error { case nil: sdb.currentChainHeight = h // start all protocols - if sdb.protocolView, err = sdb.registry.StartAll(ctx, sdb); err != nil { + if sdb.protocolViews, err = sdb.registry.StartAll(ctx, sdb); err != nil { return err } case db.ErrNotExist: @@ -141,7 +141,7 @@ func (sdb *stateDB) Start(ctx context.Context) error { return errors.Wrap(err, "failed to init statedb's height") } // start all protocols - if sdb.protocolView, err = sdb.registry.StartAll(ctx, sdb); err != nil { + if sdb.protocolViews, err = sdb.registry.StartAll(ctx, sdb); err != nil { return err } ctx = protocol.WithBlockCtx( @@ -192,7 +192,7 @@ func (sdb *stateDB) newWorkingSetWithKVStore(ctx context.Context, height uint64, if err := store.Start(ctx); err != nil { return nil, err } - views := sdb.protocolView.Clone() + views := sdb.protocolViews.Clone() if err := views.Commit(ctx, sdb); err != nil { return nil, err } @@ -375,7 +375,7 @@ func (sdb *stateDB) PutBlock(ctx context.Context, blk *block.Block) error { if err := ws.Commit(ctx); err != nil { return err } - sdb.protocolView = ws.views + sdb.protocolViews = ws.views sdb.currentChainHeight = h return nil } @@ -419,7 +419,7 @@ func (sdb *stateDB) States(opts ...protocol.StateOption) (uint64, state.Iterator // ReadView reads the view func (sdb *stateDB) ReadView(name string) (protocol.View, error) { - return sdb.protocolView.Read(name) + return sdb.protocolViews.Read(name) } // StateReaderAt returns a state reader at a specific height @@ -491,7 +491,7 @@ func (sdb *stateDB) createGenesisStates(ctx context.Context) error { if err := ws.Commit(ctx); err != nil { return err } - sdb.protocolView = ws.views + sdb.protocolViews = ws.views return nil }