diff --git a/Sources/NIOSSH/ByteBuffer+SSH.swift b/Sources/NIOSSH/ByteBuffer+SSH.swift index def36c5..0349803 100644 --- a/Sources/NIOSSH/ByteBuffer+SSH.swift +++ b/Sources/NIOSSH/ByteBuffer+SSH.swift @@ -169,7 +169,7 @@ extension ByteBuffer { /// Writes a given number of SSH-acceptable padding bytes to this buffer. @discardableResult - mutating func writeSSHPaddingBytes(count: Int) -> Int { + public mutating func writeSSHPaddingBytes(count: Int) -> Int { // Annoyingly, the system random number generator can only give bytes to us 8 bytes at a time. precondition(count >= 0, "Cannot write negative number of padding bytes: \(count)") diff --git a/Sources/NIOSSH/Connection State Machine/Operations/AcceptsKeyExchangeMessages.swift b/Sources/NIOSSH/Connection State Machine/Operations/AcceptsKeyExchangeMessages.swift index b053d32..a9787f8 100644 --- a/Sources/NIOSSH/Connection State Machine/Operations/AcceptsKeyExchangeMessages.swift +++ b/Sources/NIOSSH/Connection State Machine/Operations/AcceptsKeyExchangeMessages.swift @@ -30,7 +30,7 @@ extension AcceptsKeyExchangeMessages { } mutating func receiveKeyExchangeInitMessage(_ message: SSHMessage.KeyExchangeECDHInitMessage) throws -> SSHConnectionStateMachine.StateMachineInboundProcessResult { - let message = try self.keyExchangeStateMachine.handle(keyExchangeInit: message) + let message = try self.keyExchangeStateMachine.handle(keyExchangeInit: message.publicKey) if let message = message { return .emitMessage(message) diff --git a/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift b/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift index 56d473b..2935be9 100644 --- a/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift +++ b/Sources/NIOSSH/Connection State Machine/SSHConnectionStateMachine.swift @@ -60,12 +60,12 @@ struct SSHConnectionStateMachine { /// The state of this state machine. private var state: State - private static let defaultTransportProtectionSchemes: [NIOSSHTransportProtection.Type] = [ + static let bundledTransportProtectionSchemes: [NIOSSHTransportProtection.Type] = [ AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self, ] - init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type] = Self.defaultTransportProtectionSchemes) { - self.state = .idle(IdleState(role: role, protectionSchemes: protectionSchemes)) + init(role: SSHConnectionRole) { + self.state = .idle(IdleState(role: role)) } func start() -> SSHMultiMessage? { @@ -182,6 +182,7 @@ struct SSHConnectionStateMachine { return .noMessage case .unimplemented(let unimplemented): throw NIOSSHError.remotePeerDoesNotSupportMessage(unimplemented) + default: // TODO: enforce RFC 4253: // diff --git a/Sources/NIOSSH/Connection State Machine/States/ActiveState.swift b/Sources/NIOSSH/Connection State Machine/States/ActiveState.swift index 70ee709..1e34cad 100644 --- a/Sources/NIOSSH/Connection State Machine/States/ActiveState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/ActiveState.swift @@ -27,8 +27,6 @@ extension SSHConnectionStateMachine { internal var remoteVersion: String - internal var protectionSchemes: [NIOSSHTransportProtection.Type] - internal var sessionIdentifier: ByteBuffer init(_ previous: UserAuthenticationState) { @@ -36,7 +34,6 @@ extension SSHConnectionStateMachine { self.serializer = previous.serializer self.parser = previous.parser self.remoteVersion = previous.remoteVersion - self.protectionSchemes = previous.protectionSchemes self.sessionIdentifier = previous.sessionIdentifier } @@ -45,7 +42,6 @@ extension SSHConnectionStateMachine { self.serializer = previous.serializer self.parser = previous.parser self.remoteVersion = previous.remoteVersion - self.protectionSchemes = previous.protectionSchemes self.sessionIdentifier = previous.sessionIdentifier } @@ -54,7 +50,6 @@ extension SSHConnectionStateMachine { self.serializer = previous.serializer self.parser = previous.parser self.remoteVersion = previous.remoteVersion - self.protectionSchemes = previous.protectionSchemes self.sessionIdentifier = previous.sessionIdentifier } } diff --git a/Sources/NIOSSH/Connection State Machine/States/IdleState.swift b/Sources/NIOSSH/Connection State Machine/States/IdleState.swift index 0ea83d3..edba5ea 100644 --- a/Sources/NIOSSH/Connection State Machine/States/IdleState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/IdleState.swift @@ -23,10 +23,13 @@ extension SSHConnectionStateMachine { internal var protectionSchemes: [NIOSSHTransportProtection.Type] - init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type]) { + internal var keyExchangeAlgorithms: [NIOSSHKeyExchangeAlgorithmProtocol.Type] + + init(role: SSHConnectionRole) { self.role = role self.serializer = SSHPacketSerializer() - self.protectionSchemes = protectionSchemes + self.protectionSchemes = role.transportProtectionSchemes + self.keyExchangeAlgorithms = role.keyExchangeAlgorithms } } } diff --git a/Sources/NIOSSH/Connection State Machine/States/KeyExchangeState.swift b/Sources/NIOSSH/Connection State Machine/States/KeyExchangeState.swift index 50f13ae..fede0e7 100644 --- a/Sources/NIOSSH/Connection State Machine/States/KeyExchangeState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/KeyExchangeState.swift @@ -28,8 +28,6 @@ extension SSHConnectionStateMachine { var remoteVersion: String - var protectionSchemes: [NIOSSHTransportProtection.Type] - /// The backing state machine. var keyExchangeStateMachine: SSHKeyExchangeStateMachine @@ -38,8 +36,7 @@ extension SSHConnectionStateMachine { self.parser = state.parser self.serializer = state.serializer self.remoteVersion = remoteVersion - self.protectionSchemes = state.protectionSchemes - self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: state.role, remoteVersion: remoteVersion, protectionSchemes: state.protectionSchemes, previousSessionIdentifier: nil) + self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: state.role, remoteVersion: remoteVersion, keyExchangeAlgorithms: state.role.keyExchangeAlgorithms, transportProtectionSchemes: state.role.transportProtectionSchemes, previousSessionIdentifier: nil) } } } diff --git a/Sources/NIOSSH/Connection State Machine/States/ReceivedKexInitWhenActiveState.swift b/Sources/NIOSSH/Connection State Machine/States/ReceivedKexInitWhenActiveState.swift index c8f53e4..5b578ae 100644 --- a/Sources/NIOSSH/Connection State Machine/States/ReceivedKexInitWhenActiveState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/ReceivedKexInitWhenActiveState.swift @@ -28,8 +28,6 @@ extension SSHConnectionStateMachine { internal var remoteVersion: String - internal var protectionSchemes: [NIOSSHTransportProtection.Type] - internal var keyExchangeStateMachine: SSHKeyExchangeStateMachine internal var sessionIdentifier: ByteBuffer @@ -39,9 +37,8 @@ extension SSHConnectionStateMachine { self.serializer = previous.serializer self.parser = previous.parser self.remoteVersion = previous.remoteVersion - self.protectionSchemes = previous.protectionSchemes self.sessionIdentifier = previous.sessionIdentifier - self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: previous.role, remoteVersion: previous.remoteVersion, protectionSchemes: previous.protectionSchemes, previousSessionIdentifier: self.sessionIdentifier) + self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: previous.role, remoteVersion: previous.remoteVersion, keyExchangeAlgorithms: self.role.keyExchangeAlgorithms, transportProtectionSchemes: self.role.transportProtectionSchemes, previousSessionIdentifier: self.sessionIdentifier) } } } diff --git a/Sources/NIOSSH/Connection State Machine/States/ReceivedNewKeysState.swift b/Sources/NIOSSH/Connection State Machine/States/ReceivedNewKeysState.swift index f183d3d..3308bae 100644 --- a/Sources/NIOSSH/Connection State Machine/States/ReceivedNewKeysState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/ReceivedNewKeysState.swift @@ -29,8 +29,6 @@ extension SSHConnectionStateMachine { var remoteVersion: String - var protectionSchemes: [NIOSSHTransportProtection.Type] - var sessionIdentifier: ByteBuffer /// The backing state machine. @@ -45,7 +43,6 @@ extension SSHConnectionStateMachine { self.parser = state.parser self.serializer = state.serializer self.remoteVersion = state.remoteVersion - self.protectionSchemes = state.protectionSchemes self.keyExchangeStateMachine = state.keyExchangeStateMachine // We force unwrap the session ID because it's programmer error to not have it at this time. diff --git a/Sources/NIOSSH/Connection State Machine/States/RekeyingReceivedNewKeysState.swift b/Sources/NIOSSH/Connection State Machine/States/RekeyingReceivedNewKeysState.swift index 7caa796..140c24f 100644 --- a/Sources/NIOSSH/Connection State Machine/States/RekeyingReceivedNewKeysState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/RekeyingReceivedNewKeysState.swift @@ -29,8 +29,6 @@ extension SSHConnectionStateMachine { var remoteVersion: String - var protectionSchemes: [NIOSSHTransportProtection.Type] - var sessionIdentifier: ByteBuffer /// The backing state machine. @@ -41,7 +39,6 @@ extension SSHConnectionStateMachine { self.parser = previousState.parser self.serializer = previousState.serializer self.remoteVersion = previousState.remoteVersion - self.protectionSchemes = previousState.protectionSchemes self.sessionIdentifier = previousState.sessionIdentifier self.keyExchangeStateMachine = previousState.keyExchangeStateMachine } diff --git a/Sources/NIOSSH/Connection State Machine/States/RekeyingSentNewKeysState.swift b/Sources/NIOSSH/Connection State Machine/States/RekeyingSentNewKeysState.swift index 95c0865..41f5cfa 100644 --- a/Sources/NIOSSH/Connection State Machine/States/RekeyingSentNewKeysState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/RekeyingSentNewKeysState.swift @@ -29,8 +29,6 @@ extension SSHConnectionStateMachine { var remoteVersion: String - var protectionSchemes: [NIOSSHTransportProtection.Type] - var sessionIdentifier: ByteBuffer /// The backing state machine. @@ -41,7 +39,6 @@ extension SSHConnectionStateMachine { self.parser = previousState.parser self.serializer = previousState.serializer self.remoteVersion = previousState.remoteVersion - self.protectionSchemes = previousState.protectionSchemes self.sessionIdentifier = previousState.sessionIdentifier self.keyExchangeStateMachine = previousState.keyExchangeStateMachine } diff --git a/Sources/NIOSSH/Connection State Machine/States/RekeyingState.swift b/Sources/NIOSSH/Connection State Machine/States/RekeyingState.swift index cbeea56..f6d73f0 100644 --- a/Sources/NIOSSH/Connection State Machine/States/RekeyingState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/RekeyingState.swift @@ -28,8 +28,6 @@ extension SSHConnectionStateMachine { var remoteVersion: String - var protectionSchemes: [NIOSSHTransportProtection.Type] - var sessionIdentifier: ByteBuffer /// The backing state machine. @@ -40,7 +38,6 @@ extension SSHConnectionStateMachine { self.parser = previousState.parser self.serializer = previousState.serializer self.remoteVersion = previousState.remoteVersion - self.protectionSchemes = previousState.protectionSchemes self.sessionIdentifier = previousState.sessionIdentifier self.keyExchangeStateMachine = previousState.keyExchangeStateMachine } @@ -50,7 +47,6 @@ extension SSHConnectionStateMachine { self.parser = previousState.parser self.serializer = previousState.serializer self.remoteVersion = previousState.remoteVersion - self.protectionSchemes = previousState.protectionSchemes self.sessionIdentifier = previousState.sessionIdentitifier self.keyExchangeStateMachine = previousState.keyExchangeStateMachine } diff --git a/Sources/NIOSSH/Connection State Machine/States/SentKexInitWhenActiveState.swift b/Sources/NIOSSH/Connection State Machine/States/SentKexInitWhenActiveState.swift index 5af0890..f0a1fc5 100644 --- a/Sources/NIOSSH/Connection State Machine/States/SentKexInitWhenActiveState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/SentKexInitWhenActiveState.swift @@ -28,8 +28,6 @@ extension SSHConnectionStateMachine { internal var remoteVersion: String - internal var protectionSchemes: [NIOSSHTransportProtection.Type] - internal var sessionIdentitifier: ByteBuffer internal var keyExchangeStateMachine: SSHKeyExchangeStateMachine @@ -39,9 +37,8 @@ extension SSHConnectionStateMachine { self.serializer = previous.serializer self.parser = previous.parser self.remoteVersion = previous.remoteVersion - self.protectionSchemes = previous.protectionSchemes self.sessionIdentitifier = previous.sessionIdentifier - self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: self.role, remoteVersion: self.remoteVersion, protectionSchemes: self.protectionSchemes, previousSessionIdentifier: previous.sessionIdentifier) + self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: self.role, remoteVersion: self.remoteVersion, keyExchangeAlgorithms: self.role.keyExchangeAlgorithms, transportProtectionSchemes: self.role.transportProtectionSchemes, previousSessionIdentifier: previous.sessionIdentifier) } } } diff --git a/Sources/NIOSSH/Connection State Machine/States/SentNewKeysState.swift b/Sources/NIOSSH/Connection State Machine/States/SentNewKeysState.swift index 53b637f..f1227a7 100644 --- a/Sources/NIOSSH/Connection State Machine/States/SentNewKeysState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/SentNewKeysState.swift @@ -29,8 +29,6 @@ extension SSHConnectionStateMachine { var remoteVersion: String - var protectionSchemes: [NIOSSHTransportProtection.Type] - var sessionIdentifier: ByteBuffer /// The backing state machine. @@ -46,7 +44,6 @@ extension SSHConnectionStateMachine { self.serializer = state.serializer self.keyExchangeStateMachine = state.keyExchangeStateMachine self.remoteVersion = state.remoteVersion - self.protectionSchemes = state.protectionSchemes // We force unwrap the session ID here because it's programmer error to not have it at this stage. self.sessionIdentifier = self.keyExchangeStateMachine.sessionID! diff --git a/Sources/NIOSSH/Connection State Machine/States/SentVersionState.swift b/Sources/NIOSSH/Connection State Machine/States/SentVersionState.swift index 4c85b83..c86080a 100644 --- a/Sources/NIOSSH/Connection State Machine/States/SentVersionState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/SentVersionState.swift @@ -26,14 +26,11 @@ extension SSHConnectionStateMachine { /// The packet serializer used by this state machine. var serializer: SSHPacketSerializer - var protectionSchemes: [NIOSSHTransportProtection.Type] - private let allocator: ByteBufferAllocator init(idleState state: IdleState, allocator: ByteBufferAllocator) { self.role = state.role self.serializer = state.serializer - self.protectionSchemes = state.protectionSchemes self.parser = SSHPacketParser(allocator: allocator) self.allocator = allocator diff --git a/Sources/NIOSSH/Connection State Machine/States/UserAuthenticationState.swift b/Sources/NIOSSH/Connection State Machine/States/UserAuthenticationState.swift index 847e53d..cde70c2 100644 --- a/Sources/NIOSSH/Connection State Machine/States/UserAuthenticationState.swift +++ b/Sources/NIOSSH/Connection State Machine/States/UserAuthenticationState.swift @@ -28,8 +28,6 @@ extension SSHConnectionStateMachine { var remoteVersion: String - var protectionSchemes: [NIOSSHTransportProtection.Type] - var sessionIdentifier: ByteBuffer /// The backing state machine. @@ -41,7 +39,6 @@ extension SSHConnectionStateMachine { self.serializer = state.serializer self.userAuthStateMachine = state.userAuthStateMachine self.remoteVersion = state.remoteVersion - self.protectionSchemes = state.protectionSchemes self.sessionIdentifier = state.sessionIdentifier } @@ -51,7 +48,6 @@ extension SSHConnectionStateMachine { self.serializer = state.serializer self.userAuthStateMachine = state.userAuthStateMachine self.remoteVersion = state.remoteVersion - self.protectionSchemes = state.protectionSchemes self.sessionIdentifier = state.sessionIdentifier } } diff --git a/Sources/NIOSSH/Key Exchange/EllipticCurveKeyExchange.swift b/Sources/NIOSSH/Key Exchange/EllipticCurveKeyExchange.swift index 548c2e1..b0c72e2 100644 --- a/Sources/NIOSSH/Key Exchange/EllipticCurveKeyExchange.swift +++ b/Sources/NIOSSH/Key Exchange/EllipticCurveKeyExchange.swift @@ -16,34 +16,56 @@ import Crypto import NIOCore import NIOFoundationCompat +public struct NIOSSHKeyExchangeServerReply { + public var hostKey: NIOSSHPublicKey + public var publicKey: ByteBuffer + public var signature: NIOSSHSignature + + public init(hostKey: NIOSSHPublicKey, publicKey: ByteBuffer, signature: NIOSSHSignature) { + self.hostKey = hostKey + self.publicKey = publicKey + self.signature = signature + } +} + /// This protocol defines a container used by the key exchange state machine to manage key exchange. /// This type erases the specific key exchanger. -protocol EllipticCurveKeyExchangeProtocol { +public protocol NIOSSHKeyExchangeAlgorithmProtocol { + static var keyExchangeInitMessageId: UInt8 { get } + static var keyExchangeReplyMessageId: UInt8 { get } + init(ourRole: SSHConnectionRole, previousSessionIdentifier: ByteBuffer?) - func initiateKeyExchangeClientSide(allocator: ByteBufferAllocator) -> SSHMessage.KeyExchangeECDHInitMessage + func initiateKeyExchangeClientSide(allocator: ByteBufferAllocator) -> ByteBuffer - mutating func completeKeyExchangeServerSide(clientKeyExchangeMessage message: SSHMessage.KeyExchangeECDHInitMessage, - serverHostKey: NIOSSHPrivateKey, - initialExchangeBytes: inout ByteBuffer, - allocator: ByteBufferAllocator, - expectedKeySizes: ExpectedKeySizes) throws -> (KeyExchangeResult, SSHMessage.KeyExchangeECDHReplyMessage) + mutating func completeKeyExchangeServerSide( + clientKeyExchangeMessage message: ByteBuffer, + serverHostKey: NIOSSHPrivateKey, + initialExchangeBytes: inout ByteBuffer, + allocator: ByteBufferAllocator, + expectedKeySizes: ExpectedKeySizes + ) throws -> (KeyExchangeResult, NIOSSHKeyExchangeServerReply) - mutating func receiveServerKeyExchangePayload(serverKeyExchangeMessage message: SSHMessage.KeyExchangeECDHReplyMessage, - initialExchangeBytes: inout ByteBuffer, - allocator: ByteBufferAllocator, - expectedKeySizes: ExpectedKeySizes) throws -> KeyExchangeResult + mutating func receiveServerKeyExchangePayload( + serverKeyExchangeMessage: NIOSSHKeyExchangeServerReply, + initialExchangeBytes: inout ByteBuffer, + allocator: ByteBufferAllocator, + expectedKeySizes: ExpectedKeySizes + ) throws -> KeyExchangeResult static var keyExchangeAlgorithmNames: [Substring] { get } } -struct EllipticCurveKeyExchange: EllipticCurveKeyExchangeProtocol { +struct EllipticCurveKeyExchange: NIOSSHKeyExchangeAlgorithmProtocol { private var previousSessionIdentifier: ByteBuffer? private var ourKey: PrivateKey private var theirKey: PrivateKey.PublicKey? private var ourRole: SSHConnectionRole private var sharedSecret: SharedSecret? + static var keyExchangeInitMessageId: UInt8 { 30 } + static var keyExchangeReplyMessageId: UInt8 { 31 } + init(ourRole: SSHConnectionRole, previousSessionIdentifier: ByteBuffer?) { self.ourRole = ourRole self.ourKey = PrivateKey() @@ -59,14 +81,13 @@ extension EllipticCurveKeyExchange { /// Initiates key exchange by producing an SSH message. /// /// For now, we just return the ByteBuffer containing the SSH string. - func initiateKeyExchangeClientSide(allocator: ByteBufferAllocator) -> SSHMessage.KeyExchangeECDHInitMessage { + func initiateKeyExchangeClientSide(allocator: ByteBufferAllocator) -> ByteBuffer { precondition(self.ourRole.isClient, "Only clients may initiate the client side key exchange!") // The largest key we're likely to end up with here is 256 bytes. var buffer = allocator.buffer(capacity: 256) self.ourKey.publicKey.write(to: &buffer) - - return .init(publicKey: buffer) + return buffer } /// Handles receiving the client key exchange payload on the server side. @@ -77,15 +98,15 @@ extension EllipticCurveKeyExchange { /// - initialExchangeBytes: The initial bytes of the exchange, suitable for writing into the exchange hash. /// - allocator: A `ByteBufferAllocator` suitable for this connection. /// - expectedKeySizes: The sizes of the keys we need to generate. - mutating func completeKeyExchangeServerSide(clientKeyExchangeMessage message: SSHMessage.KeyExchangeECDHInitMessage, + mutating func completeKeyExchangeServerSide(clientKeyExchangeMessage message: ByteBuffer, serverHostKey: NIOSSHPrivateKey, initialExchangeBytes: inout ByteBuffer, allocator: ByteBufferAllocator, - expectedKeySizes: ExpectedKeySizes) throws -> (KeyExchangeResult, SSHMessage.KeyExchangeECDHReplyMessage) { + expectedKeySizes: ExpectedKeySizes) throws -> (KeyExchangeResult, NIOSSHKeyExchangeServerReply) { precondition(self.ourRole.isServer, "Only servers may receive a client key exchange packet!") // With that, we have enough to finalize the key exchange. - let kexResult = try self.finalizeKeyExchange(theirKeyBytes: message.publicKey, + let kexResult = try self.finalizeKeyExchange(theirKeyBytes: message, initialExchangeBytes: &initialExchangeBytes, serverHostKey: serverHostKey.publicKey, allocator: allocator, @@ -100,9 +121,9 @@ extension EllipticCurveKeyExchange { self.ourKey.publicKey.write(to: &publicKeyBytes) // Now we have all we need. - let responseMessage = SSHMessage.KeyExchangeECDHReplyMessage(hostKey: serverHostKey.publicKey, - publicKey: publicKeyBytes, - signature: exchangeHashSignature) + let responseMessage = NIOSSHKeyExchangeServerReply(hostKey: serverHostKey.publicKey, + publicKey: publicKeyBytes, + signature: exchangeHashSignature) return (KeyExchangeResult(kexResult), responseMessage) } @@ -116,10 +137,12 @@ extension EllipticCurveKeyExchange { /// - initialExchangeBytes: The initial bytes of the exchange, suitable for writing into the exchange hash. /// - allocator: A `ByteBufferAllocator` suitable for this connection. /// - expectedKeySizes: The sizes of the keys we need to generate. - mutating func receiveServerKeyExchangePayload(serverKeyExchangeMessage message: SSHMessage.KeyExchangeECDHReplyMessage, - initialExchangeBytes: inout ByteBuffer, - allocator: ByteBufferAllocator, - expectedKeySizes: ExpectedKeySizes) throws -> KeyExchangeResult { + mutating func receiveServerKeyExchangePayload( + serverKeyExchangeMessage: NIOSSHKeyExchangeServerReply, + initialExchangeBytes: inout ByteBuffer, + allocator: ByteBufferAllocator, + expectedKeySizes: ExpectedKeySizes + ) throws -> KeyExchangeResult { precondition(self.ourRole.isClient, "Only clients may receive a server key exchange packet!") // Ok, we have a few steps here. Firstly, we need to extract the server's public key and generate our shared @@ -131,14 +154,14 @@ extension EllipticCurveKeyExchange { // // Finally, we return our generated keys to the state machine. - let kexResult = try self.finalizeKeyExchange(theirKeyBytes: message.publicKey, + let kexResult = try self.finalizeKeyExchange(theirKeyBytes: serverKeyExchangeMessage.publicKey, initialExchangeBytes: &initialExchangeBytes, - serverHostKey: message.hostKey, + serverHostKey: serverKeyExchangeMessage.hostKey, allocator: allocator, expectedKeySizes: expectedKeySizes) // We can now verify signature over the exchange hash. - guard message.hostKey.isValidSignature(message.signature, for: kexResult.exchangeHash) else { + guard serverKeyExchangeMessage.hostKey.isValidSignature(serverKeyExchangeMessage.signature, for: kexResult.exchangeHash) else { throw NIOSSHError.invalidExchangeHashSignature } diff --git a/Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift b/Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift index 8c9e394..a7bf006 100644 --- a/Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift +++ b/Sources/NIOSSH/Key Exchange/SSHKeyExchangeResult.swift @@ -19,11 +19,16 @@ import NIOCore /// /// A round of key exchange generates a number of keys and also generates an exchange hash. /// This exchange hash is used for a number of purposes. -struct KeyExchangeResult { +public struct KeyExchangeResult { /// The session ID to use for this connection. Will be static across the lifetime of a connection. - var sessionID: ByteBuffer + public var sessionID: ByteBuffer - var keys: NIOSSHSessionKeys + public var keys: NIOSSHSessionKeys + + public init(sessionID: ByteBuffer, keys: NIOSSHSessionKeys) { + self.sessionID = sessionID + self.keys = keys + } } extension KeyExchangeResult: Equatable {} @@ -46,18 +51,27 @@ extension KeyExchangeResult: Equatable {} /// Of these types, the encryption keys and the MAC keys are intended to be secret, and so /// we store them in the `SymmetricKey` types. The IVs do not need to be secret, and so are /// stored in regular heap buffers. -struct NIOSSHSessionKeys { - var initialInboundIV: [UInt8] +public struct NIOSSHSessionKeys { + public var initialInboundIV: [UInt8] + + public var initialOutboundIV: [UInt8] - var initialOutboundIV: [UInt8] + public var inboundEncryptionKey: SymmetricKey - var inboundEncryptionKey: SymmetricKey + public var outboundEncryptionKey: SymmetricKey - var outboundEncryptionKey: SymmetricKey + public var inboundMACKey: SymmetricKey - var inboundMACKey: SymmetricKey + public var outboundMACKey: SymmetricKey - var outboundMACKey: SymmetricKey + public init(initialInboundIV: [UInt8], initialOutboundIV: [UInt8], inboundEncryptionKey: SymmetricKey, outboundEncryptionKey: SymmetricKey, inboundMACKey: SymmetricKey, outboundMACKey: SymmetricKey) { + self.initialInboundIV = initialInboundIV + self.initialOutboundIV = initialOutboundIV + self.inboundEncryptionKey = inboundEncryptionKey + self.outboundEncryptionKey = outboundEncryptionKey + self.inboundMACKey = inboundMACKey + self.outboundMACKey = outboundMACKey + } } extension NIOSSHSessionKeys: Equatable {} @@ -68,10 +82,16 @@ extension NIOSSHSessionKeys: Equatable {} /// hash function invocations. The output of these hash functions is truncated to an appropriate /// length as needed, which means we need to ensure the code doing the calculation knows how /// to truncate appropriately. -struct ExpectedKeySizes { - var ivSize: Int +public struct ExpectedKeySizes { + public var ivSize: Int + + public var encryptionKeySize: Int - var encryptionKeySize: Int + public var macKeySize: Int - var macKeySize: Int + public init(ivSize: Int, encryptionKeySize: Int, macKeySize: Int) { + self.ivSize = ivSize + self.encryptionKeySize = encryptionKeySize + self.macKeySize = macKeySize + } } diff --git a/Sources/NIOSSH/Key Exchange/SSHKeyExchangeStateMachine.swift b/Sources/NIOSSH/Key Exchange/SSHKeyExchangeStateMachine.swift index 350baaa..3a075e5 100644 --- a/Sources/NIOSSH/Key Exchange/SSHKeyExchangeStateMachine.swift +++ b/Sources/NIOSSH/Key Exchange/SSHKeyExchangeStateMachine.swift @@ -39,19 +39,19 @@ struct SSHKeyExchangeStateMachine { /// party can enter this state. The remote peer may be sending a guess as well. /// /// We store the message we sent for later. - case keyExchangeReceived(exchange: EllipticCurveKeyExchangeProtocol, negotiated: NegotiationResult, expectingGuess: Bool) + case keyExchangeReceived(exchange: NIOSSHKeyExchangeAlgorithmProtocol, negotiated: NegotiationResult, expectingGuess: Bool) /// The peer has guessed what key exchange init packet is coming, and guessed wrong. We need to wait for them to send that packet. - case awaitingKeyExchangeInitInvalidGuess(exchange: EllipticCurveKeyExchangeProtocol, negotiated: NegotiationResult) + case awaitingKeyExchangeInitInvalidGuess(exchange: NIOSSHKeyExchangeAlgorithmProtocol, negotiated: NegotiationResult) /// Both sides have sent their initial key exchange message but we have not begun actually performing a key exchange. - case awaitingKeyExchangeInit(exchange: EllipticCurveKeyExchangeProtocol, negotiated: NegotiationResult) + case awaitingKeyExchangeInit(exchange: NIOSSHKeyExchangeAlgorithmProtocol, negotiated: NegotiationResult) /// We've received the key exchange init, but not sent our reply yet. case keyExchangeInitReceived(result: KeyExchangeResult, negotiated: NegotiationResult) /// We've sent our keyExchangeInit, but not received the keyExchangeReply. - case keyExchangeInitSent(exchange: EllipticCurveKeyExchangeProtocol, negotiated: NegotiationResult) + case keyExchangeInitSent(exchange: NIOSSHKeyExchangeAlgorithmProtocol, negotiated: NegotiationResult) /// The keys have been exchanged. case keysExchanged(result: KeyExchangeResult, protection: NIOSSHTransportProtection, negotiated: NegotiationResult) @@ -71,16 +71,18 @@ struct SSHKeyExchangeStateMachine { private let role: SSHConnectionRole private var state: State private var initialExchangeBytes: ByteBuffer - private var protectionSchemes: [NIOSSHTransportProtection.Type] + private var transportProtectionSchemes: [NIOSSHTransportProtection.Type] + private var keyExchangeAlgorithms: [NIOSSHKeyExchangeAlgorithmProtocol.Type] private var previousSessionIdentifier: ByteBuffer? - init(allocator: ByteBufferAllocator, loop: EventLoop, role: SSHConnectionRole, remoteVersion: String, protectionSchemes: [NIOSSHTransportProtection.Type], previousSessionIdentifier: ByteBuffer?) { + init(allocator: ByteBufferAllocator, loop: EventLoop, role: SSHConnectionRole, remoteVersion: String, keyExchangeAlgorithms: [NIOSSHKeyExchangeAlgorithmProtocol.Type], transportProtectionSchemes: [NIOSSHTransportProtection.Type], previousSessionIdentifier: ByteBuffer?) { self.allocator = allocator self.loop = loop self.role = role self.initialExchangeBytes = allocator.buffer(capacity: 1024) self.state = .idle - self.protectionSchemes = protectionSchemes + self.keyExchangeAlgorithms = keyExchangeAlgorithms + self.transportProtectionSchemes = transportProtectionSchemes self.previousSessionIdentifier = previousSessionIdentifier switch self.role { @@ -103,7 +105,7 @@ struct SSHKeyExchangeStateMachine { return .init( cookie: rng.randomCookie(allocator: self.allocator), - keyExchangeAlgorithms: Self.supportedKeyExchangeAlgorithms, + keyExchangeAlgorithms: self.role.keyExchangeAlgorithmNames, serverHostKeyAlgorithms: self.supportedHostKeyAlgorithms, encryptionAlgorithmsClientToServer: encryptionAlgorithms, encryptionAlgorithmsServerToClient: encryptionAlgorithms, @@ -129,7 +131,8 @@ struct SSHKeyExchangeStateMachine { let exchanger = try self.exchangerForAlgorithm(negotiated.negotiatedKeyExchangeAlgorithm) // Ok, we need to send the key exchange message. - let message = SSHMessage.keyExchangeInit(exchanger.initiateKeyExchangeClientSide(allocator: self.allocator)) + let publicKeyBuffer = exchanger.initiateKeyExchangeClientSide(allocator: self.allocator) + let message = SSHMessage.keyExchangeInit(.init(publicKey: publicKeyBuffer)) self.state = .awaitingKeyExchangeInit(exchange: exchanger, negotiated: negotiated) return SSHMultiMessage(message) case .server: @@ -165,7 +168,8 @@ struct SSHKeyExchangeStateMachine { let result: SSHMultiMessage switch self.role { case .client: - result = SSHMultiMessage(.keyExchange(ourMessage), SSHMessage.keyExchangeInit(exchanger.initiateKeyExchangeClientSide(allocator: self.allocator))) + let publicKeyBuffer = exchanger.initiateKeyExchangeClientSide(allocator: self.allocator) + result = SSHMultiMessage(.keyExchange(ourMessage), SSHMessage.keyExchangeInit(.init(publicKey: publicKeyBuffer))) case .server: result = SSHMultiMessage(.keyExchange(ourMessage)) } @@ -202,7 +206,7 @@ struct SSHKeyExchangeStateMachine { } } - mutating func handle(keyExchangeInit message: SSHMessage.KeyExchangeECDHInitMessage) throws -> SSHMultiMessage? { + mutating func handle(keyExchangeInit message: ByteBuffer) throws -> SSHMultiMessage? { switch self.state { case .awaitingKeyExchangeInitInvalidGuess(exchange: let exchanger, negotiated: let negotiated): // We're going to ignore this one, we already know it's wrong. @@ -222,7 +226,7 @@ struct SSHKeyExchangeStateMachine { allocator: self.allocator, expectedKeySizes: negotiated.negotiatedProtection.keySizes ) - let message = SSHMessage.keyExchangeReply(reply) + let message = SSHMessage.keyExchangeReply(.init(hostKey: reply.hostKey, publicKey: reply.publicKey, signature: reply.signature)) self.state = .keyExchangeInitReceived(result: result, negotiated: negotiated) return SSHMultiMessage(message, .newKeys) } @@ -253,7 +257,11 @@ struct SSHKeyExchangeStateMachine { } let result = try exchanger.receiveServerKeyExchangePayload( - serverKeyExchangeMessage: message, + serverKeyExchangeMessage: .init( + hostKey: message.hostKey, + publicKey: message.publicKey, + signature: message.signature + ), initialExchangeBytes: &self.initialExchangeBytes, allocator: self.allocator, expectedKeySizes: negotiated.negotiatedProtection.keySizes @@ -328,7 +336,7 @@ struct SSHKeyExchangeStateMachine { } // Ok, now we need to find the right transport protection scheme. This can technically fail. - guard let scheme = self.protectionSchemes.first(where: { $0.cipherName == clientEncryption && ($0.macName == nil || $0.macName! == clientMAC) }) else { + guard let scheme = self.transportProtectionSchemes.first(where: { $0.cipherName == clientEncryption && ($0.macName == nil || $0.macName! == clientMAC) }) else { throw NIOSSHError.keyExchangeNegotiationFailure } @@ -369,13 +377,13 @@ struct SSHKeyExchangeStateMachine { switch self.role { case .client: - clientAlgorithms = Self.supportedKeyExchangeAlgorithms + clientAlgorithms = self.role.keyExchangeAlgorithmNames serverAlgorithms = peerKeyExchangeAlgorithms clientHostKeyAlgorithms = self.supportedHostKeyAlgorithms serverHostKeyAlgorithms = peerHostKeyAlgorithms case .server: clientAlgorithms = peerKeyExchangeAlgorithms - serverAlgorithms = Self.supportedKeyExchangeAlgorithms + serverAlgorithms = self.role.keyExchangeAlgorithmNames clientHostKeyAlgorithms = peerHostKeyAlgorithms serverHostKeyAlgorithms = self.supportedHostKeyAlgorithms } @@ -447,21 +455,21 @@ struct SSHKeyExchangeStateMachine { } } - private func exchangerForAlgorithm(_ algorithm: Substring) throws -> EllipticCurveKeyExchangeProtocol { - for implementation in Self.supportedKeyExchangeImplementations { + private func exchangerForAlgorithm(_ algorithm: Substring) throws -> NIOSSHKeyExchangeAlgorithmProtocol { + for implementation in self.keyExchangeAlgorithms { if implementation.keyExchangeAlgorithmNames.contains(algorithm) { return implementation.init(ourRole: self.role, previousSessionIdentifier: self.previousSessionIdentifier) } } - // Huh, we didn't find it. Weird error. + // We didn't find a match throw NIOSSHError.keyExchangeNegotiationFailure } private func expectingIncorrectGuess(_ kexMessage: SSHMessage.KeyExchangeMessage) -> Bool { // A guess is wrong if the key exchange algorithm and/or the host key algorithm differ from our preference. kexMessage.firstKexPacketFollows && ( - kexMessage.keyExchangeAlgorithms.first != Self.supportedKeyExchangeAlgorithms.first || + kexMessage.keyExchangeAlgorithms.first != self.role.keyExchangeAlgorithmNames.first || kexMessage.serverHostKeyAlgorithms.first != self.supportedHostKeyAlgorithms.first ) } @@ -478,12 +486,12 @@ struct SSHKeyExchangeStateMachine { /// The encryption algorithms supported by this peer, in order of preference. private var supportedEncryptionAlgorithms: [Substring] { - self.protectionSchemes.map { Substring($0.cipherName) } + self.transportProtectionSchemes.map { Substring($0.cipherName) } } /// The MAC algorithms supported by this peer, in order of preference. private var supportedMacAlgorithms: [Substring] { - let schemes = self.protectionSchemes.compactMap { $0.macName.map { Substring($0) } } + let schemes = self.transportProtectionSchemes.compactMap { $0.macName.map { Substring($0) } } // We do a weird thing here: if there are no MAC schemes, we lie and put one in. This is // because some schemes (such as AES-GCM in OpenSSH mode) ignore the MAC negotiation. @@ -498,17 +506,22 @@ struct SSHKeyExchangeStateMachine { extension SSHKeyExchangeStateMachine { // For now this is a static list. - static let supportedKeyExchangeImplementations: [EllipticCurveKeyExchangeProtocol.Type] = [ + static let bundledKeyExchangeImplementations: [NIOSSHKeyExchangeAlgorithmProtocol.Type] = [ EllipticCurveKeyExchange.self, EllipticCurveKeyExchange.self, EllipticCurveKeyExchange.self, EllipticCurveKeyExchange.self, ] - static let supportedKeyExchangeAlgorithms: [Substring] = supportedKeyExchangeImplementations.flatMap { $0.keyExchangeAlgorithmNames } - /// All known host key algorithms. - static let supportedServerHostKeyAlgorithms: [Substring] = ["ssh-ed25519", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp521"] + static let bundledServerHostKeyAlgorithms: [Substring] = ["ssh-ed25519", "ecdsa-sha2-nistp384", "ecdsa-sha2-nistp256", "ecdsa-sha2-nistp521"] + + static var supportedServerHostKeyAlgorithms: [Substring] { + let bundledAlgorithms = bundledServerHostKeyAlgorithms + let customAlgorithms = NIOSSHPublicKey.customPublicKeyAlgorithms.map { Substring($0.publicKeyPrefix) } + + return bundledAlgorithms + customAlgorithms + } } extension SSHKeyExchangeStateMachine { diff --git a/Sources/NIOSSH/Keys And Signatures/CustomKeys.swift b/Sources/NIOSSH/Keys And Signatures/CustomKeys.swift new file mode 100644 index 0000000..c1d86b9 --- /dev/null +++ b/Sources/NIOSSH/Keys And Signatures/CustomKeys.swift @@ -0,0 +1,89 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2022 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation +import NIO + +/// A signature is a mathematical scheme for verifying the authenticity of digital messages or documents. +/// +/// This protocol can be implemented by a type that represents such a signature to NIOSSH. +/// +/// - See: https://en.wikipedia.org/wiki/Digital_signature +public protocol NIOSSHSignatureProtocol { + /// An identifier that represents the type of signature used in an SSH packet. + /// This identifier MUST be unique to the signature implementation. + /// The returned value MUST NOT overlap with other signature implementations or a specifications that the signature does not implement. + static var signaturePrefix: String { get } + + /// The raw reprentation of this signature as a blob. + var rawRepresentation: Data { get } + + /// Serializes and writes the signature to the buffer. The calling function SHOULD NOT keep track of the size of the written blob. + /// If the result is not a fixed size, the serialized format SHOULD include a length. + func write(to buffer: inout ByteBuffer) -> Int + + /// Reads this Signature from the buffer using the same format implemented in `write(to:)` + static func read(from buffer: inout ByteBuffer) throws -> Self +} + +internal extension NIOSSHSignatureProtocol { + var signaturePrefix: String { + Self.signaturePrefix + } +} + +public protocol NIOSSHPublicKeyProtocol { + /// An identifier that represents the type of public key used in an SSH packet. + /// This identifier MUST be unique to the public key implementation. + /// The returned value MUST NOT overlap with other public key implementations or a specifications that the public key does not implement. + static var publicKeyPrefix: String { get } + + /// The raw reprentation of this publc key as a blob. + var rawRepresentation: Data { get } + + /// Verifies that `signature` is the result of signing `data` using the private key that this public key is derived from. + func isValidSignature(_ signature: NIOSSHSignatureProtocol, for data: D) -> Bool + + /// Serializes and writes the public key to the buffer. The calling function SHOULD NOT keep track of the size of the written blob. + /// If the result is not a fixed size, the serialized format SHOULD include a length. + func write(to buffer: inout ByteBuffer) -> Int + + /// Reads this Public Key from the buffer using the same format implemented in `write(to:)` + static func read(from buffer: inout ByteBuffer) throws -> Self +} + +internal extension NIOSSHPublicKeyProtocol { + var publicKeyPrefix: String { + Self.publicKeyPrefix + } +} + +public protocol NIOSSHPrivateKeyProtocol { + /// An identifier that represents the type of private key used in an SSH packet. + /// This identifier MUST be unique to the private key implementation. + /// The returned value MUST NOT overlap with other private key implementations or a specifications that the private key does not implement. + static var keyPrefix: String { get } + + /// A public key instance that is able to verify signatures that are created using this private key. + var publicKey: NIOSSHPublicKeyProtocol { get } + + /// Creates a signature, proving that `data` has been sent by the holder of this private key, and can be verified by `publicKey`. + func signature(for data: D) throws -> NIOSSHSignatureProtocol +} + +internal extension NIOSSHPrivateKeyProtocol { + var keyPrefix: String { + Self.keyPrefix + } +} diff --git a/Sources/NIOSSH/Keys And Signatures/NIOSSHCertifiedPublicKey.swift b/Sources/NIOSSH/Keys And Signatures/NIOSSHCertifiedPublicKey.swift index 42f3598..6eee10c 100644 --- a/Sources/NIOSSH/Keys And Signatures/NIOSSHCertifiedPublicKey.swift +++ b/Sources/NIOSSH/Keys And Signatures/NIOSSHCertifiedPublicKey.swift @@ -335,6 +335,8 @@ extension NIOSSHCertifiedPublicKey { return Self.p384KeyPrefix case .ecdsaP521: return Self.p521KeyPrefix + case .custom(let backingKey): + return backingKey.publicKeyPrefix.utf8 case .certified: preconditionFailure("base key cannot be certified") } diff --git a/Sources/NIOSSH/Keys And Signatures/NIOSSHPrivateKey.swift b/Sources/NIOSSH/Keys And Signatures/NIOSSHPrivateKey.swift index 2b16dbc..5a858b4 100644 --- a/Sources/NIOSSH/Keys And Signatures/NIOSSHPrivateKey.swift +++ b/Sources/NIOSSH/Keys And Signatures/NIOSSHPrivateKey.swift @@ -46,6 +46,10 @@ public struct NIOSSHPrivateKey { self.backingKey = .ecdsaP521(key) } + public init(custom key: PrivateKey) { + self.backingKey = .custom(key) + } + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) public init(secureEnclaveP256Key key: SecureEnclave.P256.Signing.PrivateKey) { self.backingKey = .secureEnclaveP256(key) @@ -63,6 +67,8 @@ public struct NIOSSHPrivateKey { return ["ecdsa-sha2-nistp384"] case .ecdsaP521: return ["ecdsa-sha2-nistp521"] + case .custom(let backingKey): + return [Substring(backingKey.keyPrefix)] #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) case .secureEnclaveP256: return ["ecdsa-sha2-nistp256"] @@ -78,6 +84,7 @@ extension NIOSSHPrivateKey { case ecdsaP256(P256.Signing.PrivateKey) case ecdsaP384(P384.Signing.PrivateKey) case ecdsaP521(P521.Signing.PrivateKey) + case custom(NIOSSHPrivateKeyProtocol) #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) case secureEnclaveP256(SecureEnclave.P256.Signing.PrivateKey) @@ -108,6 +115,11 @@ extension NIOSSHPrivateKey { try key.signature(for: ptr) } return NIOSSHSignature(backingSignature: .ecdsaP521(signature)) + case .custom(let key): + let signature = try digest.withUnsafeBytes { ptr in + try key.signature(for: ptr) + } + return NIOSSHSignature(backingSignature: .custom(signature)) #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) case .secureEnclaveP256(let key): @@ -133,6 +145,9 @@ extension NIOSSHPrivateKey { case .ecdsaP521(let key): let signature = try key.signature(for: payload.bytes.readableBytesView) return NIOSSHSignature(backingSignature: .ecdsaP521(signature)) + case .custom(let key): + let signature = try key.signature(for: payload.bytes.readableBytesView) + return NIOSSHSignature(backingSignature: .custom(signature)) #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) case .secureEnclaveP256(let key): let signature = try key.signature(for: payload.bytes.readableBytesView) @@ -154,6 +169,8 @@ extension NIOSSHPrivateKey { return NIOSSHPublicKey(backingKey: .ecdsaP384(privateKey.publicKey)) case .ecdsaP521(let privateKey): return NIOSSHPublicKey(backingKey: .ecdsaP521(privateKey.publicKey)) + case .custom(let privateKey): + return NIOSSHPublicKey(backingKey: .custom(privateKey.publicKey)) #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) case .secureEnclaveP256(let privateKey): return NIOSSHPublicKey(backingKey: .ecdsaP256(privateKey.publicKey)) diff --git a/Sources/NIOSSH/Keys And Signatures/NIOSSHPublicKey.swift b/Sources/NIOSSH/Keys And Signatures/NIOSSHPublicKey.swift index 820d0d2..7d7d8fc 100644 --- a/Sources/NIOSSH/Keys And Signatures/NIOSSHPublicKey.swift +++ b/Sources/NIOSSH/Keys And Signatures/NIOSSHPublicKey.swift @@ -14,6 +14,7 @@ import Crypto import Foundation +import NIOConcurrencyHelpers import NIOCore /// An SSH public key. @@ -68,7 +69,7 @@ public struct NIOSSHPublicKey: Hashable { extension NIOSSHPublicKey { /// Verifies that a given `NIOSSHSignature` was created by the holder of the private key associated with this /// public key. - internal func isValidSignature(_ signature: NIOSSHSignature, for digest: DigestBytes) -> Bool { + public func isValidSignature(_ signature: NIOSSHSignature, for digest: DigestBytes) -> Bool { switch (self.backingKey, signature.backingSignature) { case (.ed25519(let key), .ed25519(let sig)): return digest.withUnsafeBytes { digestPtr in @@ -91,12 +92,17 @@ extension NIOSSHPublicKey { return digest.withUnsafeBytes { digestPtr in key.isValidSignature(sig, for: digestPtr) } + case (.custom(let key), .custom(let sig)): + return digest.withUnsafeBytes { digestPtr in + key.isValidSignature(sig, for: digestPtr) + } case (.certified(let key), _): return key.isValidSignature(signature, for: digest) case (.ed25519, _), (.ecdsaP256, _), (.ecdsaP384, _), - (.ecdsaP521, _): + (.ecdsaP521, _), + (.custom, _): return false } } @@ -113,12 +119,15 @@ extension NIOSSHPublicKey { return key.isValidSignature(sig, for: bytes.readableBytesView) case (.ecdsaP521(let key), .ecdsaP521(let sig)): return key.isValidSignature(sig, for: bytes.readableBytesView) + case (.custom(let key), .custom(let sig)): + return key.isValidSignature(sig, for: bytes.readableBytesView) case (.certified(let key), _): return key.isValidSignature(signature, for: bytes) case (.ed25519, _), (.ecdsaP256, _), (.ecdsaP384, _), - (.ecdsaP521, _): + (.ecdsaP521, _), + (.custom, _): return false } } @@ -135,12 +144,15 @@ extension NIOSSHPublicKey { return key.isValidSignature(sig, for: payload.bytes.readableBytesView) case (.ecdsaP521(let key), .ecdsaP521(let sig)): return key.isValidSignature(sig, for: payload.bytes.readableBytesView) + case (.custom(let key), .custom(let sig)): + return key.isValidSignature(sig, for: payload.bytes.readableBytesView) case (.certified(let key), _): return key.isValidSignature(signature, for: payload) case (.ed25519, _), (.ecdsaP256, _), (.ecdsaP384, _), - (.ecdsaP521, _): + (.ecdsaP521, _), + (.custom, _): return false } } @@ -153,6 +165,7 @@ extension NIOSSHPublicKey { case ecdsaP256(P256.Signing.PublicKey) case ecdsaP384(P384.Signing.PublicKey) case ecdsaP521(P521.Signing.PublicKey) + case custom(NIOSSHPublicKeyProtocol) case certified(NIOSSHCertifiedPublicKey) // This case recursively contains `NIOSSHPublicKey`. } @@ -178,16 +191,107 @@ extension NIOSSHPublicKey { return Self.ecdsaP384PublicKeyPrefix case .ecdsaP521: return Self.ecdsaP521PublicKeyPrefix + case .custom(let publicKey): + return publicKey.publicKeyPrefix.utf8 case .certified(let base): return base.keyPrefix } } + private static let bundledAlgorithms: [String.UTF8View] = [ + Self.ed25519PublicKeyPrefix, Self.ecdsaP384PublicKeyPrefix, Self.ecdsaP256PublicKeyPrefix, Self.ecdsaP521PublicKeyPrefix, + ] + internal static var knownAlgorithms: [String.UTF8View] { - [Self.ed25519PublicKeyPrefix, Self.ecdsaP384PublicKeyPrefix, Self.ecdsaP256PublicKeyPrefix, Self.ecdsaP521PublicKeyPrefix] + bundledAlgorithms + customPublicKeyAlgorithms.map { $0.publicKeyPrefix.utf8 } + } + + internal static var customPublicKeyAlgorithms: [NIOSSHPublicKeyProtocol.Type] { + _CustomAlgorithms.publicKeyAlgorithmsLock.withLock { + _CustomAlgorithms.publicKeyAlgorithms + } + } + + internal static var customSignatures: [NIOSSHSignatureProtocol.Type] { + _CustomAlgorithms.signaturesLock.withLock { + _CustomAlgorithms.signatures + } + } +} + +public enum NIOSSHAlgorithms { + public static func register(keyExchangeAlgorithm type: NIOSSHKeyExchangeAlgorithmProtocol.Type) { + _CustomAlgorithms.keyExchangeAlgorithmsLock.withLockVoid { + if !_CustomAlgorithms.keyExchangeAlgorithms.contains(where: { ObjectIdentifier($0) == ObjectIdentifier(type) }) { + _CustomAlgorithms.keyExchangeAlgorithms.append(type) + } + } + } + + public static func register(transportProtectionScheme type: NIOSSHTransportProtection.Type) { + _CustomAlgorithms.transportProtectionSchemesLock.withLockVoid { + if !_CustomAlgorithms.transportProtectionSchemes.contains(where: { ObjectIdentifier($0) == ObjectIdentifier(type) }) { + _CustomAlgorithms.transportProtectionSchemes.append(type) + } + } + } + + /// Registers a custom type tuple for use in Public Key Authentication. + public static func register< + PublicKey: NIOSSHPublicKeyProtocol, + Signature: NIOSSHSignatureProtocol + >( + publicKey type: PublicKey.Type, + signature: Signature.Type + ) { + _CustomAlgorithms.publicKeyAlgorithmsLock.withLockVoid { + if !_CustomAlgorithms.publicKeyAlgorithms.contains(where: { ObjectIdentifier($0) == ObjectIdentifier(type) }) { + _CustomAlgorithms.publicKeyAlgorithms.append(type) + _CustomAlgorithms.signatures.append(signature) + } + } + } + + /// Used for our unit tests + internal static func unregisterAlgorithms() { + _CustomAlgorithms.transportProtectionSchemesLock.withLockVoid { + _CustomAlgorithms.transportProtectionSchemes = [] + } + _CustomAlgorithms.publicKeyAlgorithmsLock.withLockVoid { + _CustomAlgorithms.publicKeyAlgorithms = [] + } + _CustomAlgorithms.signaturesLock.withLockVoid { + _CustomAlgorithms.signatures = [] + } + _CustomAlgorithms.keyExchangeAlgorithmsLock.withLockVoid { + _CustomAlgorithms.keyExchangeAlgorithms = [] + } + } +} + +internal var customTransportProtectionSchemes: [NIOSSHTransportProtection.Type] { + _CustomAlgorithms.transportProtectionSchemesLock.withLock { + _CustomAlgorithms.transportProtectionSchemes + } +} + +internal var customKeyExchangeAlgorithms: [NIOSSHKeyExchangeAlgorithmProtocol.Type] { + _CustomAlgorithms.keyExchangeAlgorithmsLock.withLock { + _CustomAlgorithms.keyExchangeAlgorithms } } +private enum _CustomAlgorithms { + static var transportProtectionSchemesLock = Lock() + static var transportProtectionSchemes = [NIOSSHTransportProtection.Type]() + static var keyExchangeAlgorithmsLock = Lock() + static var keyExchangeAlgorithms = [NIOSSHKeyExchangeAlgorithmProtocol.Type]() + static var publicKeyAlgorithmsLock = Lock() + static var publicKeyAlgorithms: [NIOSSHPublicKeyProtocol.Type] = [] + static var signaturesLock = Lock() + static var signatures: [NIOSSHSignatureProtocol.Type] = [] +} + extension NIOSSHPublicKey.BackingKey: Equatable { static func == (lhs: NIOSSHPublicKey.BackingKey, rhs: NIOSSHPublicKey.BackingKey) -> Bool { // We implement equatable in terms of the key representation. @@ -200,12 +304,17 @@ extension NIOSSHPublicKey.BackingKey: Equatable { return lhs.rawRepresentation == rhs.rawRepresentation case (.ecdsaP521(let lhs), .ecdsaP521(let rhs)): return lhs.rawRepresentation == rhs.rawRepresentation + case (.custom(let lhs), .custom(let rhs)): + return + lhs.publicKeyPrefix == rhs.publicKeyPrefix && + lhs.rawRepresentation == rhs.rawRepresentation case (.certified(let lhs), .certified(let rhs)): return lhs == rhs case (.ed25519, _), (.ecdsaP256, _), (.ecdsaP384, _), (.ecdsaP521, _), + (.custom, _), (.certified, _): return false } @@ -227,14 +336,48 @@ extension NIOSSHPublicKey.BackingKey: Hashable { case .ecdsaP521(let pkey): hasher.combine(4) hasher.combine(pkey.rawRepresentation) - case .certified(let pkey): + case .custom(let pkey): hasher.combine(5) + hasher.combine(pkey.publicKeyPrefix) + hasher.combine(pkey.rawRepresentation) + case .certified(let pkey): + hasher.combine(6) hasher.combine(pkey) } } } +extension NIOSSHPublicKey { + @discardableResult + public func write(to buffer: inout ByteBuffer) -> Int { + buffer.writeSSHHostKey(self) + } + + @discardableResult + public func writeWithoutHeader(to buffer: inout ByteBuffer) -> Int { + buffer.writeSSHHostKeyWithoutHeader(self) + } +} + extension ByteBuffer { + @discardableResult + mutating func writeSSHHostKeyWithoutHeader(_ key: NIOSSHPublicKey) -> Int { + switch key.backingKey { + case .ed25519(let key): + return self.writeEd25519PublicKey(baseKey: key) + case .ecdsaP256(let key): + return self.writeECDSAP256PublicKey(baseKey: key) + case .ecdsaP384(let key): + return self.writeECDSAP384PublicKey(baseKey: key) + case .ecdsaP521(let key): + return self.writeECDSAP521PublicKey(baseKey: key) + case .custom(let key): + return key.write(to: &self) + case .certified(let key): + return self.writeCertifiedKey(key) + } + } + /// Writes an SSH host key to this `ByteBuffer`. @discardableResult mutating func writeSSHHostKey(_ key: NIOSSHPublicKey) -> Int { @@ -253,6 +396,9 @@ extension ByteBuffer { case .ecdsaP521(let key): writtenBytes += self.writeSSHString(NIOSSHPublicKey.ecdsaP521PublicKeyPrefix) writtenBytes += self.writeECDSAP521PublicKey(baseKey: key) + case .custom(let key): + writtenBytes += writeSSHString(key.publicKeyPrefix.utf8) + writtenBytes += key.write(to: &self) case .certified(let key): return self.writeCertifiedKey(key) } @@ -274,6 +420,10 @@ extension ByteBuffer { return self.writeECDSAP384PublicKey(baseKey: key) case .ecdsaP521(let key): return self.writeECDSAP521PublicKey(baseKey: key) + case .custom(let key): + var writtenBytes = writeSSHString(key.publicKeyPrefix.utf8) + writtenBytes += key.write(to: &self) + return writtenBytes case .certified: preconditionFailure("Certified keys are the only callers of this method, and cannot contain themselves") } @@ -302,6 +452,13 @@ extension ByteBuffer { } else if keyIdentifierBytes.elementsEqual(NIOSSHPublicKey.ecdsaP521PublicKeyPrefix) { return try buffer.readECDSAP521PublicKey() } else { + for type in NIOSSHPublicKey.customPublicKeyAlgorithms { + if keyIdentifierBytes.elementsEqual(type.publicKeyPrefix.utf8) { + let publicKey = try type.read(from: &buffer) + return NIOSSHPublicKey(backingKey: .custom(publicKey)) + } + } + // We don't know this public key type. Maybe the certified keys do. return try buffer.readCertifiedKeyWithoutKeyPrefix(keyIdentifierBytes).map(NIOSSHPublicKey.init) } diff --git a/Sources/NIOSSH/Keys And Signatures/NIOSSHSignature.swift b/Sources/NIOSSH/Keys And Signatures/NIOSSHSignature.swift index 8872507..2317d29 100644 --- a/Sources/NIOSSH/Keys And Signatures/NIOSSHSignature.swift +++ b/Sources/NIOSSH/Keys And Signatures/NIOSSHSignature.swift @@ -36,6 +36,7 @@ extension NIOSSHSignature { case ecdsaP256(P256.Signing.ECDSASignature) case ecdsaP384(P384.Signing.ECDSASignature) case ecdsaP521(P521.Signing.ECDSASignature) + case custom(NIOSSHSignatureProtocol) internal enum RawBytes { case byteBuffer(ByteBuffer) @@ -85,10 +86,13 @@ extension NIOSSHSignature.BackingSignature: Equatable { return lhs.rawRepresentation == rhs.rawRepresentation case (.ecdsaP521(let lhs), .ecdsaP521(let rhs)): return lhs.rawRepresentation == rhs.rawRepresentation + case (.custom(let lhs), .custom(let rhs)): + return lhs.rawRepresentation == rhs.rawRepresentation case (.ed25519, _), (.ecdsaP256, _), (.ecdsaP384, _), - (.ecdsaP521, _): + (.ecdsaP521, _), + (.custom, _): return false } } @@ -109,6 +113,10 @@ extension NIOSSHSignature.BackingSignature: Hashable { case .ecdsaP521(let sig): hasher.combine(3) hasher.combine(sig.rawRepresentation) + case .custom(let sig): + hasher.combine(4) + hasher.combine(sig.signaturePrefix) + hasher.combine(sig.rawRepresentation) } } } @@ -126,6 +134,10 @@ extension ByteBuffer { return self.writeECDSAP384Signature(baseSignature: sig) case .ecdsaP521(let sig): return self.writeECDSAP521Signature(baseSignature: sig) + case .custom(let sig): + var writtenBytes = writeSSHString(sig.signaturePrefix.utf8) + writtenBytes += sig.write(to: &self) + return writtenBytes } } @@ -222,6 +234,13 @@ extension ByteBuffer { } else if bytesView.elementsEqual(NIOSSHSignature.ecdsaP521SignaturePrefix) { return try buffer.readECDSAP521Signature() } else { + for signature in NIOSSHPublicKey.customSignatures { + if bytesView.elementsEqual(signature.signaturePrefix.utf8) { + let signature = try signature.read(from: &buffer) + return NIOSSHSignature(backingSignature: .custom(signature)) + } + } + // We don't know this signature type. let signature = signatureIdentifierBytes.readString(length: signatureIdentifierBytes.readableBytes) ?? "" throw NIOSSHError.unknownSignature(algorithm: signature) diff --git a/Sources/NIOSSH/Role.swift b/Sources/NIOSSH/Role.swift index d118ab0..e012b2d 100644 --- a/Sources/NIOSSH/Role.swift +++ b/Sources/NIOSSH/Role.swift @@ -37,4 +37,26 @@ public enum SSHConnectionRole { return true } } + + internal var transportProtectionSchemes: [NIOSSHTransportProtection.Type] { + switch self { + case .client(let client): + return client.transportProtectionSchemes + case .server(let server): + return server.transportProtectionSchemes + } + } + + internal var keyExchangeAlgorithmNames: [Substring] { + self.keyExchangeAlgorithms.flatMap { $0.keyExchangeAlgorithmNames } + } + + internal var keyExchangeAlgorithms: [NIOSSHKeyExchangeAlgorithmProtocol.Type] { + switch self { + case .client(let client): + return client.keyExchangeAlgorithms + case .server(let server): + return server.keyExchangeAlgorithms + } + } } diff --git a/Sources/NIOSSH/SSHClientConfiguration.swift b/Sources/NIOSSH/SSHClientConfiguration.swift index 7ecbbc4..233c54f 100644 --- a/Sources/NIOSSH/SSHClientConfiguration.swift +++ b/Sources/NIOSSH/SSHClientConfiguration.swift @@ -23,6 +23,12 @@ public struct SSHClientConfiguration { /// The global request delegate to be used with this client. public var globalRequestDelegate: GlobalRequestDelegate + /// The enabled TransportProtectionSchemes + public var transportProtectionSchemes: [NIOSSHTransportProtection.Type] = SSHConnectionStateMachine.bundledTransportProtectionSchemes + + /// The enabled KeyExchangeAlgorithms + public var keyExchangeAlgorithms: [NIOSSHKeyExchangeAlgorithmProtocol.Type] = SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations + public init(userAuthDelegate: NIOSSHClientUserAuthenticationDelegate, serverAuthDelegate: NIOSSHClientServerAuthenticationDelegate, globalRequestDelegate: GlobalRequestDelegate? = nil) { diff --git a/Sources/NIOSSH/SSHMessages.swift b/Sources/NIOSSH/SSHMessages.swift index 1720216..a0c9116 100644 --- a/Sources/NIOSSH/SSHMessages.swift +++ b/Sources/NIOSSH/SSHMessages.swift @@ -1162,7 +1162,7 @@ extension ByteBuffer { writtenBytes += self.writeInteger(SSHMessage.UserAuthFailureMessage.id) writtenBytes += self.writeUserAuthFailureMessage(message) case .userAuthSuccess: - writtenBytes += self.writeInteger(52 as UInt8) + writtenBytes += self.writeInteger(SSHMessage.UserAuthSuccessMessage.id) case .userAuthBanner(let message): writtenBytes += self.writeInteger(SSHMessage.UserAuthBannerMessage.id) writtenBytes += self.writeUserAuthBannerMessage(message) diff --git a/Sources/NIOSSH/SSHPacketParser.swift b/Sources/NIOSSH/SSHPacketParser.swift index 1d70c94..dcbc9dc 100644 --- a/Sources/NIOSSH/SSHPacketParser.swift +++ b/Sources/NIOSSH/SSHPacketParser.swift @@ -25,6 +25,7 @@ struct SSHPacketParser { private var buffer: ByteBuffer private var state: State + private var sequenceNumber: UInt32 = 0 /// Testing only: the number of bytes we can discard from this buffer. internal var _discardableBytes: Int { @@ -73,6 +74,7 @@ struct SSHPacketParser { if let length = self.buffer.getInteger(at: self.buffer.readerIndex, as: UInt32.self) { if let message = try self.parsePlaintext(length: length) { self.state = .cleartextWaitingForLength + self.sequenceNumber = self.sequenceNumber &+ 1 return message } self.state = .cleartextWaitingForBytes(length) @@ -82,6 +84,7 @@ struct SSHPacketParser { case .cleartextWaitingForBytes(let length): if let message = try self.parsePlaintext(length: length) { self.state = .cleartextWaitingForLength + self.sequenceNumber = self.sequenceNumber &+ 1 return message } return nil @@ -92,6 +95,7 @@ struct SSHPacketParser { if let message = try self.parseCiphertext(length: length, protection: protection) { self.state = .encryptedWaitingForLength(protection) + self.sequenceNumber = self.sequenceNumber &+ 1 return message } self.state = .encryptedWaitingForBytes(length, protection) @@ -99,6 +103,7 @@ struct SSHPacketParser { case .encryptedWaitingForBytes(let length, let protection): if let message = try self.parseCiphertext(length: length, protection: protection) { self.state = .encryptedWaitingForLength(protection) + self.sequenceNumber = self.sequenceNumber &+ 1 return message } return nil @@ -169,7 +174,7 @@ struct SSHPacketParser { return nil } - var content = try protection.decryptAndVerifyRemainingPacket(&buffer) + var content = try protection.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: sequenceNumber) guard let message = try content.readSSHMessage(), content.readableBytes == 0, buffer.readableBytes == 0 else { // Throw this error if the content wasn't exactly the right length for the message. throw NIOSSHError.invalidPacketFormat diff --git a/Sources/NIOSSH/SSHPacketSerializer.swift b/Sources/NIOSSH/SSHPacketSerializer.swift index 1ab96b7..af39db9 100644 --- a/Sources/NIOSSH/SSHPacketSerializer.swift +++ b/Sources/NIOSSH/SSHPacketSerializer.swift @@ -22,6 +22,7 @@ struct SSHPacketSerializer { } private var state: State = .initialized + private var sequenceNumber: UInt32 = 0 /// Encryption schemes can be added to a packet serializer whenever encryption is negotiated. mutating func addEncryption(_ protection: NIOSSHTransportProtection) { @@ -35,7 +36,10 @@ struct SSHPacketSerializer { } } - mutating func serialize(message: SSHMessage, to buffer: inout ByteBuffer) throws { + mutating func serialize( + message: SSHMessage, + to buffer: inout ByteBuffer + ) throws { switch self.state { case .initialized: switch message { @@ -75,9 +79,11 @@ struct SSHPacketSerializer { buffer.setInteger(UInt8(paddingLength), at: index + 4) /// random padding buffer.writeSSHPaddingBytes(count: paddingLength) + self.sequenceNumber = self.sequenceNumber &+ 1 case .encrypted(let protection): let payload = NIOSSHEncryptablePayload(message: message) - try protection.encryptPacket(payload, to: &buffer) + try protection.encryptPacket(payload, to: &buffer, sequenceNumber: self.sequenceNumber) + self.sequenceNumber = self.sequenceNumber &+ 1 } } } diff --git a/Sources/NIOSSH/SSHServerConfiguration.swift b/Sources/NIOSSH/SSHServerConfiguration.swift index 5c8f7b6..e0312d3 100644 --- a/Sources/NIOSSH/SSHServerConfiguration.swift +++ b/Sources/NIOSSH/SSHServerConfiguration.swift @@ -26,6 +26,12 @@ public struct SSHServerConfiguration { /// The ssh banner to display to clients upon authentication public var banner: UserAuthBanner? + /// The enabled TransportProtectionSchemes + public var transportProtectionSchemes: [NIOSSHTransportProtection.Type] = SSHConnectionStateMachine.bundledTransportProtectionSchemes + + /// The enabled KeyExchangeAlgorithms + public var keyExchangeAlgorithms: [NIOSSHKeyExchangeAlgorithmProtocol.Type] = SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations + public init(hostKeys: [NIOSSHPrivateKey], userAuthDelegate: NIOSSHServerUserAuthenticationDelegate, globalRequestDelegate: GlobalRequestDelegate? = nil, banner: UserAuthBanner? = nil) { self.hostKeys = hostKeys self.userAuthDelegate = userAuthDelegate diff --git a/Sources/NIOSSH/TransportProtection/AESGCM.swift b/Sources/NIOSSH/TransportProtection/AESGCM.swift index 3cc1503..a090ce2 100644 --- a/Sources/NIOSSH/TransportProtection/AESGCM.swift +++ b/Sources/NIOSSH/TransportProtection/AESGCM.swift @@ -79,7 +79,7 @@ extension AESGCMTransportProtection: NIOSSHTransportProtection { // unencrypted! } - func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer { + func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequenceNumber: UInt32) throws -> ByteBuffer { var plaintext: Data // Establish a nested scope here to avoid the byte buffer views causing an accidental CoW. @@ -117,7 +117,7 @@ extension AESGCMTransportProtection: NIOSSHTransportProtection { return source.readSlice(length: plaintext.count)! } - func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws { + func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer, sequenceNumber: UInt32) throws { // Keep track of where the length is going to be written. let packetLengthIndex = outboundBuffer.writerIndex let packetLengthLength = MemoryLayout.size diff --git a/Sources/NIOSSH/TransportProtection/SSHTransportProtection.swift b/Sources/NIOSSH/TransportProtection/SSHTransportProtection.swift index 5b7ba97..bce4928 100644 --- a/Sources/NIOSSH/TransportProtection/SSHTransportProtection.swift +++ b/Sources/NIOSSH/TransportProtection/SSHTransportProtection.swift @@ -44,7 +44,7 @@ import NIOCore /// Implementers of this protocol **must not** expose unauthenticated plaintext, except for the length field. This /// is required by the SSH protocol, and swift-nio-ssh does its best to treat the length field as fundamentally /// untrusted information. -protocol NIOSSHTransportProtection: AnyObject { +public protocol NIOSSHTransportProtection: AnyObject { /// The name of the cipher portion of this transport protection scheme as negotiated on the wire. static var cipherName: String { get } @@ -87,10 +87,10 @@ protocol NIOSSHTransportProtection: AnyObject { /// length, the padding, or the MAC), and update source to indicate the consumed bytes. /// It must also perform any integrity checking that /// is required and throw if the integrity check fails. - func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer + func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequenceNumber: UInt32) throws -> ByteBuffer /// Encrypt an entire outbound packet - func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws + func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer, sequenceNumber: UInt32) throws } extension NIOSSHTransportProtection { diff --git a/Tests/NIOSSHTests/AESGCMTests.swift b/Tests/NIOSSHTests/AESGCMTests.swift index 2b46fbc..eb89833 100644 --- a/Tests/NIOSSHTests/AESGCMTests.swift +++ b/Tests/NIOSSHTests/AESGCMTests.swift @@ -42,7 +42,7 @@ final class AESGCMTests: XCTestCase { let initialKeys = self.generateKeys(keySize: .bits128) let aes128Encryptor = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: initialKeys)) - XCTAssertNoThrow(try aes128Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), to: &self.buffer)) + XCTAssertNoThrow(try aes128Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), to: &self.buffer, sequenceNumber: 0)) // The newKeys message is very straightforward: a single byte. Because of that, we expect that we will need // 14 padding bytes: one byte for the padding length, then 14 more to get out to one block size. Thus, the total @@ -59,7 +59,7 @@ final class AESGCMTests: XCTestCase { XCTAssertEqual(bufferCopy, self.buffer) /// After decryption the plaintext should be a newKeys message. - var plaintext = try assertNoThrowWithValue(aes128Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy)) + var plaintext = try assertNoThrowWithValue(aes128Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy, sequenceNumber: 0)) XCTAssertEqual(bufferCopy.readableBytes, 0) XCTAssertNotEqual(plaintext, self.buffer) XCTAssertEqual(plaintext.readableBytes, 1) @@ -77,7 +77,7 @@ final class AESGCMTests: XCTestCase { let initialKeys = self.generateKeys(keySize: .bits256) let aes256Encryptor = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: initialKeys)) - XCTAssertNoThrow(try aes256Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), to: &self.buffer)) + XCTAssertNoThrow(try aes256Encryptor.encryptPacket(NIOSSHEncryptablePayload(message: .newKeys), to: &self.buffer, sequenceNumber: 0)) // The newKeys message is very straightforward: a single byte. Because of that, we expect that we will need // 14 padding bytes: one byte for the padding length, then 14 more to get out to one block size. Thus, the total @@ -94,7 +94,7 @@ final class AESGCMTests: XCTestCase { XCTAssertEqual(bufferCopy, self.buffer) /// After decryption the plaintext should be a newKeys message. - var plaintext = try assertNoThrowWithValue(aes256Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy)) + var plaintext = try assertNoThrowWithValue(aes256Decryptor.decryptAndVerifyRemainingPacket(&bufferCopy, sequenceNumber: 0)) XCTAssertEqual(bufferCopy.readableBytes, 0) XCTAssertNotEqual(plaintext, self.buffer) XCTAssertEqual(plaintext.readableBytes, 1) @@ -300,7 +300,7 @@ final class AESGCMTests: XCTestCase { buffer.clear() buffer.writeRepeatingByte(42, count: ciphertextSize) - XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .invalidEncryptedPacketLength) } } @@ -320,7 +320,7 @@ final class AESGCMTests: XCTestCase { buffer.clear() buffer.writeRepeatingByte(42, count: ciphertextSize) - XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .invalidEncryptedPacketLength) } } @@ -350,7 +350,7 @@ final class AESGCMTests: XCTestCase { // We can now attempt to decrypt this packet. let aes128 = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: keys)) - XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .excessPadding) } } @@ -379,7 +379,7 @@ final class AESGCMTests: XCTestCase { // We can now attempt to decrypt this packet. let aes256 = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: keys)) - XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .excessPadding) } } @@ -408,7 +408,7 @@ final class AESGCMTests: XCTestCase { // We can now attempt to decrypt this packet. let aes128 = try assertNoThrowWithValue(AES128GCMOpenSSHTransportProtection(initialKeys: keys)) - XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes128.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .insufficientPadding) } } @@ -437,7 +437,7 @@ final class AESGCMTests: XCTestCase { // We can now attempt to decrypt this packet. let aes256 = try assertNoThrowWithValue(AES256GCMOpenSSHTransportProtection(initialKeys: keys)) - XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer)) { error in + XCTAssertThrowsError(try aes256.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0)) { error in XCTAssertEqual((error as? NIOSSHError)?.type, .insufficientPadding) } } diff --git a/Tests/NIOSSHTests/ECKeyExchangeTests.swift b/Tests/NIOSSHTests/ECKeyExchangeTests.swift index 76fd012..221c51b 100644 --- a/Tests/NIOSSHTests/ECKeyExchangeTests.swift +++ b/Tests/NIOSSHTests/ECKeyExchangeTests.swift @@ -279,7 +279,8 @@ final class KeyExchangeTests: XCTestCase { } func testWeValidateTheExchangeHash() throws { - var server = EllipticCurveKeyExchange(ourRole: .server([.init(ed25519Key: .init())]), previousSessionIdentifier: nil) + let serverPrivateKey = NIOSSHPrivateKey(ed25519Key: .init()) + var server = EllipticCurveKeyExchange(ourRole: .server([serverPrivateKey]), previousSessionIdentifier: nil) var client = EllipticCurveKeyExchange(ourRole: .client, previousSessionIdentifier: nil) let serverHostKey = NIOSSHPrivateKey(ed25519Key: .init()) @@ -297,7 +298,12 @@ final class KeyExchangeTests: XCTestCase { initialExchangeBytes.clear() // Ok, the server has sent a signature over the exchange hash. Let's change that signature. - serverResponse.signature = try assertNoThrowWithValue(serverHostKey.sign(digest: SHA256.hash(data: [1, 2, 3, 4, 5]))) + let badServerSignature = try serverHostKey.sign(digest: SHA256.hash(data: [1, 2, 3, 4, 5])) + serverResponse = NIOSSHKeyExchangeServerReply( + hostKey: serverResponse.hostKey, + publicKey: serverResponse.publicKey, + signature: badServerSignature + ) XCTAssertThrowsError( try client.receiveServerKeyExchangePayload(serverKeyExchangeMessage: serverResponse, diff --git a/Tests/NIOSSHTests/EndToEndTests.swift b/Tests/NIOSSHTests/EndToEndTests.swift index 3e5c3a8..b517d69 100644 --- a/Tests/NIOSSHTests/EndToEndTests.swift +++ b/Tests/NIOSSHTests/EndToEndTests.swift @@ -19,7 +19,343 @@ import NIOEmbedded import XCTest enum EndToEndTestError: Error { - case unableToCreateChildChannel + case unableToCreateChildChannel, invalidCustomPublicKey, invalidCustomSignature +} + +private let testKey = Data([0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]) + +final class CustomTransportProtection: NIOSSHTransportProtection { + static let cipherName = "xor-with-42" + static let macName: String? = "insecure-sha1" + static var wasUsed = false + + static var keySizes: ExpectedKeySizes { + .init(ivSize: 19, encryptionKeySize: 17, macKeySize: 15) + } + + required init(initialKeys: NIOSSHSessionKeys) throws {} + + static var cipherBlockSize: Int { 18 } + var macBytes: Int { 20 } + + func updateKeys(_: NIOSSHSessionKeys) throws {} + + func decryptFirstBlock(_: inout ByteBuffer) throws { + // For us, decrypting the first block is very easy: do nothing. The length bytes are already + // unencrypted! + } + + func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequenceNumber: UInt32) throws -> ByteBuffer { + Self.wasUsed = true + var plaintext: Data + + // The first 4 bytes are the length. The last 16 are the tag. Everything else is ciphertext. We expect + // that the ciphertext is a clean multiple of the block size, and to be non-zero. + guard + let lengthView: UInt32 = source.readInteger(), + var ciphertext = source.readData(length: Int(lengthView)), + let mac = source.readData(length: Insecure.SHA1.byteCount), + ciphertext.count > 0, ciphertext.count % Self.cipherBlockSize == 0 else { + // The only way this fails is if the payload doesn't match this encryption scheme. + throw NIOSSHError.invalidEncryptedPacketLength + } + + for i in 0 ..< ciphertext.count { + ciphertext[i] ^= 42 + } + + plaintext = ciphertext + + struct InvalidSHA1TestSignature: Error {} + guard Insecure.SHA1.hash(data: plaintext) == mac else { + throw InvalidSHA1TestSignature() + } + + let paddingBytes = plaintext[0] + + if paddingBytes < 4 || paddingBytes >= plaintext.count { + throw NIOSSHError.invalidDecryptedPlaintextLength + } + + // All good! A quick soundness check to verify that the length of the plaintext is ok. + guard plaintext.count % Self.cipherBlockSize == 0, plaintext.count == ciphertext.count else { + throw NIOSSHError.invalidDecryptedPlaintextLength + } + + // Remove padding + plaintext.removeFirst() + plaintext.removeLast(Int(paddingBytes)) + + return ByteBuffer(data: plaintext) + } + + func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer, sequenceNumber: UInt32) throws { + // Keep track of where the length is going to be written. + let packetLengthIndex = outboundBuffer.writerIndex + let packetLengthLength = MemoryLayout.size + let packetPaddingIndex = outboundBuffer.writerIndex + packetLengthLength + let packetPaddingLength = MemoryLayout.size + + outboundBuffer.moveWriterIndex(forwardBy: packetLengthLength + packetPaddingLength) + + // First, we write the packet. + let payloadBytes = outboundBuffer.writeEncryptablePayload(packet) + + // Ok, now we need to pad. The rules for padding for AES GCM are: + // + // 1. We must pad out such that the total encrypted content (padding length byte, + // plus content bytes, plus padding bytes) is a multiple of the block size. + // 2. At least 4 bytes of padding MUST be added. + // 3. This padding SHOULD be random. + // + // Note that, unlike other protection modes, the length is not encrypted, and so we + // must exclude it from the padding calculation. + // + // So we check how many bytes we've already written, use modular arithmetic to work out + // how many more bytes we need, and then if that's fewer than 4 we add a block size to it + // to fill it out. + var encryptedBufferSize = payloadBytes + packetPaddingLength + var necessaryPaddingBytes = Self.cipherBlockSize - (encryptedBufferSize % Self.cipherBlockSize) + if necessaryPaddingBytes < 4 { + necessaryPaddingBytes += Self.cipherBlockSize + } + + // We now want to write that many padding bytes to the end of the buffer. These are supposed to be + // random bytes. We're going to get those from the system random number generator. + encryptedBufferSize += outboundBuffer.writeSSHPaddingBytes(count: necessaryPaddingBytes) + precondition(encryptedBufferSize % Self.cipherBlockSize == 0, "Incorrectly counted buffer size; got \(encryptedBufferSize)") + + // We now know the length: it's going to be "encrypted buffer size". The length does not include the tag, so don't add it. + // Let's write that in. We also need to write the number of padding bytes in. + outboundBuffer.setInteger(UInt32(encryptedBufferSize), at: packetLengthIndex) + outboundBuffer.setInteger(UInt8(necessaryPaddingBytes), at: packetPaddingIndex) + + // Ok, nice! Now we need to encrypt the data. We pass the length field as additional authenticated data, and the encrypted + // payload portion as the data to encrypt. We know these views will be valid, so we forcibly unwrap them: if they're invalid, + // our math was wrong and we cannot recover. + let plaintext = outboundBuffer.getBytes(at: packetPaddingIndex, length: encryptedBufferSize)! + let hash = Insecure.SHA1.hash(data: plaintext) + + var ciphertext = plaintext + for i in 0 ..< ciphertext.count { + ciphertext[i] ^= 42 + } + + // We now want to overwrite the portion of the bytebuffer that contains the plaintext with the ciphertext, and then append the + // tag. + outboundBuffer.setContiguousBytes(ciphertext, at: packetPaddingIndex) + let macLength = outboundBuffer.writeBytes(hash) + precondition(macLength == self.macBytes, "Unexpected short tag") + } +} + +struct CustomPrivateKey: NIOSSHPrivateKeyProtocol { + static let keyPrefix = "custom-prefix" + + var publicKey: NIOSSHPublicKeyProtocol { + CustomPublicKey() + } + + func generatedSharedSecret(with theirKey: CustomPublicKey) throws -> [UInt8] { + Array(testKey.reversed()) + } + + func signature(for data: D) throws -> NIOSSHSignatureProtocol where D: DataProtocol { + var data = Data(data) + + let testKeySize = testKey.count + for i in 0 ..< data.count { + data[i] ^= testKey[i % testKeySize] + } + + return CustomSignature(rawRepresentation: data) + } +} + +struct CustomSignature: NIOSSHSignatureProtocol { + static let signaturePrefix = "custom-prefix" + + let rawRepresentation: Data + + func write(to buffer: inout ByteBuffer) -> Int { + buffer.writeSSHString(self.rawRepresentation) + } + + static func read(from buffer: inout ByteBuffer) throws -> CustomSignature { + guard var buffer = buffer.readSSHString() else { + throw EndToEndTestError.invalidCustomSignature + } + + let data = buffer.readData(length: buffer.readableBytes)! + return CustomSignature(rawRepresentation: data) + } +} + +struct CustomPublicKey: NIOSSHPublicKeyProtocol { + static let publicKeyPrefix = "custom-prefix" + static let keyExchangeAlgorithmNames: [Substring] = ["custom-handshake"] + static var wasUsed = false + + func isValidSignature(_ signature: NIOSSHSignatureProtocol, for data: D) -> Bool where D: DataProtocol { + let testKeySize = testKey.count + var data = Data(data) + for i in 0 ..< data.count { + data[i] ^= testKey[i % testKeySize] + } + + return data == signature.rawRepresentation + } + + @discardableResult + func write(to buffer: inout ByteBuffer) -> Int { + 0 + } + + var rawRepresentation: Data { + testKey + } + + static func read(from buffer: inout ByteBuffer) throws -> CustomPublicKey { + guard buffer.readableBytes == 0 else { + throw EndToEndTestError.invalidCustomPublicKey + } + + self.wasUsed = true + return CustomPublicKey() + } +} + +struct CustomKeyExchange: NIOSSHKeyExchangeAlgorithmProtocol { + static var keyExchangeInitMessageId: UInt8 { 0xFF } + static var keyExchangeReplyMessageId: UInt8 { 0xFF } + static var wasUsed = false + + private var previousSessionIdentifier: ByteBuffer? + private var ourKey: CustomPrivateKey + private var theirKey: CustomPublicKey? + private var ourRole: SSHConnectionRole + private var sharedSecret: [UInt8]? + + init(ourRole: SSHConnectionRole, previousSessionIdentifier: ByteBuffer?) { + self.ourRole = ourRole + self.ourKey = CustomPrivateKey() + self.previousSessionIdentifier = previousSessionIdentifier + } + + func initiateKeyExchangeClientSide(allocator: ByteBufferAllocator) -> ByteBuffer { + var buffer = ByteBuffer() + _ = self.ourKey.publicKey.write(to: &buffer) + return buffer + } + + mutating func completeKeyExchangeServerSide( + clientKeyExchangeMessage message: ByteBuffer, + serverHostKey: NIOSSHPrivateKey, + initialExchangeBytes: inout ByteBuffer, + allocator: ByteBufferAllocator, + expectedKeySizes: ExpectedKeySizes + ) throws -> (KeyExchangeResult, NIOSSHKeyExchangeServerReply) { + var theirKeyBuffer = message + let theirKey = try CustomPublicKey.read(from: &theirKeyBuffer) + self.theirKey = theirKey + + // Shared secet is "expanded" + // That should make it usable by most transport encryption, at least the one used in our test + var sharedSecret = try self.ourKey.generatedSharedSecret(with: theirKey) + sharedSecret += sharedSecret + sharedSecret += sharedSecret + sharedSecret += sharedSecret + sharedSecret += sharedSecret + + self.sharedSecret = sharedSecret + + var hasher = SHA512() + hasher.update(data: initialExchangeBytes.readableBytesView) + hasher.update(data: sharedSecret) + + let exchangeHash = hasher.finalize() + + let sessionID: ByteBuffer + if let previousSessionIdentifier = self.previousSessionIdentifier { + sessionID = previousSessionIdentifier + } else { + sessionID = ByteBuffer(bytes: SHA512.hash(data: Data(exchangeHash))) + } + + let kexResult = KeyExchangeResult( + sessionID: sessionID, + keys: NIOSSHSessionKeys( + initialInboundIV: Array(sharedSecret[0 ..< expectedKeySizes.ivSize]), + initialOutboundIV: Array(sharedSecret[0 ..< expectedKeySizes.ivSize]), + inboundEncryptionKey: SymmetricKey(data: Data(sharedSecret[0 ..< expectedKeySizes.encryptionKeySize])), + outboundEncryptionKey: SymmetricKey(data: Data(sharedSecret[0 ..< expectedKeySizes.encryptionKeySize])), + inboundMACKey: SymmetricKey(data: Data(sharedSecret[0 ..< expectedKeySizes.macKeySize])), + outboundMACKey: SymmetricKey(data: Data(sharedSecret[0 ..< expectedKeySizes.macKeySize])) + ) + ) + + var publicKeyBytes = allocator.buffer(capacity: 256) + _ = self.ourKey.publicKey.write(to: &publicKeyBytes) + + let exchangeHashSignature = try serverHostKey.sign(digest: exchangeHash) + + let serverReply = NIOSSHKeyExchangeServerReply(hostKey: serverHostKey.publicKey, + publicKey: publicKeyBytes, signature: exchangeHashSignature) + + Self.wasUsed = true + return (kexResult, serverReply) + } + + mutating func receiveServerKeyExchangePayload( + serverKeyExchangeMessage: NIOSSHKeyExchangeServerReply, + initialExchangeBytes: inout ByteBuffer, + allocator: ByteBufferAllocator, + expectedKeySizes: ExpectedKeySizes + ) throws -> KeyExchangeResult { + var theirKeyBuffer = serverKeyExchangeMessage.publicKey + let theirKey = try CustomPublicKey.read(from: &theirKeyBuffer) + self.theirKey = theirKey + + // Shared secet is "expanded" + // That should make it usable by most transport encryption, at least the one used in our test + var sharedSecret = try self.ourKey.generatedSharedSecret(with: theirKey) + sharedSecret += sharedSecret + sharedSecret += sharedSecret + sharedSecret += sharedSecret + sharedSecret += sharedSecret + self.sharedSecret = sharedSecret + + var hasher = SHA512() + hasher.update(data: initialExchangeBytes.readableBytesView) + hasher.update(data: sharedSecret) + let exchangeHash = hasher.finalize() + + let sessionID: ByteBuffer + if let previousSessionIdentifier = self.previousSessionIdentifier { + sessionID = previousSessionIdentifier + } else { + sessionID = ByteBuffer(bytes: SHA512.hash(data: Data(exchangeHash))) + } + + guard serverKeyExchangeMessage.hostKey.isValidSignature(serverKeyExchangeMessage.signature, for: exchangeHash) else { + throw NIOSSHError.invalidExchangeHashSignature + } + + Self.wasUsed = true + return KeyExchangeResult( + sessionID: sessionID, + keys: NIOSSHSessionKeys( + initialInboundIV: Array(sharedSecret[0 ..< expectedKeySizes.ivSize]), + initialOutboundIV: Array(sharedSecret[0 ..< expectedKeySizes.ivSize]), + inboundEncryptionKey: SymmetricKey(data: Data(sharedSecret[0 ..< expectedKeySizes.encryptionKeySize])), + outboundEncryptionKey: SymmetricKey(data: Data(sharedSecret[0 ..< expectedKeySizes.encryptionKeySize])), + inboundMACKey: SymmetricKey(data: Data(sharedSecret[0 ..< expectedKeySizes.macKeySize])), + outboundMACKey: SymmetricKey(data: Data(sharedSecret[0 ..< expectedKeySizes.macKeySize])) + ) + ) + } + + static var keyExchangeAlgorithmNames: [Substring] { ["xorkex"] } } class BackToBackEmbeddedChannel { @@ -81,10 +417,23 @@ class BackToBackEmbeddedChannel { } func configureWithHarness(_ harness: TestHarness) throws { - let clientHandler = NIOSSHHandler(role: .client(.init(userAuthDelegate: harness.clientAuthDelegate, serverAuthDelegate: harness.clientServerAuthDelegate, globalRequestDelegate: harness.clientGlobalRequestDelegate)), + var clientConfiguration = SSHClientConfiguration(userAuthDelegate: harness.clientAuthDelegate, serverAuthDelegate: harness.clientServerAuthDelegate, globalRequestDelegate: harness.clientGlobalRequestDelegate) + var serverConfiguration = SSHServerConfiguration(hostKeys: harness.serverHostKeys, userAuthDelegate: harness.serverAuthDelegate, globalRequestDelegate: harness.serverGlobalRequestDelegate, banner: harness.serverAuthBanner) + + if let transportProtectionAlgoritms = harness.transportProtectionAlgoritms { + clientConfiguration.transportProtectionSchemes = transportProtectionAlgoritms + serverConfiguration.transportProtectionSchemes = transportProtectionAlgoritms + } + + if let keyExchangeAlgorithms = harness.keyExchangeAlgorithms { + clientConfiguration.keyExchangeAlgorithms = keyExchangeAlgorithms + serverConfiguration.keyExchangeAlgorithms = keyExchangeAlgorithms + } + + let clientHandler = NIOSSHHandler(role: .client(clientConfiguration), allocator: self.client.allocator, inboundChildChannelInitializer: nil) - let serverHandler = NIOSSHHandler(role: .server(.init(hostKeys: harness.serverHostKeys, userAuthDelegate: harness.serverAuthDelegate, globalRequestDelegate: harness.serverGlobalRequestDelegate, banner: harness.serverAuthBanner)), + let serverHandler = NIOSSHHandler(role: .server(serverConfiguration), allocator: self.server.allocator) { channel, _ in self.activeServerChannels.append(channel) channel.closeFuture.whenComplete { _ in self.activeServerChannels.removeAll(where: { $0 === channel }) } @@ -131,6 +480,10 @@ struct TestHarness { var serverHostKeys: [NIOSSHPrivateKey] = [.init(ed25519Key: .init())] + var keyExchangeAlgorithms: [NIOSSHKeyExchangeAlgorithmProtocol.Type]? + + var transportProtectionAlgoritms: [NIOSSHTransportProtection.Type]? + var serverAuthBanner: SSHServerConfiguration.UserAuthBanner? } @@ -437,6 +790,113 @@ class EndToEndTests: XCTestCase { #endif } + func testCustomPublicKeyAlgorithms() throws { + NIOSSHAlgorithms.unregisterAlgorithms() + CustomPublicKey.wasUsed = false + NIOSSHAlgorithms.register(publicKey: CustomPublicKey.self, signature: CustomSignature.self) + + // If we can't create this key, we skip the test. + let hostKey = NIOSSHPrivateKey(ed25519Key: .init()) + let clientAuthKey = NIOSSHPrivateKey(custom: CustomPrivateKey()) + + // We use the Secure Enclave keys for everything, just because we can. + var harness = TestHarness() + harness.serverHostKeys = [hostKey] + harness.clientAuthDelegate = PrivateKeyClientAuth(clientAuthKey) + harness.serverAuthDelegate = ExpectPublicKeyAuth(clientAuthKey.publicKey) + + XCTAssertNoThrow(try self.channel.configureWithHarness(harness)) + XCTAssertNoThrow(try self.channel.activate()) + XCTAssertNoThrow(try self.channel.interactInMemory()) + + // Create a channel, again, just because we can. + _ = try self.channel.createNewChannel() + XCTAssertNoThrow(try self.channel.interactInMemory()) + XCTAssertEqual(self.channel.activeServerChannels.count, 1) + XCTAssertTrue(CustomPublicKey.wasUsed) + } + + func testCustomHostKeyAlgorithms() throws { + NIOSSHAlgorithms.unregisterAlgorithms() + CustomPublicKey.wasUsed = false + NIOSSHAlgorithms.register(publicKey: CustomPublicKey.self, signature: CustomSignature.self) + + // If we can't create this key, we skip the test. + let hostKey = NIOSSHPrivateKey(custom: CustomPrivateKey()) + let clientAuthKey = NIOSSHPrivateKey(ed25519Key: .init()) + + // We use the Secure Enclave keys for everything, just because we can. + var harness = TestHarness() + harness.serverHostKeys = [hostKey] + harness.clientAuthDelegate = PrivateKeyClientAuth(clientAuthKey) + harness.serverAuthDelegate = ExpectPublicKeyAuth(clientAuthKey.publicKey) + + XCTAssertNoThrow(try self.channel.configureWithHarness(harness)) + XCTAssertNoThrow(try self.channel.activate()) + XCTAssertNoThrow(try self.channel.interactInMemory()) + + // Create a channel, again, just because we can. + _ = try self.channel.createNewChannel() + XCTAssertNoThrow(try self.channel.interactInMemory()) + XCTAssertEqual(self.channel.activeServerChannels.count, 1) + XCTAssertTrue(CustomPublicKey.wasUsed) + } + + func testCustomTransportProtectionAlgorithms() throws { + NIOSSHAlgorithms.unregisterAlgorithms() + CustomKeyExchange.wasUsed = false + NIOSSHAlgorithms.register(transportProtectionScheme: CustomTransportProtection.self) + + // If we can't create this key, we skip the test. + let hostKey = NIOSSHPrivateKey(ed25519Key: .init()) + let clientAuthKey = NIOSSHPrivateKey(ed25519Key: .init()) + + // We use the Secure Enclave keys for everything, just because we can. + var harness = TestHarness() + harness.transportProtectionAlgoritms = [CustomTransportProtection.self] + harness.serverHostKeys = [hostKey] + harness.clientAuthDelegate = PrivateKeyClientAuth(clientAuthKey) + harness.serverAuthDelegate = ExpectPublicKeyAuth(clientAuthKey.publicKey) + + XCTAssertNoThrow(try self.channel.configureWithHarness(harness)) + XCTAssertNoThrow(try self.channel.activate()) + XCTAssertNoThrow(try self.channel.interactInMemory()) + + // Create a channel, again, just because we can. + _ = try self.channel.createNewChannel() + XCTAssertNoThrow(try self.channel.interactInMemory()) + XCTAssertEqual(self.channel.activeServerChannels.count, 1) + XCTAssertTrue(CustomTransportProtection.wasUsed) + } + + func testCustomKeyExchangeAlgorithms() throws { + NIOSSHAlgorithms.unregisterAlgorithms() + CustomKeyExchange.wasUsed = false + NIOSSHAlgorithms.register(keyExchangeAlgorithm: CustomKeyExchange.self) + NIOSSHAlgorithms.register(publicKey: CustomPublicKey.self, signature: CustomSignature.self) + + // If we can't create this key, we skip the test. + let hostKey = NIOSSHPrivateKey(custom: CustomPrivateKey()) + let clientAuthKey = NIOSSHPrivateKey(ed25519Key: .init()) + + // We use the Secure Enclave keys for everything, just because we can. + var harness = TestHarness() + harness.keyExchangeAlgorithms = [CustomKeyExchange.self] + harness.serverHostKeys = [hostKey] + harness.clientAuthDelegate = PrivateKeyClientAuth(clientAuthKey) + harness.serverAuthDelegate = ExpectPublicKeyAuth(clientAuthKey.publicKey) + + XCTAssertNoThrow(try self.channel.configureWithHarness(harness)) + XCTAssertNoThrow(try self.channel.activate()) + XCTAssertNoThrow(try self.channel.interactInMemory()) + + // Create a channel, again, just because we can. + _ = try self.channel.createNewChannel() + XCTAssertNoThrow(try self.channel.interactInMemory()) + XCTAssertEqual(self.channel.activeServerChannels.count, 1) + XCTAssertTrue(CustomKeyExchange.wasUsed) + } + func testSupportClientInitiatedRekeying() throws { XCTAssertNoThrow(try self.channel.configureWithHarness(TestHarness())) XCTAssertNoThrow(try self.channel.activate()) diff --git a/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift b/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift index 55bac3b..a2e8e4a 100644 --- a/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift +++ b/Tests/NIOSSHTests/SSHConnectionStateMachineTests.swift @@ -594,9 +594,8 @@ final class SSHConnectionStateMachineTests: XCTestCase { func testFirstBlockDecodedOnce() throws { let allocator = ByteBufferAllocator() let loop = EmbeddedEventLoop() - let schemes: [NIOSSHTransportProtection.Type] = [TestTransportProtection.self] - var client = SSHConnectionStateMachine(role: .client(.init(userAuthDelegate: InfinitePasswordDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), protectionSchemes: schemes) - var server = SSHConnectionStateMachine(role: .server(.init(hostKeys: [NIOSSHPrivateKey(ed25519Key: .init())], userAuthDelegate: DenyThenAcceptDelegate(messagesToDeny: 0))), protectionSchemes: schemes) + var client = SSHConnectionStateMachine(role: .client(.init(userAuthDelegate: InfinitePasswordDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate()))) + var server = SSHConnectionStateMachine(role: .server(.init(hostKeys: [NIOSSHPrivateKey(ed25519Key: .init())], userAuthDelegate: DenyThenAcceptDelegate(messagesToDeny: 0)))) try assertSuccessfulConnection(client: &client, server: &server, allocator: allocator, loop: loop) let message = SSHMessage.channelData(.init(recipientChannel: 1, data: ByteBuffer(repeating: 17, count: 5))) diff --git a/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift b/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift index dbc4ab7..627664a 100644 --- a/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift +++ b/Tests/NIOSSHTests/SSHKeyExchangeStateMachineTests.swift @@ -144,9 +144,9 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { var buffer = ByteBufferAllocator().buffer(capacity: 1024) do { - try client.encryptPacket(.init(message: message), to: &buffer) + try client.encryptPacket(.init(message: message), to: &buffer, sequenceNumber: 0) try server.decryptFirstBlock(&buffer) - var messageBuffer = try server.decryptAndVerifyRemainingPacket(&buffer) + var messageBuffer = try server.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 0) let decrypted = try messageBuffer.readSSHMessage() XCTAssertEqual(message, decrypted) XCTAssertEqual(0, buffer.readableBytes) @@ -158,9 +158,9 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { buffer.clear() do { - try server.encryptPacket(.init(message: message), to: &buffer) + try server.encryptPacket(.init(message: message), to: &buffer, sequenceNumber: 1) try client.decryptFirstBlock(&buffer) - var messageBuffer = try client.decryptAndVerifyRemainingPacket(&buffer) + var messageBuffer = try client.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 1) let decrypted = try messageBuffer.readSSHMessage() XCTAssertEqual(message, decrypted) XCTAssertEqual(0, buffer.readableBytes) @@ -202,7 +202,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -210,7 +211,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -268,7 +270,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -276,7 +279,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -321,7 +325,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil) // Server generates a key exchange message. @@ -363,7 +368,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) let serverMessage = server.createKeyExchangeMessage() @@ -394,7 +400,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -402,7 +409,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -460,7 +468,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -468,7 +477,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [hostKey], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -514,7 +524,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) let clientMessage = client.createKeyExchangeMessage() @@ -540,7 +551,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -548,7 +560,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES128GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -580,7 +593,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES128GCMOpenSSHTransportProtection.self, AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES128GCMOpenSSHTransportProtection.self, AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -588,7 +602,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -635,7 +650,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES128GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -643,7 +659,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES128GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -690,7 +707,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES128GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -698,7 +716,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [.init(ed25519Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES128GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES128GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -769,7 +788,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -777,7 +797,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: keys, userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -829,7 +850,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: AcceptAllHostKeysDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -837,7 +859,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [NIOSSHPrivateKey(p256Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -889,7 +912,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: RejectHostKeyDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -897,7 +921,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [NIOSSHPrivateKey(p256Key: .init())], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) @@ -953,7 +978,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .client(.init(userAuthDelegate: ExplodingAuthDelegate(), serverAuthDelegate: hostKeyDelegate)), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) var server = SSHKeyExchangeStateMachine( @@ -961,7 +987,8 @@ final class SSHKeyExchangeStateMachineTests: XCTestCase { loop: loop, role: .server(.init(hostKeys: [serverHostKey], userAuthDelegate: DenyAllServerAuthDelegate())), remoteVersion: Constants.version, - protectionSchemes: [AES256GCMOpenSSHTransportProtection.self], + keyExchangeAlgorithms: SSHKeyExchangeStateMachine.bundledKeyExchangeImplementations, + transportProtectionSchemes: [AES256GCMOpenSSHTransportProtection.self], previousSessionIdentifier: nil ) diff --git a/Tests/NIOSSHTests/Utilities.swift b/Tests/NIOSSHTests/Utilities.swift index a2c9299..d8f200e 100644 --- a/Tests/NIOSSHTests/Utilities.swift +++ b/Tests/NIOSSHTests/Utilities.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import Crypto +import NIO import NIOCore @testable import NIOSSH import XCTest @@ -30,6 +31,12 @@ func assertNoThrowWithValue(_ body: @autoclosure () throws -> T, defaultValue } } +extension SSHKeyExchangeStateMachine { + mutating func handle(keyExchangeInit message: SSHMessage.KeyExchangeECDHInitMessage) throws -> SSHMultiMessage? { + try self.handle(keyExchangeInit: message.publicKey) + } +} + // This algorithm is not secure, used only for testing purposes struct InsecureEncryptionAlgorithm { static func encrypt(key: ByteBuffer, plaintext: ByteBuffer) -> ByteBuffer { @@ -176,7 +183,7 @@ class TestTransportProtection: NIOSSHTransportProtection { source.setBytes(plaintext.readableBytesView, at: index) } - func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer) throws -> ByteBuffer { + func decryptAndVerifyRemainingPacket(_ source: inout ByteBuffer, sequenceNumber: UInt32) throws -> ByteBuffer { defer { self.lastFirstBlock = nil } @@ -211,7 +218,7 @@ class TestTransportProtection: NIOSSHTransportProtection { return plaintext.readSlice(length: plaintext.readableBytes - Int(paddingLength))! } - func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer) throws { + func encryptPacket(_ packet: NIOSSHEncryptablePayload, to outboundBuffer: inout ByteBuffer, sequenceNumber: UInt32) throws { let packetLengthIndex = outboundBuffer.writerIndex let packetLengthLength = MemoryLayout.size let packetPaddingIndex = outboundBuffer.writerIndex + packetLengthLength diff --git a/Tests/NIOSSHTests/UtilitiesTests.swift b/Tests/NIOSSHTests/UtilitiesTests.swift index 96a4c33..9eaeda3 100644 --- a/Tests/NIOSSHTests/UtilitiesTests.swift +++ b/Tests/NIOSSHTests/UtilitiesTests.swift @@ -52,9 +52,9 @@ final class UtilitiesTests: XCTestCase { let message = SSHMessage.channelRequest(.init(recipientChannel: 1, type: .exec("uname"), wantReply: false)) let allocator = ByteBufferAllocator() var buffer = allocator.buffer(capacity: 1024) - XCTAssertNoThrow(try client.encryptPacket(.init(message: message), to: &buffer)) + XCTAssertNoThrow(try client.encryptPacket(.init(message: message), to: &buffer, sequenceNumber: 1)) XCTAssertNoThrow(try server.decryptFirstBlock(&buffer)) - var decoded = try server.decryptAndVerifyRemainingPacket(&buffer) + var decoded = try server.decryptAndVerifyRemainingPacket(&buffer, sequenceNumber: 1) XCTAssertEqual(message, try decoded.readSSHMessage()) } }