diff --git a/src/SuperSocket.Connection/ConnectionBase.cs b/src/SuperSocket.Connection/ConnectionBase.cs index d681bc09e..2e091c0f9 100644 --- a/src/SuperSocket.Connection/ConnectionBase.cs +++ b/src/SuperSocket.Connection/ConnectionBase.cs @@ -9,13 +9,13 @@ namespace SuperSocket.Connection { public abstract class ConnectionBase : IConnection - { + { public abstract IAsyncEnumerable RunAsync(IPipelineFilter pipelineFilter); public abstract ValueTask SendAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default); public abstract ValueTask SendAsync(IPackageEncoder packageEncoder, TPackage package, CancellationToken cancellationToken = default); - + public abstract ValueTask SendAsync(Action write, CancellationToken cancellationToken = default); public bool IsClosed { get; private set; } @@ -28,6 +28,8 @@ public abstract class ConnectionBase : IConnection public DateTimeOffset LastActiveTime { get; protected set; } = DateTimeOffset.Now; + public CancellationToken ConnectionToken { get; protected set; } + protected virtual void OnClosed() { IsClosed = true; diff --git a/src/SuperSocket.Connection/IConnection.cs b/src/SuperSocket.Connection/IConnection.cs index bc7d90289..bdf340bcf 100644 --- a/src/SuperSocket.Connection/IConnection.cs +++ b/src/SuperSocket.Connection/IConnection.cs @@ -33,5 +33,7 @@ public interface IConnection ValueTask DetachAsync(); CloseReason? CloseReason { get; } + + CancellationToken ConnectionToken { get; } } } diff --git a/src/SuperSocket.Connection/PipeConnectionBase.cs b/src/SuperSocket.Connection/PipeConnectionBase.cs index 9937d9e18..6601420d2 100644 --- a/src/SuperSocket.Connection/PipeConnectionBase.cs +++ b/src/SuperSocket.Connection/PipeConnectionBase.cs @@ -32,7 +32,7 @@ PipeReader IPipeConnection.InputReader { get { return InputReader; } } - + IPipelineFilter IPipeConnection.PipelineFilter { get { return _pipelineFilter; } @@ -52,6 +52,7 @@ protected PipeConnectionBase(PipeReader inputReader, PipeWriter outputWriter, Co Logger = options.Logger; InputReader = inputReader; OutputWriter = outputWriter; + ConnectionToken = _cts.Token; } protected virtual Task StartTask(IObjectPipe packagePipe) @@ -72,7 +73,7 @@ public async override IAsyncEnumerable RunAsync(IPip _packagePipe = packagePipe; _pipelineFilter = pipelineFilter; - + _pipeTask = StartTask(packagePipe); _ = HandleClosing(); @@ -118,7 +119,7 @@ private async ValueTask HandleClosing() { if (!IsIgnorableException(exc)) OnError("Unhandled exception in the method PipeChannel.Close.", exc); - } + } } } } @@ -172,7 +173,7 @@ public override async ValueTask SendAsync(ReadOnlyMemory buffer, Cancellat finally { SendLock.Release(); - } + } } private void WriteBuffer(PipeWriter writer, ReadOnlyMemory buffer) @@ -236,7 +237,7 @@ protected async Task ReadPipeAsync(PipeReader reader, IObjectPipe< { if (!IsIgnorableException(e) && !(e is OperationCanceledException)) OnError("Failed to read from the pipe", e); - + break; } @@ -267,7 +268,7 @@ protected async Task ReadPipeAsync(PipeReader reader, IObjectPipe< { completed = true; break; - } + } } if (completed) @@ -344,7 +345,7 @@ private bool ReaderBuffer(ref ReadOnlySequence buffer, IPipe Close(); return false; } - + if (packageInfo == null) { // the current pipeline filter needs more data to process @@ -370,12 +371,12 @@ private bool ReaderBuffer(ref ReadOnlySequence buffer, IPipe examined = consumed = buffer.End; return true; } - + if (bytesConsumed > 0) seqReader = new SequenceReader(seqReader.Sequence.Slice(bytesConsumed)); } } - + public override async ValueTask DetachAsync() { _isDetaching = true; diff --git a/src/SuperSocket.Kestrel/KestrelPipeConnection.cs b/src/SuperSocket.Kestrel/KestrelPipeConnection.cs index 8d8b14be2..253a9ccda 100644 --- a/src/SuperSocket.Kestrel/KestrelPipeConnection.cs +++ b/src/SuperSocket.Kestrel/KestrelPipeConnection.cs @@ -16,11 +16,19 @@ public KestrelPipeConnection(ConnectionContext context, ConnectionOptions option : base(context.Transport.Input, context.Transport.Output, options) { _context = context; - context.ConnectionClosed.Register(() => OnClosed()); + context.ConnectionClosed.Register(() => OnConnectionClosed()); LocalEndPoint = context.LocalEndPoint; RemoteEndPoint = context.RemoteEndPoint; } + protected override void OnClosed() + { + if (!CloseReason.HasValue) + CloseReason = Connection.CloseReason.RemoteClosing; + + base.OnClosed(); + } + public override ValueTask DetachAsync() { throw new NotSupportedException($"Detach is not supported by {nameof(KestrelPipeConnection)}."); @@ -39,14 +47,6 @@ protected override async void Close() } } - protected override void OnClosed() - { - if (!CloseReason.HasValue) - CloseReason = Connection.CloseReason.RemoteClosing; - - base.OnClosed(); - } - protected override void OnInputPipeRead(ReadResult result) { if (!result.IsCanceled && !result.IsCompleted) @@ -72,4 +72,9 @@ public override async ValueTask SendAsync(IPackageEncoder pa await base.SendAsync(packageEncoder, package, cancellationToken); UpdateLastActiveTime(); } + + private void OnConnectionClosed() + { + Cancel(); + } } diff --git a/src/SuperSocket.Server/SuperSocketService.cs b/src/SuperSocket.Server/SuperSocketService.cs index cd1b57cc8..ba255d6fa 100644 --- a/src/SuperSocket.Server/SuperSocketService.cs +++ b/src/SuperSocket.Server/SuperSocketService.cs @@ -314,11 +314,11 @@ protected virtual ValueTask OnSessionClosedAsync(IAppSession session, CloseEvent if (closedHandler != null) return closedHandler.Invoke(session, e); - #if NETSTANDARD2_1 - return GetCompletedTask(); - #else - return ValueTask.CompletedTask; - #endif +#if NETSTANDARD2_1 + return GetCompletedTask(); +#else + return ValueTask.CompletedTask; +#endif } protected virtual async ValueTask FireSessionConnectedEvent(AppSession session) @@ -350,7 +350,7 @@ protected virtual async ValueTask FireSessionClosedEvent(AppSession session, Clo if (!handshakeSession.Handshaked) return; } - + await UnRegisterSessionFromMiddlewares(session); _logger.LogInformation($"The session disconnected: {session.SessionID} ({reason})"); @@ -396,18 +396,18 @@ private async ValueTask HandleSession(AppSession session, IConnection connection var packageHandlingScheduler = _packageHandlingScheduler; #if NET6_0_OR_GREATER - using var cancellationTokenSource = GetPackageHandlingCancellationTokenSource(CancellationToken.None); + using var cancellationTokenSource = GetPackageHandlingCancellationTokenSource(connection.ConnectionToken); #endif await foreach (var p in packageStream) { - if(_packageHandlingContextAccessor != null) + if (_packageHandlingContextAccessor != null) { _packageHandlingContextAccessor.PackageHandlingContext = new PackageHandlingContext(session, p); } #if !NET6_0_OR_GREATER - using var cancellationTokenSource = GetPackageHandlingCancellationTokenSource(CancellationToken.None); + using var cancellationTokenSource = GetPackageHandlingCancellationTokenSource(connection.ConnectionToken); #endif await packageHandlingScheduler.HandlePackage(session, p, cancellationTokenSource.Token); @@ -424,13 +424,9 @@ private async ValueTask HandleSession(AppSession session, IConnection connection protected virtual CancellationTokenSource GetPackageHandlingCancellationTokenSource(CancellationToken cancellationToken) { -#if NET6_0_OR_GREATER - return CancellationTokenSourcePool.Shared.Rent(TimeSpan.FromSeconds(Options.PackageHandlingTimeOut)); -#else var cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); cancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(Options.PackageHandlingTimeOut)); return cancellationTokenSource; -#endif } protected virtual ValueTask OnSessionErrorAsync(IAppSession session, PackageHandlingException exception) @@ -471,20 +467,20 @@ public async Task StartAsync(CancellationToken cancellationToken) protected virtual ValueTask OnStartedAsync() { - #if NETSTANDARD2_1 - return GetCompletedTask(); - #else - return ValueTask.CompletedTask; - #endif +#if NETSTANDARD2_1 + return GetCompletedTask(); +#else + return ValueTask.CompletedTask; +#endif } protected virtual ValueTask OnStopAsync() { - #if NETSTANDARD2_1 - return GetCompletedTask(); - #else - return ValueTask.CompletedTask; - #endif +#if NETSTANDARD2_1 + return GetCompletedTask(); +#else + return ValueTask.CompletedTask; +#endif } private async Task StopListener(IConnectionListener listener)