diff --git a/go.mod b/go.mod index 388bead..5dd4de5 100644 --- a/go.mod +++ b/go.mod @@ -64,6 +64,7 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/supranational/blst v0.3.11 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect diff --git a/go.sum b/go.sum index 72e0c81..63ad93c 100644 --- a/go.sum +++ b/go.sum @@ -226,6 +226,7 @@ github.com/status-im/keycard-go v0.2.0 h1:QDLFswOQu1r5jsycloeQh3bVU8n/NatHHaZobt github.com/status-im/keycard-go v0.2.0/go.mod h1:wlp8ZLbsmrF6g6WjugPAx+IzoLrkdf9+mHxBEeo3Hbg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index fb430f0..1605983 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -5,6 +5,7 @@ import ( "context" "math/big" + "github.com/LiskHQ/op-fault-detector/pkg/encoding" "github.com/LiskHQ/op-fault-detector/pkg/log" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" @@ -60,6 +61,15 @@ func (c *ChainAPIClient) GetBlockByNumber(ctx context.Context, blockNumber *big. return c.eth.BlockByNumber(ctx, blockNumber) } +// GetLatestBlockHeader returns latest block header from a connected node. +func (c *ChainAPIClient) GetLatestBlockHeader(ctx context.Context) (*types.Header, error) { + blockNumber, err := c.eth.BlockNumber(ctx) + if err != nil { + return nil, err + } + return c.eth.HeaderByNumber(ctx, encoding.MustConvertUint64ToBigInt(blockNumber)) +} + // GetProof returns the account and storage values, including the Merkle proof, of the specified account/address. func (c *ChainAPIClient) GetProof(ctx context.Context, blockNumber *big.Int, address common.Address) (*ProofResponse, error) { var result ProofResponse diff --git a/pkg/faultdetector/faultdetector.go b/pkg/faultdetector/faultdetector.go index 6f41075..e338d0e 100644 --- a/pkg/faultdetector/faultdetector.go +++ b/pkg/faultdetector/faultdetector.go @@ -8,6 +8,7 @@ import ( "github.com/LiskHQ/op-fault-detector/pkg/chain" "github.com/LiskHQ/op-fault-detector/pkg/config" + "github.com/LiskHQ/op-fault-detector/pkg/encoding" "github.com/LiskHQ/op-fault-detector/pkg/log" "github.com/prometheus/client_golang/prometheus" ) @@ -80,20 +81,20 @@ func NewFaultDetector(ctx context.Context, logger log.Logger, errorChan chan err l2ChainID, err := l2RpcApi.GetChainID(ctx) if err != nil { - logger.Errorf("Failed to get L2 provider's chainID: %d, error: %w", l2ChainID.Int64(), err) + logger.Errorf("Failed to get L2 provider's chainID: %d, error: %w", encoding.MustConvertBigIntToUint64(l2ChainID), err) return nil, err } // Initialize Oracle contract accessor chainConfig := &chain.ConfigOptions{ L1RPCEndpoint: faultDetectorConfig.L1RPCEndpoint, - ChainID: l2ChainID.Uint64(), + ChainID: encoding.MustConvertBigIntToUint64(l2ChainID), L2OutputOracleContractAddress: faultDetectorConfig.L2OutputOracleContractAddress, } oracleContractAccessor, err := chain.NewOracleAccessor(ctx, chainConfig) if err != nil { - logger.Errorf("Failed to create Oracle contract accessor with chainID: %d, L1 endpoint: %s and L2OutputOracleContractAddress: %s, error: %w", l2ChainID.Int64(), faultDetectorConfig.L1RPCEndpoint, faultDetectorConfig.L2OutputOracleContractAddress, err) + logger.Errorf("Failed to create Oracle contract accessor with chainID: %d, L1 endpoint: %s and L2OutputOracleContractAddress: %s, error: %w", encoding.MustConvertBigIntToUint64(l2ChainID), faultDetectorConfig.L1RPCEndpoint, faultDetectorConfig.L2OutputOracleContractAddress, err) return nil, err } diff --git a/pkg/faultdetector/helper.go b/pkg/faultdetector/helper.go new file mode 100644 index 0000000..43887ce --- /dev/null +++ b/pkg/faultdetector/helper.go @@ -0,0 +1,63 @@ +package faultdetector + +import ( + "context" + "fmt" + "math/big" + + "github.com/LiskHQ/op-fault-detector/pkg/chain" + "github.com/LiskHQ/op-fault-detector/pkg/encoding" + "github.com/LiskHQ/op-fault-detector/pkg/log" + "github.com/ethereum/go-ethereum/core/types" +) + +type ChainAPIClient interface { + GetLatestBlockHeader(ctx context.Context) (*types.Header, error) +} + +type OracleAccessor interface { + GetNextOutputIndex() (*big.Int, error) + GetL2Output(index *big.Int) (chain.L2Output, error) +} + +// FindFirstUnfinalizedOutputIndex finds and returns the first L2 output index that has not yet passed the fault proof window. +func FindFirstUnfinalizedOutputIndex(ctx context.Context, logger log.Logger, fpw uint64, oracleAccessor OracleAccessor, l2RpcApi ChainAPIClient) (uint64, error) { + latestBlockHeader, err := l2RpcApi.GetLatestBlockHeader(ctx) + if err != nil { + logger.Errorf("Failed to get latest block header from L2 provider, error: %w", err) + return 0, err + } + totalOutputsBigInt, err := oracleAccessor.GetNextOutputIndex() + if err != nil { + logger.Errorf("Failed to get next output index, error: %w", err) + return 0, err + } + totalOutputs := encoding.MustConvertBigIntToUint64(totalOutputsBigInt) + + // Perform a binary search to find the next batch that will pass the challenge period. + var lo uint64 = 0 + hi := totalOutputs + for lo != hi { + mid := (lo + hi) / 2 + midBigInt := encoding.MustConvertUint64ToBigInt(mid) + outputData, err := oracleAccessor.GetL2Output(midBigInt) + if err != nil { + logger.Errorf("Failed to get L2 output for index: %d, error: %w", midBigInt, err) + return 0, err + } + if outputData.L1Timestamp+fpw < latestBlockHeader.Time { + lo = mid + 1 + } else { + hi = mid + } + } + + // Result will be zero if the chain is less than FPW seconds old. Only returns 0 with error Undefined in the + // case that no batches have been submitted for an entire challenge period. + if lo == totalOutputs { + logger.Errorf("No batches have been submitted for the entire challenge period and therefore first unfinalized output index is undefined") + return 0, fmt.Errorf("Undefined") + } else { + return lo, nil + } +} diff --git a/pkg/faultdetector/helper_test.go b/pkg/faultdetector/helper_test.go new file mode 100644 index 0000000..0f527f5 --- /dev/null +++ b/pkg/faultdetector/helper_test.go @@ -0,0 +1,322 @@ +package faultdetector + +import ( + "context" + crand "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/LiskHQ/op-fault-detector/pkg/chain" + "github.com/LiskHQ/op-fault-detector/pkg/log" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type mockChainAPIClient struct { + mock.Mock +} +type mockOracleAccessor struct { + mock.Mock +} + +func (m *mockChainAPIClient) GetLatestBlockHeader(ctx context.Context) (*types.Header, error) { + called := m.MethodCalled("GetLatestBlockHeader", ctx) + return called.Get(0).(*types.Header), called.Error(1) +} + +func (o *mockOracleAccessor) GetNextOutputIndex() (*big.Int, error) { + called := o.MethodCalled("GetNextOutputIndex") + return called.Get(0).(*big.Int), called.Error(1) +} + +func (o *mockOracleAccessor) GetL2Output(index *big.Int) (chain.L2Output, error) { + called := o.MethodCalled("GetL2Output", index) + return called.Get(0).(chain.L2Output), called.Error(1) +} + +func randHash() (out common.Hash) { + _, _ = crand.Read(out[:]) + return out +} + +func TestFindFirstUnfinalizedOutputIndex(t *testing.T) { + const defaultL1Timestamp uint64 = 123456 + const finalizationPeriodSeconds uint64 = 1000 + var hdr = &types.Header{ + ParentHash: randHash(), + UncleHash: randHash(), + Coinbase: common.Address{}, + Root: randHash(), + TxHash: randHash(), + ReceiptHash: randHash(), + Bloom: types.Bloom{}, + Difficulty: big.NewInt(42), + Number: big.NewInt(1234), + GasLimit: 0, + GasUsed: 0, + Time: defaultL1Timestamp + finalizationPeriodSeconds, + Extra: make([]byte, 0), + MixDigest: randHash(), + Nonce: types.BlockNonce{}, + BaseFee: big.NewInt(100), + } + type testSuite struct { + l2Client *mockChainAPIClient + oracle *mockOracleAccessor + logger log.Logger + ctx context.Context + } + + var tests = []struct { + name string + construction func() *testSuite + assertion func(uint64, error) + }{ + { + name: "when fails to get latest block header from L2 provider", + construction: func() *testSuite { + l2Client := new(mockChainAPIClient) + oracle := new(mockOracleAccessor) + logger, _ := log.NewDefaultProductionLogger() + ctx := context.Background() + + sampleL2Output1 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp - 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + sampleL2Output2 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp + 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + + l2Client.On("GetLatestBlockHeader", ctx).Return(hdr, fmt.Errorf("Failed to get latest block header")) + oracle.On("GetNextOutputIndex").Return(big.NewInt(2), nil) + oracle.On("GetL2Output", big.NewInt(0)).Return(sampleL2Output1, nil) + oracle.On("GetL2Output", big.NewInt(1)).Return(sampleL2Output2, nil) + + return &testSuite{ + l2Client: l2Client, + oracle: oracle, + logger: logger, + ctx: ctx, + } + }, + assertion: func(index uint64, err error) { + require.EqualError(t, err, "Failed to get latest block header") + var expected uint64 = 0 + require.Equal(t, index, expected) + }, + }, + { + name: "when fails to get next output index", + construction: func() *testSuite { + l2Client := new(mockChainAPIClient) + oracle := new(mockOracleAccessor) + logger, _ := log.NewDefaultProductionLogger() + ctx := context.Background() + + sampleL2Output1 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp - 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + sampleL2Output2 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp + 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + + l2Client.On("GetLatestBlockHeader", ctx).Return(hdr, nil) + oracle.On("GetNextOutputIndex").Return(big.NewInt(2), fmt.Errorf("Failed to get next output index")) + oracle.On("GetL2Output", big.NewInt(0)).Return(sampleL2Output1, nil) + oracle.On("GetL2Output", big.NewInt(1)).Return(sampleL2Output2, nil) + + return &testSuite{ + l2Client: l2Client, + oracle: oracle, + logger: logger, + ctx: ctx, + } + }, + assertion: func(index uint64, err error) { + require.EqualError(t, err, "Failed to get next output index") + var expected uint64 = 0 + require.Equal(t, index, expected) + }, + }, + { + name: "when fails to get L2 output", + construction: func() *testSuite { + l2Client := new(mockChainAPIClient) + oracle := new(mockOracleAccessor) + logger, _ := log.NewDefaultProductionLogger() + ctx := context.Background() + + sampleL2Output1 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp - 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + sampleL2Output2 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp + 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + + l2Client.On("GetLatestBlockHeader", ctx).Return(hdr, nil) + oracle.On("GetNextOutputIndex").Return(big.NewInt(2), nil) + oracle.On("GetL2Output", big.NewInt(0)).Return(sampleL2Output1, fmt.Errorf("Failed to get L2 output")) + oracle.On("GetL2Output", big.NewInt(1)).Return(sampleL2Output2, nil) + + return &testSuite{ + l2Client: l2Client, + oracle: oracle, + logger: logger, + ctx: ctx, + } + }, + assertion: func(index uint64, err error) { + require.EqualError(t, err, "Failed to get L2 output") + var expected uint64 = 0 + require.Equal(t, index, expected) + }, + }, + { + name: "when the chain is more then FPW seconds old", + construction: func() *testSuite { + l2Client := new(mockChainAPIClient) + oracle := new(mockOracleAccessor) + logger, _ := log.NewDefaultProductionLogger() + ctx := context.Background() + + sampleL2Output1 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp - 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + sampleL2Output2 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp + 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + + l2Client.On("GetLatestBlockHeader", ctx).Return(hdr, nil) + oracle.On("GetNextOutputIndex").Return(big.NewInt(2), nil) + oracle.On("GetL2Output", big.NewInt(0)).Return(sampleL2Output1, nil) + oracle.On("GetL2Output", big.NewInt(1)).Return(sampleL2Output2, nil) + + return &testSuite{ + l2Client: l2Client, + oracle: oracle, + logger: logger, + ctx: ctx, + } + }, + assertion: func(index uint64, err error) { + require.NoError(t, err) + var expected uint64 = 1 + require.Equal(t, index, expected) + }, + }, + { + name: "when the chain is less than FPW seconds old", + construction: func() *testSuite { + l2Client := new(mockChainAPIClient) + oracle := new(mockOracleAccessor) + logger, _ := log.NewDefaultProductionLogger() + ctx := context.Background() + + sampleL2Output1 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp + 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + sampleL2Output2 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp + 2, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + + l2Client.On("GetLatestBlockHeader", ctx).Return(hdr, nil) + oracle.On("GetNextOutputIndex").Return(big.NewInt(2), nil) + oracle.On("GetL2Output", big.NewInt(0)).Return(sampleL2Output1, nil) + oracle.On("GetL2Output", big.NewInt(1)).Return(sampleL2Output2, nil) + + return &testSuite{ + l2Client: l2Client, + oracle: oracle, + logger: logger, + ctx: ctx, + } + }, + assertion: func(index uint64, err error) { + require.NoError(t, err) + var expected uint64 = 0 + require.Equal(t, index, expected) + }, + }, + { + name: "when no batches submitted for the entire FPW", + construction: func() *testSuite { + l2Client := new(mockChainAPIClient) + oracle := new(mockOracleAccessor) + logger, _ := log.NewDefaultProductionLogger() + ctx := context.Background() + + sampleL2Output1 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp - 2, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + sampleL2Output2 := chain.L2Output{ + OutputRoot: randHash().String(), + L1Timestamp: defaultL1Timestamp - 1, + L2BlockNumber: 500, + L2OutputIndex: 2, + } + + l2Client.On("GetLatestBlockHeader", ctx).Return(hdr, nil) + oracle.On("GetNextOutputIndex").Return(big.NewInt(2), nil) + oracle.On("GetL2Output", big.NewInt(0)).Return(sampleL2Output1, nil) + oracle.On("GetL2Output", big.NewInt(1)).Return(sampleL2Output2, nil) + + return &testSuite{ + l2Client: l2Client, + oracle: oracle, + logger: logger, + ctx: ctx, + } + }, + assertion: func(index uint64, err error) { + require.EqualError(t, err, "Undefined") + var expected uint64 = 0 + require.Equal(t, index, expected) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ts := test.construction() + + index, err := FindFirstUnfinalizedOutputIndex(ts.ctx, ts.logger, finalizationPeriodSeconds, ts.oracle, ts.l2Client) + test.assertion(index, err) + }) + } +}