From 35d579c3f69483025a22ad5444818bdb334f3dfb Mon Sep 17 00:00:00 2001 From: Silas Lenihan Date: Tue, 11 Feb 2025 10:04:48 -0500 Subject: [PATCH] Separated ataTx and main chainwriter txs --- pkg/solana/chainwriter/chain_writer.go | 79 +++++++++++++++++++++----- 1 file changed, 65 insertions(+), 14 deletions(-) diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index 6c0c236b6..32cbd8e8f 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -7,6 +7,7 @@ import ( "fmt" "math/big" "strings" + "time" "github.com/gagliardetto/solana-go" addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table" @@ -241,7 +242,7 @@ func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTable return nil, fmt.Errorf("error getting values at location: %w", err) } } - walletAddresses, err := GetAddresses(ctx, args, []Lookup{lookup.WalletAddress}, derivedTableMap, reader, idl) + walletAddresses, err := GetAddresses(ctx, args, []Lookup{lookup.WalletAddress}, derivedTableMap, reader) if err != nil { return nil, fmt.Errorf("error resolving wallet address: %w", err) } @@ -250,12 +251,12 @@ func CreateATAs(ctx context.Context, args any, lookups []ATALookup, derivedTable } wallet := walletAddresses[0].PublicKey - tokenPrograms, err := GetAddresses(ctx, args, []Lookup{lookup.TokenProgram}, derivedTableMap, reader, idl) + tokenPrograms, err := GetAddresses(ctx, args, []Lookup{lookup.TokenProgram}, derivedTableMap, reader) if err != nil { return nil, fmt.Errorf("error resolving token program address: %w", err) } - mints, err := GetAddresses(ctx, args, []Lookup{lookup.MintAddress}, derivedTableMap, reader, idl) + mints, err := GetAddresses(ctx, args, []Lookup{lookup.MintAddress}, derivedTableMap, reader) if err != nil { return nil, fmt.Errorf("error resolving mint address: %w", err) } @@ -374,12 +375,6 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra } } - // Fetch latest blockhash - blockhash, err := s.reader.LatestBlockhash(ctx) - if err != nil { - return errorWithDebugID(fmt.Errorf("error fetching latest blockhash: %w", err), debugID) - } - // Prepare transaction programID, err := solana.PublicKeyFromBase58(toAddress) if err != nil { @@ -395,13 +390,41 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra discriminator := GetDiscriminator(methodConfig.ChainSpecificName) encodedPayload = append(discriminator[:], encodedPayload...) - // Combine the two sets of instructions into one slice - var instructions []solana.Instruction - instructions = append(instructions, createATAinstructions...) - instructions = append(instructions, solana.NewInstruction(programID, accounts, encodedPayload)) + // Fetch latest blockhash + blockhash, err := s.reader.LatestBlockhash(ctx) + if err != nil { + return errorWithDebugID(fmt.Errorf("error fetching latest blockhash: %w", err), debugID) + } + + if len(createATAinstructions) > 0 { + ataTx, ataErr := solana.NewTransaction( + createATAinstructions, + blockhash.Value.Blockhash, + solana.TransactionPayer(feePayer), + ) + if ataErr != nil { + return errorWithDebugID(fmt.Errorf("error constructing ATA transaction: %w", err), debugID) + } + + // Enqueue ATA transaction + if err = s.txm.Enqueue(ctx, methodConfig.FromAddress, ataTx, &transactionID, blockhash.Value.LastValidBlockHeight); err != nil { + return errorWithDebugID(fmt.Errorf("error enqueuing transaction: %w", err), debugID) + } + + err = s.waitForTxFinality(ctx, transactionID) + if err != nil { + return errorWithDebugID(fmt.Errorf("error waiting for ATA transaction finality: %w", err), debugID) + } + + // refresh blockhash for next tx + blockhash, err = s.reader.LatestBlockhash(ctx) + if err != nil { + return errorWithDebugID(fmt.Errorf("error fetching latest blockhash: %w", err), debugID) + } + } tx, err := solana.NewTransaction( - instructions, + []solana.Instruction{solana.NewInstruction(programID, accounts, encodedPayload)}, blockhash.Value.Blockhash, solana.TransactionPayer(feePayer), solana.TransactionAddressTables(filteredLookupTableMap), @@ -418,6 +441,34 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra return nil } +func (s *SolanaChainWriterService) waitForTxFinality(ctx context.Context, transactionID string) error { + waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-waitCtx.Done(): + return fmt.Errorf("context ended while waiting for finality of transaction %s", transactionID) + case <-ticker.C: + status, err := s.txm.GetTransactionStatus(waitCtx, transactionID) + if err != nil { + return fmt.Errorf("error fetching transaction status: %w", err) + } + switch status { + case types.Finalized: + return nil + case types.Failed, types.Fatal: + return fmt.Errorf("transaction %s failed", transactionID) + default: + // Keep polling + } + } + } +} + // GetTransactionStatus returns the current status of a transaction in the underlying chain's TXM. func (s *SolanaChainWriterService) GetTransactionStatus(ctx context.Context, transactionID string) (types.TransactionStatus, error) { return s.txm.GetTransactionStatus(ctx, transactionID)