Skip to content
This repository has been archived by the owner on Jan 9, 2023. It is now read-only.

Commit

Permalink
Merge pull request #730 from JoshVanL/ssh-tunnel-improvments
Browse files Browse the repository at this point in the history
Have timeout for tunnel, new connections reset timer
  • Loading branch information
jetstack-bot authored Feb 12, 2019
2 parents b364c17 + d3fe1b8 commit a662f32
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 26 deletions.
5 changes: 2 additions & 3 deletions cmd/tarmak/cmd/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ var tunnelCmd = &cobra.Command{
PreRunE: func(cmd *cobra.Command, args []string) error {
if len(args) != 3 {
return fmt.Errorf(
"expecting only a destination, destination and local port argument, got=%s", args)
"expecting only a destination, destination port and local port argument, got=%s", args)
}
return nil
},
Expand All @@ -43,8 +43,7 @@ var tunnelCmd = &cobra.Command{
time.Sleep(time.Second * 2)
}

time.Sleep(time.Minute * 10)
t.Cleanup()
<-tunnel.Done()
os.Exit(0)
},
Hidden: true,
Expand Down
1 change: 1 addition & 0 deletions pkg/tarmak/interfaces/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ type Tunnel interface {
Stop()
Port() string
BindAddress() string
Done() <-chan struct{}
}

type VaultTunnel interface {
Expand Down
5 changes: 2 additions & 3 deletions pkg/tarmak/kubectl/kubectl.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,6 @@ func (k *Kubectl) Kubectl(args []string, publicEndpoint bool) error {

cmd.Wait()

k.stopTunnel()

return nil
}

Expand Down Expand Up @@ -364,7 +362,8 @@ func (k *Kubectl) setupConfig(c *api.Config, publicAPIEndpoint bool) (*api.Confi
k.tunnel.BindAddress(), k.tunnel.Port())
}

k.log.Warnf("ssh tunnel connecting to Kubernetes API server will close after 10 minutes: %s",
k.log.Warnf(
"ssh tunnel connecting to Kubernetes API server will close after 10 minutes of inactivity: %s",
cluster.Server)
}

Expand Down
87 changes: 67 additions & 20 deletions pkg/tarmak/ssh/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ import (
"github.com/jetstack/tarmak/pkg/tarmak/interfaces"
)

const (
timeout = time.Minute * 10
)

type Tunnel struct {
log *logrus.Entry
ssh *SSH
stopCh chan struct{}
ssh *SSH
log *logrus.Entry

dest string
destPort string
Expand All @@ -33,8 +36,14 @@ type Tunnel struct {
listener net.Listener
daemon *os.Process

closeConnsLock sync.Mutex // prevent closing the same connection multiple times at once
openedConns []net.Conn
// have both a stopCh and doneCh so we have a chance to clean up connections
// properly before we exit the program during daemon mode
stopCh chan struct{}
doneCh chan struct{}
openConns []<-chan struct{}

connsLock sync.Mutex // prevent closing the same connection multiple times at once
remoteConns []net.Conn
}

var _ interfaces.Tunnel = &Tunnel{}
Expand All @@ -48,7 +57,6 @@ func (s *SSH) Tunnel(dest, destPort, localPort string, daemonize bool) interface
destPort: destPort,
daemonize: daemonize,
localPort: localPort,
stopCh: make(chan struct{}),
}

s.tunnels = append(s.tunnels, tunnel)
Expand All @@ -58,6 +66,7 @@ func (s *SSH) Tunnel(dest, destPort, localPort string, daemonize bool) interface
// Start tunnel and wait till a tcp socket is reachable
func (t *Tunnel) Start() error {
t.stopCh = make(chan struct{})
t.doneCh = make(chan struct{})

// ensure there is connectivity to the bastion
bastionClient, err := t.ssh.bastionClient()
Expand Down Expand Up @@ -89,27 +98,29 @@ func (t *Tunnel) Start() error {
}

func (t *Tunnel) handle() {
tries := 10
go t.handleTimeout()
var errCount int

for {
remoteConn, err := t.bastionConn.Dial("tcp",
net.JoinHostPort(t.dest, t.destPort))
if err != nil {
tries--
if tries == 0 {
return
}

select {
case <-t.stopCh:
return
default:
}

errCount++
if errCount == 10 {
return
}

time.Sleep(time.Second * 3)
continue
}

t.openedConns = append(t.openedConns, remoteConn)
t.remoteConns = append(t.remoteConns, remoteConn)

conn, err := t.listener.Accept()
if err != nil {
Expand All @@ -122,14 +133,25 @@ func (t *Tunnel) handle() {
t.log.Warnf("error accepting ssh tunnel connection: %s", err)
continue
}
t.openedConns = append(t.openedConns, conn)
t.remoteConns = append(t.remoteConns, conn)

t.connsLock.Lock()
ch := make(chan struct{})
t.openConns = append(t.openConns, ch)
t.connsLock.Unlock()

go func() {
io.Copy(remoteConn, conn)
conn.Close()

// reset timer to another 10 mins since this connection is now closed
time.Sleep(timeout)
close(ch)
}()

go func() {
io.Copy(conn, remoteConn)
remoteConn.Close()
}()
}
}
Expand All @@ -144,19 +166,20 @@ func (t *Tunnel) Stop() {
}

func (t *Tunnel) cleanup() {
// prevent closing the same connection multiple times at once
t.closeConnsLock.Lock()
defer t.closeConnsLock.Unlock()
// prevent closing the same connection multiple times at once as well as
// accepting any new ones
t.connsLock.Lock()
defer t.connsLock.Unlock()

select {
case <-t.stopCh:
default:
close(t.stopCh)
}

for _, o := range t.openedConns {
if o != nil {
o.Close()
for _, conn := range t.remoteConns {
if conn != nil {
conn.Close()
}
}

Expand All @@ -173,6 +196,30 @@ func (t *Tunnel) BindAddress() string {
return "127.0.0.1"
}

func (t *Tunnel) handleTimeout() {
// initial timeout whilst waiting for first connection
time.Sleep(timeout)

// need to use C style for-loop so we catch new openConns channels in the
// slice to wait on
t.connsLock.Lock()
for i := 0; i < len(t.openConns); i++ {
t.connsLock.Unlock()

<-t.openConns[i]

t.connsLock.Lock()
}
t.connsLock.Unlock()

t.cleanup()
close(t.doneCh)
}

func (t *Tunnel) Done() <-chan struct{} {
return t.doneCh
}

func (t *Tunnel) startDaemon() error {
binaryPath, err := osext.Executable()
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions pkg/tarmak/vault/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ func (v *vaultTunnel) VaultClient() *vault.Client {
return v.client
}

func (v *vaultTunnel) Done() <-chan struct{} {
return v.tunnel.Done()
}

func (v *vaultTunnel) Status() int {
if v.tunnelError != nil {
return VaultStateErr
Expand Down
4 changes: 4 additions & 0 deletions pkg/tarmak/vault/tunnel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ func (ft *FakeTunnel) Stop() {
return
}

func (ft *FakeTunnel) Done() <-chan struct{} {
return nil
}

var _ interfaces.Tunnel = &FakeTunnel{}

func TestVaultTunnel(t *testing.T) {
Expand Down

0 comments on commit a662f32

Please sign in to comment.