From 58e8754a381d18de80e443a6c0603743fdad98b0 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 9 Oct 2024 18:24:05 +0800 Subject: [PATCH 01/13] remove firstdeployed field --- blockchain/encoder_initializer.go | 1 + core/contract.go | 231 ++++++++--------- core/contract_test.go | 140 +++++----- core/state.go | 416 ++++++++++++------------------ db/buckets.go | 55 ++-- 5 files changed, 368 insertions(+), 475 deletions(-) diff --git a/blockchain/encoder_initializer.go b/blockchain/encoder_initializer.go index fbc82ab4d8..93efac7a82 100644 --- a/blockchain/encoder_initializer.go +++ b/blockchain/encoder_initializer.go @@ -20,6 +20,7 @@ func RegisterCoreTypesToEncoder() { reflect.TypeOf(core.DeployAccountTransaction{}), reflect.TypeOf(core.Cairo0Class{}), reflect.TypeOf(core.Cairo1Class{}), + reflect.TypeOf(core.StateContract{}), } for _, t := range types { diff --git a/core/contract.go b/core/contract.go index 2af1fd8c4c..178f5c9cb1 100644 --- a/core/contract.go +++ b/core/contract.go @@ -7,6 +7,7 @@ import ( "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" + "github.com/NethermindEth/juno/encoder" ) // contract storage has fixed height at 251 @@ -17,191 +18,161 @@ var ( ErrContractAlreadyDeployed = errors.New("contract already deployed") ) -// NewContractUpdater creates an updater for the contract instance at the given address. -// Deploy should be called for contracts that were just deployed to the network. -func NewContractUpdater(addr *felt.Felt, txn db.Transaction) (*ContractUpdater, error) { - contractDeployed, err := deployed(addr, txn) - if err != nil { - return nil, err - } - - if !contractDeployed { - return nil, ErrContractNotDeployed - } +type OnValueChanged = func(location, oldValue *felt.Felt) error - return &ContractUpdater{ - Address: addr, - txn: txn, - }, nil +type StateContract struct { + // ClassHash is the hash of the contract's class + ClassHash *felt.Felt + // Nonce is the contract's nonce + Nonce *felt.Felt + // DeployHeight is the height at which the contract is deployed + DeployHeight uint64 + // Address that this contract instance is deployed to + Address *felt.Felt `cbor:"-"` + // Storage is the contract's storage + Storage map[felt.Felt]*felt.Felt `cbor:"-"` } -// DeployContract sets up the database for a new contract. -func DeployContract(addr, classHash *felt.Felt, txn db.Transaction) (*ContractUpdater, error) { - contractDeployed, err := deployed(addr, txn) - if err != nil { - return nil, err +func NewStateContract( + addr *felt.Felt, + classHash *felt.Felt, + nonce *felt.Felt, + storage map[felt.Felt]*felt.Felt, + DeployHeight uint64, +) *StateContract { + sc := &StateContract{ + Address: addr, + ClassHash: classHash, + Nonce: nonce, + Storage: storage, + DeployHeight: DeployHeight, } - if contractDeployed { - return nil, ErrContractAlreadyDeployed + if storage == nil { + sc.Storage = make(map[felt.Felt]*felt.Felt) } - err = setClassHash(txn, addr, classHash) + return sc +} + +func (c *StateContract) StorageRoot(txn db.Transaction) (*felt.Felt, error) { + storageTrie, err := storage(c.Address, txn) if err != nil { return nil, err } - c, err := NewContractUpdater(addr, txn) - if err != nil { - return nil, err + return storageTrie.Root() +} + +func (c *StateContract) GetStorage(key *felt.Felt, txn db.Transaction) (*felt.Felt, error) { + if c.Storage != nil { + if val, ok := c.Storage[*key]; ok { + return val, nil + } } - err = c.UpdateNonce(&felt.Zero) + // get from db + storage, err := storage(c.Address, txn) if err != nil { return nil, err } - return c, nil + return storage.Get(key) } -// ContractAddress computes the address of a Starknet contract. -func ContractAddress(callerAddress, classHash, salt *felt.Felt, constructorCallData []*felt.Felt) *felt.Felt { - prefix := new(felt.Felt).SetBytes([]byte("STARKNET_CONTRACT_ADDRESS")) - callDataHash := crypto.PedersenArray(constructorCallData...) - - // https://docs.starknet.io/architecture-and-concepts/smart-contracts/contract-address/ - return crypto.PedersenArray( - prefix, - callerAddress, - salt, - classHash, - callDataHash, - ) -} - -func deployed(addr *felt.Felt, txn db.Transaction) (bool, error) { - _, err := ContractClassHash(addr, txn) - if errors.Is(err, db.ErrKeyNotFound) { - return false, nil - } +func (c *StateContract) Commit(txn db.Transaction, cb OnValueChanged) error { + storageTrie, err := storage(c.Address, txn) if err != nil { - return false, err + return err } - return true, nil -} - -// ContractUpdater is a helper to update an existing contract instance. -type ContractUpdater struct { - // Address that this contract instance is deployed to - Address *felt.Felt - // txn to access the database - txn db.Transaction -} - -// Purge eliminates the contract instance, deleting all associated data from storage -// assumes storage is cleared in revert process -func (c *ContractUpdater) Purge() error { - addrBytes := c.Address.Marshal() - buckets := []db.Bucket{db.ContractNonce, db.ContractClassHash} - for _, bucket := range buckets { - if err := c.txn.Delete(bucket.Key(addrBytes)); err != nil { + for key, value := range c.Storage { + oldVal, err := storageTrie.Put(&key, value) + if err != nil { return err } - } - - return nil -} -// ContractNonce returns the amount transactions sent from this contract. -// Only account contracts can have a non-zero nonce. -func ContractNonce(addr *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - key := db.ContractNonce.Key(addr.Marshal()) - var nonce *felt.Felt - if err := txn.Get(key, func(val []byte) error { - nonce = new(felt.Felt) - nonce.SetBytes(val) - return nil - }); err != nil { - return nil, err + if oldVal != nil { + if err = cb(&key, oldVal); err != nil { + return err + } + } } - return nonce, nil -} -// UpdateNonce updates the nonce value in the database. -func (c *ContractUpdater) UpdateNonce(nonce *felt.Felt) error { - nonceKey := db.ContractNonce.Key(c.Address.Marshal()) - return c.txn.Set(nonceKey, nonce.Marshal()) -} + if err := storageTrie.Commit(); err != nil { + return err + } -// ContractRoot returns the root of the contract storage. -func ContractRoot(addr *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - cStorage, err := storage(addr, txn) + contractBytes, err := encoder.Marshal(c) if err != nil { - return nil, err + return err } - return cStorage.Root() + + return txn.Set(db.Contract.Key(c.Address.Marshal()), contractBytes) } -type OnValueChanged = func(location, oldValue *felt.Felt) error +// Purge eliminates the contract instance, deleting all associated data from database +// assumes storage is cleared in revert process +func (c *StateContract) Purge(txn db.Transaction) error { + addrBytes := c.Address.Marshal() -// UpdateStorage applies a change-set to the contract storage. -func (c *ContractUpdater) UpdateStorage(diff map[felt.Felt]*felt.Felt, cb OnValueChanged) error { - cStorage, err := storage(c.Address, c.txn) - if err != nil { + if err := txn.Delete(db.Contract.Key(addrBytes)); err != nil { return err } - // apply the diff - for key, value := range diff { - oldValue, pErr := cStorage.Put(&key, value) - if pErr != nil { - return pErr - } - if oldValue != nil { - if err = cb(&key, oldValue); err != nil { - return err - } - } - } - - return cStorage.Commit() + return txn.Delete(db.ContractDeploymentHeight.Key(addrBytes)) } -func ContractStorage(addr, key *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - cStorage, err := storage(addr, txn) +// GetContract is a wrapper around getContract which checks if a contract is deployed +func GetContract(addr *felt.Felt, txn db.Transaction) (*StateContract, error) { + contract, err := getContract(addr, txn) if err != nil { + if errors.Is(err, db.ErrKeyNotFound) { + return nil, ErrContractNotDeployed + } return nil, err } - return cStorage.Get(key) + + return contract, nil } -// ContractClassHash returns hash of the class that the contract at the given address instantiates. -func ContractClassHash(addr *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - key := db.ContractClassHash.Key(addr.Marshal()) - var classHash *felt.Felt +// getContract gets a contract instance from the database. +func getContract(addr *felt.Felt, txn db.Transaction) (*StateContract, error) { + key := db.Contract.Key(addr.Marshal()) + var contract StateContract if err := txn.Get(key, func(val []byte) error { - classHash = new(felt.Felt) - classHash.SetBytes(val) + if err := encoder.Unmarshal(val, &contract); err != nil { + return err + } + + contract.Address = addr + contract.Storage = make(map[felt.Felt]*felt.Felt) + return nil }); err != nil { return nil, err } - return classHash, nil + return &contract, nil } -func setClassHash(txn db.Transaction, addr, classHash *felt.Felt) error { - classHashKey := db.ContractClassHash.Key(addr.Marshal()) - return txn.Set(classHashKey, classHash.Marshal()) -} +// ContractAddress computes the address of a Starknet contract. +func ContractAddress(callerAddress, classHash, salt *felt.Felt, constructorCallData []*felt.Felt) *felt.Felt { + prefix := new(felt.Felt).SetBytes([]byte("STARKNET_CONTRACT_ADDRESS")) + callDataHash := crypto.PedersenArray(constructorCallData...) -// Replace replaces the class that the contract instantiates -func (c *ContractUpdater) Replace(classHash *felt.Felt) error { - return setClassHash(c.txn, c.Address, classHash) + // https://docs.starknet.io/architecture-and-concepts/smart-contracts/contract-address/ + return crypto.PedersenArray( + prefix, + callerAddress, + salt, + classHash, + callDataHash, + ) } // storage returns the [core.Trie] that represents the // storage of the contract. +// TODO(weiihann): how to deal with the root key? func storage(addr *felt.Felt, txn db.Transaction) (*trie.Trie, error) { addrBytes := addr.Marshal() trieTxn := trie.NewStorage(txn, db.ContractStorage.Key(addrBytes)) diff --git a/core/contract_test.go b/core/contract_test.go index 8ace83ba7e..3a32f29915 100644 --- a/core/contract_test.go +++ b/core/contract_test.go @@ -59,124 +59,129 @@ func TestNewContract(t *testing.T) { t.Cleanup(func() { require.NoError(t, txn.Discard()) }) + + blockNumber := uint64(10) addr := new(felt.Felt).SetUint64(234) classHash := new(felt.Felt).SetBytes([]byte("class hash")) - t.Run("cannot create Contract instance if un-deployed", func(t *testing.T) { - _, err = core.NewContractUpdater(addr, txn) - require.EqualError(t, err, core.ErrContractNotDeployed.Error()) + t.Run("cannot get contract if un-deployed", func(t *testing.T) { + _, err = core.GetContract(addr, txn) + require.ErrorIs(t, err, core.ErrContractNotDeployed) }) - contract, err := core.DeployContract(addr, classHash, txn) - require.NoError(t, err) + var contract *core.StateContract + t.Run("commit contract", func(t *testing.T) { + contract = core.NewStateContract(addr, classHash, &felt.Zero, nil, blockNumber) + require.NoError(t, contract.Commit(txn, nil)) + }) - t.Run("redeploy should fail", func(t *testing.T) { - _, err := core.DeployContract(addr, classHash, txn) - require.EqualError(t, err, core.ErrContractAlreadyDeployed.Error()) + t.Run("get contract from db", func(t *testing.T) { + contract, err = core.GetContract(addr, txn) + require.NoError(t, err) }) - t.Run("a call to contract should fail with a committed txn", func(t *testing.T) { - assert.NoError(t, txn.Commit()) - t.Run("ClassHash()", func(t *testing.T) { - _, err := core.ContractClassHash(addr, txn) - assert.Error(t, err) - }) - t.Run("Root()", func(t *testing.T) { - _, err := core.ContractRoot(addr, txn) - assert.Error(t, err) - }) - t.Run("Nonce()", func(t *testing.T) { - _, err := core.ContractNonce(addr, txn) - assert.Error(t, err) - }) - t.Run("Storage()", func(t *testing.T) { - _, err := core.ContractStorage(addr, classHash, txn) - assert.Error(t, err) - }) - t.Run("UpdateNonce()", func(t *testing.T) { - assert.Error(t, contract.UpdateNonce(&felt.Zero)) - }) - t.Run("UpdateStorage()", func(t *testing.T) { - assert.Error(t, contract.UpdateStorage(nil, NoopOnValueChanged)) - }) + t.Run("check contract fields", func(t *testing.T) { + assert.Equal(t, addr, contract.Address) + assert.Equal(t, classHash, contract.ClassHash) + assert.Equal(t, &felt.Zero, contract.Nonce) + assert.Empty(t, contract.Storage) + assert.Equal(t, blockNumber, contract.DeployHeight) }) } -func TestNonceAndClassHash(t *testing.T) { +func TestUpdateContract(t *testing.T) { testDB := pebble.NewMemTest(t) txn, err := testDB.NewTransaction(true) require.NoError(t, err) + blockNumber := uint64(10) addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract, err := core.DeployContract(addr, classHash, txn) + contract := core.NewStateContract(addr, classHash, &felt.Zero, nil, blockNumber) + require.NoError(t, contract.Commit(txn, nil)) + + contract, err = core.GetContract(addr, txn) require.NoError(t, err) - t.Run("initial nonce should be 0", func(t *testing.T) { - got, err := core.ContractNonce(addr, txn) - require.NoError(t, err) - assert.Equal(t, new(felt.Felt), got) + t.Run("verify initial nonce", func(t *testing.T) { + require.Equal(t, &felt.Zero, contract.Nonce) }) - t.Run("UpdateNonce()", func(t *testing.T) { - require.NoError(t, contract.UpdateNonce(classHash)) - got, err := core.ContractNonce(addr, txn) + t.Run("update contract nonce", func(t *testing.T) { + newNonce := new(felt.Felt).SetUint64(1) + contract.Nonce = newNonce + require.NoError(t, contract.Commit(txn, nil)) + + contract, err = core.GetContract(addr, txn) require.NoError(t, err) - assert.Equal(t, classHash, got) + + require.Equal(t, newNonce, contract.Nonce) }) - t.Run("ClassHash()", func(t *testing.T) { - got, err := core.ContractClassHash(addr, txn) - require.NoError(t, err) - assert.Equal(t, classHash, got) + t.Run("verify initial class hash", func(t *testing.T) { + require.Equal(t, classHash, contract.ClassHash) }) - t.Run("Replace()", func(t *testing.T) { - replaceWith := utils.HexToFelt(t, "0xDEADBEEF") - require.NoError(t, contract.Replace(replaceWith)) - got, err := core.ContractClassHash(addr, txn) + t.Run("update class hash", func(t *testing.T) { + newHash := new(felt.Felt).SetUint64(1) + contract.ClassHash = newHash + require.NoError(t, contract.Commit(txn, nil)) + + contract, err = core.GetContract(addr, txn) require.NoError(t, err) - assert.Equal(t, replaceWith, got) + + require.Equal(t, newHash, contract.ClassHash) }) } -func TestUpdateStorageAndStorage(t *testing.T) { +func TestContractStorage(t *testing.T) { testDB := pebble.NewMemTest(t) txn, err := testDB.NewTransaction(true) require.NoError(t, err) + blockNumber := uint64(10) addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract, err := core.DeployContract(addr, classHash, txn) - require.NoError(t, err) + contract := core.NewStateContract(addr, classHash, &felt.Zero, nil, blockNumber) + require.NoError(t, contract.Commit(txn, nil)) + + t.Run("get initial storage", func(t *testing.T) { + gotValue, err := contract.GetStorage(addr, txn) + require.NoError(t, err) + assert.Equal(t, &felt.Zero, gotValue) + }) t.Run("apply storage diff", func(t *testing.T) { - oldRoot, err := core.ContractRoot(addr, txn) + oldRoot, err := contract.StorageRoot(txn) require.NoError(t, err) - require.NoError(t, contract.UpdateStorage(map[felt.Felt]*felt.Felt{*addr: classHash}, NoopOnValueChanged)) + contract.Storage = map[felt.Felt]*felt.Felt{*addr: classHash} + require.NoError(t, contract.Commit(txn, NoopOnValueChanged)) - gotValue, err := core.ContractStorage(addr, addr, txn) + contract, err = core.GetContract(addr, txn) + require.NoError(t, err) + + gotValue, err := contract.GetStorage(addr, txn) require.NoError(t, err) assert.Equal(t, classHash, gotValue) - newRoot, err := core.ContractRoot(addr, txn) + newRoot, err := contract.StorageRoot(txn) require.NoError(t, err) assert.NotEqual(t, oldRoot, newRoot) }) t.Run("delete key from storage with storage diff", func(t *testing.T) { - require.NoError(t, contract.UpdateStorage(map[felt.Felt]*felt.Felt{*addr: new(felt.Felt)}, NoopOnValueChanged)) + contract.Storage[*addr] = new(felt.Felt) + require.NoError(t, contract.Commit(txn, NoopOnValueChanged)) - val, err := core.ContractStorage(addr, addr, txn) + contract, err = core.GetContract(addr, txn) require.NoError(t, err) - require.Equal(t, &felt.Zero, val) - sRoot, err := core.ContractRoot(addr, txn) + gotValue, err := contract.GetStorage(addr, txn) require.NoError(t, err) - assert.Equal(t, new(felt.Felt), sRoot) + assert.Equal(t, &felt.Zero, gotValue) }) } @@ -185,13 +190,14 @@ func TestPurge(t *testing.T) { txn, err := testDB.NewTransaction(true) require.NoError(t, err) + blockNumber := uint64(10) addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract, err := core.DeployContract(addr, classHash, txn) - require.NoError(t, err) + contract := core.NewStateContract(addr, classHash, &felt.Zero, nil, blockNumber) + require.NoError(t, contract.Commit(txn, nil)) - require.NoError(t, contract.Purge()) - _, err = core.NewContractUpdater(addr, txn) + require.NoError(t, contract.Purge(txn)) + _, err = core.GetContract(addr, txn) assert.ErrorIs(t, err, core.ErrContractNotDeployed) } diff --git a/core/state.go b/core/state.go index effde8b518..b8681257ad 100644 --- a/core/state.go +++ b/core/state.go @@ -2,20 +2,14 @@ package core import ( "bytes" - "encoding/binary" "errors" "fmt" - "maps" - "runtime" - "slices" - "sort" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/encoder" - "github.com/sourcegraph/conc/pool" ) const globalTrieHeight = 251 @@ -58,33 +52,48 @@ func NewState(txn db.Transaction) *State { // putNewContract creates a contract storage instance in the state and stores the relation between contract address and class hash to be // queried later with [GetContractClass]. -func (s *State) putNewContract(stateTrie *trie.Trie, addr, classHash *felt.Felt, blockNumber uint64) error { - contract, err := DeployContract(addr, classHash, s.txn) - if err != nil { - return err - } +// func (s *State) putNewContract(stateTrie *trie.Trie, addr, classHash *felt.Felt, blockNumber uint64) error { +// contract, err := DeployContract(addr, classHash, s.txn) +// if err != nil { +// return err +// } - numBytes := MarshalBlockNumber(blockNumber) - if err = s.txn.Set(db.ContractDeploymentHeight.Key(addr.Marshal()), numBytes); err != nil { - return err - } +// numBytes := MarshalBlockNumber(blockNumber) +// if err = s.txn.Set(db.ContractDeploymentHeight.Key(addr.Marshal()), numBytes); err != nil { +// return err +// } - return s.updateContractCommitment(stateTrie, contract) -} +// return s.updateContractCommitment(stateTrie, contract) +// } // ContractClassHash returns class hash of a contract at a given address. func (s *State) ContractClassHash(addr *felt.Felt) (*felt.Felt, error) { - return ContractClassHash(addr, s.txn) + contract, err := GetContract(addr, s.txn) + if err != nil { + return nil, err + } + + return contract.ClassHash, nil } // ContractNonce returns nonce of a contract at a given address. func (s *State) ContractNonce(addr *felt.Felt) (*felt.Felt, error) { - return ContractNonce(addr, s.txn) + contract, err := GetContract(addr, s.txn) + if err != nil { + return nil, err + } + + return contract.Nonce, nil } // ContractStorage returns value of a key in the storage of the contract at the given address. func (s *State) ContractStorage(addr, key *felt.Felt) (*felt.Felt, error) { - return ContractStorage(addr, key, s.txn) + contract, err := GetContract(addr, s.txn) + if err != nil { + return nil, err + } + + return contract.GetStorage(key, s.txn) } // Root returns the state commitment. @@ -220,17 +229,42 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return err } + contracts := make(map[felt.Felt]*StateContract) // register deployed contracts for addr, classHash := range update.StateDiff.DeployedContracts { - if err = s.putNewContract(stateTrie, &addr, classHash, blockNumber); err != nil { + // check if contract is already deployed + _, err := GetContract(&addr, s.txn) + if err == nil { + return ErrContractAlreadyDeployed + } + + if !errors.Is(err, ErrContractNotDeployed) { return err } + + contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, nil, blockNumber) } - if err = s.updateContracts(stateTrie, blockNumber, update.StateDiff, true); err != nil { + if err = s.updateContracts(blockNumber, update.StateDiff, true, contracts); err != nil { return err } + // TODO(weiihann): handle history + tempOnValChanged := func(location, oldValue *felt.Felt) error { + return nil + } + + // Commit all contract updates + for _, contract := range contracts { + if err = contract.Commit(s.txn, tempOnValChanged); err != nil { + return err + } + + if err := s.updateContractCommitment(stateTrie, contract); err != nil { + return err + } + } + if err = storageCloser(); err != nil { return err } @@ -246,16 +280,29 @@ var ( } ) -func (s *State) updateContracts(stateTrie *trie.Trie, blockNumber uint64, diff *StateDiff, logChanges bool) error { - // replace contract instances +func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool, contracts map[felt.Felt]*StateContract) error { + if contracts == nil { + return fmt.Errorf("contracts is nil") + } + + var err error + + // update contract class hashes for addr, classHash := range diff.ReplacedClasses { - oldClassHash, err := s.replaceContract(stateTrie, &addr, classHash) - if err != nil { - return err + contract, ok := contracts[addr] + if !ok { + contract, err = GetContract(&addr, s.txn) + if err != nil { + return err + } + contracts[addr] = contract } + oldClassHash := contract.ClassHash + contract.ClassHash = classHash + if logChanges { - if err = s.LogContractClassHash(&addr, oldClassHash, blockNumber); err != nil { + if err := s.LogContractClassHash(&addr, oldClassHash, blockNumber); err != nil { return err } } @@ -263,43 +310,45 @@ func (s *State) updateContracts(stateTrie *trie.Trie, blockNumber uint64, diff * // update contract nonces for addr, nonce := range diff.Nonces { - oldNonce, err := s.updateContractNonce(stateTrie, &addr, nonce) - if err != nil { - return err + contract, ok := contracts[addr] + if !ok { + contract, err = GetContract(&addr, s.txn) + if err != nil { + return err + } + contracts[addr] = contract } + oldNonce := contract.Nonce + contract.Nonce = nonce + if logChanges { - if err = s.LogContractNonce(&addr, oldNonce, blockNumber); err != nil { + if err := s.LogContractNonce(&addr, oldNonce, blockNumber); err != nil { return err } } } // update contract storages - return s.updateContractStorages(stateTrie, diff.StorageDiffs, blockNumber, logChanges) -} - -// replaceContract replaces the class that a contract at a given address instantiates -func (s *State) replaceContract(stateTrie *trie.Trie, addr, classHash *felt.Felt) (*felt.Felt, error) { - contract, err := NewContractUpdater(addr, s.txn) - if err != nil { - return nil, err - } - - oldClassHash, err := ContractClassHash(addr, s.txn) - if err != nil { - return nil, err - } - - if err = contract.Replace(classHash); err != nil { - return nil, err - } + for addr, diff := range diff.StorageDiffs { + contract, ok := contracts[addr] + if !ok { + contract, err = GetContract(&addr, s.txn) + if err != nil { + // makes sure that all noClassContracts are deployed + if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { + contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, nil, blockNumber) + } else { + return err + } + } + contracts[addr] = contract + } - if err = s.updateContractCommitment(stateTrie, contract); err != nil { - return nil, err + contract.Storage = diff } - return oldClassHash, nil + return nil } type DeclaredClass struct { @@ -342,150 +391,18 @@ func (s *State) Class(classHash *felt.Felt) (*DeclaredClass, error) { return &class, nil } -func (s *State) updateStorageBuffered(contractAddr *felt.Felt, updateDiff map[felt.Felt]*felt.Felt, blockNumber uint64, logChanges bool) ( - *db.BufferedTransaction, error, -) { - // to avoid multiple transactions writing to s.txn, create a buffered transaction and use that in the worker goroutine - bufferedTxn := db.NewBufferedTransaction(s.txn) - bufferedState := NewState(bufferedTxn) - bufferedContract, err := NewContractUpdater(contractAddr, bufferedTxn) - if err != nil { - return nil, err - } - - onValueChanged := func(location, oldValue *felt.Felt) error { - if logChanges { - return bufferedState.LogContractStorage(contractAddr, location, oldValue, blockNumber) - } - return nil - } - - if err = bufferedContract.UpdateStorage(updateDiff, onValueChanged); err != nil { - return nil, err - } - - return bufferedTxn, nil -} - -// updateContractStorage applies the diff set to the Trie of the -// contract at the given address in the given Txn context. -func (s *State) updateContractStorages(stateTrie *trie.Trie, diffs map[felt.Felt]map[felt.Felt]*felt.Felt, - blockNumber uint64, logChanges bool, -) error { - type bufferedTransactionWithAddress struct { - txn *db.BufferedTransaction - addr *felt.Felt - } - - // make sure all noClassContracts are deployed - for addr := range diffs { - if _, ok := noClassContracts[addr]; !ok { - continue - } - - _, err := NewContractUpdater(&addr, s.txn) - if err != nil { - if !errors.Is(err, ErrContractNotDeployed) { - return err - } - // Deploy noClassContract - err = s.putNewContract(stateTrie, &addr, noClassContractsClassHash, blockNumber) - if err != nil { - return err - } - } - } - - // sort the contracts in decending diff size order - // so we start with the heaviest update first - keys := slices.SortedStableFunc(maps.Keys(diffs), func(a, b felt.Felt) int { return len(diffs[a]) - len(diffs[b]) }) - - // update per-contract storage Tries concurrently - contractUpdaters := pool.NewWithResults[*bufferedTransactionWithAddress]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) - for _, key := range keys { - contractAddr := key - contractUpdaters.Go(func() (*bufferedTransactionWithAddress, error) { - bufferedTxn, err := s.updateStorageBuffered(&contractAddr, diffs[contractAddr], blockNumber, logChanges) - if err != nil { - return nil, err - } - return &bufferedTransactionWithAddress{txn: bufferedTxn, addr: &contractAddr}, nil - }) - } - - bufferedTxns, err := contractUpdaters.Wait() - if err != nil { - return err - } - - // we sort bufferedTxns in ascending contract address order to achieve an additional speedup - sort.Slice(bufferedTxns, func(i, j int) bool { - return bufferedTxns[i].addr.Cmp(bufferedTxns[j].addr) < 0 - }) - - // flush buffered txns - for _, txnWithAddress := range bufferedTxns { - if err := txnWithAddress.txn.Flush(); err != nil { - return err - } - } - - for addr := range diffs { - contract, err := NewContractUpdater(&addr, s.txn) - if err != nil { - return err - } - - if err = s.updateContractCommitment(stateTrie, contract); err != nil { - return err - } - } - - return nil -} - -// updateContractNonce updates nonce of the contract at the -// given address in the given Txn context. -func (s *State) updateContractNonce(stateTrie *trie.Trie, addr, nonce *felt.Felt) (*felt.Felt, error) { - contract, err := NewContractUpdater(addr, s.txn) - if err != nil { - return nil, err - } - - oldNonce, err := ContractNonce(addr, s.txn) - if err != nil { - return nil, err - } - - if err = contract.UpdateNonce(nonce); err != nil { - return nil, err - } - - if err = s.updateContractCommitment(stateTrie, contract); err != nil { - return nil, err - } - - return oldNonce, nil -} - // updateContractCommitment recalculates the contract commitment and updates its value in the global state Trie -func (s *State) updateContractCommitment(stateTrie *trie.Trie, contract *ContractUpdater) error { - root, err := ContractRoot(contract.Address, s.txn) - if err != nil { - return err - } - - cHash, err := ContractClassHash(contract.Address, s.txn) - if err != nil { - return err - } - - nonce, err := ContractNonce(contract.Address, s.txn) +func (s *State) updateContractCommitment(stateTrie *trie.Trie, contract *StateContract) error { + rootKey, err := contract.StorageRoot(s.txn) if err != nil { return err } - commitment := calculateContractCommitment(root, cHash, nonce) + commitment := calculateContractCommitment( + rootKey, + contract.ClassHash, + contract.Nonce, + ) _, err = stateTrie.Put(contract.Address, commitment) return err @@ -517,17 +434,15 @@ func (s *State) updateDeclaredClassesTrie(declaredClasses map[felt.Felt]*felt.Fe // ContractIsAlreadyDeployedAt returns if contract at given addr was deployed at blockNumber func (s *State) ContractIsAlreadyDeployedAt(addr *felt.Felt, blockNumber uint64) (bool, error) { - var deployedAt uint64 - if err := s.txn.Get(db.ContractDeploymentHeight.Key(addr.Marshal()), func(bytes []byte) error { - deployedAt = binary.BigEndian.Uint64(bytes) - return nil - }); err != nil { - if errors.Is(err, db.ErrKeyNotFound) { + contract, err := GetContract(addr, s.txn) + if err != nil { + if errors.Is(err, ErrContractNotDeployed) { return false, nil } return false, err } - return deployedAt <= blockNumber, nil + + return contract.DeployHeight <= blockNumber, nil } func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { @@ -540,14 +455,10 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return fmt.Errorf("remove declared classes: %v", err) } - reversedDiff, err := s.GetReverseStateDiff(blockNumber, update.StateDiff) + // update contracts + reversedDiff, err := s.buildReverseDiff(blockNumber, update.StateDiff) if err != nil { - return fmt.Errorf("error getting reverse state diff: %v", err) - } - - err = s.performStateDeletions(blockNumber, update.StateDiff) - if err != nil { - return fmt.Errorf("error performing state deletions: %v", err) + return fmt.Errorf("build reverse diff: %v", err) } stateTrie, storageCloser, err := s.storage() @@ -555,7 +466,8 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return err } - if err = s.updateContracts(stateTrie, blockNumber, reversedDiff, false); err != nil { + contracts := make(map[felt.Felt]*StateContract) + if err = s.updateContracts(blockNumber, reversedDiff, false, contracts); err != nil { return fmt.Errorf("update contracts: %v", err) } @@ -570,19 +482,13 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { } } - if err = s.purgeNoClassContracts(); err != nil { - return err - } - - return s.verifyStateUpdateRoot(update.OldRoot) -} - -func (s *State) purgeNoClassContracts() error { + // purge noClassContracts + // // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, // we can use the lack of key's existence as reason for purging noClassContracts. for addr := range noClassContracts { - noClassC, err := NewContractUpdater(&addr, s.txn) + contract, err := GetContract(&addr, s.txn) if err != nil { if !errors.Is(err, ErrContractNotDeployed) { return err @@ -590,18 +496,36 @@ func (s *State) purgeNoClassContracts() error { continue } - r, err := ContractRoot(noClassC.Address, s.txn) + rootKey, err := contract.StorageRoot(s.txn) if err != nil { - return fmt.Errorf("contract root: %v", err) + return fmt.Errorf("get root key: %v", err) } - if r.Equal(&felt.Zero) { + if rootKey.Equal(&felt.Zero) { if err = s.purgeContract(&addr); err != nil { return fmt.Errorf("purge contract: %v", err) } } } - return nil + + // TODO(weiihann): handle this + tempOnValChanged := func(location, oldValue *felt.Felt) error { + return nil + } + + // TODO(weiihann): make concurrent + // Commit the changes to the contracts and update their commitments + for _, contract := range contracts { + if err = contract.Commit(s.txn, tempOnValChanged); err != nil { + return fmt.Errorf("commit contract: %v", err) + } + + if err = s.updateContractCommitment(stateTrie, contract); err != nil { + return fmt.Errorf("update contract commitment: %v", err) + } + } + + return s.verifyStateUpdateRoot(update.OldRoot) } func (s *State) removeDeclaredClasses(blockNumber uint64, v0Classes []*felt.Felt, v1Classes map[felt.Felt]*felt.Felt) error { @@ -640,32 +564,28 @@ func (s *State) removeDeclaredClasses(blockNumber uint64, v0Classes []*felt.Felt } func (s *State) purgeContract(addr *felt.Felt) error { - contract, err := NewContractUpdater(addr, s.txn) + contract, err := GetContract(addr, s.txn) if err != nil { return err } - state, storageCloser, err := s.storage() + stateTrie, storageCloser, err := s.storage() if err != nil { return err } - if err = s.txn.Delete(db.ContractDeploymentHeight.Key(addr.Marshal())); err != nil { + if _, err = stateTrie.Put(contract.Address, &felt.Zero); err != nil { return err } - if _, err = state.Put(contract.Address, &felt.Zero); err != nil { - return err - } - - if err = contract.Purge(); err != nil { + if err = contract.Purge(s.txn); err != nil { return err } return storageCloser() } -func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { +func (s *State) buildReverseDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { reversed := *diff // storage diffs @@ -681,6 +601,10 @@ func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*State } value = oldValue } + + if err := s.DeleteContractStorageLog(&addr, &key, blockNumber); err != nil { + return nil, err + } reversedDiffs[key] = value } reversed.StorageDiffs[addr] = reversedDiffs @@ -690,6 +614,7 @@ func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*State reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces)) for addr := range diff.Nonces { oldNonce := &felt.Zero + if blockNumber > 0 { var err error oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1) @@ -697,6 +622,10 @@ func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*State return nil, err } } + + if err := s.DeleteContractNonceLog(&addr, blockNumber); err != nil { + return nil, err + } reversed.Nonces[addr] = oldNonce } @@ -711,35 +640,12 @@ func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*State return nil, err } } - reversed.ReplacedClasses[addr] = classHash - } - - return &reversed, nil -} - -func (s *State) performStateDeletions(blockNumber uint64, diff *StateDiff) error { - // storage diffs - for addr, storageDiffs := range diff.StorageDiffs { - for key := range storageDiffs { - if err := s.DeleteContractStorageLog(&addr, &key, blockNumber); err != nil { - return err - } - } - } - // nonces - for addr := range diff.Nonces { - if err := s.DeleteContractNonceLog(&addr, blockNumber); err != nil { - return err - } - } - - // replaced classes - for addr := range diff.ReplacedClasses { if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil { - return err + return nil, err } + reversed.ReplacedClasses[addr] = classHash } - return nil + return &reversed, nil } diff --git a/db/buckets.go b/db/buckets.go index 3918eb5f29..b36419bf36 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -9,30 +9,39 @@ type Bucket byte // keys like Bolt or MDBX does. We use a global prefix list as a poor // man's bucket alternative. const ( - StateTrie Bucket = iota // state metadata (e.g., the state root) - Peer // maps peer ID to peer multiaddresses - ContractClassHash // maps contract addresses and class hashes - ContractStorage // contract storages - Class // maps class hashes to classes - ContractNonce // contract nonce - ChainHeight // Latest height of the blockchain - BlockHeaderNumbersByHash - BlockHeadersByNumber - TransactionBlockNumbersAndIndicesByHash // maps transaction hashes to block number and index - TransactionsByBlockNumberAndIndex // maps block number and index to transaction - ReceiptsByBlockNumberAndIndex // maps block number and index to transaction receipt - StateUpdatesByBlockNumber + // StateTrie -> Latest state trie's root key + // StateTrie + ContractAddr -> Contract's commitment value + // StateTrie + ContractAddr + Trie node path -> Trie node value + StateTrie Bucket = iota + Peer // Peer + PeerID bytes -> Encoded peer multiaddresses + ContractClassHash // (Legacy) ContractClassHash + ContractAddr -> Contract's class hash value + // ContractStorage + ContractAddr -> Latest contract storage trie's root key + // ContractStorage + ContractAddr + Trie node path -> Trie node value + ContractStorage + Class // Class + Class hash -> Class object + ContractNonce // (Legacy) ContractNonce + ContractAddr -> Contract's nonce value + ChainHeight // ChainHeight -> Latest height of the blockchain + BlockHeaderNumbersByHash // BlockHeaderNumbersByHash + BlockHash -> Block number + BlockHeadersByNumber // BlockHeadersByNumber + BlockNumber -> Block header object + TransactionBlockNumbersAndIndicesByHash // TransactionBlockNumbersAndIndicesByHash + TransactionHash -> Encoded(BlockNumber, Index) + TransactionsByBlockNumberAndIndex // TransactionsByBlockNumberAndIndex + Encoded(BlockNumber, Index) -> Encoded(Transaction) + ReceiptsByBlockNumberAndIndex // ReceiptsByBlockNumberAndIndex + Encoded(BlockNumber, Index) -> Encoded(Receipt) + StateUpdatesByBlockNumber // StateUpdatesByBlockNumber + BlockNumber -> Encoded(StateUpdate) + // ClassesTrie -> Latest classes trie's root key + // ClassesTrie + ClassHash -> PoseidonHash(leafVersion, compiledClassHash) ClassesTrie - ContractStorageHistory - ContractNonceHistory - ContractClassHashHistory - ContractDeploymentHeight - L1Height - SchemaVersion - Pending - BlockCommitments - Temporary // used temporarily for migrations - SchemaIntermediateState + ContractStorageHistory // (Legacy) ContractStorageHistory + ContractAddr + BlockHeight + StorageLocation -> StorageValue + ContractNonceHistory // (Legacy) ContractNonceHistory + ContractAddr + BlockHeight -> Contract's nonce value + ContractClassHashHistory // (Legacy) ContractClassHashHistory + ContractAddr + BlockHeight -> Contract's class hash value + ContractDeploymentHeight // ContractDeploymentHeight + ContractAddr -> BlockHeight + L1Height // L1Height -> Latest height of the L1 chain + SchemaVersion // SchemaVersion -> DB schema version + Pending // Pending -> Pending block + BlockCommitments // BlockCommitments + BlockNumber -> Block commitments + Temporary // used temporarily for migrations + SchemaIntermediateState // used for db schema metadata + Contract // Contract + ContractAddr -> Encoded(Contract) + ContractHistory // ContractHistory + ContractAddr + BlockHeight -> Encoded(Contract) ) // Key flattens a prefix and series of byte arrays into a single []byte. From 14123ea83be6c2f189b31af874c93714d5239493 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 16 Oct 2024 19:36:13 +0800 Subject: [PATCH 02/13] pass state tests --- core/contract.go | 72 ++++++++++++++------ core/contract_test.go | 33 ++++------ core/state.go | 148 +++++++++++++++++++++++++----------------- core/state_test.go | 17 ++++- db/buckets.go | 8 +-- 5 files changed, 172 insertions(+), 106 deletions(-) diff --git a/core/contract.go b/core/contract.go index 178f5c9cb1..d9dcca58ce 100644 --- a/core/contract.go +++ b/core/contract.go @@ -29,27 +29,22 @@ type StateContract struct { DeployHeight uint64 // Address that this contract instance is deployed to Address *felt.Felt `cbor:"-"` - // Storage is the contract's storage - Storage map[felt.Felt]*felt.Felt `cbor:"-"` + // dirtyStorage is a map of storage locations that have been updated + dirtyStorage map[felt.Felt]*felt.Felt `cbor:"-"` } func NewStateContract( addr *felt.Felt, classHash *felt.Felt, nonce *felt.Felt, - storage map[felt.Felt]*felt.Felt, - DeployHeight uint64, + deployHeight uint64, ) *StateContract { sc := &StateContract{ Address: addr, ClassHash: classHash, Nonce: nonce, - Storage: storage, - DeployHeight: DeployHeight, - } - - if storage == nil { - sc.Storage = make(map[felt.Felt]*felt.Felt) + DeployHeight: deployHeight, + dirtyStorage: make(map[felt.Felt]*felt.Felt), } return sc @@ -64,9 +59,17 @@ func (c *StateContract) StorageRoot(txn db.Transaction) (*felt.Felt, error) { return storageTrie.Root() } +func (c *StateContract) UpdateStorage(key *felt.Felt, value *felt.Felt) { + if c.dirtyStorage == nil { + c.dirtyStorage = make(map[felt.Felt]*felt.Felt) + } + + c.dirtyStorage[*key] = value +} + func (c *StateContract) GetStorage(key *felt.Felt, txn db.Transaction) (*felt.Felt, error) { - if c.Storage != nil { - if val, ok := c.Storage[*key]; ok { + if c.dirtyStorage != nil { + if val, ok := c.dirtyStorage[*key]; ok { return val, nil } } @@ -80,20 +83,39 @@ func (c *StateContract) GetStorage(key *felt.Felt, txn db.Transaction) (*felt.Fe return storage.Get(key) } -func (c *StateContract) Commit(txn db.Transaction, cb OnValueChanged) error { +func (c *StateContract) logOldValue(key []byte, oldValue *felt.Felt, height uint64, txn db.Transaction) error { + return txn.Set(logDBKey(key, height), oldValue.Marshal()) +} + +func (c *StateContract) logStorage(location, oldVal *felt.Felt, height uint64, txn db.Transaction) error { + key := storageLogKey(c.Address, location) + return c.logOldValue(key, oldVal, height, txn) +} + +func (c *StateContract) logNonce(height uint64, txn db.Transaction) error { + key := nonceLogKey(c.Address) + return c.logOldValue(key, c.Nonce, height, txn) +} + +func (c *StateContract) logClassHash(height uint64, txn db.Transaction) error { + key := classHashLogKey(c.Address) + return c.logOldValue(key, c.ClassHash, height, txn) +} + +func (c *StateContract) Commit(txn db.Transaction, logChanges bool, blockNum uint64) error { storageTrie, err := storage(c.Address, txn) if err != nil { return err } - for key, value := range c.Storage { + for key, value := range c.dirtyStorage { oldVal, err := storageTrie.Put(&key, value) if err != nil { return err } - if oldVal != nil { - if err = cb(&key, oldVal); err != nil { + if oldVal != nil && logChanges { + if err = c.logStorage(&key, oldVal, blockNum, txn); err != nil { return err } } @@ -116,11 +138,19 @@ func (c *StateContract) Commit(txn db.Transaction, cb OnValueChanged) error { func (c *StateContract) Purge(txn db.Transaction) error { addrBytes := c.Address.Marshal() - if err := txn.Delete(db.Contract.Key(addrBytes)); err != nil { - return err - } + return txn.Delete(db.Contract.Key(addrBytes)) +} + +func storageLogKey(contractAddress, storageLocation *felt.Felt) []byte { + return db.ContractStorageHistory.Key(contractAddress.Marshal(), storageLocation.Marshal()) +} + +func nonceLogKey(contractAddress *felt.Felt) []byte { + return db.ContractNonceHistory.Key(contractAddress.Marshal()) +} - return txn.Delete(db.ContractDeploymentHeight.Key(addrBytes)) +func classHashLogKey(contractAddress *felt.Felt) []byte { + return db.ContractClassHashHistory.Key(contractAddress.Marshal()) } // GetContract is a wrapper around getContract which checks if a contract is deployed @@ -146,7 +176,7 @@ func getContract(addr *felt.Felt, txn db.Transaction) (*StateContract, error) { } contract.Address = addr - contract.Storage = make(map[felt.Felt]*felt.Felt) + contract.dirtyStorage = make(map[felt.Felt]*felt.Felt) return nil }); err != nil { diff --git a/core/contract_test.go b/core/contract_test.go index 3a32f29915..70a4608554 100644 --- a/core/contract_test.go +++ b/core/contract_test.go @@ -11,10 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -var NoopOnValueChanged = func(location, oldValue *felt.Felt) error { - return nil -} - func TestContractAddress(t *testing.T) { tests := []struct { callerAddress *felt.Felt @@ -71,8 +67,8 @@ func TestNewContract(t *testing.T) { var contract *core.StateContract t.Run("commit contract", func(t *testing.T) { - contract = core.NewStateContract(addr, classHash, &felt.Zero, nil, blockNumber) - require.NoError(t, contract.Commit(txn, nil)) + contract = core.NewStateContract(addr, classHash, &felt.Zero, blockNumber) + require.NoError(t, contract.Commit(txn, true, blockNumber)) }) t.Run("get contract from db", func(t *testing.T) { @@ -84,7 +80,6 @@ func TestNewContract(t *testing.T) { assert.Equal(t, addr, contract.Address) assert.Equal(t, classHash, contract.ClassHash) assert.Equal(t, &felt.Zero, contract.Nonce) - assert.Empty(t, contract.Storage) assert.Equal(t, blockNumber, contract.DeployHeight) }) } @@ -98,8 +93,8 @@ func TestUpdateContract(t *testing.T) { addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract := core.NewStateContract(addr, classHash, &felt.Zero, nil, blockNumber) - require.NoError(t, contract.Commit(txn, nil)) + contract := core.NewStateContract(addr, classHash, &felt.Zero, blockNumber) + require.NoError(t, contract.Commit(txn, true, blockNumber)) contract, err = core.GetContract(addr, txn) require.NoError(t, err) @@ -111,7 +106,7 @@ func TestUpdateContract(t *testing.T) { t.Run("update contract nonce", func(t *testing.T) { newNonce := new(felt.Felt).SetUint64(1) contract.Nonce = newNonce - require.NoError(t, contract.Commit(txn, nil)) + require.NoError(t, contract.Commit(txn, true, blockNumber)) contract, err = core.GetContract(addr, txn) require.NoError(t, err) @@ -126,7 +121,7 @@ func TestUpdateContract(t *testing.T) { t.Run("update class hash", func(t *testing.T) { newHash := new(felt.Felt).SetUint64(1) contract.ClassHash = newHash - require.NoError(t, contract.Commit(txn, nil)) + require.NoError(t, contract.Commit(txn, true, blockNumber)) contract, err = core.GetContract(addr, txn) require.NoError(t, err) @@ -144,8 +139,8 @@ func TestContractStorage(t *testing.T) { addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract := core.NewStateContract(addr, classHash, &felt.Zero, nil, blockNumber) - require.NoError(t, contract.Commit(txn, nil)) + contract := core.NewStateContract(addr, classHash, &felt.Zero, blockNumber) + require.NoError(t, contract.Commit(txn, true, blockNumber)) t.Run("get initial storage", func(t *testing.T) { gotValue, err := contract.GetStorage(addr, txn) @@ -157,8 +152,8 @@ func TestContractStorage(t *testing.T) { oldRoot, err := contract.StorageRoot(txn) require.NoError(t, err) - contract.Storage = map[felt.Felt]*felt.Felt{*addr: classHash} - require.NoError(t, contract.Commit(txn, NoopOnValueChanged)) + contract.UpdateStorage(addr, classHash) + require.NoError(t, contract.Commit(txn, false, blockNumber)) contract, err = core.GetContract(addr, txn) require.NoError(t, err) @@ -173,8 +168,8 @@ func TestContractStorage(t *testing.T) { }) t.Run("delete key from storage with storage diff", func(t *testing.T) { - contract.Storage[*addr] = new(felt.Felt) - require.NoError(t, contract.Commit(txn, NoopOnValueChanged)) + contract.UpdateStorage(addr, new(felt.Felt)) + require.NoError(t, contract.Commit(txn, false, blockNumber)) contract, err = core.GetContract(addr, txn) require.NoError(t, err) @@ -194,8 +189,8 @@ func TestPurge(t *testing.T) { addr := new(felt.Felt).SetUint64(44) classHash := new(felt.Felt).SetUint64(37) - contract := core.NewStateContract(addr, classHash, &felt.Zero, nil, blockNumber) - require.NoError(t, contract.Commit(txn, nil)) + contract := core.NewStateContract(addr, classHash, &felt.Zero, blockNumber) + require.NoError(t, contract.Commit(txn, false, blockNumber)) require.NoError(t, contract.Purge(txn)) _, err = core.GetContract(addr, txn) diff --git a/core/state.go b/core/state.go index b8681257ad..822f11aec1 100644 --- a/core/state.go +++ b/core/state.go @@ -2,6 +2,7 @@ package core import ( "bytes" + "encoding/binary" "errors" "fmt" @@ -10,6 +11,7 @@ import ( "github.com/NethermindEth/juno/core/trie" "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/encoder" + "github.com/NethermindEth/juno/utils" ) const globalTrieHeight = 251 @@ -50,22 +52,6 @@ func NewState(txn db.Transaction) *State { } } -// putNewContract creates a contract storage instance in the state and stores the relation between contract address and class hash to be -// queried later with [GetContractClass]. -// func (s *State) putNewContract(stateTrie *trie.Trie, addr, classHash *felt.Felt, blockNumber uint64) error { -// contract, err := DeployContract(addr, classHash, s.txn) -// if err != nil { -// return err -// } - -// numBytes := MarshalBlockNumber(blockNumber) -// if err = s.txn.Set(db.ContractDeploymentHeight.Key(addr.Marshal()), numBytes); err != nil { -// return err -// } - -// return s.updateContractCommitment(stateTrie, contract) -// } - // ContractClassHash returns class hash of a contract at a given address. func (s *State) ContractClassHash(addr *felt.Felt) (*felt.Felt, error) { contract, err := GetContract(addr, s.txn) @@ -96,6 +82,63 @@ func (s *State) ContractStorage(addr, key *felt.Felt) (*felt.Felt, error) { return contract.GetStorage(key, s.txn) } +func (s *State) ContractClassHashAt(addr *felt.Felt, blockNumber uint64) (*felt.Felt, error) { + return s.contractValueAt(classHashLogKey, addr, blockNumber) +} + +func (s *State) ContractStorageAt(addr, loc *felt.Felt, blockNumber uint64) (*felt.Felt, error) { + return s.contractValueAt(func(a *felt.Felt) []byte { return storageLogKey(a, loc) }, addr, blockNumber) +} + +func (s *State) ContractNonceAt(addr *felt.Felt, blockNumber uint64) (*felt.Felt, error) { + return s.contractValueAt(nonceLogKey, addr, blockNumber) +} + +func (s *State) contractValueAt(keyFunc func(*felt.Felt) []byte, addr *felt.Felt, blockNumber uint64) (*felt.Felt, error) { + key := keyFunc(addr) + value, err := s.valueAt(key, blockNumber) + if err != nil { + return nil, err + } + + return new(felt.Felt).SetBytes(value), nil +} + +func (s *State) valueAt(key []byte, height uint64) ([]byte, error) { + it, err := s.txn.NewIterator() + if err != nil { + return nil, err + } + + for it.Seek(logDBKey(key, height)); it.Valid(); it.Next() { + seekedKey := it.Key() + // seekedKey size should be `len(key) + sizeof(uint64)` and seekedKey should match key prefix + if len(seekedKey) != len(key)+8 || !bytes.HasPrefix(seekedKey, key) { + break + } + + seekedHeight := binary.BigEndian.Uint64(seekedKey[len(key):]) + if seekedHeight < height { + // last change happened before the height we are looking for + // check head state + break + } else if seekedHeight == height { + // a log exists for the height we are looking for, so the old value in this log entry is not useful. + // advance the iterator and see we can use the next entry. If not, ErrCheckHeadState will be returned + continue + } + + val, itErr := it.Value() + if err = utils.RunAndWrapOnError(it.Close, itErr); err != nil { + return nil, err + } + // seekedHeight > height + return val, nil + } + + return nil, utils.RunAndWrapOnError(it.Close, ErrCheckHeadState) +} + // Root returns the state commitment. func (s *State) Root() (*felt.Felt, error) { var storageRoot, classesRoot *felt.Felt @@ -242,21 +285,16 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return err } - contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, nil, blockNumber) + contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, blockNumber) } if err = s.updateContracts(blockNumber, update.StateDiff, true, contracts); err != nil { return err } - // TODO(weiihann): handle history - tempOnValChanged := func(location, oldValue *felt.Felt) error { - return nil - } - // Commit all contract updates for _, contract := range contracts { - if err = contract.Commit(s.txn, tempOnValChanged); err != nil { + if err = contract.Commit(s.txn, true, blockNumber); err != nil { return err } @@ -298,14 +336,13 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges contracts[addr] = contract } - oldClassHash := contract.ClassHash - contract.ClassHash = classHash - if logChanges { - if err := s.LogContractClassHash(&addr, oldClassHash, blockNumber); err != nil { + if err := contract.logClassHash(blockNumber, s.txn); err != nil { return err } } + + contract.ClassHash = classHash } // update contract nonces @@ -319,14 +356,13 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges contracts[addr] = contract } - oldNonce := contract.Nonce - contract.Nonce = nonce - if logChanges { - if err := s.LogContractNonce(&addr, oldNonce, blockNumber); err != nil { + if err := contract.logNonce(blockNumber, s.txn); err != nil { return err } } + + contract.Nonce = nonce } // update contract storages @@ -337,7 +373,7 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges if err != nil { // makes sure that all noClassContracts are deployed if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { - contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, nil, blockNumber) + contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) } else { return err } @@ -345,7 +381,7 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges contracts[addr] = contract } - contract.Storage = diff + contract.dirtyStorage = diff } return nil @@ -471,13 +507,21 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return fmt.Errorf("update contracts: %v", err) } - if err = storageCloser(); err != nil { - return err + // TODO(weiihann): make concurrent + // Commit the changes to the contracts and update their commitments + for _, contract := range contracts { + if err = contract.Commit(s.txn, false, blockNumber); err != nil { + return fmt.Errorf("commit contract: %v", err) + } + + if err = s.updateContractCommitment(stateTrie, contract); err != nil { + return fmt.Errorf("update contract commitment: %v", err) + } } // purge deployed contracts for addr := range update.StateDiff.DeployedContracts { - if err = s.purgeContract(&addr); err != nil { + if err = s.purgeContract(stateTrie, &addr); err != nil { return fmt.Errorf("purge contract: %v", err) } } @@ -502,27 +546,14 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { } if rootKey.Equal(&felt.Zero) { - if err = s.purgeContract(&addr); err != nil { + if err = s.purgeContract(stateTrie, &addr); err != nil { return fmt.Errorf("purge contract: %v", err) } } } - // TODO(weiihann): handle this - tempOnValChanged := func(location, oldValue *felt.Felt) error { - return nil - } - - // TODO(weiihann): make concurrent - // Commit the changes to the contracts and update their commitments - for _, contract := range contracts { - if err = contract.Commit(s.txn, tempOnValChanged); err != nil { - return fmt.Errorf("commit contract: %v", err) - } - - if err = s.updateContractCommitment(stateTrie, contract); err != nil { - return fmt.Errorf("update contract commitment: %v", err) - } + if err = storageCloser(); err != nil { + return err } return s.verifyStateUpdateRoot(update.OldRoot) @@ -563,17 +594,12 @@ func (s *State) removeDeclaredClasses(blockNumber uint64, v0Classes []*felt.Felt return classesCloser() } -func (s *State) purgeContract(addr *felt.Felt) error { +func (s *State) purgeContract(stateTrie *trie.Trie, addr *felt.Felt) error { contract, err := GetContract(addr, s.txn) if err != nil { return err } - stateTrie, storageCloser, err := s.storage() - if err != nil { - return err - } - if _, err = stateTrie.Put(contract.Address, &felt.Zero); err != nil { return err } @@ -582,7 +608,7 @@ func (s *State) purgeContract(addr *felt.Felt) error { return err } - return storageCloser() + return nil } func (s *State) buildReverseDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { @@ -649,3 +675,7 @@ func (s *State) buildReverseDiff(blockNumber uint64, diff *StateDiff) (*StateDif return &reversed, nil } + +func logDBKey(key []byte, height uint64) []byte { + return binary.BigEndian.AppendUint64(key, height) +} diff --git a/core/state_test.go b/core/state_test.go index 6b96d64b3b..454bf01dd5 100644 --- a/core/state_test.go +++ b/core/state_test.go @@ -41,6 +41,7 @@ func TestMain(m *testing.M) { _ = encoder.RegisterType(reflect.TypeOf(core.Cairo0Class{})) _ = encoder.RegisterType(reflect.TypeOf(core.Cairo1Class{})) + _ = encoder.RegisterType(reflect.TypeOf(core.StateContract{})) code := m.Run() @@ -440,17 +441,22 @@ func TestRevert(t *testing.T) { require.NoError(t, state.Update(1, su1, nil)) t.Run("revert a replaced class", func(t *testing.T) { + replacedVal := utils.HexToFelt(t, "0xDEADBEEF") replaceStateUpdate := &core.StateUpdate{ NewRoot: utils.HexToFelt(t, "0x30b1741b28893b892ac30350e6372eac3a6f32edee12f9cdca7fbe7540a5ee"), OldRoot: su1.NewRoot, StateDiff: &core.StateDiff{ ReplacedClasses: map[felt.Felt]*felt.Felt{ - su1FirstDeployedAddress: utils.HexToFelt(t, "0xDEADBEEF"), + su1FirstDeployedAddress: replacedVal, }, }, } require.NoError(t, state.Update(2, replaceStateUpdate, nil)) + classHash, err := state.ContractClassHash(new(felt.Felt).Set(&su1FirstDeployedAddress)) + require.NoError(t, err) + assert.Equal(t, replacedVal, classHash) + require.NoError(t, state.Revert(2, replaceStateUpdate)) classHash, sErr := state.ContractClassHash(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, sErr) @@ -458,20 +464,25 @@ func TestRevert(t *testing.T) { }) t.Run("revert a nonce update", func(t *testing.T) { + replacedVal := utils.HexToFelt(t, "0xDEADBEEF") nonceStateUpdate := &core.StateUpdate{ NewRoot: utils.HexToFelt(t, "0x6683657d2b6797d95f318e7c6091dc2255de86b72023c15b620af12543eb62c"), OldRoot: su1.NewRoot, StateDiff: &core.StateDiff{ Nonces: map[felt.Felt]*felt.Felt{ - su1FirstDeployedAddress: utils.HexToFelt(t, "0xDEADBEEF"), + su1FirstDeployedAddress: replacedVal, }, }, } require.NoError(t, state.Update(2, nonceStateUpdate, nil)) - require.NoError(t, state.Revert(2, nonceStateUpdate)) nonce, sErr := state.ContractNonce(new(felt.Felt).Set(&su1FirstDeployedAddress)) require.NoError(t, sErr) + assert.Equal(t, replacedVal, nonce) + + require.NoError(t, state.Revert(2, nonceStateUpdate)) + nonce, sErr = state.ContractNonce(new(felt.Felt).Set(&su1FirstDeployedAddress)) + require.NoError(t, sErr) assert.Equal(t, &felt.Zero, nonce) }) diff --git a/db/buckets.go b/db/buckets.go index b36419bf36..211edc65ea 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -30,10 +30,10 @@ const ( // ClassesTrie -> Latest classes trie's root key // ClassesTrie + ClassHash -> PoseidonHash(leafVersion, compiledClassHash) ClassesTrie - ContractStorageHistory // (Legacy) ContractStorageHistory + ContractAddr + BlockHeight + StorageLocation -> StorageValue - ContractNonceHistory // (Legacy) ContractNonceHistory + ContractAddr + BlockHeight -> Contract's nonce value - ContractClassHashHistory // (Legacy) ContractClassHashHistory + ContractAddr + BlockHeight -> Contract's class hash value - ContractDeploymentHeight // ContractDeploymentHeight + ContractAddr -> BlockHeight + ContractStorageHistory // ContractStorageHistory + ContractAddr + BlockHeight + StorageLocation -> StorageValue + ContractNonceHistory // ContractNonceHistory + ContractAddr + BlockHeight -> Contract's nonce value + ContractClassHashHistory // ContractClassHashHistory + ContractAddr + BlockHeight -> Contract's class hash value + ContractDeploymentHeight // (Legacy) ContractDeploymentHeight + ContractAddr -> BlockHeight L1Height // L1Height -> Latest height of the L1 chain SchemaVersion // SchemaVersion -> DB schema version Pending // Pending -> Pending block From f7677a1d7e461903a9719cf6f2081339ac97cf19 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 16 Oct 2024 20:31:09 +0800 Subject: [PATCH 03/13] remove history --- core/contract.go | 8 +-- core/history.go | 134 --------------------------------------- core/history_pkg_test.go | 104 ------------------------------ core/state.go | 29 +++++++-- core/state_test.go | 127 +++++++++++++++++++++++++++++++++++++ 5 files changed, 153 insertions(+), 249 deletions(-) delete mode 100644 core/history.go delete mode 100644 core/history_pkg_test.go diff --git a/core/contract.go b/core/contract.go index d9dcca58ce..7a02f2d364 100644 --- a/core/contract.go +++ b/core/contract.go @@ -87,17 +87,17 @@ func (c *StateContract) logOldValue(key []byte, oldValue *felt.Felt, height uint return txn.Set(logDBKey(key, height), oldValue.Marshal()) } -func (c *StateContract) logStorage(location, oldVal *felt.Felt, height uint64, txn db.Transaction) error { +func (c *StateContract) LogStorage(location, oldVal *felt.Felt, height uint64, txn db.Transaction) error { key := storageLogKey(c.Address, location) return c.logOldValue(key, oldVal, height, txn) } -func (c *StateContract) logNonce(height uint64, txn db.Transaction) error { +func (c *StateContract) LogNonce(height uint64, txn db.Transaction) error { key := nonceLogKey(c.Address) return c.logOldValue(key, c.Nonce, height, txn) } -func (c *StateContract) logClassHash(height uint64, txn db.Transaction) error { +func (c *StateContract) LogClassHash(height uint64, txn db.Transaction) error { key := classHashLogKey(c.Address) return c.logOldValue(key, c.ClassHash, height, txn) } @@ -115,7 +115,7 @@ func (c *StateContract) Commit(txn db.Transaction, logChanges bool, blockNum uin } if oldVal != nil && logChanges { - if err = c.logStorage(&key, oldVal, blockNum, txn); err != nil { + if err = c.LogStorage(&key, oldVal, blockNum, txn); err != nil { return err } } diff --git a/core/history.go b/core/history.go deleted file mode 100644 index f14db1702e..0000000000 --- a/core/history.go +++ /dev/null @@ -1,134 +0,0 @@ -package core - -import ( - "bytes" - "encoding/binary" - "errors" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db" - "github.com/NethermindEth/juno/utils" -) - -var ErrCheckHeadState = errors.New("check head state") - -type history struct { - txn db.Transaction -} - -func logDBKey(key []byte, height uint64) []byte { - return binary.BigEndian.AppendUint64(key, height) -} - -func (h *history) logOldValue(key, value []byte, height uint64) error { - return h.txn.Set(logDBKey(key, height), value) -} - -func (h *history) deleteLog(key []byte, height uint64) error { - return h.txn.Delete(logDBKey(key, height)) -} - -func (h *history) valueAt(key []byte, height uint64) ([]byte, error) { - it, err := h.txn.NewIterator() - if err != nil { - return nil, err - } - - for it.Seek(logDBKey(key, height)); it.Valid(); it.Next() { - seekedKey := it.Key() - // seekedKey size should be `len(key) + sizeof(uint64)` and seekedKey should match key prefix - if len(seekedKey) != len(key)+8 || !bytes.HasPrefix(seekedKey, key) { - break - } - - seekedHeight := binary.BigEndian.Uint64(seekedKey[len(key):]) - if seekedHeight < height { - // last change happened before the height we are looking for - // check head state - break - } else if seekedHeight == height { - // a log exists for the height we are looking for, so the old value in this log entry is not useful. - // advance the iterator and see we can use the next entry. If not, ErrCheckHeadState will be returned - continue - } - - val, itErr := it.Value() - if err = utils.RunAndWrapOnError(it.Close, itErr); err != nil { - return nil, err - } - // seekedHeight > height - return val, nil - } - - return nil, utils.RunAndWrapOnError(it.Close, ErrCheckHeadState) -} - -func storageLogKey(contractAddress, storageLocation *felt.Felt) []byte { - return db.ContractStorageHistory.Key(contractAddress.Marshal(), storageLocation.Marshal()) -} - -// LogContractStorage logs the old value of a storage location for the given contract which changed on height `height` -func (h *history) LogContractStorage(contractAddress, storageLocation, oldValue *felt.Felt, height uint64) error { - key := storageLogKey(contractAddress, storageLocation) - return h.logOldValue(key, oldValue.Marshal(), height) -} - -// DeleteContractStorageLog deletes the log at the given height -func (h *history) DeleteContractStorageLog(contractAddress, storageLocation *felt.Felt, height uint64) error { - return h.deleteLog(storageLogKey(contractAddress, storageLocation), height) -} - -// ContractStorageAt returns the value of a storage location of the given contract at the height `height` -func (h *history) ContractStorageAt(contractAddress, storageLocation *felt.Felt, height uint64) (*felt.Felt, error) { - key := storageLogKey(contractAddress, storageLocation) - value, err := h.valueAt(key, height) - if err != nil { - return nil, err - } - - return new(felt.Felt).SetBytes(value), nil -} - -func nonceLogKey(contractAddress *felt.Felt) []byte { - return db.ContractNonceHistory.Key(contractAddress.Marshal()) -} - -func (h *history) LogContractNonce(contractAddress, oldValue *felt.Felt, height uint64) error { - return h.logOldValue(nonceLogKey(contractAddress), oldValue.Marshal(), height) -} - -func (h *history) DeleteContractNonceLog(contractAddress *felt.Felt, height uint64) error { - return h.deleteLog(nonceLogKey(contractAddress), height) -} - -func (h *history) ContractNonceAt(contractAddress *felt.Felt, height uint64) (*felt.Felt, error) { - key := nonceLogKey(contractAddress) - value, err := h.valueAt(key, height) - if err != nil { - return nil, err - } - - return new(felt.Felt).SetBytes(value), nil -} - -func classHashLogKey(contractAddress *felt.Felt) []byte { - return db.ContractClassHashHistory.Key(contractAddress.Marshal()) -} - -func (h *history) LogContractClassHash(contractAddress, oldValue *felt.Felt, height uint64) error { - return h.logOldValue(classHashLogKey(contractAddress), oldValue.Marshal(), height) -} - -func (h *history) DeleteContractClassHashLog(contractAddress *felt.Felt, height uint64) error { - return h.deleteLog(classHashLogKey(contractAddress), height) -} - -func (h *history) ContractClassHashAt(contractAddress *felt.Felt, height uint64) (*felt.Felt, error) { - key := classHashLogKey(contractAddress) - value, err := h.valueAt(key, height) - if err != nil { - return nil, err - } - - return new(felt.Felt).SetBytes(value), nil -} diff --git a/core/history_pkg_test.go b/core/history_pkg_test.go deleted file mode 100644 index b883f4c0ed..0000000000 --- a/core/history_pkg_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package core - -import ( - "testing" - - "github.com/NethermindEth/juno/core/felt" - "github.com/NethermindEth/juno/db/pebble" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestHistory(t *testing.T) { - testDB := pebble.NewMemTest(t) - txn, err := testDB.NewTransaction(true) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, txn.Discard()) - }) - - history := &history{txn: txn} - contractAddress := new(felt.Felt).SetUint64(123) - - for desc, test := range map[string]struct { - logger func(location, oldValue *felt.Felt, height uint64) error - getter func(location *felt.Felt, height uint64) (*felt.Felt, error) - deleter func(location *felt.Felt, height uint64) error - }{ - "contract storage": { - logger: func(location, oldValue *felt.Felt, height uint64) error { - return history.LogContractStorage(contractAddress, location, oldValue, height) - }, - getter: func(location *felt.Felt, height uint64) (*felt.Felt, error) { - return history.ContractStorageAt(contractAddress, location, height) - }, - deleter: func(location *felt.Felt, height uint64) error { - return history.DeleteContractStorageLog(contractAddress, location, height) - }, - }, - "contract nonce": { - logger: history.LogContractNonce, - getter: history.ContractNonceAt, - deleter: history.DeleteContractNonceLog, - }, - "contract class hash": { - logger: history.LogContractClassHash, - getter: history.ContractClassHashAt, - deleter: history.DeleteContractClassHashLog, - }, - } { - location := new(felt.Felt).SetUint64(456) - - t.Run(desc, func(t *testing.T) { - t.Run("no history", func(t *testing.T) { - _, err := test.getter(location, 1) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - - value := new(felt.Felt).SetUint64(789) - - t.Run("log value changed at height 5 and 10", func(t *testing.T) { - assert.NoError(t, test.logger(location, &felt.Zero, 5)) - assert.NoError(t, test.logger(location, value, 10)) - }) - - t.Run("get value before height 5", func(t *testing.T) { - oldValue, err := test.getter(location, 1) - require.NoError(t, err) - assert.Equal(t, &felt.Zero, oldValue) - }) - - t.Run("get value between height 5-10 ", func(t *testing.T) { - oldValue, err := test.getter(location, 7) - require.NoError(t, err) - assert.Equal(t, value, oldValue) - }) - - t.Run("get value on height that change happened ", func(t *testing.T) { - oldValue, err := test.getter(location, 5) - require.NoError(t, err) - assert.Equal(t, value, oldValue) - - _, err = test.getter(location, 10) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - - t.Run("get value after height 10 ", func(t *testing.T) { - _, err := test.getter(location, 13) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - - t.Run("get a random location ", func(t *testing.T) { - _, err := test.getter(new(felt.Felt).SetUint64(37), 13) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - - require.NoError(t, test.deleter(location, 10)) - - t.Run("get after delete", func(t *testing.T) { - _, err := test.getter(location, 7) - assert.ErrorIs(t, err, ErrCheckHeadState) - }) - }) - } -} diff --git a/core/state.go b/core/state.go index 822f11aec1..d647fe6639 100644 --- a/core/state.go +++ b/core/state.go @@ -17,8 +17,9 @@ import ( const globalTrieHeight = 251 var ( - stateVersion = new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) - leafVersion = new(felt.Felt).SetBytes([]byte(`CONTRACT_CLASS_LEAF_V0`)) + stateVersion = new(felt.Felt).SetBytes([]byte(`STARKNET_STATE_V0`)) + leafVersion = new(felt.Felt).SetBytes([]byte(`CONTRACT_CLASS_LEAF_V0`)) + ErrCheckHeadState = errors.New("check head state") ) var _ StateHistoryReader = (*State)(nil) @@ -41,14 +42,12 @@ type StateReader interface { } type State struct { - *history txn db.Transaction } func NewState(txn db.Transaction) *State { return &State{ - history: &history{txn: txn}, - txn: txn, + txn: txn, } } @@ -94,6 +93,22 @@ func (s *State) ContractNonceAt(addr *felt.Felt, blockNumber uint64) (*felt.Felt return s.contractValueAt(nonceLogKey, addr, blockNumber) } +func (s *State) deleteLog(key []byte, height uint64) error { + return s.txn.Delete(logDBKey(key, height)) +} + +func (s *State) DeleteContractStorageLog(contractAddress, storageLocation *felt.Felt, height uint64) error { + return s.deleteLog(storageLogKey(contractAddress, storageLocation), height) +} + +func (s *State) DeleteContractNonceLog(contractAddress *felt.Felt, height uint64) error { + return s.deleteLog(nonceLogKey(contractAddress), height) +} + +func (s *State) DeleteContractClassHashLog(contractAddress *felt.Felt, height uint64) error { + return s.deleteLog(classHashLogKey(contractAddress), height) +} + func (s *State) contractValueAt(keyFunc func(*felt.Felt) []byte, addr *felt.Felt, blockNumber uint64) (*felt.Felt, error) { key := keyFunc(addr) value, err := s.valueAt(key, blockNumber) @@ -337,7 +352,7 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges } if logChanges { - if err := contract.logClassHash(blockNumber, s.txn); err != nil { + if err := contract.LogClassHash(blockNumber, s.txn); err != nil { return err } } @@ -357,7 +372,7 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges } if logChanges { - if err := contract.logNonce(blockNumber, s.txn); err != nil { + if err := contract.LogNonce(blockNumber, s.txn); err != nil { return err } } diff --git a/core/state_test.go b/core/state_test.go index 454bf01dd5..1119260672 100644 --- a/core/state_test.go +++ b/core/state_test.go @@ -713,3 +713,130 @@ func TestRevertDeclaredClasses(t *testing.T) { _, err = state.Class(sierraHash) require.ErrorIs(t, err, db.ErrKeyNotFound) } + +func TestHistory(t *testing.T) { + testDB := pebble.NewMemTest(t) + txn, err := testDB.NewTransaction(true) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, txn.Discard()) + }) + + state := core.NewState(txn) + addr := &felt.Zero + location := new(felt.Felt).SetUint64(456) + value := new(felt.Felt).SetUint64(789) + + t.Run("no history", func(t *testing.T) { + _, err := state.ContractNonceAt(new(felt.Felt).SetUint64(1), 1) + require.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractClassHashAt(new(felt.Felt).SetUint64(1), 1) + require.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractStorageAt(new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(1), 1) + require.ErrorIs(t, err, core.ErrCheckHeadState) + }) + + contract := core.NewStateContract(&felt.Zero, &felt.Zero, &felt.Zero, 0) + t.Run("log value changed at height 5 and 10", func(t *testing.T) { + assert.NoError(t, contract.LogNonce(5, txn)) + assert.NoError(t, contract.LogClassHash(5, txn)) + assert.NoError(t, contract.LogStorage(location, &felt.Zero, 5, txn)) + + contract.Nonce = value + contract.ClassHash = value + + assert.NoError(t, contract.LogNonce(10, txn)) + assert.NoError(t, contract.LogClassHash(10, txn)) + assert.NoError(t, contract.LogStorage(location, value, 10, txn)) + }) + + t.Run("get value before height 5", func(t *testing.T) { + oldValue, err := state.ContractStorageAt(addr, location, 1) + require.NoError(t, err) + assert.Equal(t, &felt.Zero, oldValue) + + oldValue, err = state.ContractNonceAt(addr, 1) + require.NoError(t, err) + assert.Equal(t, &felt.Zero, oldValue) + + oldValue, err = state.ContractClassHashAt(addr, 1) + require.NoError(t, err) + assert.Equal(t, &felt.Zero, oldValue) + }) + + t.Run("get value between height 5-10", func(t *testing.T) { + oldValue, err := state.ContractStorageAt(addr, location, 7) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + oldValue, err = state.ContractNonceAt(addr, 7) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + oldValue, err = state.ContractClassHashAt(addr, 7) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + }) + + t.Run("get value on height that change happened", func(t *testing.T) { + oldValue, err := state.ContractStorageAt(addr, location, 5) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + _, err = state.ContractStorageAt(addr, location, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + oldValue, err = state.ContractNonceAt(addr, 5) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + _, err = state.ContractNonceAt(addr, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + oldValue, err = state.ContractClassHashAt(addr, 5) + require.NoError(t, err) + assert.Equal(t, value, oldValue) + + _, err = state.ContractClassHashAt(addr, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + }) + + t.Run("get value after height 10 ", func(t *testing.T) { + _, err = state.ContractStorageAt(addr, location, 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractNonceAt(addr, 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractClassHashAt(addr, 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + }) + + t.Run("get a random location ", func(t *testing.T) { + _, err = state.ContractStorageAt(new(felt.Felt).SetUint64(37), new(felt.Felt).SetUint64(37), 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractNonceAt(new(felt.Felt).SetUint64(37), 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractClassHashAt(new(felt.Felt).SetUint64(37), 13) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + }) + + t.Run("delete storage and get value after delete", func(t *testing.T) { + assert.NoError(t, state.DeleteContractClassHashLog(addr, 10)) + assert.NoError(t, state.DeleteContractNonceLog(addr, 10)) + assert.NoError(t, state.DeleteContractStorageLog(addr, location, 10)) + + _, err = state.ContractStorageAt(addr, location, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractNonceAt(addr, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + + _, err = state.ContractClassHashAt(addr, 10) + assert.ErrorIs(t, err, core.ErrCheckHeadState) + }) +} From bb9adeb0b6624054e5965336ebf4f2cd364c66da Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 16 Oct 2024 20:32:24 +0800 Subject: [PATCH 04/13] minor change --- core/contract.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/contract.go b/core/contract.go index 7a02f2d364..1eea36fa71 100644 --- a/core/contract.go +++ b/core/contract.go @@ -59,7 +59,7 @@ func (c *StateContract) StorageRoot(txn db.Transaction) (*felt.Felt, error) { return storageTrie.Root() } -func (c *StateContract) UpdateStorage(key *felt.Felt, value *felt.Felt) { +func (c *StateContract) UpdateStorage(key, value *felt.Felt) { if c.dirtyStorage == nil { c.dirtyStorage = make(map[felt.Felt]*felt.Felt) } From 238a67f5b613d05a30d004cc951ed03a220f0bc7 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 16 Oct 2024 22:16:25 +0800 Subject: [PATCH 05/13] make contract commit concurrently --- core/state.go | 85 +++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 66 insertions(+), 19 deletions(-) diff --git a/core/state.go b/core/state.go index d647fe6639..e00c518943 100644 --- a/core/state.go +++ b/core/state.go @@ -5,6 +5,10 @@ import ( "encoding/binary" "errors" "fmt" + "maps" + "runtime" + "slices" + "sort" "github.com/NethermindEth/juno/core/crypto" "github.com/NethermindEth/juno/core/felt" @@ -12,6 +16,7 @@ import ( "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/encoder" "github.com/NethermindEth/juno/utils" + "github.com/sourcegraph/conc/pool" ) const globalTrieHeight = 251 @@ -307,15 +312,8 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return err } - // Commit all contract updates - for _, contract := range contracts { - if err = contract.Commit(s.txn, true, blockNumber); err != nil { - return err - } - - if err := s.updateContractCommitment(stateTrie, contract); err != nil { - return err - } + if err = s.Commit(stateTrie, contracts, true, blockNumber); err != nil { + return fmt.Errorf("state commit: %v", err) } if err = storageCloser(); err != nil { @@ -333,6 +331,63 @@ var ( } ) +func (s *State) Commit( + stateTrie *trie.Trie, + contracts map[felt.Felt]*StateContract, + logChanges bool, + blockNumber uint64, +) error { + type bufferedTransactionWithAddress struct { + txn *db.BufferedTransaction + addr *felt.Felt + } + + // // sort the contracts in descending storage diff order + keys := slices.SortedStableFunc(maps.Keys(contracts), func(a, b felt.Felt) int { + return len(contracts[a].dirtyStorage) - len(contracts[b].dirtyStorage) + }) + + contractPools := pool.NewWithResults[*bufferedTransactionWithAddress]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) + for _, addr := range keys { + contract := contracts[addr] + contractPools.Go(func() (*bufferedTransactionWithAddress, error) { + txn, err := contract.BufferedCommit(s.txn, logChanges, blockNumber) + if err != nil { + return nil, err + } + + return &bufferedTransactionWithAddress{ + txn: txn, + addr: contract.Address, + }, nil + }) + } + + bufferedTxns, err := contractPools.Wait() + if err != nil { + return err + } + + // we sort bufferedTxns in ascending contract address order to achieve an additional speedup + sort.Slice(bufferedTxns, func(i, j int) bool { + return bufferedTxns[i].addr.Cmp(bufferedTxns[j].addr) < 0 + }) + + for _, bufferedTxn := range bufferedTxns { + if err := bufferedTxn.txn.Flush(); err != nil { + return err + } + } + + for _, contract := range contracts { + if err := s.updateContractCommitment(stateTrie, contract); err != nil { + return err + } + } + + return nil +} + func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool, contracts map[felt.Felt]*StateContract) error { if contracts == nil { return fmt.Errorf("contracts is nil") @@ -522,16 +577,8 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return fmt.Errorf("update contracts: %v", err) } - // TODO(weiihann): make concurrent - // Commit the changes to the contracts and update their commitments - for _, contract := range contracts { - if err = contract.Commit(s.txn, false, blockNumber); err != nil { - return fmt.Errorf("commit contract: %v", err) - } - - if err = s.updateContractCommitment(stateTrie, contract); err != nil { - return fmt.Errorf("update contract commitment: %v", err) - } + if err = s.Commit(stateTrie, contracts, false, blockNumber); err != nil { + return fmt.Errorf("state commit: %v", err) } // purge deployed contracts From 27f924d51b192f2d796726254c34786715703a2b Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 16 Oct 2024 22:17:35 +0800 Subject: [PATCH 06/13] make contract commit concurrently --- core/contract.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/core/contract.go b/core/contract.go index 1eea36fa71..66c9058e0d 100644 --- a/core/contract.go +++ b/core/contract.go @@ -102,6 +102,16 @@ func (c *StateContract) LogClassHash(height uint64, txn db.Transaction) error { return c.logOldValue(key, c.ClassHash, height, txn) } +func (c *StateContract) BufferedCommit(txn db.Transaction, logChanges bool, blockNum uint64) (*db.BufferedTransaction, error) { + bufferedTxn := db.NewBufferedTransaction(txn) + + if err := c.Commit(bufferedTxn, logChanges, blockNum); err != nil { + return nil, err + } + + return bufferedTxn, nil +} + func (c *StateContract) Commit(txn db.Transaction, logChanges bool, blockNum uint64) error { storageTrie, err := storage(c.Address, txn) if err != nil { @@ -202,7 +212,6 @@ func ContractAddress(callerAddress, classHash, salt *felt.Felt, constructorCallD // storage returns the [core.Trie] that represents the // storage of the contract. -// TODO(weiihann): how to deal with the root key? func storage(addr *felt.Felt, txn db.Transaction) (*trie.Trie, error) { addrBytes := addr.Marshal() trieTxn := trie.NewStorage(txn, db.ContractStorage.Key(addrBytes)) From b5ced5a085e1a2ebd50838e78430bdafd9485fe5 Mon Sep 17 00:00:00 2001 From: weiihann Date: Wed, 16 Oct 2024 22:26:51 +0800 Subject: [PATCH 07/13] add state update benchmark --- clients/feeder/feeder.go | 4 ++-- core/state_test.go | 38 ++++++++++++++++++++++++++++++++++++++ db/pebble/db.go | 2 +- 3 files changed, 41 insertions(+), 3 deletions(-) diff --git a/clients/feeder/feeder.go b/clients/feeder/feeder.go index 82e99390ad..010adee3bc 100644 --- a/clients/feeder/feeder.go +++ b/clients/feeder/feeder.go @@ -92,7 +92,7 @@ func NopBackoff(d time.Duration) time.Duration { } // NewTestClient returns a client and a function to close a test server. -func NewTestClient(t *testing.T, network *utils.Network) *Client { +func NewTestClient(t testing.TB, network *utils.Network) *Client { srv := newTestServer(t, network) t.Cleanup(srv.Close) ua := "Juno/v0.0.1-test Starknet Implementation" @@ -117,7 +117,7 @@ func NewTestClient(t *testing.T, network *utils.Network) *Client { return c } -func newTestServer(t *testing.T, network *utils.Network) *httptest.Server { +func newTestServer(t testing.TB, network *utils.Network) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { queryMap, err := url.ParseQuery(r.URL.RawQuery) if err != nil { diff --git a/core/state_test.go b/core/state_test.go index 1119260672..e2d04ee2ea 100644 --- a/core/state_test.go +++ b/core/state_test.go @@ -840,3 +840,41 @@ func TestHistory(t *testing.T) { assert.ErrorIs(t, err, core.ErrCheckHeadState) }) } + +func BenchmarkStateUpdate(b *testing.B) { + client := feeder.NewTestClient(b, &utils.Mainnet) + gw := adaptfeeder.New(client) + + su0, err := gw.StateUpdate(context.Background(), 0) + require.NoError(b, err) + + su1, err := gw.StateUpdate(context.Background(), 1) + require.NoError(b, err) + + su2, err := gw.StateUpdate(context.Background(), 2) + require.NoError(b, err) + + stateUpdates := []*core.StateUpdate{su0, su1, su2} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + // Create a new test database for each iteration + testDB := pebble.NewMemTest(b) + txn, err := testDB.NewTransaction(true) + require.NoError(b, err) + + state := core.NewState(txn) + b.StartTimer() + + for i, su := range stateUpdates { + err = state.Update(uint64(i), su, nil) + if err != nil { + b.Fatalf("Error updating state: %v", err) + } + } + + b.StopTimer() + require.NoError(b, txn.Discard()) + } +} diff --git a/db/pebble/db.go b/db/pebble/db.go index 5974edf720..77aed603d7 100644 --- a/db/pebble/db.go +++ b/db/pebble/db.go @@ -60,7 +60,7 @@ func NewMem() (db.DB, error) { } // NewMemTest opens a new in-memory database, panics on error -func NewMemTest(t *testing.T) db.DB { +func NewMemTest(t testing.TB) db.DB { memDB, err := NewMem() if err != nil { t.Fatalf("create in-memory db: %v", err) From 20599d5d138eddadc49bee6854eb56a8322bcbd3 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 17 Oct 2024 11:35:10 +0800 Subject: [PATCH 08/13] add migration --- migration/migration.go | 131 ++++++++++++++++++++++++++++++++++++ migration/migration_test.go | 48 +++++++++++++ 2 files changed, 179 insertions(+) diff --git a/migration/migration.go b/migration/migration.go index f3a8dd2a72..1943badc38 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -66,6 +66,7 @@ var defaultMigrations = []Migration{ NewBucketMover(db.Temporary, db.ContractStorage), NewBucketMigrator(db.StateUpdatesByBlockNumber, changeStateDiffStruct).WithBatchSize(100), //nolint:mnd NewBucketMigrator(db.Class, migrateCairo1CompiledClass).WithBatchSize(1_000), //nolint:mnd + MigrationFunc(MigrateContractFields), } var ErrCallWithNewTransaction = errors.New("call with new transaction") @@ -714,3 +715,133 @@ func migrateCairo1CompiledClass(txn db.Transaction, key, value []byte, _ *utils. return txn.Set(key, value) } + +func MigrateContractFields(txn db.Transaction, _ *utils.Network) error { + contracts := make(map[felt.Felt]*core.StateContract) + + it, err := txn.NewIterator() + if err != nil { + return err + } + + if err := collectContractNonces(it, contracts); err != nil { + return err + } + + if err := collectContractClassHashes(it, contracts); err != nil { + return err + } + + if err := collectContractDeploymentHeights(it, contracts); err != nil { + return err + } + + if err := storeUpdatedContracts(txn, contracts); err != nil { + return err + } + + return it.Close() +} + +func collectContractNonces(it db.Iterator, contracts map[felt.Felt]*core.StateContract) error { + noncePrefix := db.ContractNonce.Key() + for it.Seek(noncePrefix); it.Valid(); it.Next() { + key := it.Key() + if !bytes.Equal(key[:len(noncePrefix)], noncePrefix) { + break + } + + addr := key[len(noncePrefix):] + if len(addr) != felt.Bytes { + return fmt.Errorf("invalid address length: %d", len(addr)) + } + + addrFelt := new(felt.Felt).SetBytes(addr) + + value, err := it.Value() + if err != nil { + return err + } + + contract := &core.StateContract{ + Nonce: new(felt.Felt).SetBytes(value), + } + contracts[*addrFelt] = contract + } + + return nil +} + +func collectContractClassHashes(it db.Iterator, contracts map[felt.Felt]*core.StateContract) error { + classHashPrefix := db.ContractClassHash.Key() + for it.Seek(classHashPrefix); it.Valid(); it.Next() { + key := it.Key() + if !bytes.Equal(key[:len(classHashPrefix)], classHashPrefix) { + break + } + + addr := key[len(classHashPrefix):] + if len(addr) != felt.Bytes { + return fmt.Errorf("invalid address length: %d", len(addr)) + } + addrFelt := new(felt.Felt).SetBytes(addr) + + // this should never happen because collectContractNonces should have collected all the contracts + if _, ok := contracts[*addrFelt]; !ok { + return fmt.Errorf("contract not found for address: %s", addrFelt) + } + + value, err := it.Value() + if err != nil { + return err + } + + contracts[*addrFelt].ClassHash = new(felt.Felt).SetBytes(value) + } + + return nil +} + +func collectContractDeploymentHeights(it db.Iterator, contracts map[felt.Felt]*core.StateContract) error { + deployHeightPrefix := db.ContractDeploymentHeight.Key() + for it.Seek(deployHeightPrefix); it.Valid(); it.Next() { + key := it.Key() + if !bytes.Equal(key[:len(deployHeightPrefix)], deployHeightPrefix) { + break + } + + addr := key[len(deployHeightPrefix):] + if len(addr) != felt.Bytes { + return fmt.Errorf("invalid address length: %d", len(addr)) + } + addrFelt := new(felt.Felt).SetBytes(addr) + + // this should never happen because collectContractNonces should have collected all the contracts + if _, ok := contracts[*addrFelt]; !ok { + return fmt.Errorf("contract not found for address: %s", addrFelt) + } + + value, err := it.Value() + if err != nil { + return err + } + + contracts[*addrFelt].DeployHeight = binary.BigEndian.Uint64(value) + } + + return nil +} + +func storeUpdatedContracts(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { + for addr, contract := range contracts { + contractBytes, err := encoder.Marshal(contract) + if err != nil { + return err + } + + if err := txn.Set(db.Contract.Key(addr.Marshal()), contractBytes); err != nil { + return err + } + } + return nil +} diff --git a/migration/migration_test.go b/migration/migration_test.go index f037974fb1..85065462aa 100644 --- a/migration/migration_test.go +++ b/migration/migration_test.go @@ -2,9 +2,14 @@ package migration_test import ( "context" + "encoding/binary" "testing" + "github.com/NethermindEth/juno/core" + "github.com/NethermindEth/juno/core/felt" + "github.com/NethermindEth/juno/db" "github.com/NethermindEth/juno/db/pebble" + "github.com/NethermindEth/juno/encoder" "github.com/NethermindEth/juno/migration" "github.com/NethermindEth/juno/utils" "github.com/stretchr/testify/require" @@ -40,3 +45,46 @@ func TestMigrateIfNeeded(t *testing.T) { require.Equal(t, meta, postVersion) }) } + +func TestMigrateContractFields(t *testing.T) { + testDB := pebble.NewMemTest(t) + txn, err := testDB.NewTransaction(true) + require.NoError(t, err) + + // Test data + contracts := []struct { + addr *felt.Felt + nonce *felt.Felt + classHash *felt.Felt + deploymentHeight uint64 + }{ + {new(felt.Felt).SetUint64(1), new(felt.Felt).SetUint64(11), new(felt.Felt).SetUint64(111), 1111}, + {new(felt.Felt).SetUint64(2), new(felt.Felt).SetUint64(22), new(felt.Felt).SetUint64(222), 2222}, + {new(felt.Felt).SetUint64(3), new(felt.Felt).SetUint64(33), new(felt.Felt).SetUint64(333), 3333}, + } + + // Set up initial data + for _, c := range contracts { + addrBytes := c.addr.Marshal() + hBytes := make([]byte, 8) + binary.BigEndian.PutUint64(hBytes, c.deploymentHeight) + + require.NoError(t, txn.Set(db.ContractNonce.Key(addrBytes), c.nonce.Marshal())) + require.NoError(t, txn.Set(db.ContractClassHash.Key(addrBytes), c.classHash.Marshal())) + require.NoError(t, txn.Set(db.ContractDeploymentHeight.Key(addrBytes), hBytes)) + } + + // Run migration + require.NoError(t, migration.MigrateContractFields(txn, nil)) + + // Verify results + for _, c := range contracts { + var contract core.StateContract + require.NoError(t, txn.Get(db.Contract.Key(c.addr.Marshal()), func(value []byte) error { + return encoder.Unmarshal(value, &contract) + })) + require.Equal(t, c.nonce, contract.Nonce) + require.Equal(t, c.classHash, contract.ClassHash) + require.Equal(t, c.deploymentHeight, contract.DeployHeight) + } +} From 70d005d43b655bacf3a0e32fcc683ace6eaf470b Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 17 Oct 2024 12:25:20 +0800 Subject: [PATCH 09/13] update bucket enumer, delete keys after migrate --- cmd/juno/dbcmd.go | 2 +- db/buckets.go | 1 - db/buckets_enumer.go | 12 +++++++---- migration/migration.go | 45 +++++++++++++++++++++++++++++++++--------- 4 files changed, 45 insertions(+), 15 deletions(-) diff --git a/cmd/juno/dbcmd.go b/cmd/juno/dbcmd.go index 4fe5cd3a81..5c79dea982 100644 --- a/cmd/juno/dbcmd.go +++ b/cmd/juno/dbcmd.go @@ -206,7 +206,7 @@ func dbSize(cmd *cobra.Command, args []string) error { totalSize += bucketItem.Size totalCount += bucketItem.Count - if utils.AnyOf(b, db.StateTrie, db.ContractStorage, db.Class, db.ContractNonce, db.ContractDeploymentHeight) { + if utils.AnyOf(b, db.StateTrie, db.ContractStorage, db.Class, db.Contract) { withoutHistorySize += bucketItem.Size withHistorySize += bucketItem.Size diff --git a/db/buckets.go b/db/buckets.go index 211edc65ea..ef8aa1b8a8 100644 --- a/db/buckets.go +++ b/db/buckets.go @@ -41,7 +41,6 @@ const ( Temporary // used temporarily for migrations SchemaIntermediateState // used for db schema metadata Contract // Contract + ContractAddr -> Encoded(Contract) - ContractHistory // ContractHistory + ContractAddr + BlockHeight -> Encoded(Contract) ) // Key flattens a prefix and series of byte arrays into a single []byte. diff --git a/db/buckets_enumer.go b/db/buckets_enumer.go index 0501198b61..0abf339cf5 100644 --- a/db/buckets_enumer.go +++ b/db/buckets_enumer.go @@ -7,11 +7,11 @@ import ( "strings" ) -const _BucketName = "StateTriePeerContractClassHashContractStorageClassContractNonceChainHeightBlockHeaderNumbersByHashBlockHeadersByNumberTransactionBlockNumbersAndIndicesByHashTransactionsByBlockNumberAndIndexReceiptsByBlockNumberAndIndexStateUpdatesByBlockNumberClassesTrieContractStorageHistoryContractNonceHistoryContractClassHashHistoryContractDeploymentHeightL1HeightSchemaVersionPendingBlockCommitmentsTemporarySchemaIntermediateState" +const _BucketName = "StateTriePeerContractClassHashContractStorageClassContractNonceChainHeightBlockHeaderNumbersByHashBlockHeadersByNumberTransactionBlockNumbersAndIndicesByHashTransactionsByBlockNumberAndIndexReceiptsByBlockNumberAndIndexStateUpdatesByBlockNumberClassesTrieContractStorageHistoryContractNonceHistoryContractClassHashHistoryContractDeploymentHeightL1HeightSchemaVersionPendingBlockCommitmentsTemporarySchemaIntermediateStateContract" -var _BucketIndex = [...]uint16{0, 9, 13, 30, 45, 50, 63, 74, 98, 118, 157, 190, 219, 244, 255, 277, 297, 321, 345, 353, 366, 373, 389, 398, 421} +var _BucketIndex = [...]uint16{0, 9, 13, 30, 45, 50, 63, 74, 98, 118, 157, 190, 219, 244, 255, 277, 297, 321, 345, 353, 366, 373, 389, 398, 421, 429} -const _BucketLowerName = "statetriepeercontractclasshashcontractstorageclasscontractnoncechainheightblockheadernumbersbyhashblockheadersbynumbertransactionblocknumbersandindicesbyhashtransactionsbyblocknumberandindexreceiptsbyblocknumberandindexstateupdatesbyblocknumberclassestriecontractstoragehistorycontractnoncehistorycontractclasshashhistorycontractdeploymentheightl1heightschemaversionpendingblockcommitmentstemporaryschemaintermediatestate" +const _BucketLowerName = "statetriepeercontractclasshashcontractstorageclasscontractnoncechainheightblockheadernumbersbyhashblockheadersbynumbertransactionblocknumbersandindicesbyhashtransactionsbyblocknumberandindexreceiptsbyblocknumberandindexstateupdatesbyblocknumberclassestriecontractstoragehistorycontractnoncehistorycontractclasshashhistorycontractdeploymentheightl1heightschemaversionpendingblockcommitmentstemporaryschemaintermediatestatecontract" func (i Bucket) String() string { if i >= Bucket(len(_BucketIndex)-1) { @@ -48,9 +48,10 @@ func _BucketNoOp() { _ = x[BlockCommitments-(21)] _ = x[Temporary-(22)] _ = x[SchemaIntermediateState-(23)] + _ = x[Contract-(24)] } -var _BucketValues = []Bucket{StateTrie, Peer, ContractClassHash, ContractStorage, Class, ContractNonce, ChainHeight, BlockHeaderNumbersByHash, BlockHeadersByNumber, TransactionBlockNumbersAndIndicesByHash, TransactionsByBlockNumberAndIndex, ReceiptsByBlockNumberAndIndex, StateUpdatesByBlockNumber, ClassesTrie, ContractStorageHistory, ContractNonceHistory, ContractClassHashHistory, ContractDeploymentHeight, L1Height, SchemaVersion, Pending, BlockCommitments, Temporary, SchemaIntermediateState} +var _BucketValues = []Bucket{StateTrie, Peer, ContractClassHash, ContractStorage, Class, ContractNonce, ChainHeight, BlockHeaderNumbersByHash, BlockHeadersByNumber, TransactionBlockNumbersAndIndicesByHash, TransactionsByBlockNumberAndIndex, ReceiptsByBlockNumberAndIndex, StateUpdatesByBlockNumber, ClassesTrie, ContractStorageHistory, ContractNonceHistory, ContractClassHashHistory, ContractDeploymentHeight, L1Height, SchemaVersion, Pending, BlockCommitments, Temporary, SchemaIntermediateState, Contract} var _BucketNameToValueMap = map[string]Bucket{ _BucketName[0:9]: StateTrie, @@ -101,6 +102,8 @@ var _BucketNameToValueMap = map[string]Bucket{ _BucketLowerName[389:398]: Temporary, _BucketName[398:421]: SchemaIntermediateState, _BucketLowerName[398:421]: SchemaIntermediateState, + _BucketName[421:429]: Contract, + _BucketLowerName[421:429]: Contract, } var _BucketNames = []string{ @@ -128,6 +131,7 @@ var _BucketNames = []string{ _BucketName[373:389], _BucketName[389:398], _BucketName[398:421], + _BucketName[421:429], } // BucketString retrieves an enum value from the enum constants string name. diff --git a/migration/migration.go b/migration/migration.go index 1943badc38..b0173dac37 100644 --- a/migration/migration.go +++ b/migration/migration.go @@ -724,15 +724,15 @@ func MigrateContractFields(txn db.Transaction, _ *utils.Network) error { return err } - if err := collectContractNonces(it, contracts); err != nil { + if err := collectContractNonces(txn, contracts); err != nil { return err } - if err := collectContractClassHashes(it, contracts); err != nil { + if err := collectContractClassHashes(txn, contracts); err != nil { return err } - if err := collectContractDeploymentHeights(it, contracts); err != nil { + if err := collectContractDeploymentHeights(txn, contracts); err != nil { return err } @@ -743,7 +743,12 @@ func MigrateContractFields(txn db.Transaction, _ *utils.Network) error { return it.Close() } -func collectContractNonces(it db.Iterator, contracts map[felt.Felt]*core.StateContract) error { +func collectContractNonces(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { + it, err := txn.NewIterator() + if err != nil { + return err + } + noncePrefix := db.ContractNonce.Key() for it.Seek(noncePrefix); it.Valid(); it.Next() { key := it.Key() @@ -767,12 +772,21 @@ func collectContractNonces(it db.Iterator, contracts map[felt.Felt]*core.StateCo Nonce: new(felt.Felt).SetBytes(value), } contracts[*addrFelt] = contract + + if err := txn.Delete(key); err != nil { + return err + } } - return nil + return it.Close() } -func collectContractClassHashes(it db.Iterator, contracts map[felt.Felt]*core.StateContract) error { +func collectContractClassHashes(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { + it, err := txn.NewIterator() + if err != nil { + return err + } + classHashPrefix := db.ContractClassHash.Key() for it.Seek(classHashPrefix); it.Valid(); it.Next() { key := it.Key() @@ -797,12 +811,21 @@ func collectContractClassHashes(it db.Iterator, contracts map[felt.Felt]*core.St } contracts[*addrFelt].ClassHash = new(felt.Felt).SetBytes(value) + + if err := txn.Delete(key); err != nil { + return err + } } - return nil + return it.Close() } -func collectContractDeploymentHeights(it db.Iterator, contracts map[felt.Felt]*core.StateContract) error { +func collectContractDeploymentHeights(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { + it, err := txn.NewIterator() + if err != nil { + return err + } + deployHeightPrefix := db.ContractDeploymentHeight.Key() for it.Seek(deployHeightPrefix); it.Valid(); it.Next() { key := it.Key() @@ -827,9 +850,13 @@ func collectContractDeploymentHeights(it db.Iterator, contracts map[felt.Felt]*c } contracts[*addrFelt].DeployHeight = binary.BigEndian.Uint64(value) + + if err := txn.Delete(key); err != nil { + return err + } } - return nil + return it.Close() } func storeUpdatedContracts(txn db.Transaction, contracts map[felt.Felt]*core.StateContract) error { From 1097fb4fa0152065f852fff2beb6d5c63deefa77 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 17 Oct 2024 17:14:53 +0800 Subject: [PATCH 10/13] Squashed commit of the following: commit f5dc02caac7c59d36015845cf9d5b3355d312da4 Author: Kirill Date: Thu Oct 17 11:15:15 2024 +0400 Small restructure of plugin logic (#2221) commit b43c46f48189926713c58a9760168ef2496c3bff Author: Rian Hughes Date: Wed Oct 16 15:42:07 2024 +0300 Support plugins (#2051) Co-authored-by: rian Co-authored-by: LordGhostX commit ca006de578cfbc5a99dc13c2f751ed0cd03adcb7 Author: Kirill Date: Wed Oct 16 16:14:16 2024 +0400 Replace UnmarshalJSON() with UnmarshalText() for transaction statuses (#2220) * Replace UnmarshalJSON() with UnmarshalText() for transaction statuses UnmarshalText avoids issues with forgetting quotes in JSON, making it simpler for parsing plain text values. fix merge conflicts --- core/state.go | 252 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 151 insertions(+), 101 deletions(-) diff --git a/core/state.go b/core/state.go index e00c518943..049914f4d8 100644 --- a/core/state.go +++ b/core/state.go @@ -323,6 +323,58 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return s.verifyStateUpdateRoot(update.NewRoot) } +func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { + reversed := *diff + + // storage diffs + reversed.StorageDiffs = make(map[felt.Felt]map[felt.Felt]*felt.Felt, len(diff.StorageDiffs)) + for addr, storageDiffs := range diff.StorageDiffs { + reversedDiffs := make(map[felt.Felt]*felt.Felt, len(storageDiffs)) + for key := range storageDiffs { + value := &felt.Zero + if blockNumber > 0 { + oldValue, err := s.ContractStorageAt(&addr, &key, blockNumber-1) + if err != nil { + return nil, err + } + value = oldValue + } + reversedDiffs[key] = value + } + reversed.StorageDiffs[addr] = reversedDiffs + } + + // nonces + reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces)) + for addr := range diff.Nonces { + oldNonce := &felt.Zero + if blockNumber > 0 { + var err error + oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1) + if err != nil { + return nil, err + } + } + reversed.Nonces[addr] = oldNonce + } + + // replaced + reversed.ReplacedClasses = make(map[felt.Felt]*felt.Felt, len(diff.ReplacedClasses)) + for addr := range diff.ReplacedClasses { + classHash := &felt.Zero + if blockNumber > 0 { + var err error + classHash, err = s.ContractClassHashAt(&addr, blockNumber-1) + if err != nil { + return nil, err + } + } + reversed.ReplacedClasses[addr] = classHash + } + + return &reversed, nil +} + var ( noClassContractsClassHash = new(felt.Felt).SetUint64(0) @@ -393,17 +445,27 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges return fmt.Errorf("contracts is nil") } - var err error + if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges, contracts); err != nil { + return err + } + + if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges, contracts); err != nil { + return err + } - // update contract class hashes - for addr, classHash := range diff.ReplacedClasses { - contract, ok := contracts[addr] - if !ok { - contract, err = GetContract(&addr, s.txn) - if err != nil { - return err - } - contracts[addr] = contract + return s.updateContractStorages(blockNumber, diff.StorageDiffs, contracts) +} + +func (s *State) updateContractClasses( + blockNumber uint64, + replacedClasses map[felt.Felt]*felt.Felt, + logChanges bool, + contracts map[felt.Felt]*StateContract, +) error { + for addr, classHash := range replacedClasses { + contract, err := s.getOrCreateContract(addr, contracts) + if err != nil { + return err } if logChanges { @@ -414,16 +476,19 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges contract.ClassHash = classHash } + return nil +} - // update contract nonces - for addr, nonce := range diff.Nonces { - contract, ok := contracts[addr] - if !ok { - contract, err = GetContract(&addr, s.txn) - if err != nil { - return err - } - contracts[addr] = contract +func (s *State) updateContractNonces( + blockNumber uint64, + nonces map[felt.Felt]*felt.Felt, + logChanges bool, + contracts map[felt.Felt]*StateContract, +) error { + for addr, nonce := range nonces { + contract, err := s.getOrCreateContract(addr, contracts) + if err != nil { + return err } if logChanges { @@ -434,29 +499,43 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges contract.Nonce = nonce } + return nil +} - // update contract storages - for addr, diff := range diff.StorageDiffs { - contract, ok := contracts[addr] - if !ok { - contract, err = GetContract(&addr, s.txn) - if err != nil { - // makes sure that all noClassContracts are deployed - if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { - contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) - } else { - return err - } +func (s *State) updateContractStorages( + blockNumber uint64, + storageDiffs map[felt.Felt]map[felt.Felt]*felt.Felt, + contracts map[felt.Felt]*StateContract, +) error { + for addr, diff := range storageDiffs { + contract, err := s.getOrCreateContract(addr, contracts) + if err != nil { + if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { + contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) + contracts[addr] = contract + } else { + return err } - contracts[addr] = contract } contract.dirtyStorage = diff } - return nil } +func (s *State) getOrCreateContract(addr felt.Felt, contracts map[felt.Felt]*StateContract) (*StateContract, error) { + contract, ok := contracts[addr] + if !ok { + var err error + contract, err = GetContract(&addr, s.txn) + if err != nil { + return nil, err + } + contracts[addr] = contract + } + return contract, nil +} + type DeclaredClass struct { At uint64 Class Class @@ -561,12 +640,15 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return fmt.Errorf("remove declared classes: %v", err) } - // update contracts - reversedDiff, err := s.buildReverseDiff(blockNumber, update.StateDiff) + reversedDiff, err := s.GetReverseStateDiff(blockNumber, update.StateDiff) if err != nil { return fmt.Errorf("build reverse diff: %v", err) } + if err = s.performStateDeletions(blockNumber, reversedDiff); err != nil { + return fmt.Errorf("perform state deletions: %v", err) + } + stateTrie, storageCloser, err := s.storage() if err != nil { return err @@ -588,30 +670,8 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { } } - // purge noClassContracts - // - // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. - // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, - // we can use the lack of key's existence as reason for purging noClassContracts. - for addr := range noClassContracts { - contract, err := GetContract(&addr, s.txn) - if err != nil { - if !errors.Is(err, ErrContractNotDeployed) { - return err - } - continue - } - - rootKey, err := contract.StorageRoot(s.txn) - if err != nil { - return fmt.Errorf("get root key: %v", err) - } - - if rootKey.Equal(&felt.Zero) { - if err = s.purgeContract(stateTrie, &addr); err != nil { - return fmt.Errorf("purge contract: %v", err) - } - } + if err = s.purgeNoClassContracts(stateTrie); err != nil { + return fmt.Errorf("purge no class contract: %v", err) } if err = storageCloser(); err != nil { @@ -673,69 +733,59 @@ func (s *State) purgeContract(stateTrie *trie.Trie, addr *felt.Felt) error { return nil } -func (s *State) buildReverseDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) { - reversed := *diff - +func (s *State) performStateDeletions(blockNumber uint64, diff *StateDiff) error { // storage diffs - reversed.StorageDiffs = make(map[felt.Felt]map[felt.Felt]*felt.Felt, len(diff.StorageDiffs)) for addr, storageDiffs := range diff.StorageDiffs { - reversedDiffs := make(map[felt.Felt]*felt.Felt, len(storageDiffs)) for key := range storageDiffs { - value := &felt.Zero - if blockNumber > 0 { - oldValue, err := s.ContractStorageAt(&addr, &key, blockNumber-1) - if err != nil { - return nil, err - } - value = oldValue - } - if err := s.DeleteContractStorageLog(&addr, &key, blockNumber); err != nil { - return nil, err + return err } - reversedDiffs[key] = value } - reversed.StorageDiffs[addr] = reversedDiffs } // nonces - reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces)) for addr := range diff.Nonces { - oldNonce := &felt.Zero - - if blockNumber > 0 { - var err error - oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1) - if err != nil { - return nil, err - } - } - if err := s.DeleteContractNonceLog(&addr, blockNumber); err != nil { - return nil, err + return err } - reversed.Nonces[addr] = oldNonce } - // replaced - reversed.ReplacedClasses = make(map[felt.Felt]*felt.Felt, len(diff.ReplacedClasses)) + // replaced classes for addr := range diff.ReplacedClasses { - classHash := &felt.Zero - if blockNumber > 0 { - var err error - classHash, err = s.ContractClassHashAt(&addr, blockNumber-1) - if err != nil { - return nil, err + if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil { + return err + } + } + + return nil +} + +func (s *State) purgeNoClassContracts(stateTrie *trie.Trie) error { + // As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists. + // Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus, + // we can use the lack of key's existence as reason for purging noClassContracts. + for addr := range noClassContracts { + contract, err := GetContract(&addr, s.txn) + if err != nil { + if !errors.Is(err, ErrContractNotDeployed) { + return err } + continue } - if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil { - return nil, err + rootKey, err := contract.StorageRoot(s.txn) + if err != nil { + return fmt.Errorf("get root key: %v", err) + } + + if rootKey.Equal(&felt.Zero) { + if err = s.purgeContract(stateTrie, &addr); err != nil { + return fmt.Errorf("purge contract: %v", err) + } } - reversed.ReplacedClasses[addr] = classHash } - return &reversed, nil + return nil } func logDBKey(key []byte, height uint64) []byte { From 1b207ce76a900b26ff2d045b01b79dc3150788d5 Mon Sep 17 00:00:00 2001 From: weiihann Date: Thu, 17 Oct 2024 17:37:37 +0800 Subject: [PATCH 11/13] add comments --- core/contract.go | 15 +++++++++++++++ core/state.go | 1 + 2 files changed, 16 insertions(+) diff --git a/core/contract.go b/core/contract.go index 66c9058e0d..84378c6b51 100644 --- a/core/contract.go +++ b/core/contract.go @@ -20,6 +20,11 @@ var ( type OnValueChanged = func(location, oldValue *felt.Felt) error +// StateContract represents a contract instance. +// The usage of a `StateContract` is as follows: +// 1. Create or obtain `StateContract` instance from the database. +// 2. Update the contract fields +// 3. Commit the contract to the database type StateContract struct { // ClassHash is the hash of the contract's class ClassHash *felt.Felt @@ -33,6 +38,7 @@ type StateContract struct { dirtyStorage map[felt.Felt]*felt.Felt `cbor:"-"` } +// NewStateContract creates a new contract instance. func NewStateContract( addr *felt.Felt, classHash *felt.Felt, @@ -50,6 +56,7 @@ func NewStateContract( return sc } +// StorageRoot returns the root of the contract's storage trie. func (c *StateContract) StorageRoot(txn db.Transaction) (*felt.Felt, error) { storageTrie, err := storage(c.Address, txn) if err != nil { @@ -59,6 +66,8 @@ func (c *StateContract) StorageRoot(txn db.Transaction) (*felt.Felt, error) { return storageTrie.Root() } +// UpdateStorage updates the storage of a contract. +// Note that this does not modify the storage trie, which must be committed separately. func (c *StateContract) UpdateStorage(key, value *felt.Felt) { if c.dirtyStorage == nil { c.dirtyStorage = make(map[felt.Felt]*felt.Felt) @@ -67,6 +76,7 @@ func (c *StateContract) UpdateStorage(key, value *felt.Felt) { c.dirtyStorage[*key] = value } +// GetStorage retrieves the value of a storage location from the contract's storage func (c *StateContract) GetStorage(key *felt.Felt, txn db.Transaction) (*felt.Felt, error) { if c.dirtyStorage != nil { if val, ok := c.dirtyStorage[*key]; ok { @@ -83,25 +93,30 @@ func (c *StateContract) GetStorage(key *felt.Felt, txn db.Transaction) (*felt.Fe return storage.Get(key) } +// logOldValue is a helper function to record the history of a contract's value func (c *StateContract) logOldValue(key []byte, oldValue *felt.Felt, height uint64, txn db.Transaction) error { return txn.Set(logDBKey(key, height), oldValue.Marshal()) } +// LogStorage records the history of the contract's storage func (c *StateContract) LogStorage(location, oldVal *felt.Felt, height uint64, txn db.Transaction) error { key := storageLogKey(c.Address, location) return c.logOldValue(key, oldVal, height, txn) } +// LogNonce records the history of the contract's nonce func (c *StateContract) LogNonce(height uint64, txn db.Transaction) error { key := nonceLogKey(c.Address) return c.logOldValue(key, c.Nonce, height, txn) } +// LogClassHash records the history of the contract's class hash func (c *StateContract) LogClassHash(height uint64, txn db.Transaction) error { key := classHashLogKey(c.Address) return c.logOldValue(key, c.ClassHash, height, txn) } +// BufferedCommit creates a buffered transaction and commits the contract to the database func (c *StateContract) BufferedCommit(txn db.Transaction, logChanges bool, blockNum uint64) (*db.BufferedTransaction, error) { bufferedTxn := db.NewBufferedTransaction(txn) diff --git a/core/state.go b/core/state.go index 049914f4d8..e63e051c1f 100644 --- a/core/state.go +++ b/core/state.go @@ -383,6 +383,7 @@ var ( } ) +// Commit updates the state by committing the dirty contracts to the database. func (s *State) Commit( stateTrie *trie.Trie, contracts map[felt.Felt]*StateContract, From 8a5ed9f2ace293b44027e545d9f646a9224dde3d Mon Sep 17 00:00:00 2001 From: weiihann Date: Fri, 18 Oct 2024 18:46:04 +0800 Subject: [PATCH 12/13] chore --- core/state.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/state.go b/core/state.go index e63e051c1f..6a62e9d865 100644 --- a/core/state.go +++ b/core/state.go @@ -464,7 +464,7 @@ func (s *State) updateContractClasses( contracts map[felt.Felt]*StateContract, ) error { for addr, classHash := range replacedClasses { - contract, err := s.getOrCreateContract(addr, contracts) + contract, err := s.getContract(addr, contracts) if err != nil { return err } @@ -487,7 +487,7 @@ func (s *State) updateContractNonces( contracts map[felt.Felt]*StateContract, ) error { for addr, nonce := range nonces { - contract, err := s.getOrCreateContract(addr, contracts) + contract, err := s.getContract(addr, contracts) if err != nil { return err } @@ -509,7 +509,7 @@ func (s *State) updateContractStorages( contracts map[felt.Felt]*StateContract, ) error { for addr, diff := range storageDiffs { - contract, err := s.getOrCreateContract(addr, contracts) + contract, err := s.getContract(addr, contracts) if err != nil { if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) @@ -524,7 +524,7 @@ func (s *State) updateContractStorages( return nil } -func (s *State) getOrCreateContract(addr felt.Felt, contracts map[felt.Felt]*StateContract) (*StateContract, error) { +func (s *State) getContract(addr felt.Felt, contracts map[felt.Felt]*StateContract) (*StateContract, error) { contract, ok := contracts[addr] if !ok { var err error From c466ccac39f4c8c9080d820d79d660fe3af53d0a Mon Sep 17 00:00:00 2001 From: weiihann Date: Sat, 19 Oct 2024 12:18:33 +0800 Subject: [PATCH 13/13] refactor dirty contracts --- core/state.go | 62 ++++++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/core/state.go b/core/state.go index 6a62e9d865..5cbdcc5c45 100644 --- a/core/state.go +++ b/core/state.go @@ -48,11 +48,15 @@ type StateReader interface { type State struct { txn db.Transaction + + // This map holds the contract objects which are being updated in the current state update. + contracts map[felt.Felt]*StateContract } func NewState(txn db.Transaction) *State { return &State{ - txn: txn, + txn: txn, + contracts: make(map[felt.Felt]*StateContract), } } @@ -292,7 +296,6 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return err } - contracts := make(map[felt.Felt]*StateContract) // register deployed contracts for addr, classHash := range update.StateDiff.DeployedContracts { // check if contract is already deployed @@ -305,14 +308,14 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses return err } - contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, blockNumber) + s.contracts[addr] = NewStateContract(&addr, classHash, &felt.Zero, blockNumber) } - if err = s.updateContracts(blockNumber, update.StateDiff, true, contracts); err != nil { + if err = s.updateContracts(blockNumber, update.StateDiff, true); err != nil { return err } - if err = s.Commit(stateTrie, contracts, true, blockNumber); err != nil { + if err = s.Commit(stateTrie, true, blockNumber); err != nil { return fmt.Errorf("state commit: %v", err) } @@ -386,7 +389,6 @@ var ( // Commit updates the state by committing the dirty contracts to the database. func (s *State) Commit( stateTrie *trie.Trie, - contracts map[felt.Felt]*StateContract, logChanges bool, blockNumber uint64, ) error { @@ -396,22 +398,21 @@ func (s *State) Commit( } // // sort the contracts in descending storage diff order - keys := slices.SortedStableFunc(maps.Keys(contracts), func(a, b felt.Felt) int { - return len(contracts[a].dirtyStorage) - len(contracts[b].dirtyStorage) + keys := slices.SortedStableFunc(maps.Keys(s.contracts), func(a, b felt.Felt) int { + return len(s.contracts[a].dirtyStorage) - len(s.contracts[b].dirtyStorage) }) contractPools := pool.NewWithResults[*bufferedTransactionWithAddress]().WithErrors().WithMaxGoroutines(runtime.GOMAXPROCS(0)) for _, addr := range keys { - contract := contracts[addr] contractPools.Go(func() (*bufferedTransactionWithAddress, error) { - txn, err := contract.BufferedCommit(s.txn, logChanges, blockNumber) + txn, err := s.contracts[addr].BufferedCommit(s.txn, logChanges, blockNumber) if err != nil { return nil, err } return &bufferedTransactionWithAddress{ txn: txn, - addr: contract.Address, + addr: &addr, }, nil }) } @@ -432,39 +433,37 @@ func (s *State) Commit( } } - for _, contract := range contracts { + for _, contract := range s.contracts { if err := s.updateContractCommitment(stateTrie, contract); err != nil { return err } } + // finally, clear the contracts map + s.contracts = make(map[felt.Felt]*StateContract) + return nil } -func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool, contracts map[felt.Felt]*StateContract) error { - if contracts == nil { - return fmt.Errorf("contracts is nil") - } - - if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges, contracts); err != nil { +func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges bool) error { + if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges); err != nil { return err } - if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges, contracts); err != nil { + if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges); err != nil { return err } - return s.updateContractStorages(blockNumber, diff.StorageDiffs, contracts) + return s.updateContractStorages(blockNumber, diff.StorageDiffs) } func (s *State) updateContractClasses( blockNumber uint64, replacedClasses map[felt.Felt]*felt.Felt, logChanges bool, - contracts map[felt.Felt]*StateContract, ) error { for addr, classHash := range replacedClasses { - contract, err := s.getContract(addr, contracts) + contract, err := s.getContract(addr) if err != nil { return err } @@ -484,10 +483,9 @@ func (s *State) updateContractNonces( blockNumber uint64, nonces map[felt.Felt]*felt.Felt, logChanges bool, - contracts map[felt.Felt]*StateContract, ) error { for addr, nonce := range nonces { - contract, err := s.getContract(addr, contracts) + contract, err := s.getContract(addr) if err != nil { return err } @@ -506,14 +504,13 @@ func (s *State) updateContractNonces( func (s *State) updateContractStorages( blockNumber uint64, storageDiffs map[felt.Felt]map[felt.Felt]*felt.Felt, - contracts map[felt.Felt]*StateContract, ) error { for addr, diff := range storageDiffs { - contract, err := s.getContract(addr, contracts) + contract, err := s.getContract(addr) if err != nil { if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) { contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber) - contracts[addr] = contract + s.contracts[addr] = contract } else { return err } @@ -524,15 +521,15 @@ func (s *State) updateContractStorages( return nil } -func (s *State) getContract(addr felt.Felt, contracts map[felt.Felt]*StateContract) (*StateContract, error) { - contract, ok := contracts[addr] +func (s *State) getContract(addr felt.Felt) (*StateContract, error) { + contract, ok := s.contracts[addr] if !ok { var err error contract, err = GetContract(&addr, s.txn) if err != nil { return nil, err } - contracts[addr] = contract + s.contracts[addr] = contract } return contract, nil } @@ -655,12 +652,11 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error { return err } - contracts := make(map[felt.Felt]*StateContract) - if err = s.updateContracts(blockNumber, reversedDiff, false, contracts); err != nil { + if err = s.updateContracts(blockNumber, reversedDiff, false); err != nil { return fmt.Errorf("update contracts: %v", err) } - if err = s.Commit(stateTrie, contracts, false, blockNumber); err != nil { + if err = s.Commit(stateTrie, false, blockNumber); err != nil { return fmt.Errorf("state commit: %v", err) }