Skip to content

Commit

Permalink
Add Solana codec validations (#16567)
Browse files Browse the repository at this point in the history
* Add Solana Address length validation

* Add missing merkleroot validation in commit codec decode()

* Update core/capabilities/ccip/oraclecreator/plugin.go

Co-authored-by: Joe Huang <[email protected]>

* Update core/capabilities/ccip/ccipsolana/commitcodec.go

Co-authored-by: amit-momin <[email protected]>

---------

Co-authored-by: Joe Huang <[email protected]>
Co-authored-by: amit-momin <[email protected]>
  • Loading branch information
3 people authored Feb 26, 2025
1 parent 87b3664 commit af0dce9
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 12 deletions.
23 changes: 13 additions & 10 deletions core/capabilities/ccip/ccipsolana/commitcodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,19 @@ func (c *CommitPluginCodecV1) Decode(ctx context.Context, bytes []byte) (cciptyp
return cciptypes.CommitPluginReport{}, err
}

merkleRoots := []cciptypes.MerkleRootChain{
{
ChainSel: cciptypes.ChainSelector(commitReport.MerkleRoot.SourceChainSelector),
OnRampAddress: commitReport.MerkleRoot.OnRampAddress,
SeqNumsRange: cciptypes.NewSeqNumRange(
cciptypes.SeqNum(commitReport.MerkleRoot.MinSeqNr),
cciptypes.SeqNum(commitReport.MerkleRoot.MaxSeqNr),
),
MerkleRoot: commitReport.MerkleRoot.MerkleRoot,
},
var merkleRoots []cciptypes.MerkleRootChain
if commitReport.MerkleRoot != nil {
merkleRoots = []cciptypes.MerkleRootChain{
{
ChainSel: cciptypes.ChainSelector(commitReport.MerkleRoot.SourceChainSelector),
OnRampAddress: commitReport.MerkleRoot.OnRampAddress,
SeqNumsRange: cciptypes.NewSeqNumRange(
cciptypes.SeqNum(commitReport.MerkleRoot.MinSeqNr),
cciptypes.SeqNum(commitReport.MerkleRoot.MaxSeqNr),
),
MerkleRoot: commitReport.MerkleRoot.MerkleRoot,
},
}
}

tokenPriceUpdates := make([]cciptypes.TokenPrice, 0, len(commitReport.PriceUpdates.TokenPriceUpdates))
Expand Down
50 changes: 50 additions & 0 deletions core/capabilities/ccip/ccipsolana/commitcodec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,56 @@ func Test_DecodingCommitReport(t *testing.T) {
require.Equal(t, chainSel, gu.ChainSel)
})

t.Run("decode on-chain commit report with no MerkleRoot", func(t *testing.T) {
chainSel := cciptypes.ChainSelector(rand.Uint64())

tokenSource := solanago.MustPublicKeyFromBase58("C8WSPj3yyus1YN3yNB6YA5zStYtbjQWtpmKadmvyUXq8")
tokenPrice := encodeBigIntToFixedLengthLE(big.NewInt(rand.Int63()), 28)
gasPrice := encodeBigIntToFixedLengthLE(big.NewInt(rand.Int63()), 28)

tpu := []ccip_offramp.TokenPriceUpdate{
{
SourceToken: tokenSource,
UsdPerToken: [28]uint8(tokenPrice),
},
}

gpu := []ccip_offramp.GasPriceUpdate{
{UsdPerUnitGas: [28]uint8(gasPrice), DestChainSelector: uint64(chainSel)},
{UsdPerUnitGas: [28]uint8(gasPrice), DestChainSelector: uint64(chainSel)},
{UsdPerUnitGas: [28]uint8(gasPrice), DestChainSelector: uint64(chainSel)},
}

onChainReport := ccip_offramp.CommitInput{
MerkleRoot: nil,
PriceUpdates: ccip_offramp.PriceUpdates{
TokenPriceUpdates: tpu,
GasPriceUpdates: gpu,
},
}

var buf bytes.Buffer
encoder := agbinary.NewBorshEncoder(&buf)
err := onChainReport.MarshalWithEncoder(encoder)
require.NoError(t, err)

commitCodec := NewCommitPluginCodecV1()
decode, err := commitCodec.Decode(testutils.Context(t), buf.Bytes())
require.NoError(t, err)
require.Nilf(t, decode.UnblessedMerkleRoots, "UnblessedMerkleRoots should be nil")
require.Nilf(t, decode.BlessedMerkleRoots, "BlessedMerkleRoots should be nil")

// check decoded ocr report token price update matches with on-chain report
pu := decode.PriceUpdates.TokenPriceUpdates[0]
require.Equal(t, decodeLEToBigInt(tokenPrice), pu.Price)
require.Equal(t, cciptypes.UnknownEncodedAddress(tokenSource.String()), pu.TokenID)

// check decoded ocr report gas price update matches with on-chain report
gu := decode.PriceUpdates.GasPriceUpdates[0]
require.Equal(t, decodeLEToBigInt(gasPrice), gu.GasPrice)
require.Equal(t, chainSel, gu.ChainSel)
})

t.Run("decode Borsh encoded commit report", func(t *testing.T) {
rep := randomBlessedCommitReport()
commitCodec := NewCommitPluginCodecV1()
Expand Down
4 changes: 4 additions & 0 deletions core/capabilities/ccip/ccipsolana/executecodec.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

agbinary "github.com/gagliardetto/binary"
"github.com/gagliardetto/solana-go"

"github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/common"

"github.com/smartcontractkit/chainlink-ccip/chains/solana/gobindings/ccip_offramp"
Expand Down Expand Up @@ -64,6 +65,9 @@ func (e *ExecutePluginCodecV1) Encode(ctx context.Context, report cciptypes.Exec
return nil, err
}

if solana.PublicKeyLength != len(tokenAmount.DestTokenAddress) {
return nil, fmt.Errorf("invalid DestTokenAddress length: %d", len(tokenAmount.DestTokenAddress))
}
tokenAmounts = append(tokenAmounts, ccip_offramp.Any2SVMTokenTransfer{
SourcePoolAddress: tokenAmount.SourcePoolAddress,
DestTokenAddress: solana.PublicKeyFromBytes(tokenAmount.DestTokenAddress),
Expand Down
19 changes: 19 additions & 0 deletions core/capabilities/ccip/ccipsolana/executecodec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
"github.com/smartcontractkit/chainlink-integrations/evm/utils"

"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -137,6 +138,24 @@ func TestExecutePluginCodecV1(t *testing.T) {
expErr: false,
chainSelector: 124615329519749607, // Solana mainnet chain selector
},
{
name: "reports have invalid DestTokenAddress",
report: func(report cciptypes.ExecutePluginReport) cciptypes.ExecutePluginReport {
report.ChainReports[0].Messages[0].TokenAmounts[0].DestTokenAddress = []byte{0, 0}
return report
},
expErr: true,
chainSelector: 124615329519749607, // Solana mainnet chain selector
},
{
name: "reports have invalid receiver",
report: func(report cciptypes.ExecutePluginReport) cciptypes.ExecutePluginReport {
report.ChainReports[0].Messages[0].Receiver = []byte{0, 0}
return report
},
expErr: true,
chainSelector: 124615329519749607, // Solana mainnet chain selector
},
}

ctx := testutils.Context(t)
Expand Down
7 changes: 7 additions & 0 deletions core/capabilities/ccip/ccipsolana/msghasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/gagliardetto/solana-go"

"github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/common"

"github.com/smartcontractkit/chainlink-ccip/chains/solana/gobindings/ccip_offramp"
Expand Down Expand Up @@ -44,6 +45,9 @@ func (h *MessageHasherV1) Hash(_ context.Context, msg cciptypes.Message) (ccipty
MessageId: msg.Header.MessageID,
Nonce: msg.Header.Nonce,
}
if solana.PublicKeyLength != len(msg.Receiver) {
return [32]byte{}, fmt.Errorf("invalid receiver length: %d", len(msg.Receiver))
}
anyToSolanaMessage.TokenReceiver = solana.PublicKeyFromBytes(msg.Receiver)
anyToSolanaMessage.Sender = msg.Sender
anyToSolanaMessage.Data = msg.Data
Expand All @@ -58,6 +62,9 @@ func (h *MessageHasherV1) Hash(_ context.Context, msg cciptypes.Message) (ccipty
return [32]byte{}, err
}

if solana.PublicKeyLength != len(ta.DestTokenAddress) {
return [32]byte{}, fmt.Errorf("invalid DestTokenAddress length: %d", len(ta.DestTokenAddress))
}
anyToSolanaMessage.TokenAmounts = append(anyToSolanaMessage.TokenAmounts, ccip_offramp.Any2SVMTokenTransfer{
SourcePoolAddress: ta.SourcePoolAddress,
DestTokenAddress: solana.PublicKeyFromBytes(ta.DestTokenAddress),
Expand Down
49 changes: 48 additions & 1 deletion core/capabilities/ccip/ccipsolana/msghasher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/logger"

"github.com/smartcontractkit/chainlink-integrations/evm/utils"

"github.com/smartcontractkit/chainlink/v2/core/internal/testutils"

cciptypes "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3"
)

func TestMessageHasher_Any2Solana(t *testing.T) {
func TestMessageHasher_Any2SVM(t *testing.T) {
any2AnyMsg, any2SolanaMsg, msgAccounts := createAny2SolanaMessages(t)
mockExtraDataCodec := &mocks.ExtraDataCodec{}
mockExtraDataCodec.On("DecodeTokenAmountDestExecData", mock.Anything, mock.Anything).Return(map[string]any{
Expand All @@ -49,6 +50,52 @@ func TestMessageHasher_Any2Solana(t *testing.T) {
require.Equal(t, expectedHash, actualHash[:32])
}

func TestMessageHasher_InvalidReceiver(t *testing.T) {
any2AnyMsg, _, _ := createAny2SolanaMessages(t)

// Set receiver to a []byte of 2 length
any2AnyMsg.Receiver = []byte{0, 0}
mockExtraDataCodec := &mocks.ExtraDataCodec{}
mockExtraDataCodec.On("DecodeTokenAmountDestExecData", mock.Anything, mock.Anything).Return(map[string]any{
"destGasAmount": uint32(10),
}, nil)
mockExtraDataCodec.On("DecodeExtraArgs", mock.Anything, mock.Anything).Return(map[string]any{
"ComputeUnits": uint32(1000),
"AccountIsWritableBitmap": uint64(10),
"Accounts": [][32]byte{
[32]byte(config.CcipLogicReceiver.Bytes()),
[32]byte(config.ReceiverTargetAccountPDA.Bytes()),
[32]byte(solana.SystemProgramID.Bytes()),
},
}, nil)
msgHasher := NewMessageHasherV1(logger.Test(t), mockExtraDataCodec)
_, err := msgHasher.Hash(testutils.Context(t), any2AnyMsg)
require.Error(t, err)
}

func TestMessageHasher_InvalidDestinationTokenAddress(t *testing.T) {
any2AnyMsg, _, _ := createAny2SolanaMessages(t)

// Set DestTokenAddress to a []byte of 2 length
any2AnyMsg.TokenAmounts[0].DestTokenAddress = []byte{0, 0}
mockExtraDataCodec := &mocks.ExtraDataCodec{}
mockExtraDataCodec.On("DecodeTokenAmountDestExecData", mock.Anything, mock.Anything).Return(map[string]any{
"destGasAmount": uint32(10),
}, nil)
mockExtraDataCodec.On("DecodeExtraArgs", mock.Anything, mock.Anything).Return(map[string]any{
"ComputeUnits": uint32(1000),
"AccountIsWritableBitmap": uint64(10),
"Accounts": [][32]byte{
[32]byte(config.CcipLogicReceiver.Bytes()),
[32]byte(config.ReceiverTargetAccountPDA.Bytes()),
[32]byte(solana.SystemProgramID.Bytes()),
},
}, nil)
msgHasher := NewMessageHasherV1(logger.Test(t), mockExtraDataCodec)
_, err := msgHasher.Hash(testutils.Context(t), any2AnyMsg)
require.Error(t, err)
}

func createAny2SolanaMessages(t *testing.T) (cciptypes.Message, ccip_offramp.Any2SVMRampMessage, []solana.PublicKey) {
messageID := utils.RandomBytes32()

Expand Down
6 changes: 5 additions & 1 deletion core/capabilities/ccip/oraclecreator/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/smartcontractkit/chainlink-common/pkg/types"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"

ccipcommon "github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/common"
evmconfig "github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/configs/evm"
"github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/ocrimpls"
Expand Down Expand Up @@ -599,8 +600,11 @@ func createChainWriter(
switch chainFamily {
case relay.NetworkSolana:
var solConfig chainwriter.ChainWriterConfig
if solana.PublicKeyLength != len(offrampProgramAddress) {
return nil, fmt.Errorf("invalid offrampProgramAddress length: %d", len(offrampProgramAddress))
}
offrampAddress := solana.PublicKeyFromBytes(offrampProgramAddress)
if solConfig, err = solanaconfig.GetSolanaChainWriterConfig(offrampAddress.String(), transmitter[0], destChainSelector); err == nil {
if solConfig, err = solanaconfig.GetSolanaChainWriterConfig(offrampAddress.String(), transmitter[0], destChainSelector); err != nil {
return nil, fmt.Errorf("failed to get Solana chain writer config: %w", err)
}
if chainWriterConfig, err = json.Marshal(solConfig); err != nil {
Expand Down

0 comments on commit af0dce9

Please sign in to comment.