Skip to content

Commit

Permalink
Insert SSH 3.11.41, add more unit tests and logging to tests (#473)
Browse files Browse the repository at this point in the history
Fix for https://github.com/devdiv-microsoft/basis-planning/issues/1618

Insert SSH 3.11.41 that has the fix.
Add a unit test for client connecting to host when the tunnel has multiple port.
Add more logging to unit tests.
  • Loading branch information
IlyaBiryukov authored Aug 16, 2024
1 parent 464f85e commit 3679192
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 12 deletions.
2 changes: 1 addition & 1 deletion cs/build/build.props
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
<ReportGeneratorVersion>4.8.13</ReportGeneratorVersion>
<SystemTextEncodingsWebPackageVersion>4.7.2</SystemTextEncodingsWebPackageVersion>
<VisualStudioValidationVersion>15.5.31</VisualStudioValidationVersion>
<DevTunnelsSshPackageVersion>3.11.36</DevTunnelsSshPackageVersion>
<DevTunnelsSshPackageVersion>3.11.41</DevTunnelsSshPackageVersion>
<XunitRunnerVisualStudioVersion>2.4.0</XunitRunnerVisualStudioVersion>
<XunitVersion>2.4.0</XunitVersion>
</PropertyGroup>
Expand Down
163 changes: 163 additions & 0 deletions cs/test/TunnelsSDK.Test/TcpListeners.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
using System.Diagnostics;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Text;
using Microsoft.DevTunnels.Management;

namespace Microsoft.DevTunnels.Test;
public sealed class TcpListeners : IAsyncDisposable
{
private const int MaxAttempts = 10;

private readonly TraceSource trace;
private readonly CancellationTokenSource cts = new();
private readonly List<TcpListener> listeners = new();
private readonly List<Task> listenerTasks = new();

public TcpListeners(int count, TraceSource trace)
{
Requires.Argument(count > 0, nameof(count), "Count must be greater than 0.");
this.trace = trace.WithName("TcpListeners");
Ports = new int[count];
for (int index = 0; index < count; index++)
{
TcpListener listener = null;
int port;
int attempt = 0;
while (true)
{
try
{
port = TcpUtils.GetAvailableTcpPort(canReuseAddress: false);
listener = new TcpListener(IPAddress.Loopback, port);
listener.Start();
break;
}
catch (SocketException ex)
{
listener?.Stop();
if (++attempt >= MaxAttempts)
{
throw new InvalidOperationException("Failed to find available port", ex);
}
}
catch
{
listener?.Stop();
throw;
}
}

Ports[index] = port;
this.listeners.Add(listener);
this.listenerTasks.Add(AcceptConnectionsAsync(listener, port));
}

this.trace.Info("Listening on ports: {0}", string.Join(", ", Ports));
}

public int Port { get; }

public int[] Ports { get; }

public async ValueTask DisposeAsync()
{
cts.Cancel();
StopListeners();
await Task.WhenAll(this.listenerTasks);
this.listenerTasks.Clear();
}

private async Task AcceptConnectionsAsync(TcpListener listener, int port)
{
var tasks = new List<Task>();
TaskCompletionSource allTasksCompleted = null;
try
{
while (!cts.IsCancellationRequested)
{
var tcpClient = await listener.AcceptTcpClientAsync(cts.Token);
var task = Task.Run(() => RunClientAsync(tcpClient, port));
lock (tasks)
{
tasks.Add(task);
}

_ = task.ContinueWith(
(t) =>
{
lock (tasks)
{
tasks.Remove(t);
if (tasks.Count == 0)
{
allTasksCompleted?.TrySetResult();
}
}
});
}
}
catch (OperationCanceledException) when (this.cts.IsCancellationRequested)
{
// Ignore
}
catch (SocketException) when (this.cts.IsCancellationRequested)
{
// Ignore
}
catch (Exception ex)
{
this.trace.Error($"Error accepting TCP client for port {port}: ${ex}");
}

lock (tasks)
{
if (tasks.Count == 0)
{
return;
}

allTasksCompleted = new TaskCompletionSource();
}

await allTasksCompleted.Task;
}

private async Task RunClientAsync(TcpClient tcpClient, int port)
{
try
{
using var disposable = tcpClient;

this.trace.Info($"Accepted client connection to TCP port {port}");
await using var stream = tcpClient.GetStream();

var bytes = Encoding.UTF8.GetBytes(port.ToString(CultureInfo.InvariantCulture));
await stream.WriteAsync(bytes);

}
catch (OperationCanceledException) when (this.cts.IsCancellationRequested)
{
// Ignore
}
catch (SocketException) when (this.cts.IsCancellationRequested)
{
// Ignore
}
catch (Exception ex)
{
this.trace.Error($"Error handling TCP client on listener running on port {port}: ${ex}");
}
}

private void StopListeners()
{
foreach (var listener in this.listeners)
{
listener.Stop();
}

this.listeners.Clear();
}
}
19 changes: 17 additions & 2 deletions cs/test/TunnelsSDK.Test/TcpUtils.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
using System.Net;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Text;

namespace Microsoft.DevTunnels.Test;

internal static class TcpUtils
{
public static int GetAvailableTcpPort()
public static int GetAvailableTcpPort(bool canReuseAddress = true)
{
// Get any available local tcp port
var l = new TcpListener(IPAddress.Loopback, 0);
if (!canReuseAddress)
{
l.Server.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, false);
}

l.Start();
int port = ((IPEndPoint)l.LocalEndpoint).Port;
l.Stop();
return port;
}

public static async Task<int> ReadIntToEndAsync(this Stream stream, CancellationToken cancellation)
{
var buffer = new byte[1024];
var length = await stream.ReadAsync(buffer, cancellation);
var text = Encoding.UTF8.GetString(buffer, 0, length);
return int.Parse(text, CultureInfo.InvariantCulture);
}
}
77 changes: 68 additions & 9 deletions cs/test/TunnelsSDK.Test/TunnelHostAndClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Microsoft.DevTunnels.Test.Mocks;
using Nerdbank.Streams;
using Xunit;
using Xunit.Abstractions;
using Xunit.Sdk;

namespace Microsoft.DevTunnels.Test;
Expand All @@ -24,27 +25,24 @@ public class TunnelHostAndClientTests : IClassFixture<LocalPortsFixture>
private const string MockHostRelayUri = "ws://localhost/tunnel/host";
private const string MockClientRelayUri = "ws://localhost/tunnel/client";

private static readonly TraceSource TestTS =
private readonly TraceSource TestTS =
new TraceSource(nameof(TunnelHostAndClientTests));
private static readonly TimeSpan Timeout = Debugger.IsAttached ? TimeSpan.FromHours(1) : TimeSpan.FromSeconds(10);
private static readonly TimeSpan Timeout = Debugger.IsAttached ? TimeSpan.FromHours(1) : TimeSpan.FromSeconds(20);
private readonly CancellationToken TimeoutToken = new CancellationTokenSource(Timeout).Token;

private Stream serverStream;
private Stream clientStream;
private readonly IKeyPair serverSshKey;
private readonly LocalPortsFixture localPortsFixture;

static TunnelHostAndClientTests()
{
// Enabling tracing to debug console.
TestTS.Switch.Level = SourceLevels.All;
}

public TunnelHostAndClientTests(LocalPortsFixture localPortsFixture)
public TunnelHostAndClientTests(LocalPortsFixture localPortsFixture, ITestOutputHelper output)
{
(this.serverStream, this.clientStream) = FullDuplexStream.CreatePair();
this.serverSshKey = SshAlgorithms.PublicKey.ECDsaSha2Nistp384.GenerateKeyPair();
this.localPortsFixture = localPortsFixture;

TestTS.Switch.Level = SourceLevels.All;
TestTS.Listeners.Add(new XunitTraceListener(output));
}

private Tunnel CreateRelayTunnel(bool addClientEndpoint = true) => CreateRelayTunnel(addClientEndpoint, Enumerable.Empty<int>());
Expand Down Expand Up @@ -1453,6 +1451,67 @@ public async Task ConnectRelayHostThenConnectRelayClientToForwardedPortStream()
using var sshStream = await clientSshSession.ConnectToForwardedPortAsync(port, TimeoutToken);
}

[Fact]
public async Task ConnectRelayHostThenConnectRelayClientsToForwardedPortStreamsThenSendData()
{
const int PortCount = 2;
const int ClientConnectionCount = 50;

var managementClient = new MockTunnelManagementClient
{
HostRelayUri = MockHostRelayUri,
ClientRelayUri = MockClientRelayUri,
};

var relayHost = new TunnelRelayTunnelHost(managementClient, TestTS);

await using var listeners = new TcpListeners(PortCount, TestTS);
var tunnel = CreateRelayTunnel(false, listeners.Ports);

using var multiChannelStream = await ConnectRelayHostAsync(relayHost, tunnel);
Assert.Equal(ConnectionStatus.Connected, relayHost.ConnectionStatus);

var clientStreamFactory = new MockTunnelRelayStreamFactory(TunnelRelayConnection.ClientWebSocketSubProtocol)
{
StreamFactory = async (accessToken) =>
{
return await multiChannelStream.OpenStreamAsync(TunnelRelayTunnelHost.ClientStreamChannelType);
},
};

for (int clientConnection = 0; clientConnection < ClientConnectionCount; clientConnection++)
{
foreach (var port in listeners.Ports)
{
TestTS.TraceInformation("Connecting client #{0} to port {1}", clientConnection, port);

// Create and connect tunnel client
await using var relayClient = new TunnelRelayTunnelClient(TestTS)
{
AcceptLocalConnectionsForForwardedPorts = false,
StreamFactory = clientStreamFactory,
};

Assert.Equal(ConnectionStatus.None, relayClient.ConnectionStatus);

await relayClient.ConnectAsync(tunnel, TimeoutToken);
Assert.Equal(ConnectionStatus.Connected, relayClient.ConnectionStatus);

await relayClient.WaitForForwardedPortAsync(port, TimeoutToken);
using var stream = await relayClient.ConnectToForwardedPortAsync(port, TimeoutToken);

var actualPort = await stream.ReadIntToEndAsync(TimeoutToken);
if (port != actualPort)
{
// Debugger.Launch();
TestTS.TraceInformation("Client #{0} received unexpected port {1} instead of {2}", clientConnection, actualPort, port);
}

Assert.Equal(port, actualPort);
}
}
}

[Fact]
public async Task ConnectRelayHostThenConnectRelayClientToDifferentPort_Fails()
{
Expand Down
32 changes: 32 additions & 0 deletions cs/test/TunnelsSDK.Test/XunitTraceListener.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using System.Diagnostics;
using System.Text;
using Xunit.Abstractions;

namespace Microsoft.DevTunnels.Test;

internal sealed class XunitTraceListener : TraceListener
{
private readonly ITestOutputHelper output;
private readonly StringBuilder currentLine = new ();
private readonly DateTimeOffset loggingStart = DateTimeOffset.UtcNow;
private DateTimeOffset? messageStart;

public XunitTraceListener(ITestOutputHelper output)
{
this.output = output;
}

public override void Write(string message)
{
this.messageStart ??= DateTimeOffset.UtcNow;
this.currentLine.Append(message);
}

public override void WriteLine(string message)
{
var messageTime = (this.messageStart ?? DateTimeOffset.UtcNow) - this.loggingStart;
this.output.WriteLine($"{messageTime} {this.currentLine}{message}");
this.currentLine.Clear();
this.messageStart = null;
}
}

0 comments on commit 3679192

Please sign in to comment.