diff --git a/commit/merkleroot/observation.go b/commit/merkleroot/observation.go index 7d48cc9ce..0727a7137 100644 --- a/commit/merkleroot/observation.go +++ b/commit/merkleroot/observation.go @@ -555,14 +555,16 @@ func (o observerImpl) ObserveLatestOnRampSeqNums(ctx context.Context) []pluginty mu := &sync.Mutex{} latestOnRampSeqNums := make([]plugintypes.SeqNumChain, 0, len(sourceChains)) - eg := &errgroup.Group{} + wg := &sync.WaitGroup{} + wg.Add(len(sourceChains)) for _, sourceChain := range sourceChains { - eg.Go(func() error { + go func() { + defer wg.Done() latestOnRampSeqNum, err := o.ccipReader.LatestMsgSeqNum(ctx, sourceChain) if err != nil { lggr.Errorf("failed to get latest msg seq num for source chain %d: %s", sourceChain, err) - return nil + return } mu.Lock() @@ -571,15 +573,9 @@ func (o observerImpl) ObserveLatestOnRampSeqNums(ctx context.Context) []pluginty plugintypes.NewSeqNumChain(sourceChain, latestOnRampSeqNum), ) mu.Unlock() - - return nil - }) - } - - if err := eg.Wait(); err != nil { - lggr.Warnw("call to GetExpectedNextSequenceNumber failed", "err", err) - return nil + }() } + wg.Wait() sort.Slice(latestOnRampSeqNums, func(i, j int) bool { return latestOnRampSeqNums[i].ChainSel < latestOnRampSeqNums[j].ChainSel diff --git a/commit/merkleroot/outcome.go b/commit/merkleroot/outcome.go index 1017f0561..cca74675a 100644 --- a/commit/merkleroot/outcome.go +++ b/commit/merkleroot/outcome.go @@ -152,6 +152,11 @@ func reportRangesOutcome( rmnRemoteConfig = observedRMNRemoteConfig[dstChain] } + if len(rangesToReport) == 0 { + lggr.Info("No ranges to report, outcomeType is ReportEmpty") + return Outcome{OutcomeType: ReportEmpty} + } + outcome := Outcome{ OutcomeType: ReportIntervalsSelected, RangesSelectedForReport: rangesToReport, diff --git a/commit/merkleroot/outcome_test.go b/commit/merkleroot/outcome_test.go index 01fd2535d..bb4cc8769 100644 --- a/commit/merkleroot/outcome_test.go +++ b/commit/merkleroot/outcome_test.go @@ -647,13 +647,8 @@ func Test_reportRangesOutcome(t *testing.T) { expectedOutcome Outcome }{ { - name: "base empty outcome", - expectedOutcome: Outcome{ - OutcomeType: ReportIntervalsSelected, - RangesSelectedForReport: []plugintypes.ChainRange{}, - OffRampNextSeqNums: []plugintypes.SeqNumChain{}, - RMNRemoteCfg: rmntypes.RemoteConfig{}, - }, + name: "base empty outcome", + expectedOutcome: Outcome{OutcomeType: ReportEmpty}, }, { name: "simple scenario with one chain", diff --git a/commit/merkleroot/rmn/controller.go b/commit/merkleroot/rmn/controller.go index e5ebd14fc..6c788e576 100644 --- a/commit/merkleroot/rmn/controller.go +++ b/commit/merkleroot/rmn/controller.go @@ -824,7 +824,7 @@ func transformAndSortObservations( return attrSigObservations } -// selectsRoots selects the roots from the signed observations. +// selectRoots selects the roots from the signed observations. // If there are more than one valid roots based on the provided F it returns an error. func selectRoots( observations []rmnSignedObservationWithMeta, diff --git a/commit/merkleroot/validate_observation.go b/commit/merkleroot/validate_observation.go index baff4ffd8..9a454940a 100644 --- a/commit/merkleroot/validate_observation.go +++ b/commit/merkleroot/validate_observation.go @@ -76,7 +76,7 @@ func validateObservedMerkleRoots( return fmt.Errorf("%s invalid: chain already appears in another observed root", root) } - if len(root.OnRampAddress) == 0 { + if root.OnRampAddress.IsZeroOrEmpty() { return fmt.Errorf("%s invalid: empty OnRampAddress", root) } @@ -123,10 +123,6 @@ func validateObservedOnRampMaxSeqNums( return fmt.Errorf("duplicate onRampMaxSeqNum for chain %d", seqNumChain.ChainSel) } - if seqNumChain.ChainSel == 0 { - return fmt.Errorf("onRampMaxSeqNum for chain %d has chain selector 0", seqNumChain.ChainSel) - } - seenChains.Add(seqNumChain.ChainSel) } @@ -192,8 +188,8 @@ func validateRMNRemoteConfig( return fmt.Errorf("not enough signers to cover F+1 threshold") } - if len(rmnRemoteConfig.ContractAddress) == 0 { - return fmt.Errorf("empty ContractAddress") + if rmnRemoteConfig.ContractAddress.IsZeroOrEmpty() { + return fmt.Errorf("empty ContractAddress: %s", rmnRemoteConfig.ContractAddress) } seenNodeIndexes := mapset.NewSet[uint64]() diff --git a/commit/plugin_e2e_test.go b/commit/plugin_e2e_test.go index 53b3b27b5..39b9e9dce 100644 --- a/commit/plugin_e2e_test.go +++ b/commit/plugin_e2e_test.go @@ -266,7 +266,7 @@ func TestPlugin_E2E_AllNodesAgree_TokenPrices(t *testing.T) { nodes := make([]ocr3types.ReportingPlugin[[]byte], len(oracleIDs)) - merkleOutcome := baseMerkleOutcome(params.rmnReportCfg) + merkleOutcome := reportEmptyMerkleRootOutcome() testCases := []struct { name string @@ -321,6 +321,7 @@ func TestPlugin_E2E_AllNodesAgree_TokenPrices(t *testing.T) { }, BlessedMerkleRoots: make([]ccipocr3.MerkleRootChain, 0), UnblessedMerkleRoots: make([]ccipocr3.MerkleRootChain, 0), + RMNSignatures: make([]ccipocr3.RMNECDSASignature, 0), }, }, }, @@ -422,6 +423,7 @@ func TestPlugin_E2E_AllNodesAgree_TokenPrices(t *testing.T) { }, BlessedMerkleRoots: make([]ccipocr3.MerkleRootChain, 0), UnblessedMerkleRoots: make([]ccipocr3.MerkleRootChain, 0), + RMNSignatures: make([]ccipocr3.RMNECDSASignature, 0), }, }, }, @@ -465,7 +467,7 @@ func TestPlugin_E2E_AllNodesAgree_TokenPrices(t *testing.T) { func TestPlugin_E2E_AllNodesAgree_ChainFee(t *testing.T) { params := defaultNodeParams(t) - merkleOutcome := baseMerkleOutcome(params.rmnReportCfg) + merkleOutcome := reportEmptyMerkleRootOutcome() nodes := make([]ocr3types.ReportingPlugin[[]byte], len(oracleIDs)) newFeeComponents, newNativePrice, packedGasPrice := newRandomFees() @@ -572,7 +574,7 @@ func TestPlugin_E2E_AllNodesAgree_ChainFee(t *testing.T) { ChainFeeOutcome: expectedChain1FeeOutcome, }, expOutcome: committypes.Outcome{ - MerkleRootOutcome: noReportMerkleOutcome(params.rmnReportCfg), + MerkleRootOutcome: merkleOutcome, ChainFeeOutcome: expectedChain1FeeOutcome, MainOutcome: committypes.MainOutcome{InflightPriceOcrSequenceNumber: 1, RemainingPriceChecks: 10}, }, @@ -599,7 +601,7 @@ func TestPlugin_E2E_AllNodesAgree_ChainFee(t *testing.T) { ChainFeeOutcome: expectedChain1FeeOutcome, }, expOutcome: committypes.Outcome{ - MerkleRootOutcome: noReportMerkleOutcome(params.rmnReportCfg), + MerkleRootOutcome: reportEmptyMerkleRootOutcome(), ChainFeeOutcome: chainfee.Outcome{ GasPrices: []ccipocr3.GasPriceChain{ { @@ -633,7 +635,7 @@ func TestPlugin_E2E_AllNodesAgree_ChainFee(t *testing.T) { ChainFeeOutcome: expectedChain1FeeOutcome, }, expOutcome: committypes.Outcome{ - MerkleRootOutcome: noReportMerkleOutcome(params.rmnReportCfg), + MerkleRootOutcome: reportEmptyMerkleRootOutcome(), ChainFeeOutcome: chainfee.Outcome{ GasPrices: []ccipocr3.GasPriceChain{ { @@ -1005,18 +1007,8 @@ var merkleRoot1 = ccipocr3.Bytes32{0x4a, 0x44, 0xdc, 0x15, 0x36, 0x42, 0x4, 0xa8 0x90, 0x39, 0x45, 0x5c, 0xc1, 0x60, 0x82, 0x81, 0x82, 0xf, 0xe2, 0xb2, 0x4f, 0x1e, 0x52, 0x33, 0xad, 0xe6, 0xaf, 0x1d, 0xd5} -func baseMerkleOutcome(r rmntypes.RemoteConfig) merkleroot.Outcome { - return merkleroot.Outcome{ - OutcomeType: merkleroot.ReportIntervalsSelected, - RMNRemoteCfg: r, - } -} - -func noReportMerkleOutcome(r rmntypes.RemoteConfig) merkleroot.Outcome { - return merkleroot.Outcome{ - OutcomeType: merkleroot.ReportEmpty, - RMNRemoteCfg: r, - } +func reportEmptyMerkleRootOutcome() merkleroot.Outcome { + return merkleroot.Outcome{OutcomeType: merkleroot.ReportEmpty} } func newRandomFees() (components types.ChainFeeComponents, nativePrice ccipocr3.BigInt, usdPrices ccipocr3.BigInt) { diff --git a/execute/observation.go b/execute/observation.go index 28e13aa6f..35b1565fe 100644 --- a/execute/observation.go +++ b/execute/observation.go @@ -198,7 +198,7 @@ func (p *Plugin) getCommitReportsObservation( } // Get pending exec reports. - groupedCommits, fullyExecutedFinalized, fullyExecutedUnfinalized, err := getPendingExecutedReports( + groupedCommits, fullyExecutedFinalized, fullyExecutedUnfinalized, err := getPendingReportsForExecution( ctx, p.ccipReader, p.commitRootsCache.CanExecute, diff --git a/execute/plugin.go b/execute/plugin.go index 0e2d9e9e2..66cda66ee 100644 --- a/execute/plugin.go +++ b/execute/plugin.go @@ -149,7 +149,7 @@ func (p *Plugin) Query(ctx context.Context, outctx ocr3types.OutcomeContext) (ty type CanExecuteHandle = func(sel cciptypes.ChainSelector, merkleRoot cciptypes.Bytes32) bool -// getPendingExecutedReports is used to find commit reports which need to be executed. +// getPendingReportsForExecution is used to find commit reports which need to be executed. // // The function checks execution status at two levels: // 1. Gets all executed messages (both finalized and unfinalized) via primitives.Unconfirmed @@ -159,7 +159,7 @@ type CanExecuteHandle = func(sel cciptypes.ChainSelector, merkleRoot cciptypes.B // - fullyExecutedFinalized: All messages executed with finality (mark as executed) // - fullyExecutedUnfinalized: All messages executed but not finalized (snooze) // - groupedCommits: Reports with unexecuted messages (available for execution) -func getPendingExecutedReports( +func getPendingReportsForExecution( ctx context.Context, ccipReader readerpkg.CCIPReader, canExecute CanExecuteHandle, @@ -286,8 +286,8 @@ func (p *Plugin) ValidateObservation( return fmt.Errorf("error finding supported chains by node: %w", err) } - state := previousOutcome.State.Next() - if state == exectypes.Initialized || state == exectypes.GetCommitReports { + nextState := previousOutcome.State.Next() + if nextState == exectypes.GetCommitReports { err = validateNoMessageRelatedObservations( decodedObservation.Messages, decodedObservation.TokenData, @@ -303,7 +303,7 @@ func (p *Plugin) ValidateObservation( } // check message related validations when states can contain messages - if state == exectypes.GetMessages || state == exectypes.Filter { + if nextState == exectypes.GetMessages || nextState == exectypes.Filter { if err = validateMsgsReadingEligibility(supportedChains, decodedObservation.Messages); err != nil { return fmt.Errorf("validate observer reading eligibility: %w", err) } @@ -337,7 +337,7 @@ func validateCommonStateObservations( return fmt.Errorf("validate commit reports reading eligibility: %w", err) } - if err := validateObservedSequenceNumbers(decodedObservation.CommitReports); err != nil { + if err := validateObservedSequenceNumbers(supportedChains, decodedObservation.CommitReports); err != nil { return fmt.Errorf("validate observed sequence numbers: %w", err) } diff --git a/execute/plugin_functions.go b/execute/plugin_functions.go index 99a04b3e6..4de60dc47 100644 --- a/execute/plugin_functions.go +++ b/execute/plugin_functions.go @@ -34,7 +34,8 @@ func validateCommitReportsReadingEligibility( } for _, data := range observedData[chainSel] { if data.SourceChain != chainSel { - return fmt.Errorf("observer not allowed to read from chain %d", data.SourceChain) + return fmt.Errorf("invalid observed data, key=%d but data chain=%d", + chainSel, data.SourceChain) } } } @@ -101,15 +102,24 @@ func validateHashesExist( } for chain, msgs := range observedMsgs { - _, ok := hashes[chain] + hashesForChain, ok := hashes[chain] if !ok { return fmt.Errorf("hash not found for chain %d", chain) } + if len(msgs) != len(hashesForChain) { + return fmt.Errorf("unexpected number of message hashes for chain %d: expected %d, got %d", + chain, len(msgs), len(hashesForChain)) + } + for seq, msg := range msgs { - if _, ok := hashes[chain][seq]; !ok { + h, exists := hashes[chain][seq] + if !exists { return fmt.Errorf("hash not found for message %s", msg) } + if h.IsEmpty() { + return fmt.Errorf("hash is empty for message %s", msg) + } } } @@ -123,6 +133,11 @@ func validateMessagesConformToCommitReports( observedData exectypes.CommitObservations, observedMsgs exectypes.MessageObservations, ) error { + if len(observedData) != len(observedMsgs) { + return fmt.Errorf("count of observed data=%d and observed msgs=%d do not match", + len(observedData), len(observedMsgs)) + } + msgsCount := 0 for chain, report := range observedData { for _, data := range report { @@ -154,9 +169,14 @@ func validateMessagesConformToCommitReports( // validateObservedSequenceNumbers checks if the sequence numbers of the provided messages are unique for each chain // and that they match the observed max sequence numbers. func validateObservedSequenceNumbers( + supportedChains mapset.Set[cciptypes.ChainSelector], observedData map[cciptypes.ChainSelector][]exectypes.CommitData, ) error { - for _, commitData := range observedData { + for chainSel, commitData := range observedData { + if !supportedChains.Contains(chainSel) { + return fmt.Errorf("observed a non-supported chain %d", chainSel) + } + // observed commitData must not contain duplicates observedMerkleRoots := mapset.NewSet[string]() diff --git a/execute/plugin_functions_test.go b/execute/plugin_functions_test.go index 057be0ad0..4143805b1 100644 --- a/execute/plugin_functions_test.go +++ b/execute/plugin_functions_test.go @@ -79,9 +79,10 @@ func Test_validateObserverReadingEligibility(t *testing.T) { func Test_validateObservedSequenceNumbers(t *testing.T) { testCases := []struct { - name string - observedData map[cciptypes.ChainSelector][]exectypes.CommitData - expErr bool + name string + observedData map[cciptypes.ChainSelector][]exectypes.CommitData + supportedChains mapset.Set[cciptypes.ChainSelector] + expErr bool }{ { name: "ValidData", @@ -101,6 +102,28 @@ func Test_validateObservedSequenceNumbers(t *testing.T) { }, }, }, + supportedChains: mapset.NewSet(cciptypes.ChainSelector(1), cciptypes.ChainSelector(2)), + }, + { + name: "UnsupportedChain", + observedData: map[cciptypes.ChainSelector][]exectypes.CommitData{ + 1: { + { + MerkleRoot: cciptypes.Bytes32{1}, + SequenceNumberRange: cciptypes.SeqNumRange{1, 3}, + ExecutedMessages: []cciptypes.SeqNum{1, 2, 3}, + }, + }, + 2: { + { + MerkleRoot: cciptypes.Bytes32{2}, + SequenceNumberRange: cciptypes.SeqNumRange{11, 15}, + ExecutedMessages: []cciptypes.SeqNum{11, 12, 13}, + }, + }, + }, + supportedChains: mapset.NewSet(cciptypes.ChainSelector(1)), // <-- 2 is missing + expErr: true, }, { name: "DuplicateMerkleRoot", @@ -118,7 +141,8 @@ func Test_validateObservedSequenceNumbers(t *testing.T) { }, }, }, - expErr: true, + supportedChains: mapset.NewSet(cciptypes.ChainSelector(1), cciptypes.ChainSelector(2)), + expErr: true, }, { name: "OverlappingSequenceNumberRange", @@ -136,7 +160,8 @@ func Test_validateObservedSequenceNumbers(t *testing.T) { }, }, }, - expErr: true, + supportedChains: mapset.NewSet(cciptypes.ChainSelector(1), cciptypes.ChainSelector(2)), + expErr: true, }, { name: "ExecutedMessageOutsideObservedRange", @@ -149,23 +174,26 @@ func Test_validateObservedSequenceNumbers(t *testing.T) { }, }, }, - expErr: true, + supportedChains: mapset.NewSet(cciptypes.ChainSelector(1), cciptypes.ChainSelector(2)), + expErr: true, }, { name: "NoCommitData", observedData: map[cciptypes.ChainSelector][]exectypes.CommitData{ 1: {}, }, + supportedChains: mapset.NewSet(cciptypes.ChainSelector(1), cciptypes.ChainSelector(2)), }, { - name: "EmptyObservedData", - observedData: map[cciptypes.ChainSelector][]exectypes.CommitData{}, + name: "EmptyObservedData", + observedData: map[cciptypes.ChainSelector][]exectypes.CommitData{}, + supportedChains: mapset.NewSet(cciptypes.ChainSelector(1), cciptypes.ChainSelector(2)), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err := validateObservedSequenceNumbers(tc.observedData) + err := validateObservedSequenceNumbers(tc.supportedChains, tc.observedData) if tc.expErr { assert.Error(t, err) return @@ -187,6 +215,7 @@ func Test_validateMessagesConformToCommitReports(t *testing.T) { observedData: map[cciptypes.ChainSelector][]exectypes.CommitData{ 1: {}, }, + expErr: true, }, { name: "EmptyObservedData", @@ -1457,7 +1486,7 @@ func Test_validateCommitReportsReadingEligibility(t *testing.T) { {SourceChain: 2}, }, }, - expErr: "observer not allowed to read from chain 2", + expErr: "invalid observed data, key=1 but data chain=2", }, } diff --git a/execute/plugin_test.go b/execute/plugin_test.go index fc11bd9cf..71947a7df 100644 --- a/execute/plugin_test.go +++ b/execute/plugin_test.go @@ -242,7 +242,7 @@ func Test_checkAlreadyExecuted(t *testing.T) { } } -func Test_getPendingExecutedReports(t *testing.T) { +func Test_getPendingReportsForExecution(t *testing.T) { canExecute := func(ret bool) CanExecuteHandle { return func(cciptypes.ChainSelector, cciptypes.Bytes32) bool { return ret } } @@ -593,19 +593,19 @@ func Test_getPendingExecutedReports(t *testing.T) { mockReader.On("ExecutedMessages", mock.Anything, k, mock.Anything, primitives.Unconfirmed).Return(v, nil) } - got, gotFinalized, gotUnfinalized, err := getPendingExecutedReports( + got, gotFinalized, gotUnfinalized, err := getPendingReportsForExecution( tests.Context(t), mockReader, tt.canExec, time.Now(), logger.Test(t), ) - if !tt.wantErr(t, err, "getPendingExecutedReports(...)") { + if !tt.wantErr(t, err, "getPendingReportsForExecution(...)") { return } - assert.Equalf(t, tt.wantObs, got, "getPendingExecutedReports(...)") - assert.Equalf(t, tt.wantExecutedFinalized, gotFinalized, "getPendingExecutedReports(...)") - assert.Equalf(t, tt.wantExecutedUnfinalized, gotUnfinalized, "getPendingExecutedReports(...)") + assert.Equalf(t, tt.wantObs, got, "getPendingReportsForExecution(...)") + assert.Equalf(t, tt.wantExecutedFinalized, gotFinalized, "getPendingReportsForExecution(...)") + assert.Equalf(t, tt.wantExecutedUnfinalized, gotUnfinalized, "getPendingReportsForExecution(...)") }) } } diff --git a/internal/reader/home_chain.go b/internal/reader/home_chain.go index 4feb0d52e..9eef7b2e4 100644 --- a/internal/reader/home_chain.go +++ b/internal/reader/home_chain.go @@ -103,6 +103,9 @@ func (r *homeChainPoller) poll() { if err := r.fetchAndSetConfigs(ctx); err != nil { // Just log, don't return error as we want to keep polling r.lggr.Errorw("Initial fetch of on-chain configs failed", "err", err) + r.failedPolls.Add(1) + } else { + r.failedPolls.Store(0) } ticker := time.NewTicker(r.pollingDuration) diff --git a/pkg/reader/ccip.go b/pkg/reader/ccip.go index aae8cb348..c47fc021b 100644 --- a/pkg/reader/ccip.go +++ b/pkg/reader/ccip.go @@ -1,6 +1,7 @@ package reader import ( + "bytes" "context" "encoding/binary" "encoding/hex" @@ -190,7 +191,7 @@ func (r *ccipChainReader) CommitReportsGTETimestamp( reports := make([]plugintypes2.CommitPluginReportWithMeta, 0) for _, item := range iter { - ev, err := validateCommitReportAcceptedEvent(item, ts) + ev, err := validateCommitReportAcceptedEvent(item, ts, r.destChain) if err != nil { lggr.Errorw("validate commit report accepted event", "err", err, "ev", ev) continue @@ -206,18 +207,9 @@ func (r *ccipChainReader) CommitReportsGTETimestamp( blessedMerkleRoots := make([]cciptypes.MerkleRootChain, 0, len(ev.BlessedMerkleRoots)) unblessedMerkleRoots := make([]cciptypes.MerkleRootChain, 0, len(ev.UnblessedMerkleRoots)) for _, mr := range allMerkleRoots { - onRampAddress, err := r.GetContractAddress( - consts.ContractNameOnRamp, - cciptypes.ChainSelector(mr.SourceChainSelector), - ) - if err != nil { - r.lggr.Errorw("get onRamp address for selector", "sourceChain", mr.SourceChainSelector, "err", err) - continue - } - mrc := cciptypes.MerkleRootChain{ ChainSel: cciptypes.ChainSelector(mr.SourceChainSelector), - OnRampAddress: onRampAddress, + OnRampAddress: mr.OnRampAddress, SeqNumsRange: cciptypes.NewSeqNumRange( cciptypes.SeqNum(mr.MinSeqNr), cciptypes.SeqNum(mr.MaxSeqNr), @@ -411,6 +403,14 @@ func (r *ccipChainReader) MsgsBetweenSeqNums( return nil, fmt.Errorf("failed to query onRamp: %w", err) } + onRampAddressAfterQuery, err := r.GetContractAddress(consts.ContractNameOnRamp, sourceChainSelector) + if err != nil { + return nil, fmt.Errorf("get onRamp address after query: %w", err) + } + if !bytes.Equal(onRampAddress, onRampAddressAfterQuery) { + return nil, fmt.Errorf("onRamp address has changed from %s to %s", onRampAddress, onRampAddressAfterQuery) + } + lggr.Infow("queried messages between sequence numbers", "numMsgs", len(seq), "sourceChainSelector", sourceChainSelector, @@ -624,15 +624,20 @@ func (r *ccipChainReader) Nonces( for i, readResult := range results { address := getAddressByIndex(addressToIndex, i) + if address == "" { + lggr.Errorw("critical error, address not found for index", "index", i) + continue + } + returnVal, err := readResult.GetResult() if err != nil { - r.lggr.Errorw("failed to get nonce for address", "address", address, "err", err) + lggr.Errorw("failed to get nonce for address", "address", address, "err", err) continue } val, ok := returnVal.(*uint64) if !ok || val == nil { - r.lggr.Errorw("invalid nonce value returned", "address", address) + lggr.Errorw("invalid nonce value returned", "address", address) continue } @@ -1023,7 +1028,7 @@ func (r *ccipChainReader) DiscoverContracts(ctx context.Context, defer mu.Unlock() // Add FeeQuoter from dynamic config - if len(config.OnRamp.DynamicConfig.DynamicConfig.FeeQuoter) > 0 { + if !cciptypes.UnknownAddress(config.OnRamp.DynamicConfig.DynamicConfig.FeeQuoter).IsZeroOrEmpty() { resp = resp.Append( consts.ContractNameFeeQuoter, chainSel, @@ -1031,7 +1036,7 @@ func (r *ccipChainReader) DiscoverContracts(ctx context.Context, } // Add Router from dest chain config - if len(config.OnRamp.DestChainConfig.Router) > 0 { + if !cciptypes.UnknownAddress(config.OnRamp.DestChainConfig.Router).IsZeroOrEmpty() { resp = resp.Append( consts.ContractNameRouter, chainSel, @@ -1844,12 +1849,18 @@ func (r *ccipChainReader) processFeeQuoterResults(results []types.BatchReadResul return FeeQuoterConfig{}, fmt.Errorf("invalid type for fee quoter static config: %T", val) } -func validateCommitReportAcceptedEvent(seq types.Sequence, gteTimestamp time.Time) (*CommitReportAcceptedEvent, error) { +func validateCommitReportAcceptedEvent( + seq types.Sequence, gteTimestamp time.Time, destChain cciptypes.ChainSelector, +) (*CommitReportAcceptedEvent, error) { ev, is := (seq.Data).(*CommitReportAcceptedEvent) if !is { return nil, fmt.Errorf("unexpected type %T while expecting a commit report", seq) } + if ev == nil { + return nil, fmt.Errorf("commit report accepted event is nil") + } + if seq.Timestamp < uint64(gteTimestamp.Unix()) { return nil, fmt.Errorf("commit report accepted event timestamp is less than the minimum timestamp %v<%v", seq.Timestamp, gteTimestamp.Unix()) @@ -1860,8 +1871,8 @@ func validateCommitReportAcceptedEvent(seq types.Sequence, gteTimestamp time.Tim } for _, tpus := range ev.PriceUpdates.TokenPriceUpdates { - if len(tpus.SourceToken) == 0 { - return nil, fmt.Errorf("empty source token") + if tpus.SourceToken.IsZeroOrEmpty() { + return nil, fmt.Errorf("invalid source token address: %s", tpus.SourceToken.String()) } if tpus.UsdPerToken == nil || tpus.UsdPerToken.Cmp(big.NewInt(0)) <= 0 { return nil, fmt.Errorf("nil or non-positive usd per token") @@ -1869,8 +1880,9 @@ func validateCommitReportAcceptedEvent(seq types.Sequence, gteTimestamp time.Tim } for _, gpus := range ev.PriceUpdates.GasPriceUpdates { - if gpus.DestChainSelector == 0 { - return nil, fmt.Errorf("dest chain is zero") + if gpus.DestChainSelector != uint64(destChain) { + return nil, fmt.Errorf("dest chain does not match the reader's one %d != %d", + gpus.DestChainSelector, destChain) } if gpus.UsdPerUnitGas == nil || gpus.UsdPerUnitGas.Cmp(big.NewInt(0)) <= 0 { return nil, fmt.Errorf("nil or non-positive usd per unit gas") @@ -1904,8 +1916,8 @@ func validateMerkleRoots(merkleRoots []MerkleRoot) error { if mr.MerkleRoot.IsEmpty() { return fmt.Errorf("empty merkle root") } - if len(mr.OnRampAddress) == 0 { - return fmt.Errorf("empty onramp address") + if mr.OnRampAddress.IsZeroOrEmpty() { + return fmt.Errorf("invalid onramp address: %s", mr.OnRampAddress.String()) } } @@ -1947,6 +1959,9 @@ func validateSendRequestedEvent( return fmt.Errorf("send requested event is nil") } + if ev.Message.Header.DestChainSelector != dest { + return fmt.Errorf("msg dest chain is not the expected queried one") + } if ev.DestChainSelector != dest { return fmt.Errorf("dest chain is not the expected queried one") } @@ -1955,6 +1970,11 @@ func validateSendRequestedEvent( return fmt.Errorf("source chain is not the expected queried one") } + if ev.SequenceNumber != ev.Message.Header.SequenceNumber { + return fmt.Errorf("event sequence number does not match the message sequence number %d != %d", + ev.SequenceNumber, ev.Message.Header.SequenceNumber) + } + if ev.SequenceNumber < seqNumRange.Start() || ev.SequenceNumber > seqNumRange.End() { return fmt.Errorf("send requested event sequence number is not in the expected range") } @@ -1964,19 +1984,19 @@ func validateSendRequestedEvent( } if len(ev.Message.Receiver) == 0 { - return fmt.Errorf("empty receiver address") + return fmt.Errorf("empty receiver address: %s", ev.Message.Receiver.String()) } - if len(ev.Message.Sender) == 0 { - return fmt.Errorf("empty sender address") + if ev.Message.Sender.IsZeroOrEmpty() { + return fmt.Errorf("invalid sender address: %s", ev.Message.Sender.String()) } if ev.Message.FeeTokenAmount.IsEmpty() { return fmt.Errorf("fee token amount is zero") } - if len(ev.Message.FeeToken) == 0 { - return fmt.Errorf("empty fee token") + if ev.Message.FeeToken.IsZeroOrEmpty() { + return fmt.Errorf("invalid fee token: %s", ev.Message.FeeToken.String()) } return nil diff --git a/pkg/types/ccipocr3/common_types.go b/pkg/types/ccipocr3/common_types.go index f0b5d8eee..744582466 100644 --- a/pkg/types/ccipocr3/common_types.go +++ b/pkg/types/ccipocr3/common_types.go @@ -32,6 +32,20 @@ func (a *UnknownAddress) UnmarshalJSON(data []byte) error { return (*Bytes)(a).UnmarshalJSON(data) } +// IsZeroOrEmpty returns true if the address contains 0 bytes or if all the bytes are 0. +func (a UnknownAddress) IsZeroOrEmpty() bool { + if len(a) == 0 { + return true // empty + } + + for _, b := range a { + if b != 0 { + return false // zero + } + } + return true +} + // UnknownEncodedAddress represents an encoded address with an unknown encoding. type UnknownEncodedAddress string @@ -65,7 +79,7 @@ func (b Bytes) MarshalJSON() ([]byte, error) { func (b *Bytes) UnmarshalJSON(data []byte) error { v := string(data) if len(v) < 4 { - return fmt.Errorf("bytes must be of at least length 2 (i.e, '\"0x\"'): %s", v) + return fmt.Errorf("bytes must be of at least length 4 (i.e, '\"0x\"'): %s", v) } // trim the start and end double quotes @@ -88,8 +102,8 @@ func (b *Bytes) UnmarshalJSON(data []byte) error { type Bytes32 [32]byte func NewBytes32FromString(s string) (Bytes32, error) { - if len(s) < 2 { - return Bytes32{}, fmt.Errorf("Bytes32 must be of at least length 2 (i.e, '0x' prefix): %s", s) + if len(s) > 66 { // "0x" + 64 hex chars + return Bytes32{}, fmt.Errorf("Bytes32 must be at most 32 bytes (64 hex chars) long: %s", s) } if !strings.HasPrefix(s, "0x") { @@ -121,10 +135,16 @@ func (b Bytes32) MarshalJSON() ([]byte, error) { func (b *Bytes32) UnmarshalJSON(data []byte) error { v := string(data) if len(v) < 4 { - return fmt.Errorf("invalid MerkleRoot: %s", v) + return fmt.Errorf("invalid Bytes32: %s", v) + } + v = v[1 : len(v)-1] // trim quotes + + if !strings.HasPrefix(v, "0x") { + return fmt.Errorf("bytes must start with '0x' prefix: %s", v) } + v = v[2:] // trim 0x prefix - bCp, err := hex.DecodeString(v[1 : len(v)-1][2:]) + bCp, err := hex.DecodeString(v) if err != nil { return err } diff --git a/pkg/types/ccipocr3/common_types_test.go b/pkg/types/ccipocr3/common_types_test.go index ddac03493..a8ebb2beb 100644 --- a/pkg/types/ccipocr3/common_types_test.go +++ b/pkg/types/ccipocr3/common_types_test.go @@ -1,6 +1,7 @@ package ccipocr3 import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -33,11 +34,34 @@ func TestNewBytes32FromString(t *testing.T) { expErr: true, }, { - name: "invalid input, not enough hex chars", + name: "invalid input, odd len", input: "0x2", expected: Bytes32{}, expErr: true, }, + { + name: "valid input, not enough hex chars", + input: "0x22", + expected: Bytes32{0x22}, + expErr: false, + }, + { + name: "valid input exact length", + input: "0x" + strings.Repeat("12", 32), + expected: Bytes32{ + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, + }, + expErr: false, + }, + { + name: "invalid input, tou much hex chars", + input: "0x" + strings.Repeat("12", 33), + expected: Bytes32{}, + expErr: true, + }, } for _, tc := range testCases {