Skip to content

Commit

Permalink
SNI Support (ported netty SniHandler) (Azure#219)
Browse files Browse the repository at this point in the history
* First cut commit *unfinished

* Replace SNI handler with TlsHandler with certificate selected based on host name found in clientHello

* first cut tests for SniHandler

* test update

* Test update

* Supress further read when handler is replaced

* made the snitest more effective and IDN in hostname lower case as per netty impl

* More asserts to check whether snihandler gets replaces with tlshaldler in the pipeline

* assert server name is always found in clienthello as per the test setup

* Provided option to select default host name in case of error or client hello does not contail SNI extension, otherwise handshake fails in those cases

* More elaborate tests

* verbosity in test

* Fixed Read continues to get called after handler removed and removed the workaround in SniHandler

* relaced goto statement with flag for breaking outer for loop from within switch

* Update SniHandler.cs

* trigger CI build

* addressed review comments

* Fixed task continuation option

* addresses further review comments

* triggere build again with some more assert in test Azure#221

* suppress read logic is still needed due to async "void"

*  changing the map to (string -> Task<ServerTlsSettings)

* one more constructor overload

* extensive tls read/write test is not needed since that's already done in tlshandler test

* more readable target host validation in test to force retrigger confusing CI build

* retrigger

* addressed review comment "this generates 30 random data frames. pls replace with new [] { 1 }"
  • Loading branch information
krish-gh authored and nayato committed Mar 29, 2017
1 parent 8357d8e commit 1d3eda9
Show file tree
Hide file tree
Showing 10 changed files with 650 additions and 1 deletion.
Binary file added shared/contoso.com.pfx
Binary file not shown.
5 changes: 4 additions & 1 deletion src/DotNetty.Codecs/ByteToMessageDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ static IByteBuffer ExpandCumulation(IByteBufferAllocator allocator, IByteBuffer
public override void HandlerRemoved(IChannelHandlerContext context)
{
IByteBuffer buf = this.InternalBuffer;

// Directly set this to null so we are sure we not access it in any other method here anymore.
this.cumulation = null;
int readable = buf.ReadableBytes;
if (readable > 0)
{
Expand All @@ -162,7 +165,7 @@ public override void HandlerRemoved(IChannelHandlerContext context)
{
buf.Release();
}
this.cumulation = null;

context.FireChannelReadComplete();
this.HandlerRemovedInternal(context);
}
Expand Down
5 changes: 5 additions & 0 deletions src/DotNetty.Handlers/DotNetty.Handlers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,9 @@
<Reference Include="System" />
<Reference Include="Microsoft.CSharp" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard1.3'">
<PackageReference Include="System.Globalization.Extensions">
<Version>4.3.0</Version>
</PackageReference>
</ItemGroup>
</Project>
23 changes: 23 additions & 0 deletions src/DotNetty.Handlers/Tls/ServerTlsSniSettings.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace DotNetty.Handlers.Tls
{
using System;
using System.Diagnostics.Contracts;
using System.Threading.Tasks;

public sealed class ServerTlsSniSettings
{
public ServerTlsSniSettings(Func<string, Task<ServerTlsSettings>> serverTlsSettingMap, string defaultServerHostName = null)
{
Contract.Requires(serverTlsSettingMap != null);
this.ServerTlsSettingMap = serverTlsSettingMap;
this.DefaultServerHostName = defaultServerHostName;
}

public Func<string, Task<ServerTlsSettings>> ServerTlsSettingMap { get; }

public string DefaultServerHostName { get; }
}
}
316 changes: 316 additions & 0 deletions src/DotNetty.Handlers/Tls/SniHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace DotNetty.Handlers.Tls
{
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Globalization;
using System.IO;
using System.Net.Security;
using System.Text;
using DotNetty.Buffers;
using DotNetty.Codecs;
using DotNetty.Common.Internal.Logging;
using DotNetty.Transport.Channels;

public sealed class SniHandler : ByteToMessageDecoder
{
// Maximal number of ssl records to inspect before fallback to the default (aligned with netty)
const int MAX_SSL_RECORDS = 4;
static readonly IInternalLogger Logger = InternalLoggerFactory.GetInstance(typeof(SniHandler));
readonly Func<Stream, SslStream> sslStreamFactory;
readonly ServerTlsSniSettings serverTlsSniSettings;

bool handshakeFailed;
bool suppressRead;
bool readPending;

public SniHandler(ServerTlsSniSettings settings)
: this(stream => new SslStream(stream, true), settings)
{
}

public SniHandler(Func<Stream, SslStream> sslStreamFactory, ServerTlsSniSettings settings)
{
Contract.Requires(settings != null);
Contract.Requires(sslStreamFactory != null);
this.sslStreamFactory = sslStreamFactory;
this.serverTlsSniSettings = settings;
}

protected override void Decode(IChannelHandlerContext context, IByteBuffer input, List<object> output)
{
if (!this.suppressRead && !this.handshakeFailed)
{
int writerIndex = input.WriterIndex;
Exception error = null;
try
{
bool continueLoop = true;
for (int i = 0; i < MAX_SSL_RECORDS && continueLoop; i++)
{
int readerIndex = input.ReaderIndex;
int readableBytes = writerIndex - readerIndex;
if (readableBytes < TlsUtils.SSL_RECORD_HEADER_LENGTH)
{
// Not enough data to determine the record type and length.
return;
}

int command = input.GetByte(readerIndex);
// tls, but not handshake command
switch (command)
{
case TlsUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
case TlsUtils.SSL_CONTENT_TYPE_ALERT:
int len = TlsUtils.GetEncryptedPacketLength(input, readerIndex);

// Not an SSL/TLS packet
if (len == TlsUtils.NOT_ENCRYPTED)
{
this.handshakeFailed = true;
var e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufferUtil.HexDump(input));
input.SkipBytes(input.ReadableBytes);

TlsUtils.NotifyHandshakeFailure(context, e);
throw e;
}
if (len == TlsUtils.NOT_ENOUGH_DATA ||
writerIndex - readerIndex - TlsUtils.SSL_RECORD_HEADER_LENGTH < len)
{
// Not enough data
return;
}

// increase readerIndex and try again.
input.SkipBytes(len);
continue;

case TlsUtils.SSL_CONTENT_TYPE_HANDSHAKE:
int majorVersion = input.GetByte(readerIndex + 1);

// SSLv3 or TLS
if (majorVersion == 3)
{
int packetLength = input.GetUnsignedShort(readerIndex + 3) + TlsUtils.SSL_RECORD_HEADER_LENGTH;

if (readableBytes < packetLength)
{
// client hello incomplete; try again to decode once more data is ready.
return;
}

// See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
//
// Decode the ssl client hello packet.
// We have to skip bytes until SessionID (which sum to 43 bytes).
//
// struct {
// ProtocolVersion client_version;
// Random random;
// SessionID session_id;
// CipherSuite cipher_suites<2..2^16-2>;
// CompressionMethod compression_methods<1..2^8-1>;
// select (extensions_present) {
// case false:
// struct {};
// case true:
// Extension extensions<0..2^16-1>;
// };
// } ClientHello;
//

int endOffset = readerIndex + packetLength;
int offset = readerIndex + 43;

if (endOffset - offset < 6)
{
continueLoop = false;
break;
}

int sessionIdLength = input.GetByte(offset);
offset += sessionIdLength + 1;

int cipherSuitesLength = input.GetUnsignedShort(offset);
offset += cipherSuitesLength + 2;

int compressionMethodLength = input.GetByte(offset);
offset += compressionMethodLength + 1;

int extensionsLength = input.GetUnsignedShort(offset);
offset += 2;
int extensionsLimit = offset + extensionsLength;

if (extensionsLimit > endOffset)
{
// Extensions should never exceed the record boundary.
continueLoop = false;
break;
}

for (;;)
{
if (extensionsLimit - offset < 4)
{
continueLoop = false;
break;
}

int extensionType = input.GetUnsignedShort(offset);
offset += 2;

int extensionLength = input.GetUnsignedShort(offset);
offset += 2;

if (extensionsLimit - offset < extensionLength)
{
continueLoop = false;
break;
}

// SNI
// See https://tools.ietf.org/html/rfc6066#page-6
if (extensionType == 0)
{
offset += 2;
if (extensionsLimit - offset < 3)
{
continueLoop = false;
break;
}

int serverNameType = input.GetByte(offset);
offset++;

if (serverNameType == 0)
{
int serverNameLength = input.GetUnsignedShort(offset);
offset += 2;

if (serverNameLength <= 0 || extensionsLimit - offset < serverNameLength)
{
continueLoop = false;
break;
}

string hostname = input.ToString(offset, serverNameLength, Encoding.UTF8);
//try
//{
// select(ctx, IDN.toASCII(hostname,
// IDN.ALLOW_UNASSIGNED).toLowerCase(Locale.US));
//}
//catch (Throwable t)
//{
// PlatformDependent.throwException(t);
//}

var idn = new IdnMapping()
{
AllowUnassigned = true
};

hostname = idn.GetAscii(hostname);
#if NETSTANDARD1_3
// TODO: netcore does not have culture sensitive tolower()
hostname = hostname.ToLowerInvariant();
#else
hostname = hostname.ToLower(new CultureInfo("en-US"));
#endif
this.Select(context, hostname);
return;
}
else
{
// invalid enum value
continueLoop = false;
break;
}
}

offset += extensionLength;
}
}

break;
// Fall-through
default:
//not tls, ssl or application data, do not try sni
continueLoop = false;
break;
}
}
}
catch (Exception e)
{
error = e;

// unexpected encoding, ignore sni and use default
if (Logger.DebugEnabled)
{
Logger.Warn($"Unexpected client hello packet: {ByteBufferUtil.HexDump(input)}", e);
}
}

if (this.serverTlsSniSettings.DefaultServerHostName != null)
{
// Just select the default certifcate
this.Select(context, this.serverTlsSniSettings.DefaultServerHostName);
}
else
{
this.handshakeFailed = true;
var e = new DecoderException($"failed to get the Tls Certificate {error}");
TlsUtils.NotifyHandshakeFailure(context, e);
throw e;
}
}
}

async void Select(IChannelHandlerContext context, string hostName)
{
Contract.Requires(hostName != null);
this.suppressRead = true;
try
{
var serverTlsSetting = await this.serverTlsSniSettings.ServerTlsSettingMap(hostName);
this.ReplaceHandler(context, serverTlsSetting);
}
catch (Exception ex)
{
this.ExceptionCaught(context, new DecoderException($"failed to get the Tls Certificate for {hostName}, {ex}"));
}
finally
{
this.suppressRead = false;
if (this.readPending)
{
this.readPending = false;
context.Read();
}
}
}

void ReplaceHandler(IChannelHandlerContext context, ServerTlsSettings serverTlsSetting)
{
Contract.Requires(serverTlsSetting != null);
var tlsHandler = new TlsHandler(this.sslStreamFactory, serverTlsSetting);
context.Channel.Pipeline.Replace(this, nameof(TlsHandler), tlsHandler);
}

public override void Read(IChannelHandlerContext context)
{
if (this.suppressRead)
{
this.readPending = true;
}
else
{
base.Read(context);
}
}
}
}
6 changes: 6 additions & 0 deletions src/DotNetty.Handlers/Tls/TlsUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ static class TlsUtils
/// the length of the ssl record header (in bytes)
public const int SSL_RECORD_HEADER_LENGTH = 5;

// Not enough data in buffer to parse the record length
public const int NOT_ENOUGH_DATA = -1;

// data is not encrypted
public const int NOT_ENCRYPTED = -2;

/// <summary>
/// Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase
/// the readerIndex of the given <see cref="IByteBuffer"/>.
Expand Down
Loading

0 comments on commit 1d3eda9

Please sign in to comment.