diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index e69de29b..00000000
diff --git a/chain/config/config.go b/chain/config/config.go
index 6a3b0c4e..3b40fe2e 100644
--- a/chain/config/config.go
+++ b/chain/config/config.go
@@ -19,6 +19,17 @@ type TestChainConfig struct {
 	// ContractAddressOverrides describes contracts that are going to be deployed at deterministic addresses
 	ContractAddressOverrides map[common.Hash]common.Address `json:"contractAddressOverrides,omitempty"`
+
+	// ForkConfig describes the RPC configuration to use when fuzzing against a network fork.
+	ForkConfig ForkConfig `json:"forkConfig,omitempty"`
+}
+
+// ForkConfig describes the configuration for fuzzing against a network fork
+type ForkConfig struct {
+	ForkModeEnabled bool   `json:"forkModeEnabled"`
+	RpcUrl          string `json:"rpcUrl"`
+	RpcBlock        uint64 `json:"rpcBlock"`
+	PoolSize        uint   `json:"poolSize"`
+}

 // CheatCodeConfig describes any configuration options related to the use of vm extensions (a.k.a. cheat codes)
diff --git a/chain/config/config_defaults.go b/chain/config/config_defaults.go
index 5a611c1c..8699173c 100644
--- a/chain/config/config_defaults.go
+++ b/chain/config/config_defaults.go
@@ -11,6 +11,12 @@ func DefaultTestChainConfig() (*TestChainConfig, error) {
 			EnableFFI: false,
 		},
 		SkipAccountChecks: true,
+		ForkConfig: ForkConfig{
+			ForkModeEnabled: false,
+			RpcUrl:          "",
+			RpcBlock:        1,
+			PoolSize:        20,
+		},
 	}

 	// Return the generated configuration.
diff --git a/chain/state/cache/caches_test.go b/chain/state/cache/caches_test.go
new file mode 100644
index 00000000..e34a8e10
--- /dev/null
+++ b/chain/state/cache/caches_test.go
@@ -0,0 +1,177 @@
+package cache
+
+import (
+	"context"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/stretchr/testify/assert"
+	"math/rand"
+	"os"
+	"sync"
+	"testing"
+)
+
+// TestNonPersistentStateObjectCacheRace runs concurrent readers and writers against the state object cache so that
+// unsynchronized access is surfaced by the race detector.
+func TestNonPersistentStateObjectCacheRace(t *testing.T) {
+	cache := newNonPersistentStateCache()
+	numObjects := 5
+	writers := 10
+	numWrites := 10_000
+	readers := 10
+	numReads := 10_000
+
+	var wg sync.WaitGroup
+	wg.Add(writers + readers)
+
+	write := func(r *rand.Rand, writesRem int) {
+		for writesRem > 0 {
+			objId := r.Uint32() % uint32(numObjects)
+			addr := common.BytesToAddress([]byte{byte(objId)})
+			stateObject := StateObject{
+				Nonce: r.Uint64(),
+			}
+			err := cache.WriteStateObject(addr, stateObject)
+			assert.NoError(t, err)
+			writesRem--
+		}
+		wg.Done()
+	}
+
+	read := func(r *rand.Rand, readsRem int) {
+		for readsRem > 0 {
+			objId := r.Uint32() % uint32(numObjects)
+			addr := common.BytesToAddress([]byte{byte(objId)})
+			_, _ = cache.GetStateObject(addr)
+			readsRem--
+		}
+		wg.Done()
+	}
+
+	for i := 0; i < readers; i++ {
+		go read(rand.New(rand.NewSource(int64(i))), numReads)
+	}
+
+	for i := 0; i < writers; i++ {
+		go write(rand.New(rand.NewSource(int64(i))), numWrites)
+	}
+	wg.Wait()
+}
+
+// TestNonPersistentSlotCacheRace runs concurrent readers and writers against the slot cache so that unsynchronized
+// access is surfaced by the race detector.
+func TestNonPersistentSlotCacheRace(t *testing.T) {
+	cache := newNonPersistentStateCache()
+	numContracts := 3
+	numObjects := 5
+	writers := 10
+	numWrites := 10_000
+	readers := 10
+	numReads := 10_000
+
+	var wg sync.WaitGroup
+	wg.Add(writers + readers)
+
+	write := func(r *rand.Rand, writesRem int) {
+		for writesRem > 0 {
+			addrId := r.Uint32() % uint32(numContracts)
+			addr := common.BytesToAddress([]byte{byte(addrId)})
+
+			objId := r.Uint32() % uint32(numObjects)
+			objHash := common.BytesToHash([]byte{byte(objId)})
+
+			data := r.Uint32() % 255
+			dataHash := common.BytesToHash([]byte{byte(data)})
+
+			err := cache.WriteSlotData(addr, objHash, dataHash)
+			assert.NoError(t, err)
+			writesRem--
+		}
+		wg.Done()
+	}
+
+	read := func(r *rand.Rand, readsRem int) {
+		for readsRem > 0 {
+			addrId := r.Uint32() % uint32(numContracts)
+			addr := common.BytesToAddress([]byte{byte(addrId)})
+
+			objId := r.Uint32() % uint32(numObjects)
+			objHash := common.BytesToHash([]byte{byte(objId)})
+			_, _ = cache.GetSlotData(addr, objHash)
+			readsRem--
+		}
+		wg.Done()
+	}
+
+	for i := 0; i < readers; i++ {
+		go read(rand.New(rand.NewSource(int64(i))), numReads)
+	}
+
+	for i := 0; i < writers; i++ {
+		go write(rand.New(rand.NewSource(int64(i))), numWrites)
+	}
+	wg.Wait()
+}
+
+// TestPersistentCache tests read/write capability of the persistent cache, along with persistence itself.
+func TestPersistentCache(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	rpcAddr := "www.rpc.net/ethereum/etc"
+	blockHeight := uint64(55555)
+	tmpDir, err := os.MkdirTemp("", "test-*")
+	assert.NoError(t, err)
+	defer os.RemoveAll(tmpDir)
+
+	pc, err := newPersistentCache(ctx, tmpDir, rpcAddr, blockHeight)
+	assert.NoError(t, err)
+
+	stateObjectAddr := common.Address{0x55}
+	stateObjectData := &StateObject{
+		Nonce: rand.Uint64(),
+	}
+	// try reading a state object that doesn't exist
+	_, err = pc.GetStateObject(stateObjectAddr)
+	assert.Error(t, err)
+	assert.Equal(t, ErrCacheMiss, err)
+
+	// write the state object, then make sure we can read it back
+	err = pc.WriteStateObject(stateObjectAddr, *stateObjectData)
+	assert.NoError(t, err)
+
+	so, err := pc.GetStateObject(stateObjectAddr)
+	assert.NoError(t, err)
+	assert.Equal(t, *stateObjectData, *so)
+
+	// repeat the above for slots
+	stateSlotAddress := common.Hash{0x66, 0x01}
+	stateSlotData := common.Hash{0x81}
+
+	// try reading a slot that doesn't exist
+	_, err = pc.GetSlotData(stateObjectAddr, stateSlotAddress)
+	assert.Error(t, err)
+	assert.Equal(t, ErrCacheMiss, err)
+
+	// write the slot, then make sure we can read it back
+	err = pc.WriteSlotData(stateObjectAddr, stateSlotAddress, stateSlotData)
+	assert.NoError(t, err)
+
+	data, err := pc.GetSlotData(stateObjectAddr, stateSlotAddress)
+	assert.NoError(t, err)
+	assert.Equal(t, stateSlotData, data)
+
+	// now terminate our cache to test persistence
+	cancel()
+
+	ctx, cancel = context.WithCancel(context.Background())
+	defer cancel()
+	pc, err = newPersistentCache(ctx, tmpDir, rpcAddr, blockHeight)
+	assert.NoError(t, err)
+
+	// state object matches
+	so, err = pc.GetStateObject(stateObjectAddr)
+	assert.NoError(t, err)
+	assert.Equal(t, *stateObjectData, *so)
+
+	// slot matches
+	data, err = pc.GetSlotData(stateObjectAddr, stateSlotAddress)
+	assert.NoError(t, err)
+	assert.Equal(t, stateSlotData, data)
+}
diff --git a/chain/state/cache/factory.go b/chain/state/cache/factory.go
new file mode 100644
index 00000000..ec5a115f
--- /dev/null
+++ b/chain/state/cache/factory.go
@@ -0,0 +1,26 @@
+package cache
+
+import (
+	"context"
+	"errors"
+	"os"
+)
+
+var _ StateCache = (*nonPersistentStateCache)(nil)
+var _ StateCache = (*persistentCache)(nil)
+
+var ErrCacheMiss = errors.New("not found in cache")
+
+// NewPersistentCache creates a new set of persistent caches that will persist cache content to disk.
+// Each cache is indexed by the RPC address (to separate caches for different networks) and block number.
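+//
+// A minimal usage sketch (illustrative only; the endpoint and address are placeholders):
+//
+//	ctx, cancel := context.WithCancel(context.Background())
+//	defer cancel() // cancelling the context flushes pending writes and closes the on-disk database
+//	sc, err := NewPersistentCache(ctx, "https://eth.example.com", 19_000_000)
+//	if err != nil {
+//		// handle error
+//	}
+//	err = sc.WriteSlotData(addr, slot, value)
+//	data, err := sc.GetSlotData(addr, slot) // returns ErrCacheMiss for slots never written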
+func NewPersistentCache(ctx context.Context, rpcAddr string, height uint64) (StateCache, error) {
+	workingDir, err := os.Getwd()
+	if err != nil {
+		return nil, err
+	}
+	return newPersistentCache(ctx, workingDir, rpcAddr, height)
+}
+
+// NewNonPersistentCache creates a new set of in-memory caches whose content is discarded on shutdown.
+func NewNonPersistentCache() (StateCache, error) {
+	return newNonPersistentStateCache(), nil
+}
diff --git a/chain/state/cache/non_persistent_cache.go b/chain/state/cache/non_persistent_cache.go
new file mode 100644
index 00000000..02ffc57e
--- /dev/null
+++ b/chain/state/cache/non_persistent_cache.go
@@ -0,0 +1,67 @@
+package cache
+
+import (
+	"github.com/ethereum/go-ethereum/common"
+	"sync"
+)
+
+// nonPersistentStateCache provides a thread-safe cache for storing state objects and slots without persisting to disk.
+type nonPersistentStateCache struct {
+	stateObjectLock  sync.RWMutex
+	stateObjectCache map[common.Address]*StateObject
+
+	slotLock  sync.RWMutex
+	slotCache map[common.Address]map[common.Hash]common.Hash
+}
+
+func newNonPersistentStateCache() *nonPersistentStateCache {
+	return &nonPersistentStateCache{
+		stateObjectLock:  sync.RWMutex{},
+		slotLock:         sync.RWMutex{},
+		stateObjectCache: make(map[common.Address]*StateObject),
+		slotCache:        make(map[common.Address]map[common.Hash]common.Hash),
+	}
+}
+
+// GetStateObject returns the state object cached for addr, or ErrCacheMiss if addr is not present in the cache.
+func (s *nonPersistentStateCache) GetStateObject(addr common.Address) (*StateObject, error) {
+	s.stateObjectLock.RLock()
+	defer s.stateObjectLock.RUnlock()
+
+	if obj, ok := s.stateObjectCache[addr]; ok {
+		return obj, nil
+	}
+	return nil, ErrCacheMiss
+}
+
+// WriteStateObject stores the provided state object for addr, overwriting any previously cached value.
+func (s *nonPersistentStateCache) WriteStateObject(addr common.Address, data StateObject) error {
+	s.stateObjectLock.Lock()
+	defer s.stateObjectLock.Unlock()
+	s.stateObjectCache[addr] = &data
+	return nil
+}
+
+// GetSlotData returns the cached data for the given slot of addr, or ErrCacheMiss if it is not present.
+func (s *nonPersistentStateCache) GetSlotData(addr common.Address, slot common.Hash) (common.Hash, error) {
+	s.slotLock.RLock()
+	defer s.slotLock.RUnlock()
+	if slotLookup, ok := s.slotCache[addr]; ok {
+		if data, ok := slotLookup[slot]; ok {
+			return data, nil
+		}
+	}
+	return common.Hash{}, ErrCacheMiss
+}
+
+// WriteSlotData stores the provided slot data for addr, overwriting any previously cached value.
+func (s *nonPersistentStateCache) WriteSlotData(addr common.Address, slot common.Hash, data common.Hash) error {
+	s.slotLock.Lock()
+	defer s.slotLock.Unlock()
+
+	if _, ok := s.slotCache[addr]; !ok {
+		s.slotCache[addr] = make(map[common.Hash]common.Hash)
+	}
+
+	s.slotCache[addr][slot] = data
+	return nil
+}
diff --git a/chain/state/cache/persistent_cache.go b/chain/state/cache/persistent_cache.go
new file mode 100644
index 00000000..c4c90a78
--- /dev/null
+++ b/chain/state/cache/persistent_cache.go
@@ -0,0 +1,238 @@
+package cache
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/ethereum/go-ethereum/common"
+	"go.etcd.io/bbolt"
+	"log"
+	"os"
+	"path/filepath"
+	"sync"
+	"time"
+)
+
+// persistentCache provides a thread-safe cache for storing state objects/slots that persists the cache to disk.
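+// Writes are buffered in memory and flushed to bbolt in batches: writeToPersist queues each entry in pendingWrites
+// and, once flushThreshold entries accumulate (or the cache is closed), flushWrites commits them in one transaction.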
+type persistentCache struct {
+	memCache *nonPersistentStateCache
+	db       *bbolt.DB
+
+	pendingWriteMutex sync.Mutex
+	pendingWrites     []pendingWrite
+	flushThreshold    int
+}
+
+type pendingWrite struct {
+	key   []byte
+	value []byte
+}
+
+func newPersistentCache(ctx context.Context, workingDir string, rpcAddr string, height uint64) (*persistentCache, error) {
+	cacheDir, err := createCacheDirectory(workingDir)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create cache directory: %w", err)
+	}
+	cacheFile := filepath.Join(cacheDir, getCacheFilename(rpcAddr, height))
+	db, err := bbolt.Open(cacheFile, 0600, &bbolt.Options{Timeout: 1 * time.Second})
+	if err != nil {
+		return nil, fmt.Errorf("could not open db: %w", err)
+	}
+
+	// create the default bucket if it doesn't exist
+	err = db.Update(func(tx *bbolt.Tx) error {
+		_, err := tx.CreateBucketIfNotExists([]byte("cache"))
+		return err
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	memCache := newNonPersistentStateCache()
+	p := &persistentCache{
+		memCache:          memCache,
+		db:                db,
+		flushThreshold:    25,
+		pendingWrites:     []pendingWrite{},
+		pendingWriteMutex: sync.Mutex{},
+	}
+
+	// flush and close the db once the context is cancelled
+	go func() {
+		<-ctx.Done()
+		err := p.Close()
+		if err != nil {
+			log.Printf("error closing database: %v", err)
+		}
+	}()
+
+	return p, nil
+}
+
+// getFromPersist looks key up in the on-disk cache, unmarshalling into value. It reports whether the key was found.
+func (p *persistentCache) getFromPersist(key []byte, value interface{}) (bool, error) {
+	found := false
+	err := p.db.View(func(tx *bbolt.Tx) error {
+		bucket := tx.Bucket([]byte("cache"))
+		data := bucket.Get(key)
+		if data == nil {
+			return nil
+		}
+		found = true
+		return json.Unmarshal(data, value)
+	})
+	if err != nil {
+		return false, fmt.Errorf("could not get value: %w", err)
+	}
+	return found, nil
+}
+
+// writeToPersist queues a key/value pair for writing to disk, flushing the queue once it reaches flushThreshold.
+func (p *persistentCache) writeToPersist(key []byte, value []byte) error {
+	item := pendingWrite{
+		key:   key,
+		value: value,
+	}
+	p.pendingWriteMutex.Lock()
+	defer p.pendingWriteMutex.Unlock()
+
+	p.pendingWrites = append(p.pendingWrites, item)
+	if len(p.pendingWrites) >= p.flushThreshold {
+		return p.flushWrites()
+	}
+	return nil
+}
+
+// flushWrites commits all pending writes to disk in a single transaction.
+func (p *persistentCache) flushWrites() error {
+	err := p.db.Update(func(tx *bbolt.Tx) error {
+		bucket := tx.Bucket([]byte("cache"))
+		for _, pw := range p.pendingWrites {
+			err := bucket.Put(pw.key, pw.value)
+			if err != nil {
+				return err
+			}
+		}
+		p.pendingWrites = p.pendingWrites[:0]
+		return nil
+	})
+	return err
+}
+
+func (p *persistentCache) GetStateObject(addr common.Address) (*StateObject, error) {
+	so, err := p.memCache.GetStateObject(addr)
+	if err == nil {
+		return so, nil
+	}
+
+	if errors.Is(err, ErrCacheMiss) {
+		// check the persistent cache
+		s := StateObject{}
+		exists, err := p.getFromPersist(addr[:], &s)
+		if err != nil {
+			return nil, err
+		}
+		if exists {
+			err = p.memCache.WriteStateObject(addr, s)
+			return &s, err
+		} else {
+			return nil, ErrCacheMiss
+		}
+	} else {
+		return nil, err
+	}
+}
+
+func (p *persistentCache) GetSlotData(addr common.Address, slot common.Hash) (common.Hash, error) {
+	data, err := p.memCache.GetSlotData(addr, slot)
+	if err == nil {
+		return data, nil
+	}
+
+	if errors.Is(err, ErrCacheMiss) {
+		// check the persistent cache
+		data := common.Hash{}
+
+		key := append(addr[:], slot[:]...)
+ exists, err := p.getFromPersist(key, &data) + if err != nil { + return common.Hash{}, err + } + if exists { + err = p.memCache.WriteSlotData(addr, slot, data) + return data, err + } else { + return common.Hash{}, ErrCacheMiss + } + } else { + return common.Hash{}, err + } +} + +func (p *persistentCache) WriteStateObject(addr common.Address, data StateObject) error { + err := p.memCache.WriteStateObject(addr, data) + if err != nil { + return err + } + + serialized, err := json.Marshal(data) + if err != nil { + return err + } + + err = p.writeToPersist(addr[:], serialized) + return err +} + +func (p *persistentCache) WriteSlotData(addr common.Address, slot common.Hash, data common.Hash) error { + err := p.memCache.WriteSlotData(addr, slot, data) + if err != nil { + return err + } + + serialized, err := json.Marshal(data) + if err != nil { + return err + } + + key := append(addr[:], slot[:]...) + err = p.writeToPersist(key, serialized) + return err +} + +func (p *persistentCache) Close() error { + err := p.flushWrites() + if err != nil { + return err + } + err = p.db.Close() + return err +} + +func createCacheDirectory(workingDir string) (string, error) { + cachePath := filepath.Join(workingDir, ".medusacache") + _, err := os.Stat(cachePath) + if os.IsNotExist(err) { + // Create directory with 0755 permissions if it doesn't exist + err = os.Mkdir(cachePath, 0755) + if err != nil { + return "", fmt.Errorf("failed to create cache directory: %w", err) + } + } else if err != nil { + return "", fmt.Errorf("failed to check cache directory: %w", err) + } + return cachePath, nil +} + +func getCacheFilename(rpcAddr string, height uint64) string { + h := sha256.New() + h.Write([]byte(rpcAddr)) + bs := h.Sum(nil) + + return fmt.Sprintf("%d-%x.dat", height, bs[0:10]) +} diff --git a/chain/state/cache/types.go b/chain/state/cache/types.go new file mode 100644 index 00000000..e8aad832 --- /dev/null +++ b/chain/state/cache/types.go @@ -0,0 +1,21 @@ +package cache + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" +) + +// StateObject gives us a way to store state objects without the overhead of using geth's stateObject +type StateObject struct { + Balance *uint256.Int + Nonce uint64 + Code []byte +} + +type StateCache interface { + GetStateObject(addr common.Address) (*StateObject, error) + WriteStateObject(addr common.Address, data StateObject) error + + GetSlotData(addr common.Address, slot common.Hash) (common.Hash, error) + WriteSlotData(addr common.Address, slot common.Hash, data common.Hash) error +} diff --git a/chain/state/empty_backend.go b/chain/state/empty_backend.go new file mode 100644 index 00000000..2efed6d9 --- /dev/null +++ b/chain/state/empty_backend.go @@ -0,0 +1,19 @@ +package state + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/holiman/uint256" +) + +/* +EmptyBackend defines a backend containing no data. Intended to be used for local-only state databases. 
+*/ +type EmptyBackend struct{} + +func (d EmptyBackend) GetStorageAt(address common.Address, hash common.Hash) (common.Hash, error) { + return common.Hash{}, nil +} + +func (d EmptyBackend) GetStateObject(address common.Address) (*uint256.Int, uint64, []byte, error) { + return uint256.NewInt(0), 0, nil, nil +} diff --git a/chain/state/factories.go b/chain/state/factories.go new file mode 100644 index 00000000..16518098 --- /dev/null +++ b/chain/state/factories.go @@ -0,0 +1,46 @@ +package state + +import ( + "github.com/crytic/medusa/chain/types" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" +) + +/* +MedusaStateFactory defines a thread-safe interface for creating new state databases. This abstraction allows globally +shared data like RPC caches to be shared across all TestChain instances. +*/ +type MedusaStateFactory interface { + // New initializes a new state + New(root common.Hash, db state.Database) (types.MedusaStateDB, error) +} + +var _ MedusaStateFactory = (*UnbackedStateFactory)(nil) +var _ MedusaStateFactory = (*ForkedStateFactory)(nil) + +// ForkedStateFactory is used to build StateDBs that are backed by a remote RPC +type ForkedStateFactory struct { + globalRemoteStateQuery stateBackend +} + +func NewForkedStateFactory(globalCache stateBackend) *ForkedStateFactory { + return &ForkedStateFactory{globalCache} +} + +func (f *ForkedStateFactory) New(root common.Hash, db state.Database) (types.MedusaStateDB, error) { + remoteStateProvider := newRemoteStateProvider(f.globalRemoteStateQuery) + return state.NewForkedStateDb(root, db, remoteStateProvider) +} + +// UnbackedStateFactory is used to build StateDBs that are not backed by any remote state, but still use the custom +// forked stateDB logic around state object existence checks. 
+type UnbackedStateFactory struct{} + +func NewUnbackedStateFactory() *UnbackedStateFactory { + return &UnbackedStateFactory{} +} + +func (f *UnbackedStateFactory) New(root common.Hash, db state.Database) (types.MedusaStateDB, error) { + remoteStateProvider := newRemoteStateProvider(EmptyBackend{}) + return state.NewForkedStateDb(root, db, remoteStateProvider) +} diff --git a/chain/state/factories_test.go b/chain/state/factories_test.go new file mode 100644 index 00000000..8c0897b3 --- /dev/null +++ b/chain/state/factories_test.go @@ -0,0 +1,246 @@ +package state + +import ( + "github.com/crytic/medusa/chain/state/cache" + types2 "github.com/crytic/medusa/chain/types" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/rawdb" + gethstate "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/tracing" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/triedb" + "github.com/holiman/uint256" + "github.com/stretchr/testify/assert" + "testing" +) + +/* TestForkedStateDB provides unit testing for medusa-geth's ForkedStateDb */ +func TestForkedStateDB(t *testing.T) { + fixture := newPrePopulatedBackendFixture() + factory := NewForkedStateFactory(fixture.Backend) + + db := rawdb.NewMemoryDatabase() + tdb := triedb.NewDatabase(db, nil) + cachingDb := gethstate.NewDatabaseWithNodeDB(db, tdb) + + stateDb1, err := factory.New(types.EmptyRootHash, cachingDb) + assert.NoError(t, err) + genesisSnap := stateDb1.Snapshot() + + /* ensure the statedb is hitting the backend */ + assert.True(t, stateDb1.Exist(fixture.StateObjectContractAddress)) + assert.True(t, stateDb1.Exist(fixture.StateObjectEOAAddress)) + assert.False(t, stateDb1.Exist(fixture.StateObjectEmptyAddress)) + + fixture.verifyAgainstState(t, stateDb1) + + /* write some new data and make sure it's readable */ + newAccount := common.BytesToAddress([]byte{1, 2, 3, 4, 5, 6}) + newAccountData := cache.StateObject{ + Balance: uint256.NewInt(5), + Nonce: 99, + Code: []byte{1, 2, 3}, + } + + stateDb1.SetNonce(newAccount, newAccountData.Nonce) + assert.True(t, stateDb1.Exist(newAccount)) + stateDb1.SetCode(newAccount, newAccountData.Code) + stateDb1.SetBalance(newAccount, newAccountData.Balance, tracing.BalanceChangeUnspecified) + checkAccountAgainstFixture(t, stateDb1, newAccount, newAccountData) + + /* roll back to snapshot, ensure fork data still queryable and newly added data was purged */ + stateDb1.Snapshot() + stateDb1.RevertToSnapshot(genesisSnap) + fixture.verifyAgainstState(t, stateDb1) + assert.False(t, stateDb1.Exist(newAccount)) + + /* now we want to test to verify our fork-populated data is being persisted */ + root, err := stateDb1.Commit(1, true) + assert.NoError(t, err) + stateDb2, err := factory.New(root, cachingDb) + assert.NoError(t, err) + + fixture.verifyAgainstState(t, stateDb2) +} + +/* +TestForkedStateFactory verifies the various independence/shared properties of each forkedStateDB created by the +factory. This is because the underlying RPC/caching layer is shared between all TestChain instances globally, +but this sharing relationship should not cause state to leak from one forkedStateDb to another. 
+*/ +func TestForkedStateFactory(t *testing.T) { + fixture := newPrePopulatedBackendFixture() + factory := NewForkedStateFactory(fixture.Backend) + + stateDb1, err := createEmptyStateDb(factory) + assert.NoError(t, err) + stateDb1.Snapshot() + + stateDb2, err := createEmptyStateDb(factory) + assert.NoError(t, err) + stateDb2.Snapshot() + + /* naive check to ensure they're both pulling from the same remote */ + fixture.verifyAgainstState(t, stateDb1) + fixture.verifyAgainstState(t, stateDb2) + + // snapshot and roll em back + stateDb1.Snapshot() + stateDb2.Snapshot() + stateDb1.RevertToSnapshot(0) + stateDb2.RevertToSnapshot(0) + + /* now we'll mutate a cold account in one stateDB and ensure the mutation doesn't propagate */ + valueAdded := uint256.NewInt(100) + expectedSum := uint256.NewInt(0).Add(fixture.StateObjectEOA.Balance, valueAdded) + stateDb1.AddBalance(fixture.StateObjectEOAAddress, valueAdded, tracing.BalanceChangeUnspecified) + bal := stateDb1.GetBalance(fixture.StateObjectEOAAddress) + assert.Equal(t, expectedSum, bal) + + // check the other statedb + bal = stateDb2.GetBalance(fixture.StateObjectEOAAddress) + assert.Equal(t, bal, fixture.StateObjectEOA.Balance) + + // just in case there's some weird pointer issue that was introduced, create a new stateDB and check it as well + stateDb3, err := createEmptyStateDb(factory) + assert.NoError(t, err) + bal = stateDb3.GetBalance(fixture.StateObjectEOAAddress) + assert.Equal(t, bal, fixture.StateObjectEOA.Balance) + + /* + now we'll emulate one stateDB obtaining a new piece of data from RPC and ensuring the other stateDB loads + the same data + */ + newAccount := common.BytesToAddress([]byte{1, 2, 3, 4, 5, 6}) + slotKey := common.BytesToHash([]byte{5, 5, 5, 5, 5, 5, 5}) + slotData := common.BytesToHash([]byte{6, 6, 6, 6, 6, 6, 6}) + + fixture.Backend.SetStorageAt(newAccount, slotKey, slotData) + data := stateDb1.GetState(newAccount, slotKey) + assert.EqualValues(t, slotData, data) + + // do it again with a fresh stateDB + stateDb4, err := createEmptyStateDb(factory) + assert.NoError(t, err) + data = stateDb4.GetState(newAccount, slotKey) + assert.EqualValues(t, slotData, data) +} + +/* +TestEmptyBackendFactoryDifferential tests the differential properties between a stateDB using an empty forked backend +versus directly using geth's statedb. +*/ +func TestEmptyBackendFactoryDifferential(t *testing.T) { + gethFactory := &gethStateFactory{} + unbackedFactory := NewUnbackedStateFactory() + + gethStateDb, err := createEmptyStateDb(gethFactory) + assert.NoError(t, err) + + unbackedStateDb, err := createEmptyStateDb(unbackedFactory) + assert.NoError(t, err) + + /* start with existence/empty of an existing object. should be identical. */ + addr := common.BytesToAddress([]byte{1}) + gethStateDb.SetNonce(addr, 5) + unbackedStateDb.SetNonce(addr, 5) + assert.EqualValues(t, gethStateDb.Exist(addr), unbackedStateDb.Exist(addr)) + assert.EqualValues(t, gethStateDb.Empty(addr), unbackedStateDb.Empty(addr)) + + /* existence/empty of a non-existing object, should be identical. 
*/ + nonExistentStateObjAddr := common.BytesToAddress([]byte{5, 5, 5, 5, 5}) + assert.EqualValues(t, gethStateDb.Exist(nonExistentStateObjAddr), unbackedStateDb.Exist(nonExistentStateObjAddr)) + assert.EqualValues(t, gethStateDb.Empty(nonExistentStateObjAddr), unbackedStateDb.Empty(nonExistentStateObjAddr)) + + emptyStateObjectAddr := common.BytesToAddress([]byte{6, 7, 8, 9, 10}) + value := uint256.NewInt(5000) + gethStateDb.SetBalance(emptyStateObjectAddr, value, tracing.BalanceChangeUnspecified) + unbackedStateDb.SetBalance(emptyStateObjectAddr, value, tracing.BalanceChangeUnspecified) + + /* existence/empty of an empty object, should be identical. */ + gethStateDb.SubBalance(emptyStateObjectAddr, value, tracing.BalanceChangeUnspecified) + unbackedStateDb.SubBalance(emptyStateObjectAddr, value, tracing.BalanceChangeUnspecified) + assert.EqualValues(t, gethStateDb.Exist(emptyStateObjectAddr), unbackedStateDb.Exist(emptyStateObjectAddr)) + assert.EqualValues(t, gethStateDb.Empty(emptyStateObjectAddr), unbackedStateDb.Empty(emptyStateObjectAddr)) +} + +/* +TestForkedBackendDifferential tests the differential properties between a stateDB using a forked backend +versus directly using geth's statedb. Consider this test a canonical definition of how our forked stateDB acts +differently from geth's. +Good place for future fuzz testing if we run into issues. +*/ +func TestForkedBackendDifferential(t *testing.T) { + fixture := newPrePopulatedBackendFixture() + factory := NewForkedStateFactory(fixture.Backend) + forkedStateDb, err := createEmptyStateDb(factory) + assert.NoError(t, err) + + gethFactory := &gethStateFactory{} + gethStateDb, err := createEmptyStateDb(gethFactory) + assert.NoError(t, err) + + // modify the geth statedb to reflect the fixture's different accounts + // contract + gethStateDb.SetBalance( + fixture.StateObjectContractAddress, + fixture.StateObjectContract.Balance, + tracing.BalanceChangeUnspecified) + gethStateDb.SetNonce(fixture.StateObjectContractAddress, fixture.StateObjectContract.Nonce) + gethStateDb.SetCode(fixture.StateObjectContractAddress, fixture.StateObjectContract.Code) + // eoa + gethStateDb.SetBalance( + fixture.StateObjectEOAAddress, + fixture.StateObjectEOA.Balance, + tracing.BalanceChangeUnspecified) + gethStateDb.SetNonce(fixture.StateObjectEOAAddress, fixture.StateObjectEOA.Nonce) + // do not set the empty account. On a live geth node, the empty account will be pruned. + + // check exist/empty equivalence for the contract account + assert.EqualValues( + t, + gethStateDb.Exist(fixture.StateObjectContractAddress), + forkedStateDb.Exist(fixture.StateObjectContractAddress)) + assert.EqualValues( + t, + gethStateDb.Empty(fixture.StateObjectContractAddress), + forkedStateDb.Empty(fixture.StateObjectContractAddress)) + + // check exist/empty equivalence for the eoa account + assert.EqualValues( + t, + gethStateDb.Exist(fixture.StateObjectEOAAddress), + forkedStateDb.Exist(fixture.StateObjectEOAAddress)) + assert.EqualValues( + t, + gethStateDb.Empty(fixture.StateObjectEOAAddress), + forkedStateDb.Empty(fixture.StateObjectEOAAddress)) + + // check exist/empty equivalence for the empty account + assert.EqualValues( + t, + gethStateDb.Empty(fixture.StateObjectEmptyAddress), + forkedStateDb.Empty(fixture.StateObjectEmptyAddress)) + // note how this is _not_ EqualValues. As far as we know, this is the only place where the forked state provider + // diverges from geth's behavior. 
+	assert.NotEqualValues(
+		t,
+		gethStateDb.Exist(fixture.StateObjectEmptyAddress),
+		forkedStateDb.Exist(fixture.StateObjectEmptyAddress))
+}
+
+// createEmptyStateDb creates an empty stateDB using the provided factory. Intended for tests only.
+func createEmptyStateDb(factory MedusaStateFactory) (types2.MedusaStateDB, error) {
+	db := rawdb.NewMemoryDatabase()
+	tdb := triedb.NewDatabase(db, nil)
+	cachingDb := gethstate.NewDatabaseWithNodeDB(db, tdb)
+	return factory.New(types.EmptyRootHash, cachingDb)
+}
+
+// gethStateFactory is used to build vanilla StateDBs that perfectly reproduce geth's statedb behavior. Only intended
+// to be used for differential testing against the unbacked state factory.
+type gethStateFactory struct{}
+
+func (f *gethStateFactory) New(root common.Hash, db gethstate.Database) (types2.MedusaStateDB, error) {
+	return gethstate.New(root, db, nil)
+}
diff --git a/chain/state/fixtures_test.go b/chain/state/fixtures_test.go
new file mode 100644
index 00000000..91c72835
--- /dev/null
+++ b/chain/state/fixtures_test.go
@@ -0,0 +1,149 @@
+package state
+
+import (
+	"github.com/crytic/medusa/chain/state/cache"
+	types2 "github.com/crytic/medusa/chain/types"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/holiman/uint256"
+	"github.com/stretchr/testify/assert"
+	"testing"
+)
+
+/* This file is exclusively for test fixtures. */
+
+var _ stateBackend = (*prePopulatedBackend)(nil)
+
+// prePopulatedBackend is an offline-only backend used for testing.
+type prePopulatedBackend struct {
+	storageSlots map[common.Address]map[common.Hash]common.Hash
+	stateObjects map[common.Address]cache.StateObject
+}
+
+func newPrepopulatedBackend(
+	storageSlots map[common.Address]map[common.Hash]common.Hash,
+	stateObjects map[common.Address]cache.StateObject,
+) *prePopulatedBackend {
+	return &prePopulatedBackend{
+		storageSlots: storageSlots,
+		stateObjects: stateObjects,
+	}
+}
+
+func (p *prePopulatedBackend) GetStorageAt(address common.Address, hash common.Hash) (common.Hash, error) {
+	if c, exists := p.storageSlots[address]; exists {
+		if data, exists := c[hash]; exists {
+			return data, nil
+		}
+	}
+	return common.Hash{}, nil
+}
+
+func (p *prePopulatedBackend) GetStateObject(address common.Address) (*uint256.Int, uint64, []byte, error) {
+	if s, exists := p.stateObjects[address]; exists {
+		return s.Balance, s.Nonce, s.Code, nil
+	}
+	return uint256.NewInt(0), uint64(0), []byte{}, nil
+}
+
+// SetStorageAt lets tests add slot data to the backend after construction.
+func (p *prePopulatedBackend) SetStorageAt(address common.Address, slotKey common.Hash, value common.Hash) {
+	if _, exists := p.storageSlots[address]; !exists {
+		p.storageSlots[address] = make(map[common.Hash]common.Hash)
+	}
+	p.storageSlots[address][slotKey] = value
+}
+
+// prepopulatedBackendFixture is a test fixture for a pre-populated backend
+type prepopulatedBackendFixture struct {
+	Backend *prePopulatedBackend
+
+	StateObjectContractAddress common.Address
+	StateObjectContract        cache.StateObject
+
+	StorageSlotPopulatedKey  common.Hash
+	StorageSlotPopulatedData common.Hash
+
+	StorageSlotEmptyKey common.Hash
+	StorageSlotEmpty    common.Hash
+
+	StateObjectEOAAddress common.Address
+	StateObjectEOA        cache.StateObject
+
+	StateObjectEmptyAddress common.Address
+	StateObjectEmpty        cache.StateObject
+}
+
+func newPrePopulatedBackendFixture() *prepopulatedBackendFixture {
+	stateObjectContract := cache.StateObject{
+		Balance: uint256.NewInt(1000),
+		Nonce:   5,
+		Code:    []byte{1, 2, 3},
+	}
+	stateObjectEOA := cache.StateObject{
+		Balance: uint256.NewInt(5000),
+		Nonce:   1,
+		Code:
nil, + } + + stateObjectEmpty := cache.StateObject{ + Balance: uint256.NewInt(0), + Nonce: 0, + Code: nil, + } + + contractAddress := common.BytesToAddress([]byte{5, 5, 5, 5}) + eoaAddress := common.BytesToAddress([]byte{6, 6, 6, 6}) + emptyAddress := common.BytesToAddress([]byte{0, 0, 0, 1}) + + storageSlotPopulated := common.HexToHash("0xdeadbeef") + storageSlotPopulatedAddress := common.HexToHash("0xaaaaaaaa") + + storageSlotEmpty := common.Hash{} + storageSlotEmptyAddress := common.HexToHash("0xbbbbbbbbb") + + stateObjects := make(map[common.Address]cache.StateObject) + stateObjects[contractAddress] = stateObjectContract + stateObjects[eoaAddress] = stateObjectEOA + stateObjects[emptyAddress] = stateObjectEmpty + + storageObjects := make(map[common.Address]map[common.Hash]common.Hash) + storageObjects[contractAddress] = make(map[common.Hash]common.Hash) + storageObjects[contractAddress][storageSlotPopulatedAddress] = storageSlotPopulated + storageObjects[contractAddress][storageSlotEmptyAddress] = storageSlotEmpty + + prepopulatedBackend := newPrepopulatedBackend(storageObjects, stateObjects) + + return &prepopulatedBackendFixture{ + Backend: prepopulatedBackend, + StateObjectContractAddress: contractAddress, + StateObjectContract: stateObjectContract, + StorageSlotPopulatedKey: storageSlotPopulatedAddress, + StorageSlotPopulatedData: storageSlotPopulated, + StorageSlotEmptyKey: storageSlotEmptyAddress, + StorageSlotEmpty: storageSlotEmpty, + StateObjectEOAAddress: eoaAddress, + StateObjectEOA: stateObjectEOA, + StateObjectEmpty: stateObjectEmpty, + StateObjectEmptyAddress: emptyAddress, + } +} + +// verifyAgainstState is used by the test suite to verify the statedb is pulling fields from the +// prepopulated fixture +func (p *prepopulatedBackendFixture) verifyAgainstState(t *testing.T, stateDb types2.MedusaStateDB) { + checkAccountAgainstFixture(t, stateDb, p.StateObjectContractAddress, p.StateObjectContract) + checkAccountAgainstFixture(t, stateDb, p.StateObjectEOAAddress, p.StateObjectEOA) + checkAccountAgainstFixture(t, stateDb, p.StateObjectEmptyAddress, p.StateObjectEmpty) +} + +// checkAccountAgainstFixture is used by the test suite to verify an account in the stateDB matches the provided fixture +func checkAccountAgainstFixture(t *testing.T, stateDb types2.MedusaStateDB, addr common.Address, fixture cache.StateObject) { + bal := stateDb.GetBalance(addr) + assert.NoError(t, stateDb.Error()) + assert.EqualValues(t, bal, fixture.Balance) + nonce := stateDb.GetNonce(addr) + assert.NoError(t, stateDb.Error()) + assert.EqualValues(t, nonce, fixture.Nonce) + code := stateDb.GetCode(addr) + assert.NoError(t, stateDb.Error()) + assert.EqualValues(t, code, fixture.Code) +} diff --git a/chain/state/remote_state_provider.go b/chain/state/remote_state_provider.go new file mode 100644 index 00000000..8df34f38 --- /dev/null +++ b/chain/state/remote_state_provider.go @@ -0,0 +1,244 @@ +package state + +import ( + "fmt" + "github.com/ethereum/go-ethereum/common" + gethState "github.com/ethereum/go-ethereum/core/state" + "github.com/holiman/uint256" +) + +var _ gethState.RemoteStateProvider = (*RemoteStateProvider)(nil) + +/* +RemoteStateProvider implements an import mechanism for state that was not written by a locally executed transaction. +This allows us to use the state of a remote RPC server for fork mode, or the state of some other serialized database. +It is consumed by medusa-geth's ForkStateDb. 
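+
+A typical interaction, mirrored by the unit tests in this change (illustrative sketch):
+
+	bal, nonce, code, err := provider.ImportStateObject(addr, snapId) // fetched from the backend
+	_, _, _, err = provider.ImportStateObject(addr, snapId)           // dirty: already imported at snapId
+	provider.NotifyRevertedToSnapshot(snapId - 1)                     // forget everything imported at snapId
+	bal, nonce, code, err = provider.ImportStateObject(addr, snapId)  // fetched from the backend again
+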
+This provider is snapshot-aware and will refuse to fetch certain data if it has reason to believe the local statedb +has newer data. +*/ +type RemoteStateProvider struct { + // stateBackend is used to fetch state when RemoteStateProvider believes the remote source to be canonical + stateBackend stateBackend + + // stateObjBySnapshot keeps track of imported state objects by snapshot, thus allowing state objects to be + // re-imported when their snapshot is reverted + stateObjBySnapshot map[int][]common.Address + stateSlotBySnapshot map[int]map[common.Address][]common.Hash + + // stateObjsImported keeps track of all the state objects RemoteStateProvider has imported + stateObjsImported map[common.Address]struct{} + // stateSlotsImported keeps track of all the storage slots RemoteStateProvider has imported + stateSlotsImported map[common.Address]map[common.Hash]struct{} + + // contractsDeployed keeps track of contracts that were deployed locally. + contractsDeployed map[common.Address]struct{} + contractsDeployedBySnapshot map[int][]common.Address +} + +func newRemoteStateProvider(stateBackend stateBackend) *RemoteStateProvider { + return &RemoteStateProvider{ + stateBackend: stateBackend, + stateObjBySnapshot: make(map[int][]common.Address), + stateSlotBySnapshot: make(map[int]map[common.Address][]common.Hash), + stateObjsImported: make(map[common.Address]struct{}), + stateSlotsImported: make(map[common.Address]map[common.Hash]struct{}), + contractsDeployed: make(map[common.Address]struct{}), + contractsDeployedBySnapshot: make(map[int][]common.Address), + } +} + +/* +ImportStateObject attempts to import a state object from the backend. If the state object has already been imported and +its snapshot has not been reverted, this function will return an error with CannotQueryDirtyAccount set to true. +*/ +func (s *RemoteStateProvider) ImportStateObject( + addr common.Address, + snapId int, +) (bal *uint256.Int, nonce uint64, code []byte, e *gethState.RemoteStateError) { + if _, ok := s.stateObjsImported[addr]; ok { + return nil, 0, nil, &gethState.RemoteStateError{ + CannotQueryDirtyAccount: true, + Error: fmt.Errorf("state object %s was already imported", + addr.Hex()), + } + } + + bal, nonce, code, err := s.stateBackend.GetStateObject(addr) + if err == nil { + s.recordDirtyStateObject(addr, snapId) + return bal, nonce, code, nil + } else { + return uint256.NewInt(0), 0, nil, &gethState.RemoteStateError{ + CannotQueryDirtyAccount: false, + Error: err, + } + } +} + +/* +ImportStorageAt attempts to import a storage slot from the backend. If the slot has already been imported and its +snapshot has not been reverted, this function will return an error with CannotQueryDirtySlot set to true. +If the storage slot is associated with a contract that was deployed locally, this function will return an error with +CannotQueryDirtySlot set to true, since the remote database will never contain canonical slot data for a locally +deployed contract. 
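+
+For example (illustrative; this is exactly what the unit tests assert):
+
+	provider.MarkContractDeployed(contractAddr, snapId)
+	_, err := provider.ImportStorageAt(contractAddr, slot, snapId) // err.CannotQueryDirtySlot == true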
+*/ +func (s *RemoteStateProvider) ImportStorageAt( + addr common.Address, + slot common.Hash, + snapId int, +) (common.Hash, *gethState.RemoteStorageError) { + // if the contract was deployed locally, the RPC will not have data for its slots + if _, exists := s.contractsDeployed[addr]; exists { + return common.Hash{}, &gethState.RemoteStorageError{ + CannotQueryDirtySlot: true, + Error: fmt.Errorf( + "state slot %s of address %s cannot be remote-queried because the contract was deployed locally", + slot.Hex(), + addr.Hex(), + ), + } + } + + imported := s.isStateSlotImported(addr, slot) + if imported { + return common.Hash{}, &gethState.RemoteStorageError{ + CannotQueryDirtySlot: true, + Error: fmt.Errorf( + "state slot %s of address %s was already imported in snapshot %d", + slot.Hex(), + addr.Hex(), + snapId, + ), + } + } + data, err := s.stateBackend.GetStorageAt(addr, slot) + if err == nil { + s.recordDirtyStateSlot(addr, slot, snapId) + return data, nil + } else { + return common.Hash{}, &gethState.RemoteStorageError{ + CannotQueryDirtySlot: false, + Error: err, + } + } +} + +/* +MarkSlotWritten is used to notify the provider that a local transaction has written a value to the specified slot. +As long as the snapshot indicated by snapId is not reverted, the provider will now return "dirty" if ImportStorageAt is +called for the slot in the future. +*/ +func (s *RemoteStateProvider) MarkSlotWritten(addr common.Address, slot common.Hash, snapId int) { + s.recordDirtyStateSlot(addr, slot, snapId) +} + +/* +MarkContractDeployed is used to notify the provider that a contract was locally deployed to the specified address. +As long as the snapshot indicated by snapId is not reverted, the provider will not return "dirty" if ImportStorageAt is +called for any slots associated with the contract. +*/ +func (s *RemoteStateProvider) MarkContractDeployed(addr common.Address, snapId int) { + s.recordContractDeployed(addr, snapId) +} + +/* +NotifyRevertedToSnapshot is used to notify the provider that the state has been reverted back to snapId. The provider +uses this information to clear its import history up to and not including the provided snapId. +*/ +func (s *RemoteStateProvider) NotifyRevertedToSnapshot(snapId int) { + // purge all records down to and not including the provided snapId + + /* accounts */ + accountsToClear := make([]common.Address, 0) + for sId, accounts := range s.stateObjBySnapshot { + if sId > snapId { + accountsToClear = append(accountsToClear, accounts...) + delete(s.stateObjBySnapshot, sId) + } + } + for _, addr := range accountsToClear { + delete(s.stateObjsImported, addr) + } + + /* state slots */ + accountSlotsToClear := make(map[common.Address][]common.Hash) + for sId, accounts := range s.stateSlotBySnapshot { + if sId > snapId { + for addr, slots := range accounts { + if _, ok := accountSlotsToClear[addr]; !ok { + accountSlotsToClear[addr] = make([]common.Hash, 0, len(slots)) + } + accountSlotsToClear[addr] = append(accountSlotsToClear[addr], slots...) + } + delete(s.stateSlotBySnapshot, sId) + } + } + + for addr, slots := range accountSlotsToClear { + for _, slot := range slots { + delete(s.stateSlotsImported[addr], slot) + } + } + + /* contract deploys */ + contractsToClear := make([]common.Address, 0) + for sId, contracts := range s.contractsDeployedBySnapshot { + if sId > snapId { + contractsToClear = append(contractsToClear, contracts...) 
+ delete(s.contractsDeployedBySnapshot, sId) + } + } + for _, contract := range contractsToClear { + delete(s.contractsDeployed, contract) + } +} + +func (s *RemoteStateProvider) isStateSlotImported(addr common.Address, slot common.Hash) bool { + if _, ok := s.stateSlotsImported[addr]; !ok { + return false + } else { + if _, ok := s.stateSlotsImported[addr][slot]; !ok { + return false + } else { + return true + } + } +} + +func (s *RemoteStateProvider) recordDirtyStateObject(addr common.Address, snapId int) { + s.stateObjsImported[addr] = struct{}{} + if _, ok := s.stateObjBySnapshot[snapId]; !ok { + s.stateObjBySnapshot[snapId] = make([]common.Address, 0) + } + s.stateObjBySnapshot[snapId] = append(s.stateObjBySnapshot[snapId], addr) +} + +func (s *RemoteStateProvider) recordDirtyStateSlot(addr common.Address, slot common.Hash, snapId int) { + if _, ok := s.stateSlotsImported[addr]; !ok { + s.stateSlotsImported[addr] = make(map[common.Hash]struct{}) + } + // If this slot has already been marked dirty, we don't want to do it again or overwrite the old snapId. + // We only want to track the oldest snapId that made a slot dirty, since that is the only snapId that should + // change the RemoteStateProvider's import behavior if reverted. + if _, exists := s.stateSlotsImported[addr][slot]; exists { + return + } + + s.stateSlotsImported[addr][slot] = struct{}{} + if _, ok := s.stateSlotBySnapshot[snapId]; !ok { + s.stateSlotBySnapshot[snapId] = make(map[common.Address][]common.Hash) + } + if _, ok := s.stateSlotBySnapshot[snapId][addr]; !ok { + s.stateSlotBySnapshot[snapId][addr] = make([]common.Hash, 0) + } + + s.stateSlotBySnapshot[snapId][addr] = append(s.stateSlotBySnapshot[snapId][addr], slot) +} + +func (s *RemoteStateProvider) recordContractDeployed(addr common.Address, snapId int) { + s.contractsDeployed[addr] = struct{}{} + if _, ok := s.contractsDeployedBySnapshot[snapId]; !ok { + s.contractsDeployedBySnapshot[snapId] = make([]common.Address, 0) + } + s.contractsDeployedBySnapshot[snapId] = append(s.contractsDeployedBySnapshot[snapId], addr) +} diff --git a/chain/state/remote_state_provider_test.go b/chain/state/remote_state_provider_test.go new file mode 100644 index 00000000..996637f7 --- /dev/null +++ b/chain/state/remote_state_provider_test.go @@ -0,0 +1,154 @@ +package state + +import ( + "github.com/crytic/medusa/chain/state/cache" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestRemoteStateProvider_ImportStateObject(t *testing.T) { + fixture := newPrePopulatedBackendFixture() + stateProvider := newRemoteStateProvider(fixture.Backend) + + snapId := 5 + importTest := func(objectAddr common.Address, expectedObjectData cache.StateObject) { + /* test a basic state cache read */ + bal, nonce, code, err := stateProvider.ImportStateObject(objectAddr, snapId) + assert.Nil(t, err) + assert.EqualValues(t, bal, expectedObjectData.Balance) + assert.EqualValues(t, nonce, expectedObjectData.Nonce) + assert.EqualValues(t, code, expectedObjectData.Code) + + /* reading the same state cache twice should return dirty error */ + _, _, _, err = stateProvider.ImportStateObject(objectAddr, snapId) + assert.True(t, err.CannotQueryDirtyAccount) + assert.NotNil(t, err) + assert.Error(t, err.Error) + + /* reverting to a snapshot equal to that of the imported value should still result in a dirty error */ + stateProvider.NotifyRevertedToSnapshot(snapId) + _, _, _, err = stateProvider.ImportStateObject(objectAddr, snapId) + assert.True(t, 
err.CannotQueryDirtyAccount) + assert.NotNil(t, err.Error) + + /* reverting to a snapshot before that of the imported value should result in the value being returned */ + stateProvider.NotifyRevertedToSnapshot(snapId - 1) + bal, nonce, code, err = stateProvider.ImportStateObject(objectAddr, snapId) + assert.Nil(t, err) + assert.EqualValues(t, bal, expectedObjectData.Balance) + assert.EqualValues(t, nonce, expectedObjectData.Nonce) + assert.EqualValues(t, code, expectedObjectData.Code) + } + + // run importTest for a contract + importTest(fixture.StateObjectContractAddress, fixture.StateObjectContract) + // run importTest for an EOA + importTest(fixture.StateObjectEOAAddress, fixture.StateObjectEOA) + // run importTest for an empty/non-existent account + importTest(fixture.StateObjectEmptyAddress, fixture.StateObjectEmpty) +} + +func TestRemoteStateProvider_ImportStorageAt(t *testing.T) { + fixture := newPrePopulatedBackendFixture() + stateProvider := newRemoteStateProvider(fixture.Backend) + + snapId := 5 + importTest := func(contractAddr common.Address, slotKey common.Hash, expectedData common.Hash) { + /* test a basic state slot read */ + data, err := stateProvider.ImportStorageAt(contractAddr, slotKey, snapId) + assert.Nil(t, err) + assert.EqualValues(t, expectedData, data) + + /* reading the same slot twice should result in an error */ + _, err = stateProvider.ImportStorageAt(contractAddr, slotKey, snapId) + assert.NotNil(t, err) + assert.True(t, err.CannotQueryDirtySlot) + + /* reverting to a snapshot equal to that of the imported value should still result in a dirty error */ + stateProvider.NotifyRevertedToSnapshot(snapId) + _, err = stateProvider.ImportStorageAt(contractAddr, slotKey, snapId) + assert.NotNil(t, err) + assert.True(t, err.CannotQueryDirtySlot) + + /* reverting to a snapshot before that of the imported value should result in the value being returned */ + stateProvider.NotifyRevertedToSnapshot(snapId - 1) + data, err = stateProvider.ImportStorageAt(contractAddr, slotKey, snapId) + assert.Nil(t, err) + assert.EqualValues(t, expectedData, data) + } + + /* test for populated slot */ + importTest(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, fixture.StorageSlotPopulatedData) + /* test for empty slot */ + importTest(fixture.StateObjectContractAddress, fixture.StorageSlotEmptyKey, fixture.StorageSlotEmpty) +} + +func TestRemoteStateProvider_MarkSlotWritten(t *testing.T) { + fixture := newPrePopulatedBackendFixture() + stateProvider := newRemoteStateProvider(fixture.Backend) + snapId := 5 + + /* marking a slot as dirty should result in a dirty error when it's read */ + stateProvider.MarkSlotWritten(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, snapId) + _, err := stateProvider.ImportStorageAt(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, snapId) + assert.NotNil(t, err) + assert.Error(t, err.Error) + assert.True(t, err.CannotQueryDirtySlot) + + /* reverting to a snapshot before the mark should allow the value to be read */ + stateProvider.NotifyRevertedToSnapshot(snapId - 1) + _, err = stateProvider.ImportStorageAt(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, snapId) + assert.Nil(t, err) + + /* + If a slot is written to twice in two successive snapshots, reverting one of the snapshots should not make the + value readable from the remoteStateProvider. 
+ */ + initSnapId := snapId - 2 + stateProvider.NotifyRevertedToSnapshot(initSnapId) + _, err = stateProvider.ImportStorageAt(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, initSnapId) + assert.Nil(t, err) + + // first write + stateProvider.MarkSlotWritten(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, initSnapId) + _, err = stateProvider.ImportStorageAt(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, initSnapId) + assert.NotNil(t, err) + assert.Error(t, err.Error) + assert.True(t, err.CannotQueryDirtySlot) + + // second write + secondSnapId := initSnapId + 1 + stateProvider.MarkSlotWritten(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, secondSnapId) + _, err = stateProvider.ImportStorageAt(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, secondSnapId) + assert.NotNil(t, err) + assert.Error(t, err.Error) + assert.True(t, err.CannotQueryDirtySlot) + + // revert to initSnapId + stateProvider.NotifyRevertedToSnapshot(initSnapId) + + // ensure it's still dirty + _, err = stateProvider.ImportStorageAt(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, initSnapId) + assert.NotNil(t, err) + assert.Error(t, err.Error) + assert.True(t, err.CannotQueryDirtySlot) +} + +func TestRemoteStateProvider_MarkContractDeployed(t *testing.T) { + fixture := newPrePopulatedBackendFixture() + stateProvider := newRemoteStateProvider(fixture.Backend) + snapId := 5 + + /* marking a contract as deployed should result in a dirty error when we try to read its slots */ + stateProvider.MarkContractDeployed(fixture.StateObjectContractAddress, snapId) + _, err := stateProvider.ImportStorageAt(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, snapId) + assert.NotNil(t, err) + assert.Error(t, err.Error) + assert.True(t, err.CannotQueryDirtySlot) + + /* reverting to a snapshot before the mark should allow the value to be read */ + stateProvider.NotifyRevertedToSnapshot(snapId - 1) + _, err = stateProvider.ImportStorageAt(fixture.StateObjectContractAddress, fixture.StorageSlotPopulatedKey, snapId) + assert.Nil(t, err) +} diff --git a/chain/state/rpc/client_pool.go b/chain/state/rpc/client_pool.go new file mode 100644 index 00000000..6c5802c3 --- /dev/null +++ b/chain/state/rpc/client_pool.go @@ -0,0 +1,124 @@ +package rpc + +import ( + "github.com/ethereum/go-ethereum/rpc" + "golang.org/x/net/context" + "sync" + "time" +) + +const maxRetries = 3 + +/* +ClientPool is an Ethereum JSON-RPC provider that provides automatic connection pooling and request deduplication. +*/ +type ClientPool struct { + rpcClients []*rpc.Client + currentClientIdx int + clientLock sync.Mutex + + inflightRequests map[requestKey]*inflightRequest + inflightLock sync.Mutex + + endpoint string + maxRetries int +} + +func NewClientPool(endpoint string, poolSize uint) (*ClientPool, error) { + pool := &ClientPool{ + rpcClients: make([]*rpc.Client, poolSize), + clientLock: sync.Mutex{}, + inflightRequests: make(map[requestKey]*inflightRequest), + inflightLock: sync.Mutex{}, + endpoint: endpoint, + maxRetries: maxRetries, + } + + // dial out + for i := uint(0); i < poolSize; i++ { + client, err := rpc.Dial(endpoint) + if err != nil { + return nil, err + } + pool.rpcClients[i] = client + } + + return pool, nil +} + +/* +ExecuteRequestBlocking makes a blocking RPC request and stores the result in the result interface pointer. 
+If there is an existing request on the wire with the same method/args, the calling thread will be blocked until it has
+completed.
+*/
+func (c *ClientPool) ExecuteRequestBlocking(ctx context.Context, result interface{}, method string, args ...interface{}) error {
+	pending, err := c.ExecuteRequestAsync(ctx, method, args...)
+	if err != nil {
+		return err
+	}
+	return pending.GetResultBlocking(result)
+}
+
+/*
+ExecuteRequestAsync makes a non-blocking RPC request whose result may be obtained from interacting with *PendingResult.
+If there is an existing request on the wire with the same method/args, this function will return a PendingResult linked
+to that request.
+*/
+func (c *ClientPool) ExecuteRequestAsync(ctx context.Context, method string, args ...interface{}) (*PendingResult, error) {
+	key, err := makeRequestKey(method, args...)
+	if err != nil {
+		return nil, err
+	}
+
+	// check for an in-flight request with the same key; if one exists, piggyback on it
+	c.inflightLock.Lock()
+	if inflight, exists := c.inflightRequests[key]; exists {
+		c.inflightLock.Unlock()
+		return newPendingResult(inflight), nil
+	} else {
+		// no in-flight request; launch a new one
+		inflight = &inflightRequest{
+			Done:    make(chan struct{}),
+			Context: ctx,
+		}
+		c.inflightRequests[key] = inflight
+		c.inflightLock.Unlock()
+		client := c.getClient()
+
+		go c.launchRequest(client, inflight, method, args...)
+		return newPendingResult(inflight), nil
+	}
+}
+
+// getClient obtains the next client in the round-robin of clients in the pool
+func (c *ClientPool) getClient() *rpc.Client {
+	c.clientLock.Lock()
+	defer c.clientLock.Unlock()
+
+	client := c.rpcClients[c.currentClientIdx]
+	c.currentClientIdx = (c.currentClientIdx + 1) % len(c.rpcClients)
+
+	return client
+}
+
+// launchRequest performs the actual RPC request with retries, storing the result or error in the inflightRequest
+func (c *ClientPool) launchRequest(
+	client *rpc.Client,
+	request *inflightRequest,
+	method string,
+	args ...interface{}) {
+	defer close(request.Done)
+
+	var err error
+	var result string
+	for attempt := 0; attempt < c.maxRetries; attempt++ {
+		err = client.CallContext(request.Context, &result, method, args...)
+		if err == nil {
+			// re-wrap the raw string result as a JSON string so GetResultBlocking can unmarshal it into typed values
+			request.Result = []byte("\"" + result + "\"")
+			return
+		}
+		// back off linearly before retrying
+		time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
+	}
+	request.Error = err
+}
diff --git a/chain/state/rpc/structs.go b/chain/state/rpc/structs.go
new file mode 100644
index 00000000..2a854b4d
--- /dev/null
+++ b/chain/state/rpc/structs.go
@@ -0,0 +1,64 @@
+package rpc
+
+import (
+	"context"
+	"encoding/json"
+)
+
+/*
+PendingResult defines an object that is returned when calling the RPC asynchronously. It is similar to a promise in
+other languages.
+*/
+type PendingResult struct {
+	request *inflightRequest
+}
+
+func newPendingResult(request *inflightRequest) *PendingResult {
+	return &PendingResult{
+		request: request,
+	}
+}
+
+/*
+GetResultBlocking obtains the result from the client, blocking until the result or an error is available. Callers must
+pass a pointer to their data through result. Note that if the fuzzer is shutting down, an error may be returned to
+signify the context has been cancelled.
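+
+For example (an illustrative sketch; hexutil.Bytes is how the RPC backend in this change decodes eth_getCode):
+
+	pending, err := pool.ExecuteRequestAsync(ctx, "eth_getCode", addr, height)
+	// ...issue other requests concurrently...
+	var code hexutil.Bytes
+	err = pending.GetResultBlocking(&code)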
+*/ +func (p *PendingResult) GetResultBlocking(result interface{}) error { + select { + case <-p.request.Done: + if p.request.Error != nil { + return p.request.Error + } else { + err := json.Unmarshal(p.request.Result, result) + return err + } + case <-p.request.Context.Done(): + return p.request.Context.Err() + } +} + +// requestKey defines a struct that can uniquely identify an Ethereum RPC request for request deduplication purposes. +type requestKey struct { + Method string + Args string +} + +func makeRequestKey(method string, args ...interface{}) (requestKey, error) { + serialized, err := json.Marshal(args) + if err != nil { + return requestKey{}, err + } else { + return requestKey{Method: method, Args: string(serialized)}, nil + } + +} + +// inflightRequest represents an HTTP-JSON request that is currently traversing the network. +type inflightRequest struct { + // Done is used to signal to each interested worker that the request is completed (possibly with error). + Done chan struct{} + Error error + Result []byte + Context context.Context +} diff --git a/chain/state/rpc_backend.go b/chain/state/rpc_backend.go new file mode 100644 index 00000000..8b628287 --- /dev/null +++ b/chain/state/rpc_backend.go @@ -0,0 +1,173 @@ +package state + +import ( + "context" + "github.com/crytic/medusa/chain/state/cache" + "github.com/crytic/medusa/chain/state/rpc" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/holiman/uint256" +) + +/* +stateBackend defines an interface for fetching arbitrary state from a different source such as a remote RPC server or +K/V store. +*/ +type stateBackend interface { + GetStorageAt(common.Address, common.Hash) (common.Hash, error) + GetStateObject(common.Address) (*uint256.Int, uint64, []byte, error) +} + +var _ stateBackend = (*EmptyBackend)(nil) +var _ stateBackend = (*RPCBackend)(nil) + +/* +RPCBackend defines a stateBackend for fetching state from a remote RPC server. It is locked to a single block height, +and caches data in-memory with no expiry. +*/ +type RPCBackend struct { + context context.Context + clientPool *rpc.ClientPool + height string + + cache cache.StateCache +} + +func NewRPCBackend( + ctx context.Context, + url string, + height uint64, + poolSize uint) (*RPCBackend, error) { + clientPool, err := rpc.NewClientPool(url, poolSize) + if err != nil { + return nil, err + } + + cache, err := cache.NewPersistentCache(ctx, url, height) + if err != nil { + return nil, err + } + + return &RPCBackend{ + context: ctx, + clientPool: clientPool, + height: hexutil.Uint64(height).String(), + cache: cache, + }, nil +} + +// newRPCBackendNoPersistence creates a new RPC backend that will not persist its cache to disk. used for tests. +// nolint:unused +func newRPCBackendNoPersistence( + ctx context.Context, + url string, + height uint64, + poolSize uint) (*RPCBackend, error) { + clientPool, err := rpc.NewClientPool(url, poolSize) + if err != nil { + return nil, err + } + + cache, err := cache.NewNonPersistentCache() + if err != nil { + return nil, err + } + + return &RPCBackend{ + context: ctx, + clientPool: clientPool, + height: hexutil.Uint64(height).String(), + cache: cache, + }, nil +} + +/* +GetStorageAt returns data stored in the remote RPC for the given address/slot. +Note that Ethereum RPC will return zero for slots that have never been written to or are associated with undeployed +contracts. +Errors may be network errors or a context cancelled error when the fuzzer is shutting down. 
+*/ +func (q *RPCBackend) GetStorageAt(addr common.Address, slot common.Hash) (common.Hash, error) { + data, err := q.cache.GetSlotData(addr, slot) + if err == nil { + return data, nil + } else { + method := "eth_getStorageAt" + var result hexutil.Bytes + err = q.clientPool.ExecuteRequestBlocking(q.context, &result, method, addr, slot, q.height) + if err != nil { + return common.Hash{}, err + } else { + resultCast := common.HexToHash(common.Bytes2Hex(result)) + err = q.cache.WriteSlotData(addr, slot, resultCast) + return resultCast, err + } + } +} + +/* +GetStateObject returns the data stored in the remote RPC for the specified state object +Note that the Ethereum RPC will return zero for accounts that do not exist. +Errors may be network errors or a context cancelled error when the fuzzer is shutting down. +*/ +func (q *RPCBackend) GetStateObject(addr common.Address) (*uint256.Int, uint64, []byte, error) { + obj, err := q.cache.GetStateObject(addr) + if err == nil { + return obj.Balance, obj.Nonce, obj.Code, nil + } else { + balance := hexutil.Big{} + nonce := hexutil.Uint(0) + code := hexutil.Bytes{} + + pendingBalance, err := q.clientPool.ExecuteRequestAsync( + q.context, + "eth_getBalance", + addr, + q.height) + if err != nil { + return nil, 0, nil, err + } + pendingNonce, err := q.clientPool.ExecuteRequestAsync( + q.context, + "eth_getTransactionCount", + addr, + q.height) + if err != nil { + return nil, 0, nil, err + } + + pendingCode, err := q.clientPool.ExecuteRequestAsync( + q.context, + "eth_getCode", + addr, + q.height) + if err != nil { + return nil, 0, nil, err + } + + err = pendingBalance.GetResultBlocking(&balance) + if err != nil { + return nil, 0, nil, err + } + balanceTyped := &uint256.Int{} + balanceTyped.SetFromBig(balance.ToInt()) + + err = pendingNonce.GetResultBlocking(&nonce) + if err != nil { + return nil, 0, nil, err + } + + err = pendingCode.GetResultBlocking(&code) + if err != nil { + return nil, 0, nil, err + } + err = q.cache.WriteStateObject( + addr, + cache.StateObject{ + Balance: balanceTyped, + Nonce: uint64(nonce), + Code: code, + }) + return balanceTyped, uint64(nonce), code, err + } +} diff --git a/chain/test_chain.go b/chain/test_chain.go index 0137c6c3..c84b8030 100644 --- a/chain/test_chain.go +++ b/chain/test_chain.go @@ -3,6 +3,8 @@ package chain import ( "errors" "fmt" + "github.com/crytic/medusa/chain/state" + "golang.org/x/net/context" "math/big" "sort" @@ -14,14 +16,14 @@ import ( "github.com/holiman/uint256" "golang.org/x/exp/maps" - chainTypes "github.com/crytic/medusa/chain/types" + "github.com/crytic/medusa/chain/types" "github.com/crytic/medusa/chain/vendored" "github.com/crytic/medusa/utils" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/math" "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/types" + gethState "github.com/ethereum/go-ethereum/core/state" + gethTypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/params" @@ -33,10 +35,10 @@ type TestChain struct { // blocks represents the blocks created on the current chain. If blocks are sent to the chain which skip some // block numbers, any block in that gap will not be committed here and its block hash and other parameters // will be spoofed when requested through the API, for efficiency. 
- blocks []*chainTypes.Block + blocks []*types.Block // pendingBlock is a block currently under construction by the chain which has not yet been committed. - pendingBlock *chainTypes.Block + pendingBlock *types.Block // pendingBlockContext is the vm.BlockContext for the current pending block. This is used by cheatcodes to override the EVM // interpreter's behavior. This should be set when a new EVM is created by the test chain e.g. using vm.NewEVM. @@ -63,13 +65,13 @@ type TestChain struct { // genesisDefinition represents the Genesis information used to generate the chain's initial state. genesisDefinition *core.Genesis - // state represents the current Ethereum world state.StateDB. It tracks all state across the chain and dummyChain - // and is the subject of state changes when executing new transactions. This does not track the current block - // head or anything of that nature and simply tracks accounts, balances, code, storage, etc. - state *state.StateDB + // state represents the current Ethereum world (interface implementing state.StateDB). It tracks all state across + // the chain and dummyChain and is the subject of state changes when executing new transactions. This does not + // track the current block head or anything of that nature and simply tracks accounts, balances, code, storage, etc. + state types.MedusaStateDB // stateDatabase refers to the database object which state uses to store data. It is constructed over db. - stateDatabase state.Database + stateDatabase gethState.Database // db represents the in-memory database used by the TestChain and its underlying chain to store state changes. // This is constructed over the kvstore. @@ -85,12 +87,55 @@ type TestChain struct { // Events defines the event system for the TestChain. Events TestChainEvents + + // stateFactory used to construct state databases from db/root. Abstracts away the backing RPC when running in + // fork mode. + stateFactory state.MedusaStateFactory } // NewTestChain creates a simulated Ethereum backend used for testing, or returns an error if one occurred. // This creates a test chain with a test chain configuration and the provided genesis allocation and config. // If a nil config is provided, a default one is used. -func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestChainConfig) (*TestChain, error) { +// Additional TestChain objects should be obtained via calling Clone on the original, as this allows cloned chains to +// benefit from shared RPC caching and certain kinds of state memoization that may be implemented in the future. 
+func NewTestChain( + fuzzerContext context.Context, + genesisAlloc gethTypes.GenesisAlloc, + testChainConfig *config.TestChainConfig) (*TestChain, error) { + + // Use a default config if we were not provided one + var err error + if testChainConfig == nil { + testChainConfig, err = config.DefaultTestChainConfig() + if err != nil { + return nil, err + } + } + var stateFactory state.MedusaStateFactory + if testChainConfig.ForkConfig.ForkModeEnabled { + provider, err := state.NewRPCBackend( + fuzzerContext, + testChainConfig.ForkConfig.RpcUrl, + testChainConfig.ForkConfig.RpcBlock, + testChainConfig.ForkConfig.PoolSize) + if err != nil { + return nil, err + } + stateFactory = state.NewForkedStateFactory(provider) + } else { + stateFactory = state.NewUnbackedStateFactory() + } + + return newTestChainWithStateFactory(genesisAlloc, testChainConfig, stateFactory) +} + +// newTestChainWithStateFactory creates a simulated backend, using the provided stateFactory for optionally fetching +// remote state if RPC mode is configured. +func newTestChainWithStateFactory( + genesisAlloc gethTypes.GenesisAlloc, + testChainConfig *config.TestChainConfig, + stateFactory state.MedusaStateFactory) (*TestChain, error) { + // Copy our chain config, so it is not shared across chains. chainConfig, err := utils.CopyChainConfig(params.TestChainConfig) if err != nil { @@ -125,14 +170,6 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC BaseFee: big.NewInt(0), } - // Use a default config if we were not provided one - if testChainConfig == nil { - testChainConfig, err = config.DefaultTestChainConfig() - if err != nil { - return nil, err - } - } - // Obtain our VM extensions from our config vmConfigExtensions := testChainConfig.GetVMConfigExtensions() @@ -149,7 +186,7 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC return nil, err } for _, cheatContract := range cheatContracts { - genesisDefinition.Alloc[cheatContract.address] = types.Account{ + genesisDefinition.Alloc[cheatContract.address] = gethTypes.Account{ Balance: big.NewInt(0), Code: []byte{0xFF}, } @@ -170,10 +207,10 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC genesisBlock := genesisDefinition.MustCommit(db, trieDB) // Convert our genesis block (go-ethereum type) to a test chain block. - testChainGenesisBlock := chainTypes.NewBlock(genesisBlock.Header()) + testChainGenesisBlock := types.NewBlock(genesisBlock.Header()) // Create our state database over-top our database. - stateDatabase := state.NewDatabaseWithConfig(db, dbConfig) + stateDatabase := gethState.NewDatabaseWithConfig(db, dbConfig) // Create a tracer forwarder to support the addition of multiple tracers for transaction and call execution. transactionTracerRouter := NewTestChainTracerRouter() @@ -183,7 +220,7 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC chain := &TestChain{ genesisDefinition: genesisDefinition, BlockGasLimit: genesisBlock.Header().GasLimit, - blocks: []*chainTypes.Block{testChainGenesisBlock}, + blocks: []*types.Block{testChainGenesisBlock}, pendingBlock: nil, db: db, state: nil, @@ -193,6 +230,7 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC testChainConfig: testChainConfig, chainConfig: genesisDefinition.Config, vmConfigExtensions: vmConfigExtensions, + stateFactory: stateFactory, } // Add our internal tracers to this chain. 
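For illustration, a minimal sketch of enabling fork mode through the chain config when constructing a chain; the RPC URL and block number below are placeholders, not defaults:

package main

import (
	"context"

	"github.com/crytic/medusa/chain"
	"github.com/crytic/medusa/chain/config"
	gethTypes "github.com/ethereum/go-ethereum/core/types"
)

func main() {
	// Start from the default chain config, then enable fork mode.
	cfg, err := config.DefaultTestChainConfig()
	if err != nil {
		panic(err)
	}
	cfg.ForkConfig = config.ForkConfig{
		ForkModeEnabled: true,
		RpcUrl:          "https://eth.example.com", // placeholder endpoint
		RpcBlock:        20_000_000,                // placeholder height
		PoolSize:        20,
	}

	// NewTestChain builds an RPC-backed state factory when fork mode is enabled.
	testChain, err := chain.NewTestChain(context.Background(), gethTypes.GenesisAlloc{}, cfg)
	if err != nil {
		panic(err)
	}
	_ = testChain
}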
@@ -229,7 +267,7 @@ func (t *TestChain) Close() { // Returns the new chain, or an error if one occurred. func (t *TestChain) Clone(onCreateFunc func(chain *TestChain) error) (*TestChain, error) { // Create a new chain with the same genesis definition and config - targetChain, err := NewTestChain(t.genesisDefinition.Alloc, t.testChainConfig) + targetChain, err := newTestChainWithStateFactory(t.genesisDefinition.Alloc, t.testChainConfig, t.stateFactory) if err != nil { return nil, err } @@ -297,7 +335,7 @@ func (t *TestChain) GenesisDefinition() *core.Genesis { } // State returns the current state.StateDB of the chain. -func (t *TestChain) State() *state.StateDB { +func (t *TestChain) State() types.MedusaStateDB { return t.state } @@ -319,12 +357,12 @@ func (t *TestChain) CheatCodeContracts() map[common.Address]*CheatCodeContract { // CommittedBlocks returns the real blocks which were committed to the chain, where methods such as BlockFromNumber // return the simulated chain state with intermediate blocks injected for block number jumps, etc. -func (t *TestChain) CommittedBlocks() []*chainTypes.Block { +func (t *TestChain) CommittedBlocks() []*types.Block { return t.blocks } // Head returns the head of the chain (the latest block). -func (t *TestChain) Head() *chainTypes.Block { +func (t *TestChain) Head() *types.Block { return t.blocks[len(t.blocks)-1] } @@ -337,7 +375,7 @@ func (t *TestChain) HeadBlockNumber() uint64 { // When the TestChain creates a new block that jumps the block number forward, the existence of any intermediate // block will be spoofed based off of the closest preceding internally committed block. // Returns the index of the closest preceding block in blocks and the Block itself. -func (t *TestChain) fetchClosestInternalBlock(blockNumber uint64) (int, *chainTypes.Block) { +func (t *TestChain) fetchClosestInternalBlock(blockNumber uint64) (int, *types.Block) { // Perform a binary search for this exact block number, or the closest preceding block we committed. k := sort.Search(len(t.blocks), func(n int) bool { return t.blocks[n].Header.Number.Uint64() >= blockNumber @@ -365,7 +403,7 @@ func (t *TestChain) fetchClosestInternalBlock(blockNumber uint64) (int, *chainTy // the TestChain skip block numbers, this method will simulate the existence of well-formed intermediate blocks to // ensure chain validity throughout. Thus, this is a "simulated" chain API method. // Returns the block, or an error if one occurs. -func (t *TestChain) BlockFromNumber(blockNumber uint64) (*chainTypes.Block, error) { +func (t *TestChain) BlockFromNumber(blockNumber uint64) (*types.Block, error) { // If the block number is past our current head, return an error. if blockNumber > t.HeadBlockNumber() { return nil, fmt.Errorf("could not obtain block for block number %d because it exceeds the current head block number %d", blockNumber, t.HeadBlockNumber()) @@ -403,14 +441,14 @@ func (t *TestChain) BlockFromNumber(blockNumber uint64) (*chainTypes.Block, erro // - Reuses gas limit from last committed block. // - We reuse the previous timestamp and add 1 for every block generated (blocks must have different timestamps) // - Note: This means that we must check that our timestamp jump >= block number jump when committing a new block. 
- blockHeader := &types.Header{ + blockHeader := &gethTypes.Header{ ParentHash: previousBlockHash, - UncleHash: types.EmptyUncleHash, + UncleHash: gethTypes.EmptyUncleHash, Coinbase: common.Address{}, Root: closestBlock.Header.Root, - TxHash: types.EmptyRootHash, - ReceiptHash: types.EmptyRootHash, - Bloom: types.Bloom{}, + TxHash: gethTypes.EmptyRootHash, + ReceiptHash: gethTypes.EmptyRootHash, + Bloom: gethTypes.Bloom{}, Difficulty: common.Big0, Number: big.NewInt(int64(blockNumber)), GasLimit: closestBlock.Header.GasLimit, @@ -418,12 +456,12 @@ func (t *TestChain) BlockFromNumber(blockNumber uint64) (*chainTypes.Block, erro Time: closestBlock.Header.Time + (blockNumber - closestBlockNumber), Extra: []byte{}, MixDigest: previousBlockHash, - Nonce: types.BlockNonce{}, + Nonce: gethTypes.BlockNonce{}, BaseFee: closestBlock.Header.BaseFee, } // Create our new empty block with our provided header and return it. - block := chainTypes.NewBlock(blockHeader) + block := types.NewBlock(blockHeader) block.Hash = blockHash // we patch our block hash with our spoofed one immediately return block, nil } @@ -460,9 +498,9 @@ func (t *TestChain) BlockHashFromNumber(blockNumber uint64) (common.Hash, error) // StateFromRoot obtains a state from a given state root hash. // Returns the state, or an error if one occurred. -func (t *TestChain) StateFromRoot(root common.Hash) (*state.StateDB, error) { +func (t *TestChain) StateFromRoot(root common.Hash) (types.MedusaStateDB, error) { // Load our state from the database - stateDB, err := state.New(root, t.stateDatabase, nil) + stateDB, err := t.stateFactory.New(root, t.stateDatabase) if err != nil { return nil, err } @@ -486,7 +524,7 @@ func (t *TestChain) StateRootAfterBlockNumber(blockNumber uint64) (common.Hash, // StateAfterBlockNumber obtains the Ethereum world state after processing all transactions in the provided block // number. Returns the state, or an error if one occurs. -func (t *TestChain) StateAfterBlockNumber(blockNumber uint64) (*state.StateDB, error) { +func (t *TestChain) StateAfterBlockNumber(blockNumber uint64) (types.MedusaStateDB, error) { // Obtain our block's post-execution state root hash root, err := t.StateRootAfterBlockNumber(blockNumber) if err != nil { @@ -558,7 +596,7 @@ func (t *TestChain) RevertToBlockNumber(blockNumber uint64) error { // It takes an optional state argument, which is the state to execute the message over. If not provided, the // current pending state (or committed state if none is pending) will be used instead. // The state executed over may be a pending block state. -func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additionalTracers ...*TestChainTracer) (*core.ExecutionResult, error) { +func (t *TestChain) CallContract(msg *core.Message, state types.MedusaStateDB, additionalTracers ...*TestChainTracer) (*core.ExecutionResult, error) { // If our provided state is nil, use our current chain state. 
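	// In fork mode, this state may be a forked MedusaStateDB whose reads fall through to the remote RPC backend.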
	if state == nil {
		state = t.state
@@ -607,11 +645,11 @@ func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additi
	state.RevertToSnapshot(snapshot)

	// Gather receipt for OnTxEnd
-	receipt := &types.Receipt{Type: tx.Type()}
+	receipt := &gethTypes.Receipt{Type: tx.Type()}
	if msgResult.Failed() {
-		receipt.Status = types.ReceiptStatusFailed
+		receipt.Status = gethTypes.ReceiptStatusFailed
	} else {
-		receipt.Status = types.ReceiptStatusSuccessful
+		receipt.Status = gethTypes.ReceiptStatusSuccessful
	}
	receipt.TxHash = tx.Hash()
	receipt.GasUsed = msgResult.UsedGas
@@ -627,14 +665,14 @@ func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additi
// PendingBlock describes the current pending block which is being constructed and awaiting commitment to the chain.
// This may be nil if no pending block was created.
-func (t *TestChain) PendingBlock() *chainTypes.Block {
+func (t *TestChain) PendingBlock() *types.Block {
	return t.pendingBlock
}

// PendingBlockCreate constructs an empty block which is pending addition to the chain. The block produced by this
// method will have a block number and timestamp that are greater than the current chain head's by 1.
// Returns the constructed block, or an error if one occurred.
-func (t *TestChain) PendingBlockCreate() (*chainTypes.Block, error) {
+func (t *TestChain) PendingBlockCreate() (*types.Block, error) {
	// Create a block with default parameters
	blockNumber := t.HeadBlockNumber() + 1
	timestamp := t.Head().Header.Time + 1
@@ -646,7 +684,7 @@ func (t *TestChain) PendingBlockCreate() (*chainTypes.Block, error) {
// previous block). Providing a block number that is greater than the previous block number plus one will simulate empty
// blocks in between.
// Returns the constructed block, or an error if one occurred.
-func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTime uint64, blockGasLimit *uint64) (*chainTypes.Block, error) {
+func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTime uint64, blockGasLimit *uint64) (*types.Block, error) {
	// If we already have a pending block, return an error.
	if t.pendingBlock != nil {
		return nil, fmt.Errorf("could not create a new pending block for chain, as a block is already pending")
@@ -688,14 +726,14 @@ func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTi
// - GasUsed is aggregated for each transaction in the block (for now zero).
// - Mix digest is only useful for randomness, so we just fake randomness by using the previous block hash.
// - TODO: BaseFee should be revisited/checked.
- header := &types.Header{ + header := &gethTypes.Header{ ParentHash: parentBlockHash, - UncleHash: types.EmptyUncleHash, + UncleHash: gethTypes.EmptyUncleHash, Coinbase: t.Head().Header.Coinbase, Root: t.Head().Header.Root, - TxHash: types.EmptyRootHash, - ReceiptHash: types.EmptyRootHash, - Bloom: types.Bloom{}, + TxHash: gethTypes.EmptyRootHash, + ReceiptHash: gethTypes.EmptyRootHash, + Bloom: gethTypes.Bloom{}, Difficulty: common.Big0, Number: big.NewInt(int64(blockNumber)), GasLimit: *blockGasLimit, @@ -703,12 +741,12 @@ func (t *TestChain) PendingBlockCreateWithParameters(blockNumber uint64, blockTi Time: blockTime, Extra: []byte{}, MixDigest: parentBlockHash, - Nonce: types.BlockNonce{}, + Nonce: gethTypes.BlockNonce{}, BaseFee: big.NewInt(params.InitialBaseFee), } // Create a new block for our test node - t.pendingBlock = chainTypes.NewBlock(header) + t.pendingBlock = types.NewBlock(header) t.pendingBlock.Hash = t.pendingBlock.Header.Hash() // Emit our event for the pending block being created @@ -780,7 +818,7 @@ func (t *TestChain) PendingBlockAddTx(message *core.Message, additionalTracers . } // Create our message result - messageResult := &chainTypes.MessageResults{ + messageResult := &types.MessageResults{ PostStateRoot: common.BytesToHash(receipt.PostState), ExecutionResult: executionResult, Receipt: receipt, @@ -835,7 +873,7 @@ func (t *TestChain) PendingBlockCommit() error { // Committing the state invalidates the cached tries and we need to reload the state. // Otherwise, methods such as FillFromTestChainProperties will not work correctly. - t.state, err = state.New(root, t.stateDatabase, nil) + t.state, err = t.stateFactory.New(root, t.stateDatabase) if err != nil { return err } @@ -906,7 +944,7 @@ func (t *TestChain) PendingBlockDiscard() error { // emitContractChangeEvents emits events for contract deployments being added or removed by playing through a list // of provided message results. If reverting, the inverse events are emitted. -func (t *TestChain) emitContractChangeEvents(reverting bool, messageResults ...*chainTypes.MessageResults) error { +func (t *TestChain) emitContractChangeEvents(reverting bool, messageResults ...*types.MessageResults) error { // If we're not reverting, we simply play events for our contract deployment changes in order. If we are, inverse // all the events. var err error diff --git a/chain/test_chain_test.go b/chain/test_chain_test.go index 048822a1..0d50e320 100644 --- a/chain/test_chain_test.go +++ b/chain/test_chain_test.go @@ -1,6 +1,7 @@ package chain import ( + "context" "math/big" "math/rand" "testing" @@ -75,7 +76,7 @@ func createChain(t *testing.T) (*TestChain, []common.Address) { } // Create a test chain with a default test chain configuration - chain, err := NewTestChain(genesisAlloc, nil) + chain, err := NewTestChain(context.Background(), genesisAlloc, nil) assert.NoError(t, err) @@ -619,7 +620,7 @@ func TestChainCallSequenceReplayMatchSimple(t *testing.T) { } // Create another test chain which we will recreate our state from. 
- recreatedChain, err := NewTestChain(chain.genesisDefinition.Alloc, nil) + recreatedChain, err := NewTestChain(context.Background(), chain.genesisDefinition.Alloc, nil) assert.NoError(t, err) // Replay all messages after genesis diff --git a/chain/types/medusa_statedb.go b/chain/types/medusa_statedb.go new file mode 100644 index 00000000..6e61a176 --- /dev/null +++ b/chain/types/medusa_statedb.go @@ -0,0 +1,32 @@ +package types + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/tracing" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/holiman/uint256" +) + +var _ MedusaStateDB = (*state.StateDB)(nil) +var _ MedusaStateDB = (*state.ForkStateDb)(nil) + +/* +MedusaStateDB provides an interface that supersedes the stateDB interface exposed by geth. All of these functions are +implemented by the vanilla geth statedb. +This interface allows the TestChain to use a forked statedb and native geth statedb interoperably. +*/ +type MedusaStateDB interface { + vm.StateDB + IntermediateRoot(bool) common.Hash + Finalise(bool) + Logs() []*types.Log + GetLogs(common.Hash, uint64, common.Hash) []*types.Log + TxIndex() int + SetBalance(common.Address, *uint256.Int, tracing.BalanceChangeReason) + SetTxContext(common.Hash, int) + Commit(uint64, bool) (common.Hash, error) + SetLogger(*tracing.Hooks) + Error() error +} diff --git a/chain/vendored/apply_transaction.go b/chain/vendored/apply_transaction.go index 1472025b..674154e4 100644 --- a/chain/vendored/apply_transaction.go +++ b/chain/vendored/apply_transaction.go @@ -18,10 +18,10 @@ package vendored import ( "github.com/crytic/medusa/chain/config" + "github.com/crytic/medusa/chain/types" "github.com/ethereum/go-ethereum/common" . "github.com/ethereum/go-ethereum/core" - "github.com/ethereum/go-ethereum/core/state" - "github.com/ethereum/go-ethereum/core/types" + gethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/vm" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/params" @@ -36,7 +36,7 @@ import ( // This executes on an underlying EVM and returns a transaction receipt, or an error if one occurs. // Additional changes: // - Exposed core.ExecutionResult as a return value. -func EVMApplyTransaction(msg *Message, config *params.ChainConfig, testChainConfig *config.TestChainConfig, author *common.Address, gp *GasPool, statedb *state.StateDB, blockNumber *big.Int, blockHash common.Hash, tx *types.Transaction, usedGas *uint64, evm *vm.EVM) (receipt *types.Receipt, result *ExecutionResult, err error) { +func EVMApplyTransaction(msg *Message, config *params.ChainConfig, testChainConfig *config.TestChainConfig, author *common.Address, gp *GasPool, statedb types.MedusaStateDB, blockNumber *big.Int, blockHash common.Hash, tx *gethtypes.Transaction, usedGas *uint64, evm *vm.EVM) (receipt *gethtypes.Receipt, result *ExecutionResult, err error) { // Apply the OnTxStart and OnTxEnd hooks if evm.Config.Tracer != nil && evm.Config.Tracer.OnTxStart != nil { evm.Config.Tracer.OnTxStart(evm.GetVMContext(), tx, msg.From) @@ -67,11 +67,11 @@ func EVMApplyTransaction(msg *Message, config *params.ChainConfig, testChainConf // Create a new receipt for the transaction, storing the intermediate root and gas used // by the tx. 
-	receipt = &types.Receipt{Type: tx.Type(), PostState: root, CumulativeGasUsed: *usedGas}
+	receipt = &gethtypes.Receipt{Type: tx.Type(), PostState: root, CumulativeGasUsed: *usedGas}
	if result.Failed() {
-		receipt.Status = types.ReceiptStatusFailed
+		receipt.Status = gethtypes.ReceiptStatusFailed
	} else {
-		receipt.Status = types.ReceiptStatusSuccessful
+		receipt.Status = gethtypes.ReceiptStatusSuccessful
	}
	receipt.TxHash = tx.Hash()
	receipt.GasUsed = result.UsedGas
@@ -95,7 +95,7 @@ func EVMApplyTransaction(msg *Message, config *params.ChainConfig, testChainConf
	// Set the receipt logs and create the bloom filter.
	receipt.Logs = statedb.GetLogs(tx.Hash(), blockNumber.Uint64(), blockHash)
-	receipt.Bloom = types.CreateBloom(types.Receipts{receipt})
+	receipt.Bloom = gethtypes.CreateBloom(gethtypes.Receipts{receipt})
	receipt.BlockHash = blockHash
	receipt.BlockNumber = blockNumber
	receipt.TransactionIndex = uint(statedb.TxIndex())
diff --git a/docs/src/project_configuration/chain_config.md b/docs/src/project_configuration/chain_config.md
index 13bc0685..efccd71b 100644
--- a/docs/src/project_configuration/chain_config.md
+++ b/docs/src/project_configuration/chain_config.md
@@ -29,3 +29,29 @@ The chain configuration defines the parameters for setting up `medusa`'s underly
- **Description**: Determines whether the `ffi` cheatcode is enabled.
  > 🚩 Enabling the `ffi` cheatcode may allow for arbitrary code execution on your machine.
- **Default**: `false`
+
+## Fork Configuration
+
+### `forkModeEnabled`
+
+- **Type**: Boolean
+- **Description**: Determines whether fork mode is enabled.
+- **Default**: `false`
+
+### `rpcUrl`
+
+- **Type**: String
+- **Description**: Determines the RPC URL that will be queried during fork mode.
+- **Default**: `n/a`
+
+### `rpcBlock`
+
+- **Type**: Integer
+- **Description**: Determines the block height at which fork state will be queried. Block tags such as `latest` are not supported.
+- **Default**: `1`
+
+### `poolSize`
+
+- **Type**: Integer
+- **Description**: Determines the size of the client pool used to query the RPC. It is recommended to use a pool size that is 2-3x the number of fuzzer workers, but smaller pools may be required to avoid exceeding external RPC query limits.
+- **Default**: `20`
diff --git a/fuzzing/executiontracer/execution_tracer.go b/fuzzing/executiontracer/execution_tracer.go
index 0b876c0c..0a20e29e 100644
--- a/fuzzing/executiontracer/execution_tracer.go
+++ b/fuzzing/executiontracer/execution_tracer.go
@@ -4,11 +4,11 @@ import (
	"math/big"

	"github.com/crytic/medusa/chain"
+	"github.com/crytic/medusa/chain/types"
	"github.com/crytic/medusa/fuzzing/contracts"
	"github.com/crytic/medusa/utils"
	"github.com/ethereum/go-ethereum/common"
	"github.com/ethereum/go-ethereum/core"
-	"github.com/ethereum/go-ethereum/core/state"
	"github.com/ethereum/go-ethereum/core/tracing"

	coretypes "github.com/ethereum/go-ethereum/core/types"
@@ -20,7 +20,7 @@ import (
// CallWithExecutionTrace obtains an execution trace for a given call, on the provided chain, using the state
// provided. If a nil state is provided, the current chain state will be used.
// Returns the ExecutionTrace for the call or an error if one occurs.
-func CallWithExecutionTrace(testChain *chain.TestChain, contractDefinitions contracts.Contracts, msg *core.Message, state *state.StateDB) (*core.ExecutionResult, *ExecutionTrace, error) { +func CallWithExecutionTrace(testChain *chain.TestChain, contractDefinitions contracts.Contracts, msg *core.Message, state types.MedusaStateDB) (*core.ExecutionResult, *ExecutionTrace, error) { // Create an execution tracer executionTracer := NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts()) defer executionTracer.Close() @@ -302,7 +302,7 @@ func (t *ExecutionTracer) OnOpcode(pc uint64, op byte, gas, cost uint64, scope t // TODO: Move this to OnLog if op == byte(vm.LOG0) || op == byte(vm.LOG1) || op == byte(vm.LOG2) || op == byte(vm.LOG3) || op == byte(vm.LOG4) { t.onNextCaptureState = append(t.onNextCaptureState, func() { - logs := t.evmContext.StateDB.(*state.StateDB).Logs() + logs := t.evmContext.StateDB.(types.MedusaStateDB).Logs() if len(logs) > 0 { t.currentCallFrame.Operations = append(t.currentCallFrame.Operations, logs[len(logs)-1]) } diff --git a/fuzzing/fuzzer.go b/fuzzing/fuzzer.go index 4d862dc0..68d29373 100644 --- a/fuzzing/fuzzer.go +++ b/fuzzing/fuzzer.go @@ -386,7 +386,7 @@ func (f *Fuzzer) createTestChain() (*chain.TestChain, error) { f.config.Fuzzing.TestChainConfig.ContractAddressOverrides = contractAddressOverrides // Create our test chain with our basic allocations and passed medusa's chain configuration - testChain, err := chain.NewTestChain(genesisAlloc, &f.config.Fuzzing.TestChainConfig) + testChain, err := chain.NewTestChain(f.ctx, genesisAlloc, &f.config.Fuzzing.TestChainConfig) // Set our block gas limit testChain.BlockGasLimit = f.config.Fuzzing.BlockGasLimit diff --git a/go.mod b/go.mod index 3bae89ac..2988b69b 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.9.0 + go.etcd.io/bbolt v1.3.11 golang.org/x/crypto v0.25.0 golang.org/x/exp v0.0.0-20240707233637-46b078467d37 golang.org/x/net v0.27.0 @@ -88,4 +89,4 @@ require ( rsc.io/tmplfunc v0.0.3 // indirect ) -replace github.com/ethereum/go-ethereum => github.com/crytic/medusa-geth v0.0.0-20240919134035-0fd368c28419 +replace github.com/ethereum/go-ethereum => github.com/crytic/medusa-geth v0.0.0-20241203194135-80361903eb03 diff --git a/go.sum b/go.sum index 6c8f950d..161a37ca 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/crate-crypto/go-ipa v0.0.0-20240223125850-b1e8a79f509c/go.mod h1:geZJ github.com/crate-crypto/go-kzg-4844 v1.0.0 h1:TsSgHwrkTKecKJ4kadtHi4b3xHW5dCFUDFnUp1TsawI= github.com/crate-crypto/go-kzg-4844 v1.0.0/go.mod h1:1kMhvPgI0Ky3yIa+9lFySEBUBXkYxeOi8ZF1sYioxhc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/crytic/medusa-geth v0.0.0-20240919134035-0fd368c28419 h1:MJXzWPObZtF0EMRqX64JkzJDj+GMLPxg3XK5xb12FFU= -github.com/crytic/medusa-geth v0.0.0-20240919134035-0fd368c28419/go.mod h1:ajGCVsk6ctffGwe9TSDQqj4HIUUQ1WdUit5tWFNl8Tw= +github.com/crytic/medusa-geth v0.0.0-20241203194135-80361903eb03 h1:hFoDTUHSCcyy8KQaIIF2gIrYmIZ7KjOa2lvXdXGC+gw= +github.com/crytic/medusa-geth v0.0.0-20241203194135-80361903eb03/go.mod h1:ajGCVsk6ctffGwe9TSDQqj4HIUUQ1WdUit5tWFNl8Tw= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -208,6 +208,8 
@@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.etcd.io/bbolt v1.3.11 h1:yGEzV1wPz2yVCLsD8ZAiGHhHVlczyC9d1rP43/VCRJ0= +go.etcd.io/bbolt v1.3.11/go.mod h1:dksAq7YMXoljX0xu6VF5DMZGbhYYoLUalEiSySYAS4I= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
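To tie the fork-mode pieces together, a hypothetical standalone use of the RPC state backend added above; the endpoint, block height, and address are placeholders:

package main

import (
	"context"
	"fmt"

	"github.com/crytic/medusa/chain/state"
	"github.com/ethereum/go-ethereum/common"
)

func main() {
	// Placeholder endpoint and height; NewRPCBackend pins all queries to this block.
	backend, err := state.NewRPCBackend(context.Background(), "https://eth.example.com", 20_000_000, 20)
	if err != nil {
		panic(err)
	}

	// Fetch balance, nonce, and code for an address; the backend issues the three
	// RPC lookups concurrently and caches the result for subsequent calls.
	balance, nonce, code, err := backend.GetStateObject(common.HexToAddress("0x0000000000000000000000000000000000000001"))
	if err != nil {
		panic(err)
	}
	fmt.Println(balance, nonce, len(code))
}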