From fd05ceb8bfe53465fe8edacad41166b59a72f392 Mon Sep 17 00:00:00 2001 From: Lonny Wong Date: Sat, 13 Jul 2024 10:05:57 +0800 Subject: [PATCH] expect send text supports ssh tokens --- tssh/ctrl_unix.go | 15 +++++------ tssh/expect.go | 61 ++++++++++++++++++++++++++++++--------------- tssh/login.go | 20 +++++++-------- tssh/ssh.go | 1 + tssh/tokens.go | 10 ++++---- tssh/tokens_test.go | 1 + tssh/udp.go | 6 +++-- 7 files changed, 70 insertions(+), 44 deletions(-) diff --git a/tssh/ctrl_unix.go b/tssh/ctrl_unix.go index 2f42e69..3d53455 100644 --- a/tssh/ctrl_unix.go +++ b/tssh/ctrl_unix.go @@ -102,7 +102,7 @@ func (c *controlMaster) handleStdout() <-chan error { return doneCh } -func (c *controlMaster) fillPassword(args *sshArgs, expectCount int) (cancel context.CancelFunc) { +func (c *controlMaster) fillPassword(args *sshArgs, param *sshParam, expectCount int) (cancel context.CancelFunc) { var ctx context.Context expectTimeout := getExpectTimeout(args, "Ctrl") if expectTimeout > 0 { @@ -112,7 +112,8 @@ func (c *controlMaster) fillPassword(args *sshArgs, expectCount int) (cancel con } expect := &sshExpect{ - alias: args.Destination, + param: param, + args: args, ctx: ctx, pre: "Ctrl", out: make(chan []byte, 100), @@ -141,7 +142,7 @@ func (c *controlMaster) checkExit() <-chan struct{} { return exitCh } -func (c *controlMaster) start(args *sshArgs) error { +func (c *controlMaster) start(args *sshArgs, param *sshParam) error { var err error c.cmd = exec.Command(c.path, c.args...) expectCount := getExpectCount(args, "Ctrl") @@ -157,7 +158,7 @@ func (c *controlMaster) start(args *sshArgs) error { defer tty.Close() c.cmd.Stdin = tty c.ptmx = pty - cancel := c.fillPassword(args, expectCount) + cancel := c.fillPassword(args, param, expectCount) defer cancel() } if c.stdout, err = c.cmd.StdoutPipe(); err != nil { @@ -250,7 +251,7 @@ func getOpenSSH() (string, int, int, error) { return sshPath, majorVersion, minorVersion, nil } -func startControlMaster(args *sshArgs, sshPath string) error { +func startControlMaster(args *sshArgs, param *sshParam, sshPath string) error { cmdArgs := []string{"-T", "-oRemoteCommand=none", "-oConnectTimeout=10"} if args.Debug { @@ -311,7 +312,7 @@ func startControlMaster(args *sshArgs, sshPath string) error { } ctrlMaster := &controlMaster{path: sshPath, args: cmdArgs} - if err := ctrlMaster.start(args); err != nil { + if err := ctrlMaster.start(args, param); err != nil { return err } debug("start control master success") @@ -356,7 +357,7 @@ func connectViaControl(args *sshArgs, param *sshParam) SshClient { } fallthrough case "auto", "autoask": - if err := startControlMaster(args, sshPath); err != nil { + if err := startControlMaster(args, param, sshPath); err != nil { warning("start control master failed: %v", err) } } diff --git a/tssh/expect.go b/tssh/expect.go index b518b52..d8e0457 100644 --- a/tssh/expect.go +++ b/tssh/expect.go @@ -41,7 +41,8 @@ const ( ) type sshExpect struct { - alias string + param *sshParam + args *sshArgs pre string ctx context.Context out chan []byte @@ -54,6 +55,11 @@ type expectSender struct { input string } +type expectSendText struct { + showText string + sendText string +} + type caseSend struct { pattern string sender *expectSender @@ -81,8 +87,22 @@ func newTextSender(expect *sshExpect, input string) *expectSender { return &expectSender{expect, false, input} } -func (s *expectSender) decodeText(text string) [][]string { - var texts [][]string +func (s *expectSender) newSendText(showText, sendText string) *expectSendText { + var err error + showText, err = expandTokens(showText, s.expect.args, s.expect.param, "%hprnLlj") + if err != nil { + warning("expand send text [%s] failed: %v", showText, err) + } else { + sendText, err = expandTokens(sendText, s.expect.args, s.expect.param, "%hprnLlj") + if err != nil { + warning("expand send text %s failed: %v", strconv.QuoteToASCII(sendText), strconv.QuoteToASCII(err.Error())) + } + } + return &expectSendText{showText: showText, sendText: sendText} +} + +func (s *expectSender) decodeText(text string) []*expectSendText { + var texts []*expectSendText var buf strings.Builder state := byte(0) idx := 0 @@ -105,7 +125,7 @@ func (s *expectSender) decodeText(text string) [][]string { case 'n': buf.WriteRune('\n') case '|': - texts = append(texts, []string{text[idx : i-1], buf.String()}) + texts = append(texts, s.newSendText(text[idx:i-1], buf.String())) idx = i + 1 buf.Reset() default: @@ -118,12 +138,12 @@ func (s *expectSender) decodeText(text string) [][]string { warning("[%s] ends with \\ is invalid", text) buf.WriteRune('\\') } - texts = append(texts, []string{text[idx:], buf.String()}) + texts = append(texts, s.newSendText(text[idx:], buf.String())) return texts } func (s *expectSender) getExpectPsssSleep() (bool, bool) { - passSleep := getExConfig(s.expect.alias, fmt.Sprintf("%sExpectPassSleep", s.expect.pre)) + passSleep := getExConfig(s.expect.args.Destination, fmt.Sprintf("%sExpectPassSleep", s.expect.pre)) switch strings.ToLower(passSleep) { case "each": return true, false @@ -135,7 +155,7 @@ func (s *expectSender) getExpectPsssSleep() (bool, bool) { } func (s *expectSender) getExpectSleepTime() time.Duration { - expectSleepMS := getExConfig(s.expect.alias, fmt.Sprintf("%sExpectSleepMS", s.expect.pre)) + expectSleepMS := getExConfig(s.expect.args.Destination, fmt.Sprintf("%sExpectSleepMS", s.expect.pre)) if expectSleepMS == "" { return kDefaultExpectSleepMS * time.Millisecond } @@ -188,11 +208,11 @@ func (s *expectSender) sendInput(writer io.Writer, id string) bool { debug("expect %s sleep: %v", id, sleepTime) time.Sleep(sleepTime) } - if text[1] == "" { + if text.sendText == "" { continue } - debug("expect %s send: %s", id, text[0]) - if err := writeAll(writer, []byte(text[1])); err != nil { + debug("expect %s send: %s", id, text.showText) + if err := writeAll(writer, []byte(text.sendText)); err != nil { warning("expect %s send input failed: %v", id, err) return false } @@ -377,7 +397,7 @@ func (e *sshExpect) waitForPattern(pattern string, caseSends *caseSendList) erro } func (e *sshExpect) getExpectSender(idx int) *expectSender { - if pass := getExConfig(e.alias, fmt.Sprintf("%sExpectSendPass%d", e.pre, idx)); pass != "" { + if pass := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendPass%d", e.pre, idx)); pass != "" { secret, err := decodeSecret(pass) if err != nil { warning("decode %sExpectSendPass%d [%s] failed: %v", e.pre, idx, pass, err) @@ -386,11 +406,11 @@ func (e *sshExpect) getExpectSender(idx int) *expectSender { return newPassSender(e, secret) } - if text := getExConfig(e.alias, fmt.Sprintf("%sExpectSendText%d", e.pre, idx)); text != "" { + if text := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendText%d", e.pre, idx)); text != "" { return newTextSender(e, text) } - if encTotp := getExConfig(e.alias, fmt.Sprintf("%sExpectSendEncTotp%d", e.pre, idx)); encTotp != "" { + if encTotp := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendEncTotp%d", e.pre, idx)); encTotp != "" { secret, err := decodeSecret(encTotp) if err != nil { warning("decode %sExpectSendEncTotp%d [%s] failed: %v", e.pre, idx, encTotp, err) @@ -399,7 +419,7 @@ func (e *sshExpect) getExpectSender(idx int) *expectSender { return newPassSender(e, getTotpCode(secret)) } - if encOtp := getExConfig(e.alias, fmt.Sprintf("%sExpectSendEncOtp%d", e.pre, idx)); encOtp != "" { + if encOtp := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendEncOtp%d", e.pre, idx)); encOtp != "" { command, err := decodeSecret(encOtp) if err != nil { warning("decode %sExpectSendEncOtp%d [%s] failed: %v", e.pre, idx, encOtp, err) @@ -408,11 +428,11 @@ func (e *sshExpect) getExpectSender(idx int) *expectSender { return newPassSender(e, getOtpCommandOutput(command)) } - if secret := getExConfig(e.alias, fmt.Sprintf("%sExpectSendTotp%d", e.pre, idx)); secret != "" { + if secret := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendTotp%d", e.pre, idx)); secret != "" { return newPassSender(e, getTotpCode(secret)) } - if command := getExConfig(e.alias, fmt.Sprintf("%sExpectSendOtp%d", e.pre, idx)); command != "" { + if command := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectSendOtp%d", e.pre, idx)); command != "" { return newPassSender(e, getOtpCommandOutput(command)) } @@ -421,19 +441,19 @@ func (e *sshExpect) getExpectSender(idx int) *expectSender { func (e *sshExpect) execInteractions(writer io.Writer, expectCount int) { for idx := 1; idx <= expectCount; idx++ { - pattern := getExConfig(e.alias, fmt.Sprintf("%sExpectPattern%d", e.pre, idx)) + pattern := getExConfig(e.args.Destination, fmt.Sprintf("%sExpectPattern%d", e.pre, idx)) if pattern != "" { debug("expect %d pattern: %s", idx, pattern) } else { warning("expect %d pattern is empty, no output will be matched", idx) } caseSends := &caseSendList{e, writer, nil} - for _, cfg := range getAllExConfig(e.alias, fmt.Sprintf("%sExpectCaseSendPass%d", e.pre, idx)) { + for _, cfg := range getAllExConfig(e.args.Destination, fmt.Sprintf("%sExpectCaseSendPass%d", e.pre, idx)) { if err := caseSends.addCaseSendPass(cfg); err != nil { warning("Invalid ExpectCaseSendPass%d: %v", idx, err) } } - for _, cfg := range getAllExConfig(e.alias, fmt.Sprintf("%sExpectCaseSendText%d", e.pre, idx)) { + for _, cfg := range getAllExConfig(e.args.Destination, fmt.Sprintf("%sExpectCaseSendText%d", e.pre, idx)) { if err := caseSends.addCaseSendText(cfg); err != nil { warning("Invalid ExpectCaseSendText%d: %v", idx, err) } @@ -497,7 +517,8 @@ func execExpectInteractions(args *sshArgs, ss *sshClientSession) { defer cancel() expect := &sshExpect{ - alias: args.Destination, + param: ss.param, + args: args, ctx: ctx, out: make(chan []byte, 10), err: make(chan []byte, 10), diff --git a/tssh/login.go b/tssh/login.go index c557f94..b420a2b 100644 --- a/tssh/login.go +++ b/tssh/login.go @@ -1153,7 +1153,7 @@ func sshAgentForward(args *sshArgs, param *sshParam, client SshClient, session S debug("request ssh agent forwarding success") } -func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode int, err error) { +func sshTcpLogin(args *sshArgs) (ss *sshClientSession, udpMode int, err error) { ss = &sshClientSession{} defer func() { if err != nil { @@ -1161,13 +1161,13 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode } else { sshLoginSuccess.Store(true) // execute local command if necessary - execLocalCommand(args, param) + execLocalCommand(args, ss.param) } }() // ssh login var control bool - ss.client, param, control, err = sshConnect(args, nil, "") + ss.client, ss.param, control, err = sshConnect(args, nil, "") if err != nil { return } @@ -1176,7 +1176,7 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode udpMode = getUdpMode(args) // parse cmd and tty - ss.cmd, ss.tty, err = parseCmdAndTTY(args, param) + ss.cmd, ss.tty, err = parseCmdAndTTY(args, ss.param) if err != nil { return } @@ -1194,7 +1194,7 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode // ssh port forwarding if !control && udpMode == kUdpModeNo { - if err = sshForward(ss.client, args, param); err != nil { + if err = sshForward(ss.client, args, ss.param); err != nil { return } } @@ -1231,7 +1231,7 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode if !control && udpMode == kUdpModeNo { // ssh agent forward - sshAgentForward(args, param, ss.client, ss.session) + sshAgentForward(args, ss.param, ss.client, ss.session) // x11 forward sshX11Forward(args, ss.client, ss.session) } @@ -1240,20 +1240,20 @@ func sshTcpLogin(args *sshArgs) (ss *sshClientSession, param *sshParam, udpMode } func sshLogin(args *sshArgs) (*sshClientSession, error) { - ss, param, udpMode, err := sshTcpLogin(args) + ss, udpMode, err := sshTcpLogin(args) if err != nil { return nil, err } if udpMode != kUdpModeNo { - ss, err = sshUdpLogin(args, param, ss, udpMode) + ss, err = sshUdpLogin(args, ss, udpMode) if err != nil { return nil, err } // ssh port forwarding if not running as a proxy ( aka: not stdio forward ). if args.StdioForward == "" { - if err := sshForward(ss.client, args, param); err != nil { + if err := sshForward(ss.client, args, ss.param); err != nil { ss.Close() return nil, err } @@ -1263,7 +1263,7 @@ func sshLogin(args *sshArgs) (*sshClientSession, error) { // if not running as a proxy ( aka: not stdio forward ) and executing remote command if args.StdioForward == "" && !args.NoCommand { // ssh agent forward - sshAgentForward(args, param, ss.client, ss.session) + sshAgentForward(args, ss.param, ss.client, ss.session) // x11 forward sshX11Forward(args, ss.client, ss.session) } diff --git a/tssh/ssh.go b/tssh/ssh.go index b00be0c..f2b7b38 100644 --- a/tssh/ssh.go +++ b/tssh/ssh.go @@ -207,6 +207,7 @@ type sshClientSession struct { serverIn io.WriteCloser serverOut io.Reader serverErr io.Reader + param *sshParam cmd string tty bool } diff --git a/tssh/tokens.go b/tssh/tokens.go index dc94941..2fbeb58 100644 --- a/tssh/tokens.go +++ b/tssh/tokens.go @@ -93,21 +93,21 @@ func expandTokens(str string, args *sshArgs, param *sshParam, tokens string) (st } state = 0 if !strings.ContainsRune(tokens, c) { - return "", fmt.Errorf("token [%%%c] in [%s] is not supported", c, str) + return str, fmt.Errorf("token [%%%c] in [%s] is not supported", c, str) } switch c { case '%': buf.WriteRune('%') case 'h': if !isHostValid(param.host) { - return "", fmt.Errorf("hostname contains invalid characters") + return str, fmt.Errorf("hostname contains invalid characters") } buf.WriteString(param.host) case 'p': buf.WriteString(param.port) case 'r': if !isUserValid(param.user) { - return "", fmt.Errorf("remote username contains invalid characters") + return str, fmt.Errorf("remote username contains invalid characters") } buf.WriteString(param.user) case 'n': @@ -131,11 +131,11 @@ func expandTokens(str string, args *sshArgs, param *sshParam, tokens string) (st } buf.WriteString(fmt.Sprintf("%x", sha1.Sum([]byte(hashStr)))) default: - return "", fmt.Errorf("token [%%%c] in [%s] is not supported yet", c, str) + return str, fmt.Errorf("token [%%%c] in [%s] is not supported yet", c, str) } } if state != 0 { - return "", fmt.Errorf("[%s] ends with %% is invalid", str) + return str, fmt.Errorf("[%s] ends with %% is invalid", str) } return buf.String(), nil } diff --git a/tssh/tokens_test.go b/tssh/tokens_test.go index 9dde4ea..1dd9e70 100644 --- a/tssh/tokens_test.go +++ b/tssh/tokens_test.go @@ -54,6 +54,7 @@ func TestExpandTokens(t *testing.T) { result, err := expandTokens(original, args, param, "%hnpr") if errMsg != "" { require.NotNil(err) + assert.Equal(original, result) assert.Equal(errMsg, err.Error()) return } diff --git a/tssh/udp.go b/tssh/udp.go index d938fed..090b496 100644 --- a/tssh/udp.go +++ b/tssh/udp.go @@ -704,14 +704,14 @@ func (c *sshUdpChannel) Stderr() io.ReadWriter { return nil } -func sshUdpLogin(args *sshArgs, param *sshParam, ss *sshClientSession, udpMode int) (*sshClientSession, error) { +func sshUdpLogin(args *sshArgs, ss *sshClientSession, udpMode int) (*sshClientSession, error) { defer ss.Close() serverInfo, err := startTsshdServer(args, ss, udpMode) if err != nil { return nil, err } - client, err := tsshd.NewClient(param.host, serverInfo) + client, err := tsshd.NewClient(ss.param.host, serverInfo) if err != nil { return nil, err } @@ -759,6 +759,7 @@ func sshUdpLogin(args *sshArgs, param *sshParam, ss *sshClientSession, udpMode i if args.StdioForward != "" || args.NoCommand { return &sshClientSession{ client: &udpClient, + param: ss.param, cmd: ss.cmd, tty: ss.tty, }, nil @@ -778,6 +779,7 @@ func sshUdpLogin(args *sshArgs, param *sshParam, ss *sshClientSession, udpMode i serverIn: serverIn, serverOut: serverOut, serverErr: nil, + param: ss.param, cmd: ss.cmd, tty: ss.tty, }, nil