diff --git a/tunnel-server/index.ts b/tunnel-server/index.ts index fdba32cd..367268f5 100644 --- a/tunnel-server/index.ts +++ b/tunnel-server/index.ts @@ -40,11 +40,11 @@ const app = createApp({ proxyHandlers: proxyHandlers({ envStore, logger }), logger, }) -const sshLogger = logger.child({ name: 'ssh_server' }) +const sshLog = logger.child({ name: 'ssh_server' }) -const tunnelName = (clientId: string, remotePath: string) => { +const tunnelName = (hostnameSuffix: string, remotePath: string) => { const serviceName = remotePath.replace(/^\//, '') - return `${serviceName}-${clientId}`.toLowerCase() + return `${serviceName}-${hostnameSuffix}`.toLowerCase() } const tunnelUrl = ( @@ -53,43 +53,72 @@ const tunnelUrl = ( tunnel: string, ) => replaceHostname(rootUrl, `${tunnelName(clientId, tunnel)}.${rootUrl.hostname}`).toString() +const tunnelsPerClientUniqueId = new Map void }>>() + const sshServer = createSshServer({ - log: sshLogger, + log: sshLog, sshPrivateKey, socketDir: '/tmp', // TODO }) .on('client', client => { - const { clientId, publicKey } = client - const tunnels = new Map() + const { hostnameSuffix, publicKey, uniqueId } = client + const clientLog = sshLog.child({ uniqueClientId: uniqueId }) + const tunnels = new Map void }>() + tunnelsPerClientUniqueId.set(uniqueId, tunnels) client .on('forward', async (requestId, { path: remotePath, access }, accept, reject) => { - const key = tunnelName(clientId, remotePath) - if (await envStore.has(key)) { - reject(new Error(`duplicate path: ${key}`)) - return + const forwardLog = clientLog.child({ forwardId: requestId }) + const key = tunnelName(hostnameSuffix, remotePath) + const existingEntry = await envStore.get(key) + if (existingEntry) { + if (existingEntry.clientUniqueId === uniqueId) { + reject(new Error(`duplicate request ${requestId} for client ${uniqueId} suffix ${hostnameSuffix}`)) + return + } + forwardLog.warn('forward: overriding duplicate envStore entry for path %s: %j', key, existingEntry) + await envStore.delete(key, existingEntry.clientUniqueId) + + // close tunnel of overridden client + tunnelsPerClientUniqueId.get(existingEntry.clientUniqueId)?.get(requestId)?.closeForward() } const forward = await accept() - sshLogger.debug('creating tunnel %s for localSocket %s', key, forward.localSocketPath) + forwardLog.debug('creating tunnel %s for localSocket %s', key, forward.localSocketPath) await envStore.set(key, { + clientUniqueId: uniqueId, + hostnameSuffix, target: forward.localSocketPath, - clientId, publicKey: createPublicKey(publicKey.getPublicPEM()), access, }) - tunnels.set(requestId, tunnelUrl(BASE_URL, clientId, remotePath)) - tunnelsGauge.inc({ clientId }) + tunnels.set(requestId, { + tunnelUrl: tunnelUrl(BASE_URL, hostnameSuffix, remotePath), + closeForward: () => { + forwardLog.debug('calling forward.close') + forward.close() + }, + }) + tunnelsGauge.inc({ clientId: hostnameSuffix }) - forward.on('close', () => { - sshLogger.debug('deleting tunnel %s', key) + forward.on('close', async () => { + forwardLog.debug('forward close event') tunnels.delete(requestId) - void envStore.delete(key) - tunnelsGauge.dec({ clientId }) + const storedEnv = await envStore.delete(key, uniqueId) + if (!storedEnv) { + forwardLog.info('forward.close: no stored env') + return + } + tunnelsGauge.dec({ clientId: hostnameSuffix }) }) }) - .on('error', err => { sshLogger.warn('client error %j: %j', clientId, inspect(err)) }) + .on('close', () => { + clientLog.debug('client %s closed', uniqueId) + tunnels.forEach(t => t.closeForward()) + tunnelsPerClientUniqueId.delete(uniqueId) + }) + .on('error', err => { clientLog.warn('client error %j', inspect(err)) }) .on('hello', channel => { channel.stdout.write(`${JSON.stringify({ - clientId, + clientId: hostnameSuffix, // TODO: backwards compat, remove when we drop support for CLI v0.0.35 baseUrl: { hostname: BASE_URL.hostname, port: BASE_URL.port, protocol: BASE_URL.protocol }, rootUrl: BASE_URL.toString(), diff --git a/tunnel-server/src/preview-env.ts b/tunnel-server/src/preview-env.ts index 661de63b..73ac047b 100644 --- a/tunnel-server/src/preview-env.ts +++ b/tunnel-server/src/preview-env.ts @@ -1,7 +1,9 @@ import { KeyObject } from 'crypto' +import EventEmitter from 'events' export type PreviewEnv = { - clientId: string + clientUniqueId: string + hostnameSuffix: string target: string publicKey: KeyObject access: 'private' | 'public' @@ -11,17 +13,27 @@ export type PreviewEnvStore = { get: (key: string) => Promise set: (key: string, env: PreviewEnv) => Promise has: (key: string) => Promise - delete: (key: string) => Promise + delete: (key: string, clientUniqueId: string) => Promise } export const inMemoryPreviewEnvStore = (initial?: Record): PreviewEnvStore => { const map = new Map(Object.entries(initial ?? {})) - return { - get: async key => map.get(key), - set: async (key, value) => { + const emitter = new EventEmitter() + return Object.assign(emitter, { + get: async (key: string) => map.get(key), + set: async (key: string, value: PreviewEnv) => { map.set(key, value) }, - has: async key => map.has(key), - delete: async key => map.delete(key), - } + has: async (key: string) => map.has(key), + delete: (key: string, clientUniqueId: string) => new Promise(resolve => { + const existing = map.get(key) + if (!existing || existing.clientUniqueId !== clientUniqueId) { + resolve(false) + return + } + map.delete(key) + resolve(true) + emitter.emit('deleted', key) + }), + }) } diff --git a/tunnel-server/src/proxy.ts b/tunnel-server/src/proxy.ts index 051e8c58..6aaaa633 100644 --- a/tunnel-server/src/proxy.ts +++ b/tunnel-server/src/proxy.ts @@ -72,7 +72,7 @@ export function proxyHandlers({ } logger.debug('proxying to %j', { target: env.target, url: req.url }) - requestsCounter.inc({ clientId: env.clientId }) + requestsCounter.inc({ clientId: env.hostnameSuffix }) return proxy.web( req, diff --git a/tunnel-server/src/ssh-server.ts b/tunnel-server/src/ssh-server.ts index 6c62c258..8dd9ebc8 100644 --- a/tunnel-server/src/ssh-server.ts +++ b/tunnel-server/src/ssh-server.ts @@ -8,7 +8,7 @@ import EventEmitter from 'node:events' import { Writable } from 'node:stream' import { ForwardRequest, parseForwardRequest } from './forward-request' -const idFromPublicSsh = (key: Buffer) => +const hostnameSuffixFromPublicSsh = (key: Buffer) => crypto.createHash('sha1').update(key).digest('base64url').replace(/[_-]/g, '') .slice(0, 8) .toLowerCase() @@ -31,10 +31,12 @@ const parseForwardRequestFromSocketBindInfo = ( export interface ClientForward extends EventEmitter { localSocketPath: string on: (event: 'close', listener: () => void) => this + close: () => void } export interface SshClient extends EventEmitter { - clientId: string + hostnameSuffix: string + uniqueId: string publicKey: ParsedKey on: ( ( @@ -58,6 +60,11 @@ export interface SshClient extends EventEmitter { event: 'error', listener: (err: Error) => void, ) => this + ) & ( + ( + event: 'close', + listener: () => void, + ) => this ) } @@ -88,6 +95,8 @@ export const sshServer = ( socketDir: string } ): SshServer => { + const serverId = randomBytes(8).toString('base64url').replace(/[^A-Za-z0-9]/g, '') + let currentClientId = 0 const serverEmitter = new EventEmitter() as Omit const server = new ssh2.Server( { @@ -97,12 +106,15 @@ export const sshServer = ( hostKeys: [sshPrivateKey], }, client => { + currentClientId += 1 + const uniqueId = `${serverId}-${currentClientId}` + const clientLog = log.child({ clientUniqueId: uniqueId }) let preevySshClient: SshClient const socketServers = new Map() client .on('authentication', ctx => { - log.debug('authentication: %j', ctx) + clientLog.debug('authentication: %j', ctx) if (ctx.method !== 'publickey') { ctx.reject(['publickey']) return @@ -110,7 +122,7 @@ export const sshServer = ( const keyOrError = ssh2.utils.parseKey(ctx.key.data) if (!('getPublicSSH' in keyOrError)) { - log.error('error parsing key: %j', keyOrError) + clientLog.error('error parsing key: %j', keyOrError) ctx.reject() return } @@ -118,23 +130,24 @@ export const sshServer = ( // calling "accept" when no signature specified does not result in authenticated state // see: https://github.com/mscdex/ssh2/issues/561#issuecomment-303263753 if (ctx.signature && !keyOrError.verify(ctx.blob as Buffer, ctx.signature, ctx.key.algo)) { - log.error('error verifying key: %j', keyOrError) + clientLog.error('error verifying key: %j', keyOrError) ctx.reject(['publickey']) return } preevySshClient = Object.assign(new EventEmitter(), { publicKey: keyOrError, - clientId: idFromPublicSsh(keyOrError.getPublicSSH()), + hostnameSuffix: hostnameSuffixFromPublicSsh(keyOrError.getPublicSSH()), + uniqueId, }) - log.debug('accepting clientId %j', preevySshClient.clientId) + clientLog.debug('accepting hostnameSuffix %j', preevySshClient.hostnameSuffix) ctx.accept() serverEmitter.emit('client', preevySshClient) }) .on('request', async (accept, reject, name, info) => { - log.debug('request %j', { accept, reject, name, info }) + clientLog.debug('request %j', { accept, reject, name, info }) if (!client.authenticated) { - log.error('not authenticated, rejecting') + clientLog.error('not authenticated, rejecting') reject?.() return } @@ -143,7 +156,7 @@ export const sshServer = ( const request = forwardRequestFromSocketBindInfo(info as unknown as SocketBindInfo) const deleted = socketServers.get(request) if (!deleted) { - log.error('cancel-streamlocal-forward@openssh.com: request %j not found, rejecting', request) + clientLog.error('cancel-streamlocal-forward@openssh.com: request %j not found, rejecting', request) reject?.() return } @@ -153,7 +166,7 @@ export const sshServer = ( } if ((name as string) !== 'streamlocal-forward@openssh.com') { - log.error('invalid request %j', { name, info }) + clientLog.error('invalid request %j', { name, info }) reject?.() return } @@ -161,7 +174,7 @@ export const sshServer = ( const res = parseForwardRequestFromSocketBindInfo(info as unknown as SocketBindInfo) const { request } = res if ('error' in res) { - log.error('streamlocal-forward@openssh.com: rejecting %j, error parsing: %j', request, inspect(res.error)) + clientLog.error('streamlocal-forward@openssh.com: rejecting %j, error parsing: %j', request, inspect(res.error)) reject?.() return } @@ -169,7 +182,7 @@ export const sshServer = ( const { parsed } = res if (socketServers.has(request)) { - log.error('streamlocal-forward@openssh.com: rejecting %j, duplicate socket request', request) + clientLog.error('streamlocal-forward@openssh.com: rejecting %j, duplicate socket request', request) reject?.() return } @@ -180,15 +193,15 @@ export const sshServer = ( parsed, () => new Promise((resolveForward, rejectForward) => { const socketServer = net.createServer(socket => { - log.debug('socketServer connected %j', socket) + clientLog.debug('socketServer connected %j', socket) client.openssh_forwardOutStreamLocal( request, (err, upstream) => { if (err) { - log.error('error forwarding request %j: %s', request, inspect(err)) + clientLog.error('error forwarding request %j: %s', request, inspect(err)) socket.end() socketServer.close(closeErr => { - log.error('error closing socket server for request %j: %j', request, inspect(closeErr)) + clientLog.error('error closing socket server for request %j: %j', request, inspect(closeErr)) }) return } @@ -197,24 +210,24 @@ export const sshServer = ( ) }) - const socketPath = path.join(socketDir, `s_${preevySshClient.clientId}_${randomBytes(16).toString('hex')}`) + const socketPath = path.join(socketDir, `s_${preevySshClient.hostnameSuffix}_${randomBytes(16).toString('hex')}`) const closeSocketServer = () => socketServer.close() socketServer .listen(socketPath, () => { - log.debug('streamlocal-forward@openssh.com: request %j calling accept: %j', request, accept) + clientLog.debug('streamlocal-forward@openssh.com: request %j calling accept: %j', request, accept) accept?.() socketServers.set(request, socketServer) resolveForward(Object.assign(socketServer, { localSocketPath: socketPath })) }) .on('error', (err: unknown) => { - log.error('socketServer request %j error: %j', request, err) + clientLog.error('socketServer request %j error: %j', request, err) socketServer.close() rejectForward(err) }) .on('close', () => { - log.debug('socketServer close: %j', socketPath) + clientLog.debug('socketServer close: %j', socketPath) socketServers.delete(request) client.removeListener('close', closeSocketServer) }) @@ -222,23 +235,27 @@ export const sshServer = ( client.once('close', closeSocketServer) }), (reason: Error) => { - log.error('streamlocal-forward@openssh.com: rejecting %j, reason: %j', request, inspect(reason)) + clientLog.error('streamlocal-forward@openssh.com: rejecting %j, reason: %j', request, inspect(reason)) reject?.() } ) }) .on('error', err => { - log.error('client error: %j', inspect(err)) + clientLog.error('client error: %j', inspect(err)) preevySshClient?.emit('error', err) }) + .on('close', () => { + clientLog.debug('client close') + serverEmitter?.emit('close') + }) .on('session', accept => { - log.debug('session') + clientLog.debug('session') const session = accept() session.on('exec', async (acceptExec, rejectExec, info) => { - log.debug('exec %j', info) + clientLog.debug('exec %j', info) if (info.command !== 'hello') { - log.error('invalid exec command %j', info.command) + clientLog.error('invalid exec command %j', info.command) rejectExec() return }