Skip to content

Commit

Permalink
Fix HTTP rewriting
Browse files Browse the repository at this point in the history
  • Loading branch information
angelobreuer committed Jan 7, 2021
1 parent 3c4b681 commit 330e757
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 29 deletions.
2 changes: 1 addition & 1 deletion src/CommandLine/TunnelDashboard.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ private static void UpdateConnections(Tunnel tunnel, Stack<TunnelConnection> con
foreach (var connection in connectionHistory)
{
var httpConnection = connection as ProxiedHttpTunnelConnection;
var requestMessage = httpConnection?.RequestMessage;
var requestMessage = httpConnection?.HttpRequest;

var bytesIn = httpConnection?.Statistics.BytesIn / 1024F;
var bytesOut = httpConnection?.Statistics.BytesIn / 1024F;
Expand Down
36 changes: 19 additions & 17 deletions src/Connections/ProxiedHttpTunnelConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public ProxiedHttpTunnelConnection(TunnelConnectionHandle handle, ProxiedHttpTun

public ProxiedHttpTunnelOptions Options { get; }

public HttpRequestMessage? RequestMessage { get; private set; }
public HttpRequestMessage? HttpRequest { get; private set; }

public ConnectionStatistics Statistics => _statistics;

Expand Down Expand Up @@ -75,7 +75,6 @@ protected override void Dispose(bool disposing)
return;
}

RequestMessage?.Dispose();
_proxyStream?.Dispose();
_proxySocket?.Dispose();
ArrayPool<byte>.Shared.Return(_receiveBuffer);
Expand Down Expand Up @@ -104,28 +103,31 @@ private void BeginRead()

private void ProcessRequest(ref ArraySegment<byte> data)
{
var memoryStream = new MemoryStream(data.Array!, data.Offset, data.Array!.Length);
var requestBuffer = data.Array!;
var requestBody = (ReadOnlySpan<byte>)data;

using (var streamReader = new StreamReader(memoryStream, leaveOpen: true))
{
RequestMessage = RequestReader.Parse(streamReader, BaseUri)!;
}
HttpRequest = RequestReader.Parse(ref requestBody, BaseUri)!;
Options.RequestProcessor!.Process(this, HttpRequest);

// save request body as span
var requestBody = data.Array.AsSpan(data.Offset + (int)memoryStream.Position);
memoryStream.Position = 0;

Options.RequestProcessor!.Process(this, RequestMessage);
var pooledBuffer = Tunnel.ArrayPool.Rent(data.Count + 8096);

// write request back
using (var streamWriter = new StreamWriter(memoryStream, leaveOpen: true))
int requestLength;
using (var memoryStream = new MemoryStream(pooledBuffer))
{
RequestWriter.WriteRequest(streamWriter, RequestMessage);
using (var streamWriter = new StreamWriter(memoryStream, leaveOpen: true))
{
RequestWriter.WriteRequest(streamWriter, HttpRequest, requestBody.Length);
}

requestLength = (int)memoryStream.Position;
}

// write request body
memoryStream.Write(requestBody);
data = new(data.Array!, data.Offset, (int)memoryStream.Position);
requestBody.CopyTo(pooledBuffer.AsSpan(requestLength));
data = new(pooledBuffer, 0, requestLength + requestBody.Length);

// return current buffer
Tunnel.ArrayPool.Return(requestBuffer);
}

private void ReceiveCallbackInternal(IAsyncResult asyncResult)
Expand Down
28 changes: 22 additions & 6 deletions src/Http/RequestReader.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
namespace Localtunnel.Http
{
using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;

internal static class RequestReader
{
public static HttpRequestMessage? Parse(TextReader textReader, Uri baseUri)
private static readonly byte[] _eol = new byte[] { (byte)'\r', (byte)'\n' };

private static string? ReadLine(ref ReadOnlySpan<byte> span)
{
var start = span.IndexOf(_eol);

if (start is -1)
{
return null;
}

var content = span[0..start];
span = span[(start + 2)..];
return Encoding.UTF8.GetString(content);
}

public static HttpRequestMessage? Parse(ref ReadOnlySpan<byte> span, Uri baseUri)
{
var statusLine = textReader.ReadLine();
var statusLine = ReadLine(ref span);

if (string.IsNullOrWhiteSpace(statusLine))
{
Expand All @@ -32,7 +48,7 @@ internal static class RequestReader
};

// read headers
ReadHttpHeaders(textReader, requestMessage.Headers);
ReadHttpHeaders(ref span, requestMessage.Headers);

return requestMessage;
}
Expand All @@ -45,10 +61,10 @@ internal static class RequestReader
_ => HttpVersion.Unknown,
};

private static void ReadHttpHeaders(TextReader textReader, HttpRequestHeaders headers)
private static void ReadHttpHeaders(ref ReadOnlySpan<byte> span, HttpRequestHeaders headers)
{
string? line;
while (!string.IsNullOrWhiteSpace(line = textReader.ReadLine()))
while (!string.IsNullOrWhiteSpace(line = ReadLine(ref span)))
{
var index = line.IndexOf(':');

Expand Down
11 changes: 7 additions & 4 deletions src/Http/RequestWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@ internal static class RequestWriter
{
private const string HTTP_EOL = "\r\n";

public static void WriteRequest(TextWriter writer, HttpRequestMessage request)
public static void WriteRequest(TextWriter writer, HttpRequestMessage request, long contentLength)
{
// status line
writer.Write(request.Method);
writer.Write(' ');
writer.Write(request.RequestUri!.PathAndQuery);
writer.Write(" HTTP/");
writer.Write(request.Version.ToString(2));
writer.Write(" HTTP/1.1");
writer.Write(HTTP_EOL);

// content length
writer.Write("Content-Length: ");
writer.Write(contentLength);
writer.Write(HTTP_EOL);

// headers
Expand All @@ -27,7 +31,6 @@ public static void WriteRequest(TextWriter writer, HttpRequestMessage request)
}

writer.Write(HTTP_EOL);
writer.Write(HTTP_EOL);
}
}
}
11 changes: 10 additions & 1 deletion src/Tunnels/TunnelSocketContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,16 @@ private void NotifyCompletedReceive(SocketAsyncEventArgs eventArgs)
// initialize connection
var handle = new TunnelConnectionHandle(this);
connection = _connection = Tunnel.ConnectionFactory(handle);
connection.Open();

try
{
connection.Open();
}
catch (Exception)
{
Dispose();
return;
}
}

// capture buffer
Expand Down

0 comments on commit 330e757

Please sign in to comment.