diff --git a/go/node/client/client.go b/go/node/client/client.go index 007b5fa6..bc962cf2 100644 --- a/go/node/client/client.go +++ b/go/node/client/client.go @@ -4,11 +4,10 @@ import ( "context" "errors" - "github.com/spf13/pflag" - sdkclient "github.com/cosmos/cosmos-sdk/client" tmjclient "github.com/tendermint/tendermint/rpc/jsonrpc/client" + cltypes "github.com/akash-network/akash-api/go/node/client/types" "github.com/akash-network/akash-api/go/node/client/v1beta2" ) @@ -16,7 +15,9 @@ var ( ErrUnknownClientVersion = errors.New("akash-api: unknown client version") ) -func DiscoverClient(ctx context.Context, cctx sdkclient.Context, flags *pflag.FlagSet, setup func(interface{}) error) error { +type SetupFn func(interface{}) error + +func DiscoverClient(ctx context.Context, cctx sdkclient.Context, setup SetupFn, opts ...cltypes.ClientOption) error { rpc, err := tmjclient.New(cctx.NodeURI) if err != nil { return err @@ -39,7 +40,7 @@ func DiscoverClient(ctx context.Context, cctx sdkclient.Context, flags *pflag.Fl switch result.ClientInfo.ApiVersion { case "v1beta2": - cl, err = v1beta2.NewClient(ctx, cctx, flags) + cl, err = v1beta2.NewClient(ctx, cctx, opts...) default: err = ErrUnknownClientVersion } @@ -55,7 +56,7 @@ func DiscoverClient(ctx context.Context, cctx sdkclient.Context, flags *pflag.Fl return nil } -func DiscoverQueryClient(ctx context.Context, cctx sdkclient.Context, setup func(interface{}) error) error { +func DiscoverQueryClient(ctx context.Context, cctx sdkclient.Context, setup SetupFn) error { rpc, err := tmjclient.New(cctx.NodeURI) if err != nil { return err diff --git a/go/node/client/types/options.go b/go/node/client/types/options.go new file mode 100644 index 00000000..800ba265 --- /dev/null +++ b/go/node/client/types/options.go @@ -0,0 +1,181 @@ +package types + +import ( + "time" + + "github.com/spf13/pflag" + + "github.com/cosmos/cosmos-sdk/client" + "github.com/cosmos/cosmos-sdk/client/flags" + "github.com/cosmos/cosmos-sdk/client/tx" + "github.com/cosmos/cosmos-sdk/types/tx/signing" +) + +type ClientOptions struct { + AccountNumber uint64 + AccountSequence uint64 + GasAdjustment float64 + Gas flags.GasSetting + GasPrices string + Fees string + Note string + TimeoutHeight uint64 + BroadcastTimeout time.Duration +} + +type ClientOption func(options *ClientOptions) error + +// NewTxFactory creates a new Factory. +func NewTxFactory(cctx client.Context, opts ...ClientOption) (tx.Factory, error) { + clOpts := &ClientOptions{} + + for _, opt := range opts { + if err := opt(clOpts); err != nil { + return tx.Factory{}, err + } + } + + signMode := signing.SignMode_SIGN_MODE_UNSPECIFIED + switch cctx.SignModeStr { + case flags.SignModeDirect: + signMode = signing.SignMode_SIGN_MODE_DIRECT + case flags.SignModeLegacyAminoJSON: + signMode = signing.SignMode_SIGN_MODE_LEGACY_AMINO_JSON + case flags.SignModeEIP191: + signMode = signing.SignMode_SIGN_MODE_EIP_191 + } + + txf := tx.Factory{} + + txf = txf.WithTxConfig(cctx.TxConfig). + WithAccountRetriever(cctx.AccountRetriever). + WithAccountNumber(clOpts.AccountNumber). + WithSequence(clOpts.AccountSequence). + WithKeybase(cctx.Keyring). + WithChainID(cctx.ChainID). + WithGas(clOpts.Gas.Gas). + WithGasAdjustment(clOpts.GasAdjustment). + WithGasPrices(clOpts.GasPrices). + WithSimulateAndExecute(clOpts.Gas.Simulate). + WithTimeoutHeight(clOpts.TimeoutHeight). + WithMemo(clOpts.Note). + WithSignMode(signMode). + WithFees(clOpts.Fees) + + if !cctx.Offline { + from := cctx.GetFromAddress() + + if err := txf.AccountRetriever().EnsureExists(cctx, from); err != nil { + return txf, err + } + + if txf.AccountNumber() == 0 || txf.Sequence() == 0 { + num, seq, err := txf.AccountRetriever().GetAccountNumberSequence(cctx, from) + if err != nil { + return txf, err + } + + txf = txf.WithAccountNumber(num).WithSequence(seq) + } + } + + return txf, nil +} + +func WithAccountNumber(val uint64) ClientOption { + return func(options *ClientOptions) error { + options.AccountNumber = val + return nil + } +} + +func WithAccountSequence(val uint64) ClientOption { + return func(options *ClientOptions) error { + options.AccountSequence = val + return nil + } +} + +func WithGasAdjustment(val float64) ClientOption { + return func(options *ClientOptions) error { + options.GasAdjustment = val + return nil + } +} + +func WithNote(val string) ClientOption { + return func(options *ClientOptions) error { + options.Note = val + return nil + } +} + +func WithGas(val flags.GasSetting) ClientOption { + return func(options *ClientOptions) error { + options.Gas = val + return nil + } +} + +func WithGasPrices(val string) ClientOption { + return func(options *ClientOptions) error { + options.GasPrices = val + return nil + } +} + +func WithFees(val string) ClientOption { + return func(options *ClientOptions) error { + options.Fees = val + return nil + } +} + +func WithTimeoutHeight(val uint64) ClientOption { + return func(options *ClientOptions) error { + options.TimeoutHeight = val + return nil + } +} + +func ClientOptionsFromFlags(flagSet *pflag.FlagSet) ([]ClientOption, error) { + opts := make([]ClientOption, 0) + + if flagSet.Changed(flags.FlagAccountNumber) { + accNum, _ := flagSet.GetUint64(flags.FlagAccountNumber) + opts = append(opts, WithAccountNumber(accNum)) + } + + if flagSet.Changed(flags.FlagSequence) { + accSeq, _ := flagSet.GetUint64(flags.FlagSequence) + opts = append(opts, WithAccountSequence(accSeq)) + } + + // if flagSet.Changed(flags.FlagGasAdjustment) { + gasAdj, _ := flagSet.GetFloat64(flags.FlagGasAdjustment) + opts = append(opts, WithGasAdjustment(gasAdj)) + // } + + if flagSet.Changed(flags.FlagNote) { + memo, _ := flagSet.GetString(flags.FlagNote) + opts = append(opts, WithNote(memo)) + } + + if flagSet.Changed(flags.FlagTimeoutHeight) { + timeoutHeight, _ := flagSet.GetUint64(flags.FlagTimeoutHeight) + opts = append(opts, WithTimeoutHeight(timeoutHeight)) + } + + // if flagSet.Changed(flags.FlagGas) { + gasStr, _ := flagSet.GetString(flags.FlagGas) + gasSetting, _ := flags.ParseGasSetting(gasStr) + opts = append(opts, WithGas(gasSetting)) + // } + + // if flagSet.Changed(flags.FlagFees) { + feesStr, _ := flagSet.GetString(flags.FlagFees) + opts = append(opts, WithFees(feesStr)) + // } + + return opts, nil +} diff --git a/go/node/client/v1beta2/client.go b/go/node/client/v1beta2/client.go index 6db4d500..7eb4ee58 100644 --- a/go/node/client/v1beta2/client.go +++ b/go/node/client/v1beta2/client.go @@ -4,16 +4,14 @@ import ( "context" "fmt" - "github.com/gogo/protobuf/proto" - "github.com/spf13/pflag" - - tmrpc "github.com/tendermint/tendermint/rpc/core/types" - sdkclient "github.com/cosmos/cosmos-sdk/client" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/gogo/protobuf/proto" + tmrpc "github.com/tendermint/tendermint/rpc/core/types" atypes "github.com/akash-network/akash-api/go/node/audit/v1beta3" ctypes "github.com/akash-network/akash-api/go/node/cert/v1beta3" + cltypes "github.com/akash-network/akash-api/go/node/client/types" dtypes "github.com/akash-network/akash-api/go/node/deployment/v1beta3" mtypes "github.com/akash-network/akash-api/go/node/market/v1beta4" ptypes "github.com/akash-network/akash-api/go/node/provider/v1beta3" @@ -56,7 +54,7 @@ type client struct { var _ Client = (*client)(nil) -func NewClient(ctx context.Context, cctx sdkclient.Context, flags *pflag.FlagSet) (Client, error) { +func NewClient(ctx context.Context, cctx sdkclient.Context, opts ...cltypes.ClientOption) (Client, error) { nd := newNode(cctx) cl := &client{ @@ -65,7 +63,7 @@ func NewClient(ctx context.Context, cctx sdkclient.Context, flags *pflag.FlagSet } var err error - cl.tx, err = newSerialTx(ctx, cctx, flags, nd, BroadcastDefaultTimeout) + cl.tx, err = newSerialTx(ctx, cctx, nd, opts...) if err != nil { return nil, err } diff --git a/go/node/client/v1beta2/tx.go b/go/node/client/v1beta2/tx.go index 2c6e97b2..897c1112 100644 --- a/go/node/client/v1beta2/tx.go +++ b/go/node/client/v1beta2/tx.go @@ -12,11 +12,10 @@ import ( "unsafe" "github.com/boz/go-lifecycle" + "github.com/cosmos/cosmos-sdk/client/flags" "github.com/cosmos/cosmos-sdk/client/input" "github.com/edwingeng/deque/v2" "github.com/gogo/protobuf/proto" - "github.com/spf13/pflag" - "github.com/tendermint/tendermint/libs/log" ttypes "github.com/tendermint/tendermint/types" @@ -27,6 +26,7 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" authtx "github.com/cosmos/cosmos-sdk/x/auth/tx" + cltypes "github.com/akash-network/akash-api/go/node/client/types" "github.com/akash-network/akash-api/go/util/ctxlog" ) @@ -34,7 +34,6 @@ var ( ErrNotRunning = errors.New("tx client: not running") ErrSyncTimedOut = errors.New("tx client: timed-out waiting for sequence sync") ErrNodeCatchingUp = errors.New("tx client: cannot sync from catching up node") - ErrAdjustGas = errors.New("tx client: couldn't adjust gas") ErrSimulateOffline = errors.New("tx client: cannot simulate tx in offline mode") ErrBroadcastOffline = errors.New("tx client: cannot broadcast tx in offline mode") ErrTxCanceledByUser = errors.New("tx client: transaction declined by user input") @@ -57,16 +56,90 @@ const ( notFoundErrorMessageSuffix = ") not found" ) +type ConfirmFn func(string) (bool, error) + type BroadcastOptions struct { - resultAsError bool + timeoutHeight *uint64 + gasAdjustment *float64 + gas *flags.GasSetting + gasPrices *string + fees *string + note *string + broadcastTimeout *time.Duration + resultAsError bool + skipConfirm *bool + confirmFn ConfirmFn +} + +type BroadcastOption func(*BroadcastOptions) error + +func WithGasAdjustment(val float64) BroadcastOption { + return func(options *BroadcastOptions) error { + options.gasAdjustment = new(float64) + *options.gasAdjustment = val + return nil + } } -type BroadcastOption func(*BroadcastOptions) *BroadcastOptions +func WithNote(val string) BroadcastOption { + return func(options *BroadcastOptions) error { + options.note = new(string) + *options.note = val + return nil + } +} + +func WithGas(val flags.GasSetting) BroadcastOption { + return func(options *BroadcastOptions) error { + options.gas = new(flags.GasSetting) + *options.gas = val + return nil + } +} + +func WithGasPrices(val string) BroadcastOption { + return func(options *BroadcastOptions) error { + options.gasPrices = new(string) + *options.gasPrices = val + return nil + } +} + +func WithFees(val string) BroadcastOption { + return func(options *BroadcastOptions) error { + options.fees = new(string) + *options.fees = val + return nil + } +} + +func WithTimeoutHeight(val uint64) BroadcastOption { + return func(options *BroadcastOptions) error { + options.timeoutHeight = new(uint64) + *options.timeoutHeight = val + return nil + } +} func WithResultCodeAsError() BroadcastOption { - return func(opts *BroadcastOptions) *BroadcastOptions { + return func(opts *BroadcastOptions) error { opts.resultAsError = true - return opts + return nil + } +} + +func WithSkipConfirm(val bool) BroadcastOption { + return func(opts *BroadcastOptions) error { + opts.skipConfirm = new(bool) + *opts.skipConfirm = val + return nil + } +} + +func WithConfirmFn(val ConfirmFn) BroadcastOption { + return func(opts *BroadcastOptions) error { + opts.confirmFn = val + return nil } } @@ -79,6 +152,11 @@ type broadcastReq struct { id uintptr responsech chan<- broadcastResp msgs []sdk.Msg + opts *BroadcastOptions +} +type broadcastTxs struct { + msgs []sdk.Msg + opts *BroadcastOptions } type seqResp struct { @@ -95,23 +173,26 @@ type broadcast struct { donech chan<- error respch chan<- broadcastResp msgs []sdk.Msg + opts *BroadcastOptions } type serialBroadcaster struct { - ctx context.Context - cctx sdkclient.Context - info keyring.Info - broadcastTimeout time.Duration - reqch chan broadcastReq - broadcastch chan broadcast - seqreqch chan seqReq - lc lifecycle.Lifecycle - nd *node - log log.Logger + ctx context.Context + cctx sdkclient.Context + info keyring.Info + reqch chan broadcastReq + broadcastch chan broadcast + seqreqch chan seqReq + lc lifecycle.Lifecycle + nd *node + log log.Logger } -func newSerialTx(ctx context.Context, cctx sdkclient.Context, flags *pflag.FlagSet, nd *node, timeout time.Duration) (*serialBroadcaster, error) { - txf := tx.NewFactoryCLI(cctx, flags).WithTxConfig(cctx.TxConfig).WithAccountRetriever(cctx.AccountRetriever) +func newSerialTx(ctx context.Context, cctx sdkclient.Context, nd *node, opts ...cltypes.ClientOption) (*serialBroadcaster, error) { + txf, err := cltypes.NewTxFactory(cctx, opts...) + if err != nil { + return nil, err + } keyname := cctx.GetFromName() info, err := txf.Keybase().Key(keyname) @@ -123,23 +204,16 @@ func newSerialTx(ctx context.Context, cctx sdkclient.Context, flags *pflag.FlagS return nil, err } - // populate account number, current sequence number - txf, err = PrepareFactory(cctx, txf) - if err != nil { - return nil, err - } - client := &serialBroadcaster{ - ctx: ctx, - cctx: cctx, - info: info, - broadcastTimeout: timeout, - lc: lifecycle.New(), - reqch: make(chan broadcastReq, 1), - broadcastch: make(chan broadcast, 1), - seqreqch: make(chan seqReq), - nd: nd, - log: ctxlog.Logger(ctx).With("cmp", "client/broadcaster"), + ctx: ctx, + cctx: cctx, + info: info, + lc: lifecycle.New(), + reqch: make(chan broadcastReq, 1), + broadcastch: make(chan broadcast, 1), + seqreqch: make(chan seqReq), + nd: nd, + log: ctxlog.Logger(ctx).With("cmp", "client/broadcaster"), } go client.lc.WatchContext(ctx) @@ -154,20 +228,30 @@ func newSerialTx(ctx context.Context, cctx sdkclient.Context, flags *pflag.FlagS } func (c *serialBroadcaster) Broadcast(ctx context.Context, msgs []sdk.Msg, opts ...BroadcastOption) (interface{}, error) { + bOpts := &BroadcastOptions{ + confirmFn: defaultTxConfirm, + } + + for _, opt := range opts { + if err := opt(bOpts); err != nil { + return nil, err + } + } + + if bOpts.broadcastTimeout == nil { + bOpts.broadcastTimeout = new(time.Duration) + *bOpts.broadcastTimeout = BroadcastDefaultTimeout + } + responsech := make(chan broadcastResp, 1) request := broadcastReq{ responsech: responsech, msgs: msgs, + opts: bOpts, } request.id = uintptr(unsafe.Pointer(&request)) - ropts := &BroadcastOptions{} - - for _, opt := range opts { - _ = opt(ropts) - } - select { case c.reqch <- request: case <-ctx.Done(): @@ -180,7 +264,7 @@ func (c *serialBroadcaster) Broadcast(ctx context.Context, msgs []sdk.Msg, opts case resp := <-responsech: // if returned error is sdk error, it is likely to be wrapped response so discard it // as clients supposed to check Tx code, unless resp is nil, which is error during Tx preparation - if !errors.As(resp.err, &sdkerrors.Error{}) || resp.resp == nil || ropts.resultAsError { + if !errors.As(resp.err, &sdkerrors.Error{}) || resp.resp == nil || bOpts.resultAsError { return resp.resp, resp.err } return resp.resp, nil @@ -210,6 +294,7 @@ func (c *serialBroadcaster) run() { donech: broadcastDoneCh, respch: req.responsech, msgs: req.msgs, + opts: req.opts, }: broadcastCh = nil _ = pending.PopFront() @@ -238,6 +323,30 @@ loop: } } +func deriveTxfFromOptions(txf tx.Factory, opts *BroadcastOptions) tx.Factory { + if opt := opts.note; opt != nil { + txf = txf.WithMemo(*opt) + } + + if opt := opts.gas; opt != nil { + txf = txf.WithGas(opt.Gas).WithSimulateAndExecute(opt.Simulate) + } + + if opt := opts.fees; opt != nil { + txf = txf.WithFees(*opt) + } + + if opt := opts.gasPrices; opt != nil { + txf = txf.WithGasPrices(*opt) + } + + if opt := opts.timeoutHeight; opt != nil { + txf = txf.WithTimeoutHeight(*opt) + } + + return txf +} + func (c *serialBroadcaster) broadcaster(ptxf tx.Factory) { syncSequence := func(f tx.Factory, rErr error) (uint64, bool) { if rErr != nil { @@ -262,13 +371,19 @@ func (c *serialBroadcaster) broadcaster(ptxf tx.Factory) { var err error var resp interface{} + txf := deriveTxfFromOptions(ptxf, req.opts) + if c.cctx.GenerateOnly { - resp, err = c.generateTxs(ptxf, req.msgs...) + resp, err = c.generateTxs(txf, req.msgs...) } else { done: for i := 0; i < 2; i++ { var rseq uint64 - resp, rseq, err = c.broadcastTxs(ptxf, req.msgs...) + txs := broadcastTxs{ + msgs: req.msgs, + opts: req.opts, + } + resp, rseq, err = c.broadcastTxs(txf, txs) ptxf = ptxf.WithSequence(rseq) rSeq, synced := syncSequence(ptxf, err) @@ -365,7 +480,7 @@ func (c *serialBroadcaster) generateTxs(txf tx.Factory, msgs ...sdk.Msg) ([]byte return data, nil } -func DefaultTxConfirm(txn string) (bool, error) { +func defaultTxConfirm(txn string) (bool, error) { _, _ = fmt.Printf("%s\n\n", txn) buf := bufio.NewReader(os.Stdin) @@ -373,13 +488,13 @@ func DefaultTxConfirm(txn string) (bool, error) { return input.GetConfirmation("confirm transaction before signing and broadcasting", buf, os.Stdin) } -func (c *serialBroadcaster) broadcastTxs(txf tx.Factory, msgs ...sdk.Msg) (interface{}, uint64, error) { +func (c *serialBroadcaster) broadcastTxs(txf tx.Factory, txs broadcastTxs) (interface{}, uint64, error) { var err error var resp proto.Message if txf.SimulateAndExecute() || c.cctx.Simulate { var adjusted uint64 - resp, adjusted, err = tx.CalculateGas(c.cctx, txf, msgs...) + resp, adjusted, err = tx.CalculateGas(c.cctx, txf, txs.msgs...) if err != nil { return nil, txf.Sequence(), err } @@ -391,7 +506,7 @@ func (c *serialBroadcaster) broadcastTxs(txf tx.Factory, msgs ...sdk.Msg) (inter return resp, txf.Sequence(), nil } - txn, err := tx.BuildUnsignedTx(txf, msgs...) + txn, err := tx.BuildUnsignedTx(txf, txs.msgs...) if err != nil { return nil, txf.Sequence(), err } @@ -406,7 +521,7 @@ func (c *serialBroadcaster) broadcastTxs(txf tx.Factory, msgs ...sdk.Msg) (inter return nil, txf.Sequence(), err } - isYes, err := DefaultTxConfirm(string(out)) + isYes, err := txs.opts.confirmFn(string(out)) if err != nil { return nil, txf.Sequence(), err } @@ -428,7 +543,7 @@ func (c *serialBroadcaster) broadcastTxs(txf tx.Factory, msgs ...sdk.Msg) (inter return nil, txf.Sequence(), err } - response, err := c.doBroadcast(c.cctx, bytes, c.broadcastTimeout) + response, err := c.doBroadcast(c.cctx, bytes, *txs.opts.broadcastTimeout) if err != nil { return response, txf.Sequence(), err } @@ -505,28 +620,3 @@ func (c *serialBroadcaster) doBroadcast(cctx sdkclient.Context, data []byte, tim return cres, lctx.Err() } - -// PrepareFactory has been copied from cosmos-sdk to make it public. -// Source: https://github.com/cosmos/cosmos-sdk/blob/v0.43.0-rc2/client/tx/tx.go#L311 -func PrepareFactory(cctx sdkclient.Context, txf tx.Factory) (tx.Factory, error) { - if cctx.Offline { - return txf, nil - } - - from := cctx.GetFromAddress() - - if err := txf.AccountRetriever().EnsureExists(cctx, from); err != nil { - return txf, err - } - - if txf.AccountNumber() == 0 || txf.Sequence() == 0 { - num, seq, err := txf.AccountRetriever().GetAccountNumberSequence(cctx, from) - if err != nil { - return txf, err - } - - txf = txf.WithAccountNumber(num).WithSequence(seq) - } - - return txf, nil -}