Skip to content

Commit

Permalink
Enabled automatic ATA creation in CW
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Jan 30, 2025
1 parent aa71d84 commit 78dc10f
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 23 deletions.
72 changes: 69 additions & 3 deletions pkg/solana/chainwriter/chain_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package chainwriter
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/big"

"github.com/gagliardetto/solana-go"
addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table"
"github.com/gagliardetto/solana-go/rpc"

"github.com/smartcontractkit/chainlink-ccip/chains/solana/utils/tokens"
commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"
Expand Down Expand Up @@ -55,6 +57,7 @@ type MethodConfig struct {
FromAddress string
InputModifications commoncodec.ModifiersConfig
ChainSpecificName string
ATAs []ATALookup
LookupTables LookupTables
Accounts []Lookup
// Location in the args where the debug ID is stored
Expand Down Expand Up @@ -214,6 +217,61 @@ func (s *SolanaChainWriterService) FilterLookupTableAddresses(
return filteredLookupTables
}

// CreateATAs first checks if a specified location exists, then checks if the accounts derived from the
// ATALookups in the ChainWriter's configuration exist on-chain and creates them if they do not.
func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader, idl string, feePayer solana.PublicKey) ([]solana.Instruction, error) {
createATAInstructions := []solana.Instruction{}
for _, lookup := range lookups {
// Check if location exists
if lookup.Location != "" {
// TODO refactor GetValuesAtLocation to not return an error if the field doesn't exist
_, err := GetValuesAtLocation(args, lookup.Location)
if err != nil {
// field doesn't exist, so ignore ATA creation
if errors.Is(err, errFieldNotFound) {
continue
}
return nil, fmt.Errorf("error getting values at location: %w", err)
}
}
walletAddresses, err := GetAddresses(ctx, args, []Lookup{lookup.WalletAddress}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving wallet address: %w", err)
}
if len(walletAddresses) != 1 {
return nil, fmt.Errorf("expected exactly one wallet address, got %d", len(walletAddresses))
}
wallet := walletAddresses[0].PublicKey

tokenPrograms, err := GetAddresses(ctx, args, []Lookup{lookup.TokenProgram}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving token program address: %w", err)
}

mints, err := GetAddresses(ctx, args, []Lookup{lookup.MintAddress}, derivedTableMap, reader, idl)
if err != nil {
return nil, fmt.Errorf("error resolving mint address: %w", err)
}

if len(tokenPrograms) != len(mints) {
return nil, fmt.Errorf("expected equal number of token programs and mints, got %d tokenPrograms and %d mints", len(tokenPrograms), len(mints))
}

for i := range tokenPrograms {
tokenProgram := tokenPrograms[i].PublicKey
mint := mints[i].PublicKey

ins, _, err := tokens.CreateAssociatedTokenAccount(tokenProgram, mint, wallet, feePayer)
if err != nil {
return nil, fmt.Errorf("error creating associated token account: %w", err)
}
createATAInstructions = append(createATAInstructions, ins)
}
}

return createATAInstructions, nil
}

// SubmitTransaction builds, encodes, and enqueues a transaction using the provided program
// configuration and method details. It relies on the configured IDL, account lookups, and
// lookup tables to gather the necessary accounts and data. The function retrieves the latest
Expand Down Expand Up @@ -274,6 +332,11 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID)
}

createATAinstructions, err := CreateATAs(ctx, args, methodConfig.ATAs, derivedTableMap, s.reader, programConfig.IDL, feePayer)
if err != nil {
return errorWithDebugID(fmt.Errorf("error resolving account addresses: %w", err), debugID)
}

// Filter the lookup table addresses based on which accounts are actually used
filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap)

Expand Down Expand Up @@ -310,10 +373,13 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)

// Combine the two sets of instructions into one slice
var instructions []solana.Instruction
instructions = append(instructions, createATAinstructions...)
instructions = append(instructions, solana.NewInstruction(programID, accounts, encodedPayload))

tx, err := solana.NewTransaction(
[]solana.Instruction{
solana.NewInstruction(programID, accounts, encodedPayload),
},
instructions,
blockhash.Value.Blockhash,
solana.TransactionPayer(feePayer),
solana.TransactionAddressTables(filteredLookupTableMap),
Expand Down
49 changes: 29 additions & 20 deletions pkg/solana/chainwriter/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,39 +43,48 @@ func FetchTestContractIDL() string {
return testContractIDL
}

var (
errFieldNotFound = errors.New("key not found")
)

// GetValuesAtLocation parses through nested types and arrays to find all locations of values
func GetValuesAtLocation(args any, location string) ([][]byte, error) {
var vals [][]byte
// If the user specified no location, just return empty (no-op).
if location == "" {
return nil, nil
}

path := strings.Split(location, ".")

addressList, err := traversePath(args, path)
items, err := traversePath(args, path)
if err != nil {
return nil, err
}
for _, value := range addressList {
// Dereference if it's a pointer
rv := reflect.ValueOf(value)

for _, item := range items {
rv := reflect.ValueOf(item)
if rv.Kind() == reflect.Ptr && !rv.IsNil() {
value = rv.Elem().Interface()
item = rv.Elem().Interface()
}

if byteArray, ok := value.([]byte); ok {
vals = append(vals, byteArray)
} else if address, ok := value.(solana.PublicKey); ok {
vals = append(vals, address.Bytes())
} else if num, ok := value.(uint64); ok {
switch value := item.(type) {
case []byte:
vals = append(vals, value)
case solana.PublicKey:
vals = append(vals, value.Bytes())
case ccipocr3.UnknownAddress:
vals = append(vals, value)
case uint64:
buf := make([]byte, 8)
binary.LittleEndian.PutUint64(buf, num)
binary.LittleEndian.PutUint64(buf, value)
vals = append(vals, buf)
} else if addr, ok := value.(ccipocr3.UnknownAddress); ok {
vals = append(vals, addr)
} else if arr, ok := value.([32]uint8); ok {
vals = append(vals, arr[:])
} else {
case [32]uint8:
vals = append(vals, value[:])
default:
return nil, fmt.Errorf("invalid value format at path: %s, type: %s", location, reflect.TypeOf(value).String())
}
}

return vals, nil
}

Expand Down Expand Up @@ -135,7 +144,7 @@ func traversePath(data any, path []string) ([]any, error) {
case reflect.Struct:
field := val.FieldByName(path[0])
if !field.IsValid() {
return nil, errors.New("field not found: " + path[0])
return []any{}, errFieldNotFound
}
return traversePath(field.Interface(), path[1:])

Expand All @@ -150,13 +159,13 @@ func traversePath(data any, path []string) ([]any, error) {
if len(result) > 0 {
return result, nil
}
return nil, errors.New("no matching field found in array")
return []any{}, errFieldNotFound

case reflect.Map:
key := reflect.ValueOf(path[0])
value := val.MapIndex(key)
if !value.IsValid() {
return nil, errors.New("key not found: " + path[0])
return []any{}, errFieldNotFound
}
return traversePath(value.Interface(), path[1:])
default:
Expand Down
10 changes: 10 additions & 0 deletions pkg/solana/chainwriter/lookups.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ type AccountsFromLookupTable struct {
IncludeIndexes []int
}

type ATALookup struct {
// Field that determines whether the ATA lookup is necessary. Basically
// just need to check this field exists. Dot separated location.
Location string
// If the field exists, initialize a ATA account using the Wallet, Token Program, and Mint addresses below
WalletAddress Lookup
TokenProgram Lookup
MintAddress Lookup
}

func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[string][]*solana.AccountMeta, _ client.Reader, _ string) ([]*solana.AccountMeta, error) {
address, err := solana.PublicKeyFromBase58(ac.Address)
if err != nil {
Expand Down

0 comments on commit 78dc10f

Please sign in to comment.