Skip to content

Commit

Permalink
addressed feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Feb 13, 2025
1 parent 6d677d6 commit f275ca0
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 76 deletions.
77 changes: 77 additions & 0 deletions pkg/solana/chainwriter/ata_creation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package chainwriter

import (
"context"
"fmt"
"time"

"github.com/gagliardetto/solana-go"
"github.com/google/uuid"
"github.com/smartcontractkit/chainlink-common/pkg/types"
)

const (
maxAtas = 12
)

func (s *SolanaChainWriterService) handleATACreation(ctx context.Context, createATAinstructions []solana.Instruction, methodConfig MethodConfig, contractName, method string, feePayer solana.PublicKey) error {
blockhash, err := s.client.LatestBlockhash(ctx)
if err != nil {
return fmt.Errorf("error fetching latest blockhash: %w", err)
}

if len(createATAinstructions) > maxAtas {
return fmt.Errorf("too many ATAs to create: %d, max allowed: %d", len(createATAinstructions), maxAtas)
}
ataTx, ataErr := solana.NewTransaction(
createATAinstructions,
blockhash.Value.Blockhash,
solana.TransactionPayer(feePayer),
)
if ataErr != nil {
return fmt.Errorf("error constructing ATA transaction: %w", err)
}
ataUUID := fmt.Sprintf("ATA-%s", uuid.NewString())

s.lggr.Info("Sending create ATA transaction", "contract", contractName, "method", method)

// Enqueue ATA transaction
if err = s.txm.Enqueue(ctx, methodConfig.FromAddress, ataTx, &ataUUID, blockhash.Value.LastValidBlockHeight); err != nil {
return fmt.Errorf("error enqueuing transaction: %w", err)
}

err = s.waitForTxFinality(ctx, ataUUID)
if err != nil {
return fmt.Errorf("error waiting for ATA transaction finality: %w", err)
}
return nil
}

func (s *SolanaChainWriterService) waitForTxFinality(ctx context.Context, transactionID string) error {
waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()

ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()

for {
select {
case <-waitCtx.Done():
return fmt.Errorf("context ended while waiting for finality of transaction %s", transactionID)
case <-ticker.C:
status, err := s.txm.GetTransactionStatus(waitCtx, transactionID)
if err != nil {
return fmt.Errorf("error fetching transaction status: %w", err)
}
switch status {
case types.Finalized:
s.lggr.Debug("ATA transaction finalized", "transactionID", transactionID)
return nil
case types.Failed, types.Fatal:
return fmt.Errorf("transaction %s failed", transactionID)
default:
// Keep polling
}
}
}
}
88 changes: 12 additions & 76 deletions pkg/solana/chainwriter/chain_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@ import (
"fmt"
"math/big"
"strings"
"time"

"github.com/gagliardetto/solana-go"
addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table"
"github.com/gagliardetto/solana-go/rpc"
"github.com/google/uuid"

"github.com/smartcontractkit/chainlink-ccip/chains/solana/utils/tokens"
commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec"
Expand Down Expand Up @@ -69,10 +67,6 @@ type MethodConfig struct {
ArgsTransform string `json:"argsTransform,omitempty"`
}

const (
maxAtas = 12
)

func NewSolanaChainWriterService(logger logger.Logger, client client.MultiClient, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) {
w := SolanaChainWriterService{
lggr: logger,
Expand Down Expand Up @@ -354,7 +348,7 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error getting lookup tables: %w", err), debugID)
}

s.lggr.Info("Resolving account addresses", "contract", contractName, "method", method)
s.lggr.Debug("Resolving account addresses", "contract", contractName, "method", method)
// Resolve account metas
accounts, err := GetAddresses(ctx, args, methodConfig.Accounts, derivedTableMap, s.client)
if err != nil {
Expand All @@ -366,13 +360,13 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID)
}

s.lggr.Info("Creating ATAs", "contract", contractName, "method", method)
s.lggr.Debug("Creating ATAs", "contract", contractName, "method", method)
createATAinstructions, err := CreateATAs(ctx, args, methodConfig.ATAs, derivedTableMap, s.client, programConfig.IDL, feePayer, s.lggr)
if err != nil {
return errorWithDebugID(fmt.Errorf("error resolving account addresses: %w", err), debugID)
}

s.lggr.Info("Filtering lookup table addresses", "contract", contractName, "method", method)
s.lggr.Debug("Filtering lookup table addresses", "contract", contractName, "method", method)
// Filter the lookup table addresses based on which accounts are actually used
filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap)

Expand All @@ -382,7 +376,7 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
if tfErr != nil {
return errorWithDebugID(fmt.Errorf("error finding transform function: %w", tfErr), debugID)
}
s.lggr.Info("Applying args transformation", "contract", contractName, "method", method)
s.lggr.Debug("Applying args transformation", "contract", contractName, "method", method)
args, err = transformFunc(ctx, s, args, accounts, toAddress)
if err != nil {
return errorWithDebugID(fmt.Errorf("error transforming args: %w", err), debugID)
Expand All @@ -395,7 +389,7 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error parsing program ID: %w", err), debugID)
}

s.lggr.Info("Encoding transaction payload", "contract", contractName, "method", method)
s.lggr.Debug("Encoding transaction payload", "contract", contractName, "method", method)
encodedPayload, err := s.encoder.Encode(ctx, args, codec.WrapItemType(true, contractName, method))

if err != nil {
Expand All @@ -405,45 +399,16 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)

if len(createATAinstructions) > 0 {
s.handleATACreation(ctx, createATAinstructions, methodConfig, contractName, method, feePayer)
}

// Fetch latest blockhash
blockhash, err := s.client.LatestBlockhash(ctx)
if err != nil {
return errorWithDebugID(fmt.Errorf("error fetching latest blockhash: %w", err), debugID)
}

if len(createATAinstructions) > 0 {
if len(createATAinstructions) > maxAtas {
return errorWithDebugID(fmt.Errorf("too many ATAs to create: %d, max allowed: %d", len(createATAinstructions), maxAtas), debugID)
}
ataTx, ataErr := solana.NewTransaction(
createATAinstructions,
blockhash.Value.Blockhash,
solana.TransactionPayer(feePayer),
)
if ataErr != nil {
return errorWithDebugID(fmt.Errorf("error constructing ATA transaction: %w", err), debugID)
}
ataUUID := fmt.Sprintf("ATA-%s", uuid.NewString())

s.lggr.Info("Sending create ATA transaction", "contract", contractName, "method", method)

// Enqueue ATA transaction
if err = s.txm.Enqueue(ctx, methodConfig.FromAddress, ataTx, &ataUUID, blockhash.Value.LastValidBlockHeight); err != nil {
return errorWithDebugID(fmt.Errorf("error enqueuing transaction: %w", err), debugID)
}

err = s.waitForTxFinality(ctx, transactionID)
if err != nil {
return errorWithDebugID(fmt.Errorf("error waiting for ATA transaction finality: %w", err), debugID)
}

// refresh blockhash for next tx
blockhash, err = s.client.LatestBlockhash(ctx)
if err != nil {
return errorWithDebugID(fmt.Errorf("error fetching latest blockhash: %w", err), debugID)
}
}

tx, err := solana.NewTransaction(
[]solana.Instruction{solana.NewInstruction(programID, accounts, encodedPayload)},
blockhash.Value.Blockhash,
Expand All @@ -454,7 +419,7 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error constructing transaction: %w", err), debugID)
}

s.lggr.Info("Sending main transaction", "contract", contractName, "method", method)
s.lggr.Debug("Sending main transaction", "contract", contractName, "method", method)
// Enqueue transaction
if err = s.txm.Enqueue(ctx, methodConfig.FromAddress, tx, &transactionID, blockhash.Value.LastValidBlockHeight); err != nil {
return errorWithDebugID(fmt.Errorf("error enqueuing transaction: %w", err), debugID)
Expand All @@ -463,38 +428,9 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return nil
}

func (s *SolanaChainWriterService) waitForTxFinality(ctx context.Context, transactionID string) error {
waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()

ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()

for {
select {
case <-waitCtx.Done():
return fmt.Errorf("context ended while waiting for finality of transaction %s", transactionID)
case <-ticker.C:
status, err := s.txm.GetTransactionStatus(waitCtx, transactionID)
if err != nil {
return fmt.Errorf("error fetching transaction status: %w", err)
}
switch status {
case types.Finalized:
s.lggr.Info("ATA transaction finalized", "transactionID", transactionID)
return nil
case types.Failed, types.Fatal:
return fmt.Errorf("transaction %s failed", transactionID)
default:
// Keep polling
}
}
}
}

// GetTransactionStatus returns the current status of a transaction in the underlying chain's TXM.
func (s *SolanaChainWriterService) GetTransactionStatus(ctx context.Context, transactionID string) (types.TransactionStatus, error) {
s.lggr.Info("Fetching transaction status", "transactionID", transactionID)
s.lggr.Debug("Fetching transaction status", "transactionID", transactionID)
return s.txm.GetTransactionStatus(ctx, transactionID)
}

Expand All @@ -504,7 +440,7 @@ func (s *SolanaChainWriterService) GetFeeComponents(ctx context.Context) (*types
return nil, fmt.Errorf("gas estimator not available")
}

s.lggr.Info("Fetching fee components")
s.lggr.Debug("Fetching fee components")
fee := s.ge.BaseComputeUnitPrice()
return &types.ChainFeeComponents{
ExecutionFee: new(big.Int).SetUint64(fee),
Expand Down

0 comments on commit f275ca0

Please sign in to comment.