diff --git a/Protest/Proxy/TrafficCountingHttpMiddleware.cs b/Protest/Proxy/TrafficCountingHttpMiddleware.cs index 4944ea4b..849309b9 100644 --- a/Protest/Proxy/TrafficCountingHttpMiddleware.cs +++ b/Protest/Proxy/TrafficCountingHttpMiddleware.cs @@ -4,6 +4,8 @@ using System.Threading; using System.Net; using Microsoft.AspNetCore.Http; +using System.Collections.Generic; +using Microsoft.Extensions.Primitives; namespace Protest.Proxy; internal class TrafficCountingHttpMiddleware { @@ -21,6 +23,9 @@ public async Task InvokeAsync(HttpContext context) { IPAddress remoteIp = context.Connection.RemoteIpAddress; uint key = BitConverter.ToUInt32(remoteIp.GetAddressBytes(), 0); + long requestHeadersSize = CalculateHeadersSize(context.Request.Headers); + long responseHeadersSize; + Stream originalRequestBody = context.Request.Body; Stream originalResponseBody = context.Response.Body; @@ -31,18 +36,28 @@ public async Task InvokeAsync(HttpContext context) { context.Response.Body = responseBodyStream; await originalRequestBody.CopyToAsync(requestBodyStream); - bytesRx.AddOrUpdate(key, requestBodyStream.Length, (_, old) => old + requestBodyStream.Length); + bytesRx.AddOrUpdate(key, requestBodyStream.Length + requestHeadersSize, (_, old) => old + requestBodyStream.Length + requestHeadersSize); requestBodyStream.Seek(0, SeekOrigin.Begin); context.Request.Body = requestBodyStream; await _next(context); + responseHeadersSize = CalculateHeadersSize(context.Response.Headers); + responseBodyStream.Seek(0, SeekOrigin.Begin); await responseBodyStream.CopyToAsync(originalResponseBody); - bytesTx.AddOrUpdate(key, responseBodyStream.Length, (_, old) => old + responseBodyStream.Length); + bytesTx.AddOrUpdate(key, responseBodyStream.Length + responseHeadersSize, (_, old) => old + responseBodyStream.Length + responseHeadersSize); context.Request.Body = originalRequestBody; context.Response.Body = originalResponseBody; } + + private long CalculateHeadersSize(IHeaderDictionary headers) { + long size = 0; + foreach (KeyValuePair header in headers) { + size += header.Key.Length + header.Value.Sum(value => value.Length) + 4; + } + return size; + } }