diff --git a/action/protocol/context.go b/action/protocol/context.go index 268ba2ea3d..030c21cb9b 100644 --- a/action/protocol/context.go +++ b/action/protocol/context.go @@ -155,6 +155,7 @@ type ( CheckStakingDurationUpperLimit bool FixRevertSnapshot bool TimestampedStakingContract bool + PreStateSystemAction bool MakeUpBlockReward bool } @@ -316,6 +317,7 @@ func WithFeatureCtx(ctx context.Context) context.Context { CheckStakingDurationUpperLimit: g.IsVanuatu(height), FixRevertSnapshot: g.IsVanuatu(height), TimestampedStakingContract: g.IsWake(height), + PreStateSystemAction: !g.IsWake(height), MakeUpBlockReward: g.IsWake(height), }, ) diff --git a/state/factory/workingset.go b/state/factory/workingset.go index 9da01128f3..51b921f535 100644 --- a/state/factory/workingset.go +++ b/state/factory/workingset.go @@ -469,6 +469,12 @@ func (ws *workingSet) process(ctx context.Context, actions []*action.SealedEnvel if err := ws.validate(ctx); err != nil { return err } + userActions, systemActions := ws.splitActions(actions) + if protocol.MustGetFeatureCtx(ctx).PreStateSystemAction { + if err := ws.validatePostSystemActions(ctx, systemActions); err != nil { + return err + } + } reg := protocol.MustGetRegistry(ctx) for _, p := range reg.All() { if pp, ok := p.(protocol.PreStatesCreator); ok { @@ -478,11 +484,10 @@ func (ws *workingSet) process(ctx context.Context, actions []*action.SealedEnvel } } var ( - receipts = make([]*action.Receipt, 0) - ctxWithBlockContext = ctx - blkCtx = protocol.MustGetBlockCtx(ctx) - fCtx = protocol.MustGetFeatureCtx(ctx) - userActions, systemActions = ws.splitActions(actions) + receipts = make([]*action.Receipt, 0) + ctxWithBlockContext = ctx + blkCtx = protocol.MustGetBlockCtx(ctx) + fCtx = protocol.MustGetFeatureCtx(ctx) ) for _, act := range userActions { if err := ws.txValidator.ValidateWithState(ctxWithBlockContext, act); err != nil { @@ -513,8 +518,10 @@ func (ws *workingSet) process(ctx context.Context, actions []*action.SealedEnvel } } // Handle post system actions - if err := ws.validatePostSystemActions(ctxWithBlockContext, systemActions); err != nil { - return err + if !protocol.MustGetFeatureCtx(ctx).PreStateSystemAction { + if err := ws.validatePostSystemActions(ctxWithBlockContext, systemActions); err != nil { + return err + } } for _, act := range systemActions { actionCtx, err := withActionCtx(ctxWithBlockContext, act) @@ -682,6 +689,16 @@ func (ws *workingSet) pickAndRunActions( executedActions := make([]*action.SealedEnvelope, 0) reg := protocol.MustGetRegistry(ctx) + var ( + systemActions []*action.SealedEnvelope + ) + if protocol.MustGetFeatureCtx(ctx).PreStateSystemAction { + systemActions, err = ws.generateSignedSystemActions(ctx, sign) + if err != nil { + return nil, err + } + } + for _, p := range reg.All() { if pp, ok := p.(protocol.PreStatesCreator); ok { if err := pp.CreatePreStates(ctx, ws); err != nil { @@ -800,19 +817,14 @@ func (ws *workingSet) pickAndRunActions( } } - unsignedSystemActions, err := ws.generateSystemActions(ctxWithBlockContext) - if err != nil { - return nil, err - } - postSystemActions := make([]*action.SealedEnvelope, len(unsignedSystemActions)) - for i, elp := range unsignedSystemActions { - selp, err := sign(elp) + if !fCtx.PreStateSystemAction { + systemActions, err = ws.generateSignedSystemActions(ctx, sign) if err != nil { - return nil, errors.Wrapf(err, "failed to sign %+v", elp.Action()) + return nil, err } - postSystemActions[i] = selp } - for _, selp := range postSystemActions { + + for _, selp := range systemActions { actionCtx, err := withActionCtx(ctxWithBlockContext, selp) if err != nil { return nil, err @@ -832,6 +844,22 @@ func (ws *workingSet) pickAndRunActions( return executedActions, ws.finalize() } +func (ws *workingSet) generateSignedSystemActions(ctx context.Context, sign func(elp action.Envelope) (*action.SealedEnvelope, error)) ([]*action.SealedEnvelope, error) { + unsignedSystemActions, err := ws.generateSystemActions(ctx) + if err != nil { + return nil, err + } + postSystemActions := make([]*action.SealedEnvelope, len(unsignedSystemActions)) + for i, elp := range unsignedSystemActions { + selp, err := sign(elp) + if err != nil { + return nil, errors.Wrapf(err, "failed to sign %+v", elp.Action()) + } + postSystemActions[i] = selp + } + return postSystemActions, nil +} + func updateReceiptIndex(receipts []*action.Receipt) { var txIndex, logIndex uint32 for _, r := range receipts {