From 66c94484b85014178c76a03c41e5f4afdd5191e4 Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Wed, 29 Jan 2025 17:24:55 -0500 Subject: [PATCH] Enabled automatic ATA creation in CW --- pkg/solana/chainwriter/chain_writer.go | 72 ++++++++++++++++++++++++-- pkg/solana/chainwriter/helpers.go | 49 +++++++++++------- pkg/solana/chainwriter/lookups.go | 10 ++++ 3 files changed, 108 insertions(+), 23 deletions(-) diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index 1bf8a3a8b..5466f7e81 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -3,6 +3,7 @@ package chainwriter import ( "context" "encoding/json" + "errors" "fmt" "math/big" @@ -10,6 +11,7 @@ import ( addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table" "github.com/gagliardetto/solana-go/rpc" + "github.com/smartcontractkit/chainlink-ccip/chains/solana/utils/tokens" commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" @@ -55,6 +57,7 @@ type MethodConfig struct { FromAddress string InputModifications commoncodec.ModifiersConfig ChainSpecificName string + ATAs []ATALookup LookupTables LookupTables Accounts []Lookup // Location in the args where the debug ID is stored @@ -214,6 +217,61 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses( return filteredLookupTables } +// CreateATAs first checks if a specified location exists, then checks if the accounts derived from the +// ATALookups in the ChainWriter's configuration exist on-chain and creates them if they do not. +func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader, idl string, feePayer solana.PublicKey) ([]solana.Instruction, error) { + createATAInstructions := []solana.Instruction{} + for _, lookup := range lookups { + // Check if location exists + if lookup.Location != "" { + // TODO refactor GetValuesAtLocation to not return an error if the field doesn't exist + _, err := GetValuesAtLocation(args, lookup.Location) + if err != nil { + // field doesn't exist, so ignore ATA creation + if errors.Is(err, errFieldNotFound) { + continue + } + return nil, fmt.Errorf("error getting values at location: %w", err) + } + } + walletAddresses, err := GetAddresses(ctx, args, []Lookup{lookup.WalletAddress}, derivedTableMap, reader, idl) + if err != nil { + return nil, fmt.Errorf("error resolving wallet address: %w", err) + } + if len(walletAddresses) != 1 { + return nil, fmt.Errorf("expected exactly one wallet address, got %d", len(walletAddresses)) + } + wallet := walletAddresses[0].PublicKey + + tokenPrograms, err := GetAddresses(ctx, args, []Lookup{lookup.TokenProgram}, derivedTableMap, reader, idl) + if err != nil { + return nil, fmt.Errorf("error resolving token program address: %w", err) + } + + mints, err := GetAddresses(ctx, args, []Lookup{lookup.MintAddress}, derivedTableMap, reader, idl) + if err != nil { + return nil, fmt.Errorf("error resolving mint address: %w", err) + } + + if len(tokenPrograms) != len(mints) { + return nil, fmt.Errorf("expected equal number of token programs and mints, got %d tokenPrograms and %d mints", len(tokenPrograms), len(mints)) + } + + for i := range tokenPrograms { + tokenProgram := tokenPrograms[i].PublicKey + mint := mints[i].PublicKey + + ins, _, err := tokens.CreateAssociatedTokenAccount(tokenProgram, mint, wallet, feePayer) + if err != nil { + return nil, fmt.Errorf("error creating associated token account: %w", err) + } + createATAInstructions = append(createATAInstructions, ins) + } + } + + return createATAInstructions, nil +} + // SubmitTransaction builds, encodes, and enqueues a transaction using the provided program // configuration and method details. It relies on the configured IDL, account lookups, and // lookup tables to gather the necessary accounts and data. The function retrieves the latest @@ -274,6 +332,11 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID) } + createATAinstructions, err := CreateATAs(ctx, args, methodConfig.ATAs, derivedTableMap, s.reader, programConfig.IDL, feePayer) + if err != nil { + return errorWithDebugID(fmt.Errorf("error resolving account addresses: %w", err), debugID) + } + // Filter the lookup table addresses based on which accounts are actually used filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) @@ -310,10 +373,13 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra discriminator := GetDiscriminator(methodConfig.ChainSpecificName) encodedPayload = append(discriminator[:], encodedPayload...) + // Combine the two sets of instructions into one slice + var instructions []solana.Instruction + instructions = append(instructions, createATAinstructions...) + instructions = append(instructions, solana.NewInstruction(programID, accounts, encodedPayload)) + tx, err := solana.NewTransaction( - []solana.Instruction{ - solana.NewInstruction(programID, accounts, encodedPayload), - }, + instructions, blockhash.Value.Blockhash, solana.TransactionPayer(feePayer), solana.TransactionAddressTables(filteredLookupTableMap), diff --git a/pkg/solana/chainwriter/helpers.go b/pkg/solana/chainwriter/helpers.go index 6e2a3e5be..1c135fb69 100644 --- a/pkg/solana/chainwriter/helpers.go +++ b/pkg/solana/chainwriter/helpers.go @@ -43,39 +43,48 @@ func FetchTestContractIDL() string { return testContractIDL } +var ( + errFieldNotFound = errors.New("key not found") +) + // GetValuesAtLocation parses through nested types and arrays to find all locations of values func GetValuesAtLocation(args any, location string) ([][]byte, error) { var vals [][]byte + // If the user specified no location, just return empty (no-op). + if location == "" { + return nil, nil + } + path := strings.Split(location, ".") - addressList, err := traversePath(args, path) + items, err := traversePath(args, path) if err != nil { return nil, err } - for _, value := range addressList { - // Dereference if it's a pointer - rv := reflect.ValueOf(value) + + for _, item := range items { + rv := reflect.ValueOf(item) if rv.Kind() == reflect.Ptr && !rv.IsNil() { - value = rv.Elem().Interface() + item = rv.Elem().Interface() } - if byteArray, ok := value.([]byte); ok { - vals = append(vals, byteArray) - } else if address, ok := value.(solana.PublicKey); ok { - vals = append(vals, address.Bytes()) - } else if num, ok := value.(uint64); ok { + switch value := item.(type) { + case []byte: + vals = append(vals, value) + case solana.PublicKey: + vals = append(vals, value.Bytes()) + case ccipocr3.UnknownAddress: + vals = append(vals, value) + case uint64: buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, num) + binary.LittleEndian.PutUint64(buf, value) vals = append(vals, buf) - } else if addr, ok := value.(ccipocr3.UnknownAddress); ok { - vals = append(vals, addr) - } else if arr, ok := value.([32]uint8); ok { - vals = append(vals, arr[:]) - } else { + case [32]uint8: + vals = append(vals, value[:]) + default: return nil, fmt.Errorf("invalid value format at path: %s, type: %s", location, reflect.TypeOf(value).String()) } } - return vals, nil } @@ -135,7 +144,7 @@ func traversePath(data any, path []string) ([]any, error) { case reflect.Struct: field := val.FieldByName(path[0]) if !field.IsValid() { - return nil, errors.New("field not found: " + path[0]) + return []any{}, errFieldNotFound } return traversePath(field.Interface(), path[1:]) @@ -150,13 +159,13 @@ func traversePath(data any, path []string) ([]any, error) { if len(result) > 0 { return result, nil } - return nil, errors.New("no matching field found in array") + return []any{}, errFieldNotFound case reflect.Map: key := reflect.ValueOf(path[0]) value := val.MapIndex(key) if !value.IsValid() { - return nil, errors.New("key not found: " + path[0]) + return []any{}, errFieldNotFound } return traversePath(value.Interface(), path[1:]) default: diff --git a/pkg/solana/chainwriter/lookups.go b/pkg/solana/chainwriter/lookups.go index 36719538a..982201b23 100644 --- a/pkg/solana/chainwriter/lookups.go +++ b/pkg/solana/chainwriter/lookups.go @@ -83,6 +83,16 @@ type AccountsFromLookupTable struct { IncludeIndexes []int } +type ATALookup struct { + // Field that determines whether the ATA lookup is necessary. Basically + // just need to check this field exists. Dot separated location. + Location string + // If the field exists, initialize a ATA account using the Wallet, Token Program, and Mint addresses below + WalletAddress Lookup + TokenProgram Lookup + MintAddress Lookup +} + func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[string][]*solana.AccountMeta, _ client.Reader, _ string) ([]*solana.AccountMeta, error) { address, err := solana.PublicKeyFromBase58(ac.Address) if err != nil {