Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket Feedback Follow-up #107662

Merged
merged 9 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,6 @@ internal static partial class WebSocketValidate
private static readonly SearchValues<char> s_validSubprotocolChars =
SearchValues.Create("!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~");

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates)
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
=> ThrowIfInvalidState(currentState, isDisposed, innerException: null, validStates ?? []);

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, WebSocketState[]? validStates = null)
{
if (validStates is not null && Array.IndexOf(validStates, currentState) == -1)
{
string invalidStateMessage = SR.Format(
SR.net_WebSockets_InvalidState, currentState, string.Join(", ", validStates));

throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException);
}

if (innerException is not null)
{
Debug.Assert(currentState == WebSocketState.Aborted);
throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException);
}

// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
}

internal static void ValidateSubprotocol(string subProtocol)
{
ArgumentException.ThrowIfNullOrWhiteSpace(subProtocol);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
<Compile Include="System\Net\WebSockets\Compression\WebSocketInflater.cs" />
<Compile Include="System\Net\WebSockets\ManagedWebSocket.cs" />
<Compile Include="System\Net\WebSockets\ManagedWebSocket.KeepAlive.cs" />
<Compile Include="System\Net\WebSockets\ManagedWebSocketStates.cs" />
<Compile Include="System\Net\WebSockets\NetEventSource.WebSockets.cs" />
<Compile Include="System\Net\WebSockets\ValueWebSocketReceiveResult.cs" />
<Compile Include="System\Net\WebSockets\WebSocket.cs" />
Expand All @@ -31,6 +32,7 @@
<Compile Include="System\Net\WebSockets\WebSocketMessageFlags.cs" />
<Compile Include="System\Net\WebSockets\WebSocketReceiveResult.cs" />
<Compile Include="System\Net\WebSockets\WebSocketState.cs" />
<Compile Include="System\Net\WebSockets\WebSocketStateHelper.cs" />
<Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketDefaults.cs"
Link="Common\System\Net\WebSockets\WebSocketDefaults.cs" />
<Compile Include="$(CommonPath)System\Net\WebSockets\WebSocketValidate.cs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,28 @@ public Task EnterAsync(CancellationToken cancellationToken)
// If cancellation was requested, bail immediately.
// If the mutex is not currently held nor contended, enter immediately.
// Otherwise, fall back to a more expensive likely-asynchronous wait.
return
cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) :
Interlocked.Decrement(ref _gate) >= 0 ? Task.CompletedTask :
Contended(cancellationToken);

if (cancellationToken.IsCancellationRequested)
{
return Task.FromCanceled(cancellationToken);
}

int gate = Interlocked.Decrement(ref _gate);
if (gate >= 0)
{
return Task.CompletedTask;
}

if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Waiting to enter, queue length {-gate}");
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved

return Contended(cancellationToken);

// Everything that follows is the equivalent of:
// return _sem.WaitAsync(cancellationToken);
// if _sem were to be constructed as `new SemaphoreSlim(0)`.

Task Contended(CancellationToken cancellationToken)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate);

var w = new Waiter(this);

// We need to register for cancellation before storing the waiter into the list.
Expand Down Expand Up @@ -178,18 +187,18 @@ static void OnCancellation(object? state, CancellationToken cancellationToken)
/// <remarks>The caller must logically own the mutex. This is not validated.</remarks>
public void Exit()
{
if (Interlocked.Increment(ref _gate) < 1)
// This is the equivalent of:
// _sem.Release();
// if _sem were to be constructed as `new SemaphoreSlim(0)`.
int gate = Interlocked.Increment(ref _gate);
if (gate < 1)
{
// This is the equivalent of:
// _sem.Release();
// if _sem were to be constructed as `new SemaphoreSlim(0)`.
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Unblocking next waiter on exit, remaining queue length {-_gate}", nameof(Exit));
Contended();
}

void Contended()
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexContended(this, _gate);

Waiter? w;
lock (SyncObj)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Buffers.Binary;
using System.Diagnostics;
using System.Runtime.ExceptionServices;
Expand All @@ -13,8 +12,6 @@ namespace System.Net.WebSockets
internal sealed partial class ManagedWebSocket : WebSocket
{
private bool IsUnsolicitedPongKeepAlive => _keepAlivePingState is null;
private static bool IsValidSendState(WebSocketState state) => Array.IndexOf(s_validSendStates, state) != -1;
private static bool IsValidReceiveState(WebSocketState state) => Array.IndexOf(s_validReceiveStates, state) != -1;

private void HeartBeat()
{
Expand All @@ -36,21 +33,19 @@ private void UnsolicitedPongHeartBeat()
TrySendKeepAliveFrameAsync(MessageOpcode.Pong));
}

private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory<byte>? payload = null)
private ValueTask TrySendKeepAliveFrameAsync(MessageOpcode opcode, ReadOnlyMemory<byte> payload = default)
{
Debug.Assert(opcode is MessageOpcode.Pong || !IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping);
Debug.Assert((opcode is MessageOpcode.Pong) || (!IsUnsolicitedPongKeepAlive && opcode is MessageOpcode.Ping));

if (!IsValidSendState(_state))
if (!WebSocketStateHelper.IsValidSendState(_state))
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"Cannot send keep-alive frame in {nameof(_state)}={_state}");

// we can't send any frames, but no need to throw as we are not observing errors anyway
return ValueTask.CompletedTask;
}

payload ??= ReadOnlyMemory<byte>.Empty;

return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload.Value, CancellationToken.None);
return SendFrameAsync(opcode, endOfMessage: true, disableCompression: true, payload, CancellationToken.None);
}

private void KeepAlivePingHeartBeat()
Expand All @@ -76,7 +71,7 @@ private void KeepAlivePingHeartBeat()

if (_keepAlivePingState.PingSent)
{
if (Environment.TickCount64 > _keepAlivePingState.PingTimeoutTimestamp)
if (now > _keepAlivePingState.PingTimeoutTimestamp)
{
if (NetEventSource.Log.IsEnabled())
{
Expand All @@ -92,7 +87,7 @@ private void KeepAlivePingHeartBeat()
}
else
{
if (Environment.TickCount64 > _keepAlivePingState.NextPingRequestTimestamp)
if (now > _keepAlivePingState.NextPingRequestTimestamp)
{
_keepAlivePingState.OnNextPingRequestCore(); // we are holding the lock
shouldSendPing = true;
Expand All @@ -119,18 +114,12 @@ private async ValueTask SendPingAsync(long pingPayload)
{
Debug.Assert(_keepAlivePingState != null);

byte[] pingPayloadBuffer = ArrayPool<byte>.Shared.Rent(sizeof(long));
byte[] pingPayloadBuffer = new byte[sizeof(long)];
BinaryPrimitives.WriteInt64BigEndian(pingPayloadBuffer, pingPayload);
try
{
await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer.AsMemory(0, sizeof(long))).ConfigureAwait(false);

if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
}
finally
{
ArrayPool<byte>.Shared.Return(pingPayloadBuffer);
}
await TrySendKeepAliveFrameAsync(MessageOpcode.Ping, pingPayloadBuffer).ConfigureAwait(false);

if (NetEventSource.Log.IsEnabled()) NetEventSource.KeepAlivePingSent(this, pingPayload);
}

// "Observe" either a ValueTask result, or any exception, ignoring it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ internal sealed partial class ManagedWebSocket : WebSocket
/// <summary>Encoding for the payload of text messages: UTF-8 encoding that throws if invalid bytes are discovered, per the RFC.</summary>
private static readonly UTF8Encoding s_textEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);

/// <summary>Valid states to be in when calling SendAsync.</summary>
private static readonly WebSocketState[] s_validSendStates = { WebSocketState.Open, WebSocketState.CloseReceived };
/// <summary>Valid states to be in when calling ReceiveAsync.</summary>
private static readonly WebSocketState[] s_validReceiveStates = { WebSocketState.Open, WebSocketState.CloseSent };
/// <summary>Valid states to be in when calling CloseOutputAsync.</summary>
private static readonly WebSocketState[] s_validCloseOutputStates = { WebSocketState.Open, WebSocketState.CloseReceived };
/// <summary>Valid states to be in when calling CloseAsync.</summary>
private static readonly WebSocketState[] s_validCloseStates = { WebSocketState.Open, WebSocketState.CloseReceived, WebSocketState.CloseSent };

/// <summary>The maximum size in bytes of a message frame header that includes mask bytes.</summary>
internal const int MaxMessageHeaderLength = 14;
/// <summary>The maximum size of a control message payload.</summary>
Expand Down Expand Up @@ -337,7 +328,7 @@ public override ValueTask SendAsync(ReadOnlyMemory<byte> buffer, WebSocketMessag

try
{
ThrowIfInvalidState(s_validSendStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidSendStates);
}
catch (Exception exc)
{
Expand Down Expand Up @@ -377,7 +368,7 @@ public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buf

try
{
ThrowIfInvalidState(s_validReceiveStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidReceiveStates);

return ReceiveAsyncPrivate<WebSocketReceiveResult>(buffer, cancellationToken).AsTask();
}
Expand All @@ -394,7 +385,7 @@ public override ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(Memory<byte>

try
{
ThrowIfInvalidState(s_validReceiveStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidReceiveStates);

return ReceiveAsyncPrivate<ValueWebSocketReceiveResult>(buffer, cancellationToken);
}
Expand All @@ -413,7 +404,7 @@ public override Task CloseAsync(WebSocketCloseStatus closeStatus, string? status

try
{
ThrowIfInvalidState(s_validCloseStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidCloseStates);
}
catch (Exception exc)
{
Expand All @@ -436,7 +427,7 @@ private async Task CloseOutputAsyncCore(WebSocketCloseStatus closeStatus, string
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this);

ThrowIfInvalidState(s_validCloseOutputStates);
ThrowIfInvalidState(WebSocketStateHelper.ValidCloseOutputStates);

await SendCloseFrameAsync(closeStatus, statusDescription, cancellationToken).ConfigureAwait(false);

Expand Down Expand Up @@ -797,11 +788,9 @@ private async ValueTask<TResult> ReceiveAsyncPrivate<TResult>(Memory<byte> paylo

if (NetEventSource.Log.IsEnabled()) NetEventSource.ReceiveAsyncPrivateStarted(this, payloadBuffer.Length);

CancellationTokenRegistration registration = default;
CancellationTokenRegistration registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this);
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
try
{
registration = cancellationToken.Register(static s => ((ManagedWebSocket)s!).Abort(), this);

await _receiveMutex.EnterAsync(cancellationToken).ConfigureAwait(false);
if (NetEventSource.Log.IsEnabled()) NetEventSource.MutexEntered(_receiveMutex);

Expand Down Expand Up @@ -1737,9 +1726,9 @@ private void ThrowIfOperationInProgress(bool operationCompleted, [CallerMemberNa
cancellationToken);
}

private void ThrowIfDisposed() => ThrowIfInvalidState();
private void ThrowIfDisposed() => ThrowIfInvalidState(validStates: ManagedWebSocketStates.All);

private void ThrowIfInvalidState(WebSocketState[]? validStates = null)
private void ThrowIfInvalidState(ManagedWebSocketStates validStates)
{
bool disposed = _disposed;
WebSocketState state = _state;
Expand All @@ -1758,7 +1747,7 @@ private void ThrowIfInvalidState(WebSocketState[]? validStates = null)

if (NetEventSource.Log.IsEnabled()) NetEventSource.Trace(this, $"_state={state}, _disposed={disposed}, _keepAlivePingState.Exception={keepAliveException}");

WebSocketValidate.ThrowIfInvalidState(state, disposed, keepAliveException, validStates);
WebSocketStateHelper.ThrowIfInvalidState(state, disposed, keepAliveException, validStates);
}

// From https://github.com/aspnet/WebSockets/blob/aa63e27fce2e9202698053620679a9a1059b501e/src/Microsoft.AspNetCore.WebSockets.Protocol/Utilities.cs#L75
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace System.Net.WebSockets
{
[Flags]
internal enum ManagedWebSocketStates
{
None = 0,

// WebSocketState.None = 0 -- this state is invalid for the managed implementation
// WebSocketState.Connecting = 1 -- this state is invalid for the managed implementation
Open = 0x04, // WebSocketState.Open = 2
CloseSent = 0x08, // WebSocketState.CloseSent = 3
CloseReceived = 0x10, // WebSocketState.CloseReceived = 4
Closed = 0x20, // WebSocketState.Closed = 5
Aborted = 0x40, // WebSocketState.Aborted = 6

All = Open | CloseSent | CloseReceived | Closed | Aborted
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ internal sealed partial class NetEventSource

private const int MutexEnterId = SendStopId + 1;
private const int MutexExitId = MutexEnterId + 1;
private const int MutexContendedId = MutexExitId + 1;

//
// Keep-Alive
Expand Down Expand Up @@ -185,10 +184,6 @@ private void MutexEnter(string objName, string memberName) =>
private void MutexExit(string objName, string memberName) =>
WriteEvent(MutexExitId, objName, memberName);

[Event(MutexContendedId, Keywords = Keywords.Debug, Level = EventLevel.Verbose)]
private void MutexContended(string objName, string memberName, int queueLength) =>
WriteEvent(MutexContendedId, objName, memberName, queueLength);

[NonEvent]
public static void MutexEntered(object? obj, [CallerMemberName] string? memberName = null)
{
Expand All @@ -203,13 +198,6 @@ public static void MutexExited(object? obj, [CallerMemberName] string? memberNam
Log.MutexExit(IdOf(obj), memberName ?? MissingMember);
}

[NonEvent]
public static void MutexContended(object? obj, int gateValue, [CallerMemberName] string? memberName = null)
{
Debug.Assert(Log.IsEnabled());
Log.MutexContended(IdOf(obj), memberName ?? MissingMember, -gateValue);
}

//
// WriteEvent overloads
//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;

namespace System.Net.WebSockets
{
internal static class WebSocketStateHelper
{
/// <summary>Valid states to be in when calling SendAsync.</summary>
internal const ManagedWebSocketStates ValidSendStates = ManagedWebSocketStates.Open | ManagedWebSocketStates.CloseReceived;
/// <summary>Valid states to be in when calling ReceiveAsync.</summary>
internal const ManagedWebSocketStates ValidReceiveStates = ManagedWebSocketStates.Open | ManagedWebSocketStates.CloseSent;
/// <summary>Valid states to be in when calling CloseOutputAsync.</summary>
internal const ManagedWebSocketStates ValidCloseOutputStates = ManagedWebSocketStates.Open | ManagedWebSocketStates.CloseReceived;
/// <summary>Valid states to be in when calling CloseAsync.</summary>
internal const ManagedWebSocketStates ValidCloseStates = ManagedWebSocketStates.Open | ManagedWebSocketStates.CloseReceived | ManagedWebSocketStates.CloseSent;

internal static bool IsValidSendState(WebSocketState state) => ValidSendStates.HasFlag(ToFlag(state));

internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, Exception? innerException, ManagedWebSocketStates validStates)
{
ManagedWebSocketStates state = ToFlag(currentState);

if (!validStates.HasFlag(state))
CarnaViire marked this conversation as resolved.
Show resolved Hide resolved
{
string invalidStateMessage = SR.Format(
SR.net_WebSockets_InvalidState, currentState, validStates);

throw new WebSocketException(WebSocketError.InvalidState, invalidStateMessage, innerException);
}

if (innerException is not null)
{
Debug.Assert(state == ManagedWebSocketStates.Aborted);
throw new OperationCanceledException(nameof(WebSocketState.Aborted), innerException);
}

// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
ObjectDisposedException.ThrowIf(isDisposed, typeof(WebSocket));
}

private static ManagedWebSocketStates ToFlag(WebSocketState value)
{
ManagedWebSocketStates flag = (ManagedWebSocketStates)(1 << (int)value);
Debug.Assert(Enum.IsDefined(flag));
return flag;
}
}
}
Loading