Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor state and contract modules #2224

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Squashed commit of the following:
commit f5dc02c
Author: Kirill <[email protected]>
Date:   Thu Oct 17 11:15:15 2024 +0400

    Small restructure of plugin logic (#2221)

commit b43c46f
Author: Rian Hughes <[email protected]>
Date:   Wed Oct 16 15:42:07 2024 +0300

    Support plugins (#2051)

    Co-authored-by: rian <[email protected]>
    Co-authored-by: LordGhostX <[email protected]>

commit ca006de
Author: Kirill <[email protected]>
Date:   Wed Oct 16 16:14:16 2024 +0400

    Replace UnmarshalJSON() with UnmarshalText() for transaction statuses (#2220)

    * Replace UnmarshalJSON() with UnmarshalText() for transaction statuses

    UnmarshalText avoids issues with forgetting quotes in JSON, making it simpler for parsing plain text values.

fix merge conflicts
weiihann committed Oct 28, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 1097fb4fa0152065f852fff2beb6d5c63deefa77
252 changes: 151 additions & 101 deletions core/state.go
Original file line number Diff line number Diff line change
@@ -323,6 +323,58 @@ func (s *State) Update(blockNumber uint64, update *StateUpdate, declaredClasses
return s.verifyStateUpdateRoot(update.NewRoot)
}

func (s *State) GetReverseStateDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) {
reversed := *diff

// storage diffs
reversed.StorageDiffs = make(map[felt.Felt]map[felt.Felt]*felt.Felt, len(diff.StorageDiffs))
for addr, storageDiffs := range diff.StorageDiffs {
reversedDiffs := make(map[felt.Felt]*felt.Felt, len(storageDiffs))
for key := range storageDiffs {
value := &felt.Zero
if blockNumber > 0 {
oldValue, err := s.ContractStorageAt(&addr, &key, blockNumber-1)
if err != nil {
return nil, err
}
value = oldValue
}
reversedDiffs[key] = value
}
reversed.StorageDiffs[addr] = reversedDiffs
}

// nonces
reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces))
for addr := range diff.Nonces {
oldNonce := &felt.Zero
if blockNumber > 0 {
var err error
oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1)
if err != nil {
return nil, err
}
}
reversed.Nonces[addr] = oldNonce
}

// replaced
reversed.ReplacedClasses = make(map[felt.Felt]*felt.Felt, len(diff.ReplacedClasses))
for addr := range diff.ReplacedClasses {
classHash := &felt.Zero
if blockNumber > 0 {
var err error
classHash, err = s.ContractClassHashAt(&addr, blockNumber-1)
if err != nil {
return nil, err
}
}
reversed.ReplacedClasses[addr] = classHash
}

return &reversed, nil
}

var (
noClassContractsClassHash = new(felt.Felt).SetUint64(0)

@@ -393,17 +445,27 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges
return fmt.Errorf("contracts is nil")
}

var err error
if err := s.updateContractClasses(blockNumber, diff.ReplacedClasses, logChanges, contracts); err != nil {
return err
}

if err := s.updateContractNonces(blockNumber, diff.Nonces, logChanges, contracts); err != nil {
return err
}

// update contract class hashes
for addr, classHash := range diff.ReplacedClasses {
contract, ok := contracts[addr]
if !ok {
contract, err = GetContract(&addr, s.txn)
if err != nil {
return err
}
contracts[addr] = contract
return s.updateContractStorages(blockNumber, diff.StorageDiffs, contracts)
}

func (s *State) updateContractClasses(
blockNumber uint64,
replacedClasses map[felt.Felt]*felt.Felt,
logChanges bool,
contracts map[felt.Felt]*StateContract,
) error {
for addr, classHash := range replacedClasses {
contract, err := s.getOrCreateContract(addr, contracts)
if err != nil {
return err
}

if logChanges {
@@ -414,16 +476,19 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges

contract.ClassHash = classHash
}
return nil
}

// update contract nonces
for addr, nonce := range diff.Nonces {
contract, ok := contracts[addr]
if !ok {
contract, err = GetContract(&addr, s.txn)
if err != nil {
return err
}
contracts[addr] = contract
func (s *State) updateContractNonces(
blockNumber uint64,
nonces map[felt.Felt]*felt.Felt,
logChanges bool,
contracts map[felt.Felt]*StateContract,
) error {
for addr, nonce := range nonces {
contract, err := s.getOrCreateContract(addr, contracts)
if err != nil {
return err
}

if logChanges {
@@ -434,29 +499,43 @@ func (s *State) updateContracts(blockNumber uint64, diff *StateDiff, logChanges

contract.Nonce = nonce
}
return nil
}

// update contract storages
for addr, diff := range diff.StorageDiffs {
contract, ok := contracts[addr]
if !ok {
contract, err = GetContract(&addr, s.txn)
if err != nil {
// makes sure that all noClassContracts are deployed
if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) {
contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber)
} else {
return err
}
func (s *State) updateContractStorages(
blockNumber uint64,
storageDiffs map[felt.Felt]map[felt.Felt]*felt.Felt,
contracts map[felt.Felt]*StateContract,
) error {
for addr, diff := range storageDiffs {
contract, err := s.getOrCreateContract(addr, contracts)
if err != nil {
if _, ok := noClassContracts[addr]; ok && errors.Is(err, ErrContractNotDeployed) {
contract = NewStateContract(&addr, noClassContractsClassHash, &felt.Zero, blockNumber)
contracts[addr] = contract
} else {
return err
}
contracts[addr] = contract
}

contract.dirtyStorage = diff
}

return nil
}

func (s *State) getOrCreateContract(addr felt.Felt, contracts map[felt.Felt]*StateContract) (*StateContract, error) {
contract, ok := contracts[addr]
if !ok {
var err error
contract, err = GetContract(&addr, s.txn)
if err != nil {
return nil, err
}
contracts[addr] = contract
}
return contract, nil
}

type DeclaredClass struct {
At uint64
Class Class
@@ -561,12 +640,15 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error {
return fmt.Errorf("remove declared classes: %v", err)
}

// update contracts
reversedDiff, err := s.buildReverseDiff(blockNumber, update.StateDiff)
reversedDiff, err := s.GetReverseStateDiff(blockNumber, update.StateDiff)
if err != nil {
return fmt.Errorf("build reverse diff: %v", err)
}

if err = s.performStateDeletions(blockNumber, reversedDiff); err != nil {
return fmt.Errorf("perform state deletions: %v", err)
}

stateTrie, storageCloser, err := s.storage()
if err != nil {
return err
@@ -588,30 +670,8 @@ func (s *State) Revert(blockNumber uint64, update *StateUpdate) error {
}
}

// purge noClassContracts
//
// As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists.
// Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus,
// we can use the lack of key's existence as reason for purging noClassContracts.
for addr := range noClassContracts {
contract, err := GetContract(&addr, s.txn)
if err != nil {
if !errors.Is(err, ErrContractNotDeployed) {
return err
}
continue
}

rootKey, err := contract.StorageRoot(s.txn)
if err != nil {
return fmt.Errorf("get root key: %v", err)
}

if rootKey.Equal(&felt.Zero) {
if err = s.purgeContract(stateTrie, &addr); err != nil {
return fmt.Errorf("purge contract: %v", err)
}
}
if err = s.purgeNoClassContracts(stateTrie); err != nil {
return fmt.Errorf("purge no class contract: %v", err)
}

if err = storageCloser(); err != nil {
@@ -673,69 +733,59 @@ func (s *State) purgeContract(stateTrie *trie.Trie, addr *felt.Felt) error {
return nil
}

func (s *State) buildReverseDiff(blockNumber uint64, diff *StateDiff) (*StateDiff, error) {
reversed := *diff

func (s *State) performStateDeletions(blockNumber uint64, diff *StateDiff) error {
// storage diffs
reversed.StorageDiffs = make(map[felt.Felt]map[felt.Felt]*felt.Felt, len(diff.StorageDiffs))
for addr, storageDiffs := range diff.StorageDiffs {
reversedDiffs := make(map[felt.Felt]*felt.Felt, len(storageDiffs))
for key := range storageDiffs {
value := &felt.Zero
if blockNumber > 0 {
oldValue, err := s.ContractStorageAt(&addr, &key, blockNumber-1)
if err != nil {
return nil, err
}
value = oldValue
}

if err := s.DeleteContractStorageLog(&addr, &key, blockNumber); err != nil {
return nil, err
return err
}
reversedDiffs[key] = value
}
reversed.StorageDiffs[addr] = reversedDiffs
}

// nonces
reversed.Nonces = make(map[felt.Felt]*felt.Felt, len(diff.Nonces))
for addr := range diff.Nonces {
oldNonce := &felt.Zero

if blockNumber > 0 {
var err error
oldNonce, err = s.ContractNonceAt(&addr, blockNumber-1)
if err != nil {
return nil, err
}
}

if err := s.DeleteContractNonceLog(&addr, blockNumber); err != nil {
return nil, err
return err
}
reversed.Nonces[addr] = oldNonce
}

// replaced
reversed.ReplacedClasses = make(map[felt.Felt]*felt.Felt, len(diff.ReplacedClasses))
// replaced classes
for addr := range diff.ReplacedClasses {
classHash := &felt.Zero
if blockNumber > 0 {
var err error
classHash, err = s.ContractClassHashAt(&addr, blockNumber-1)
if err != nil {
return nil, err
if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil {
return err
}
}

return nil
}

func (s *State) purgeNoClassContracts(stateTrie *trie.Trie) error {
// As noClassContracts are not in StateDiff.DeployedContracts we can only purge them if their storage no longer exists.
// Updating contracts with reverse diff will eventually lead to the deletion of noClassContract's storage key from db. Thus,
// we can use the lack of key's existence as reason for purging noClassContracts.
for addr := range noClassContracts {
contract, err := GetContract(&addr, s.txn)
if err != nil {
if !errors.Is(err, ErrContractNotDeployed) {
return err
}
continue
}

if err := s.DeleteContractClassHashLog(&addr, blockNumber); err != nil {
return nil, err
rootKey, err := contract.StorageRoot(s.txn)
if err != nil {
return fmt.Errorf("get root key: %v", err)
}

if rootKey.Equal(&felt.Zero) {
if err = s.purgeContract(stateTrie, &addr); err != nil {
return fmt.Errorf("purge contract: %v", err)
}
}
reversed.ReplacedClasses[addr] = classHash
}

return &reversed, nil
return nil
}

func logDBKey(key []byte, height uint64) []byte {