diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index c423fad34..3c43c880f 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -384,7 +384,7 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra // Filter the lookup table addresses based on which accounts are actually used filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) - computeLimit := uint32(0) + options := []txmutils.SetTxConfig{} // Transform args if necessary if methodConfig.ArgsTransform != "" { transformFunc, tfErr := FindTransform(methodConfig.ArgsTransform) @@ -392,7 +392,7 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra return errorWithDebugID(fmt.Errorf("error finding transform function: %w", tfErr), debugID) } s.lggr.Debugw("Applying args transformation", "contract", contractName, "method", method) - args, accounts, computeLimit, err = transformFunc(ctx, args, accounts, derivedTableMap) + args, accounts, options, err = transformFunc(ctx, args, accounts, derivedTableMap) if err != nil { return errorWithDebugID(fmt.Errorf("error transforming args: %w", err), debugID) } @@ -439,10 +439,6 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra s.lggr.Debugw("Sending main transaction", "contract", contractName, "method", method) - options := []txmutils.SetTxConfig{} - if computeLimit != 0 { - options = append(options, txmutils.SetComputeUnitLimit(computeLimit)) - } // Enqueue transaction if err = s.txm.Enqueue(ctx, methodConfig.FromAddress, tx, &transactionID, blockhash.Value.LastValidBlockHeight, options...); err != nil { return errorWithDebugID(fmt.Errorf("error enqueuing transaction: %w", err), debugID) diff --git a/pkg/solana/chainwriter/chain_writer_test.go b/pkg/solana/chainwriter/chain_writer_test.go index 84d9b17c8..34c25636b 100644 --- a/pkg/solana/chainwriter/chain_writer_test.go +++ b/pkg/solana/chainwriter/chain_writer_test.go @@ -928,13 +928,19 @@ func TestChainWriter_CCIPOfframp(t *testing.T) { require.Len(t, tokenIndexes, 1) require.Equal(t, uint8(3), tokenIndexes[0]) return true - }), &txID, mock.Anything, mock.AnythingOfType("utils.SetTxConfig")).Return(nil).Run(func(args mock.Arguments) { - setComputeLimit, ok := args[5].(txmutils.SetTxConfig) + }), &txID, mock.Anything, mock.AnythingOfType("utils.SetTxConfig"), mock.AnythingOfType("utils.SetTxConfig")).Return(nil).Run(func(args mock.Arguments) { + opt1, ok := args[5].(txmutils.SetTxConfig) + require.True(t, ok) + + opt2, ok := args[6].(txmutils.SetTxConfig) require.True(t, ok) txConfig := &txmutils.TxConfig{} - setComputeLimit(txConfig) - require.Equal(t, uint32(500), txConfig.ComputeUnitLimit) + opt1(txConfig) + opt2(txConfig) + + require.Equal(t, false, txConfig.EstimateComputeUnitLimit) + require.Equal(t, uint32(1700), txConfig.ComputeUnitLimit) }).Once() // stripped back report just for purposes of example @@ -944,6 +950,9 @@ func TestChainWriter_CCIPOfframp(t *testing.T) { TokenAmounts: []ccipocr3.RampTokenAmount{ { DestTokenAddress: destTokenAddr.Bytes(), + DestExecDataDecoded: map[string]any{ + "destGasAmount": uint32(200), + }, }, }, ExtraArgsDecoded: map[string]any{ @@ -1014,7 +1023,14 @@ func TestChainWriter_CCIPOfframp(t *testing.T) { // The CCIPCommit ArgsTransform should remove the last account since no price updates were provided in the report require.Len(t, tx.Message.Instructions[0].Accounts, 2) return true - }), &txID, mock.Anything).Return(nil).Once() + }), &txID, mock.Anything, mock.AnythingOfType("utils.SetTxConfig")).Return(nil).Run(func(args mock.Arguments) { + opt, ok := args[5].(txmutils.SetTxConfig) + require.True(t, ok) + txConfig := &txmutils.TxConfig{} + opt(txConfig) + + require.Equal(t, true, txConfig.EstimateComputeUnitLimit) + }).Once() submitErr := cw.SubmitTransaction(ctx, ccipconsts.ContractNameOffRamp, ccipconsts.MethodCommit, args, txID, offrampAddr.String(), nil, nil) require.NoError(t, submitErr) diff --git a/pkg/solana/chainwriter/transform_registry.go b/pkg/solana/chainwriter/transform_registry.go index 98071ad69..3b13e37a7 100644 --- a/pkg/solana/chainwriter/transform_registry.go +++ b/pkg/solana/chainwriter/transform_registry.go @@ -8,6 +8,7 @@ import ( "github.com/gagliardetto/solana-go" "github.com/mitchellh/mapstructure" "github.com/smartcontractkit/chainlink-ccip/pkg/types/ccipocr3" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) type ReportPostTransform struct { @@ -17,7 +18,10 @@ type ReportPostTransform struct { TokenIndexes []byte } -func FindTransform(id string) (func(context.Context, any, solana.AccountMetaSlice, map[string]map[string][]*solana.AccountMeta) (any, solana.AccountMetaSlice, uint32, error), error) { +// TODO: replace with actual value from CCIP on-chain +const staticCuOverhead = 1000 + +func FindTransform(id string) (func(context.Context, any, solana.AccountMetaSlice, map[string]map[string][]*solana.AccountMeta) (any, solana.AccountMetaSlice, []txmutils.SetTxConfig, error), error) { switch id { case "CCIPExecute": return CCIPExecuteArgsTransform, nil @@ -30,20 +34,35 @@ func FindTransform(id string) (func(context.Context, any, solana.AccountMetaSlic // This Transform function looks up the token pool addresses in the accounts slice and augments the args // with the indexes of the token pool addresses in the accounts slice. -func CCIPExecuteArgsTransform(ctx context.Context, args any, accounts solana.AccountMetaSlice, tableMap map[string]map[string][]*solana.AccountMeta) (any, solana.AccountMetaSlice, uint32, error) { +func CCIPExecuteArgsTransform(ctx context.Context, args any, accounts solana.AccountMetaSlice, tableMap map[string]map[string][]*solana.AccountMeta) (any, solana.AccountMetaSlice, []txmutils.SetTxConfig, error) { var argsTransformed ReportPostTransform err := mapstructure.Decode(args, &argsTransformed) if err != nil { - return nil, nil, 0, err + return nil, nil, []txmutils.SetTxConfig{}, err } if len(argsTransformed.Info.AbstractReports) != 1 || len(argsTransformed.Info.AbstractReports[0].Messages) != 1 { - return nil, nil, 0, fmt.Errorf("Expected 1 report with 1 message") + return nil, nil, []txmutils.SetTxConfig{}, fmt.Errorf("Expected 1 report with 1 message") + } + + // Compute Units: static cu overhead + svmExtraArgsV1.computeUnits + sum_per_token(tokenDestGasAmount) + cu := argsTransformed.Info.AbstractReports[0].Messages[0].ExtraArgsDecoded["ComputeUnits"] + if cu == nil { + return nil, nil, []txmutils.SetTxConfig{}, errors.New("missing ComputeUnits in ExtraArgs") + } + computeUnits := staticCuOverhead + cu.(uint32) + + for _, token := range argsTransformed.Info.AbstractReports[0].Messages[0].TokenAmounts { + destGasAmount := token.DestExecDataDecoded["destGasAmount"] + if destGasAmount == nil { + return nil, nil, []txmutils.SetTxConfig{}, fmt.Errorf("missing destGasAmount in DestExecData") + } + computeUnits += destGasAmount.(uint32) } - computeUnits := argsTransformed.Info.AbstractReports[0].Messages[0].ExtraArgsDecoded["ComputeUnits"] - if computeUnits == nil { - return nil, nil, 0, errors.New("missing ComputeUnits in ExtraArgs") + options := []txmutils.SetTxConfig{ + txmutils.SetEstimateComputeUnitLimit(false), + txmutils.SetComputeUnitLimit(computeUnits), } registryTables, exists := tableMap["PoolLookupTable"] @@ -51,7 +70,7 @@ func CCIPExecuteArgsTransform(ctx context.Context, args any, accounts solana.Acc // Return with empty TokenIndexes if !exists { argsTransformed.TokenIndexes = []byte{} - return argsTransformed, accounts, computeUnits.(uint32), nil + return argsTransformed, accounts, options, nil } tokenPoolAddresses := []solana.PublicKey{} @@ -64,7 +83,7 @@ func CCIPExecuteArgsTransform(ctx context.Context, args any, accounts solana.Acc for _, address := range tokenPoolAddresses { if account.PublicKey == address { if i > 255 { - return nil, nil, 0, fmt.Errorf("index %d out of range for uint8", i) + return nil, nil, []txmutils.SetTxConfig{}, fmt.Errorf("index %d out of range for uint8", i) } tokenIndexes = append(tokenIndexes, uint8(i)) //nolint:gosec } @@ -72,28 +91,28 @@ func CCIPExecuteArgsTransform(ctx context.Context, args any, accounts solana.Acc } if len(tokenIndexes) != len(tokenPoolAddresses) { - return nil, nil, 0, fmt.Errorf("missing token pools in accounts") + return nil, nil, []txmutils.SetTxConfig{}, fmt.Errorf("missing token pools in accounts") } argsTransformed.TokenIndexes = tokenIndexes - return argsTransformed, accounts, computeUnits.(uint32), nil + return argsTransformed, accounts, options, nil } // This Transform function trims off the GlobalState account from commit transactions if there are no token or gas price updates -func CCIPCommitAccountTransform(ctx context.Context, args any, accounts solana.AccountMetaSlice, _ map[string]map[string][]*solana.AccountMeta) (any, solana.AccountMetaSlice, uint32, error) { +func CCIPCommitAccountTransform(ctx context.Context, args any, accounts solana.AccountMetaSlice, _ map[string]map[string][]*solana.AccountMeta) (any, solana.AccountMetaSlice, []txmutils.SetTxConfig, error) { var tokenPriceVals, gasPriceVals [][]byte var err error tokenPriceVals, err = GetValuesAtLocation(args, "Info.TokenPrices.TokenID") if err != nil && !errors.Is(err, errFieldNotFound) { - return nil, nil, 0, fmt.Errorf("error getting values at location: %w", err) + return nil, nil, []txmutils.SetTxConfig{}, fmt.Errorf("error getting values at location: %w", err) } gasPriceVals, err = GetValuesAtLocation(args, "Info.GasPrices.ChainSel") if err != nil && !errors.Is(err, errFieldNotFound) { - return nil, nil, 0, fmt.Errorf("error getting values at location: %w", err) + return nil, nil, []txmutils.SetTxConfig{}, fmt.Errorf("error getting values at location: %w", err) } transformedAccounts := accounts if len(tokenPriceVals) == 0 && len(gasPriceVals) == 0 { transformedAccounts = accounts[:len(accounts)-1] } - return args, transformedAccounts, 0, nil + return args, transformedAccounts, []txmutils.SetTxConfig{txmutils.SetEstimateComputeUnitLimit(true)}, nil } diff --git a/pkg/solana/chainwriter/transform_registry_test.go b/pkg/solana/chainwriter/transform_registry_test.go index 97a6aae86..25e2c0d06 100644 --- a/pkg/solana/chainwriter/transform_registry_test.go +++ b/pkg/solana/chainwriter/transform_registry_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) type ReportPreTransform struct { @@ -30,6 +31,9 @@ func Test_CCIPExecuteArgsTransform(t *testing.T) { Messages: []ccipocr3.Message{{ TokenAmounts: []ccipocr3.RampTokenAmount{{ DestTokenAddress: ccipocr3.UnknownAddress(destTokenAddr.Bytes()), + DestExecDataDecoded: map[string]any{ + "destGasAmount": uint32(500), + }, }}, ExtraArgsDecoded: map[string]any{ "ComputeUnits": uint32(500), @@ -52,10 +56,11 @@ func Test_CCIPExecuteArgsTransform(t *testing.T) { } tableMap["PoolLookupTable"][lookupTablePubkey.String()] = poolKeysMeta - transformedArgs, newAccounts, computeUnits, err := chainwriter.CCIPExecuteArgsTransform(ctx, args, accounts, tableMap) + transformedArgs, newAccounts, options, err := chainwriter.CCIPExecuteArgsTransform(ctx, args, accounts, tableMap) require.NoError(t, err) - require.Equal(t, computeUnits, uint32(500)) + verifyTxOpts(t, options, true) + // Accounts should be unchanged require.Len(t, newAccounts, 2) typedArgs, ok := transformedArgs.(chainwriter.ReportPostTransform) @@ -65,9 +70,9 @@ func Test_CCIPExecuteArgsTransform(t *testing.T) { }) t.Run("CCIPExecute ArgsTransform includes empty token indexes if lookup table not found", func(t *testing.T) { - transformedArgs, newAccounts, computeUnits, err := chainwriter.CCIPExecuteArgsTransform(ctx, args, accounts, nil) + transformedArgs, newAccounts, options, err := chainwriter.CCIPExecuteArgsTransform(ctx, args, accounts, nil) require.NoError(t, err) - require.Equal(t, computeUnits, uint32(500)) + verifyTxOpts(t, options, true) // Accounts should be unchanged require.Len(t, newAccounts, 2) @@ -89,14 +94,21 @@ func Test_CCIPExecuteArgsTransform(t *testing.T) { ExtraArgsDecoded: map[string]any{ "ComputeUnits": uint32(500), }, + TokenAmounts: []ccipocr3.RampTokenAmount{ + { + DestExecDataDecoded: map[string]any{ + "destGasAmount": uint32(500), + }, + }, + }, }}, }}, }, } - transformedArgs, newAccounts, computeUnits, err := chainwriter.CCIPExecuteArgsTransform(ctx, args, accounts, nil) + transformedArgs, newAccounts, options, err := chainwriter.CCIPExecuteArgsTransform(ctx, args, accounts, nil) require.NoError(t, err) - require.Equal(t, computeUnits, uint32(500)) + verifyTxOpts(t, options, true) _, ok := transformedArgs.(chainwriter.ReportPostTransform) require.True(t, ok) require.Len(t, newAccounts, 2) @@ -146,7 +158,8 @@ func Test_CCIPCommitAccountTransform(t *testing.T) { }, } accounts := []*solana.AccountMeta{{PublicKey: key1}, {PublicKey: key2}} - _, newAccounts, _, err := chainwriter.CCIPCommitAccountTransform(ctx, args, accounts, nil) + _, newAccounts, options, err := chainwriter.CCIPCommitAccountTransform(ctx, args, accounts, nil) + verifyTxOpts(t, options, false) require.NoError(t, err) require.Len(t, newAccounts, 2) }) @@ -157,8 +170,27 @@ func Test_CCIPCommitAccountTransform(t *testing.T) { Info: ccipocr3.CommitReportInfo{}, } accounts := []*solana.AccountMeta{{PublicKey: key1}, {PublicKey: key2}} - _, newAccounts, _, err := chainwriter.CCIPCommitAccountTransform(ctx, args, accounts, nil) + _, newAccounts, options, err := chainwriter.CCIPCommitAccountTransform(ctx, args, accounts, nil) + verifyTxOpts(t, options, false) + require.NoError(t, err) require.Len(t, newAccounts, 1) }) } + +func verifyTxOpts(t *testing.T, options []txmutils.SetTxConfig, exec bool) { + expectedLen := 1 + if exec { + expectedLen = 2 + } + require.Len(t, options, expectedLen) + + txConfig := &txmutils.TxConfig{} + options[0](txConfig) + require.Equal(t, !exec, txConfig.EstimateComputeUnitLimit) + + if exec { + options[1](txConfig) + require.Equal(t, uint32(2000), txConfig.ComputeUnitLimit) + } +}