From 2955ccf4504a516a216e4ce8f0b3c346eeb840dc Mon Sep 17 00:00:00 2001 From: Scott Fairclough Date: Thu, 9 Jan 2025 10:13:22 +0000 Subject: [PATCH] include all used info tree indexes in the witness --- zk/witness/witness.go | 65 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/zk/witness/witness.go b/zk/witness/witness.go index 5ae7ac04bcf..1a14646205f 100644 --- a/zk/witness/witness.go +++ b/zk/witness/witness.go @@ -8,11 +8,13 @@ import ( "math/big" "time" + "github.com/iden3/go-iden3-crypto/keccak256" "github.com/ledgerwatch/erigon-lib/chain" libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon-lib/common/datadir" "github.com/ledgerwatch/erigon-lib/kv" libstate "github.com/ledgerwatch/erigon-lib/state" + "github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/consensus" "github.com/ledgerwatch/erigon/core" "github.com/ledgerwatch/erigon/core/rawdb" @@ -34,13 +36,14 @@ import ( zkUtils "github.com/ledgerwatch/erigon/zk/utils" "github.com/ledgerwatch/log/v3" - "github.com/ledgerwatch/erigon-lib/kv/membatchwithdb" - "github.com/holiman/uint256" "math" + + "github.com/holiman/uint256" + "github.com/ledgerwatch/erigon-lib/kv/membatchwithdb" ) var ( - maxGetProofRewindBlockCount uint64 = 500_000 + maxGetProofRewindBlockCount uint64 = 500_000_000 ErrEndBeforeStart = errors.New("end block must be higher than start block") ) @@ -270,6 +273,10 @@ func (g *Generator) generateWitness(tx kv.Tx, ctx context.Context, batchNum uint reader := state.NewPlainState(tx, blocks[0].NumberU64(), systemcontracts.SystemContractCodeLookup[g.chainCfg.ChainName]) defer reader.Close() + // used to ensure that any info tree updates for this batch are included in the witness - re-use of an index for example + // won't write to storage so will be missing from the witness but the prover needs it + forcedInfoTreeUpdates := make([]libcommon.Hash, 0) + for _, block := range blocks { blockNum := block.NumberU64() reader.SetBlockNr(blockNum) @@ -339,6 +346,14 @@ func (g *Generator) generateWitness(tx kv.Tx, ctx context.Context, batchNum uint return nil, err } + forcedInfoTreeUpdate, err := CheckForForcedInfoTreeUpdate(hermezDb, blockNum) + if err != nil { + return nil, fmt.Errorf("CheckForForcedInfoTreeUpdate: %w", err) + } + if forcedInfoTreeUpdate != nil { + forcedInfoTreeUpdates = append(forcedInfoTreeUpdates, *forcedInfoTreeUpdate) + } + prevStateRoot = block.Root() } @@ -353,6 +368,27 @@ func (g *Generator) generateWitness(tx kv.Tx, ctx context.Context, batchNum uint } } + // ensure that the ger manager is in the inclusion list if there are forced info tree updates + if len(forcedInfoTreeUpdates) > 0 { + if _, ok := inclusion[state.GER_MANAGER_ADDRESS]; !ok { + inclusion[state.GER_MANAGER_ADDRESS] = []libcommon.Hash{} + } + } + + // add any forced info tree updates to the inclusion list that aren't already there + for _, forced := range forcedInfoTreeUpdates { + skip := false + for _, hash := range inclusion[state.GER_MANAGER_ADDRESS] { + if hash == forced { + skip = true + break + } + } + if !skip { + inclusion[state.GER_MANAGER_ADDRESS] = append(inclusion[state.GER_MANAGER_ADDRESS], forced) + } + } + var rl trie.RetainDecider // if full is true, we will send all the nodes to the witness rl = &trie.AlwaysTrueRetainDecider{} @@ -401,3 +437,26 @@ func (g *Generator) generateMockWitness(batchNum uint64, blocks []*eritypes.Bloc return mockWitness, nil } + +func CheckForForcedInfoTreeUpdate(reader *hermez_db.HermezDbReader, blockNum uint64) (*libcommon.Hash, error) { + // check if there were any info tree index updates for this block number + index, err := reader.GetBlockL1InfoTreeIndex(blockNum) + if err != nil { + return nil, fmt.Errorf("failed to check for block info tree index: %w", err) + } + var result *libcommon.Hash + if index != 0 { + // we need to load this info tree index to get the storage slot address to force witness inclusion + infoTreeIndex, err := reader.GetL1InfoTreeUpdate(index) + if err != nil { + return nil, fmt.Errorf("failed to get info tree index: %w", err) + } + d1 := common.LeftPadBytes(infoTreeIndex.GER.Bytes(), 32) + d2 := common.LeftPadBytes(state.GLOBAL_EXIT_ROOT_STORAGE_POS.Bytes(), 32) + mapKey := keccak256.Hash(d1, d2) + mkh := libcommon.BytesToHash(mapKey) + result = &mkh + } + + return result, nil +}