Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix delayed server shutdown if clients are connected to the Watch server-streaming RPC of the Health API #2573

Closed
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
// Copyright 2024 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -17,11 +17,9 @@
#endregion

using Grpc.AspNetCore.HealthChecks;
using Grpc.HealthCheck;
using Grpc.Shared;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Diagnostics.HealthChecks;
using Microsoft.Extensions.Options;

namespace Microsoft.Extensions.DependencyInjection;

Expand Down
270 changes: 270 additions & 0 deletions src/Grpc.AspNetCore.HealthChecks/HealthServiceImpl.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
#region Copyright notice and license
// Copyright 2024 gRPC authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion

using System.Threading.Channels;

using Grpc.Core;
using Grpc.Core.Utils;
using Grpc.Health.V1;
using Microsoft.Extensions.Hosting;

namespace Grpc.AspNetCore.HealthChecks;

/// <summary>
/// Implementation of a simple Health service. Useful for health checking.
///
/// Registering service with a server:
/// <code>
/// var serviceImpl = new HealthServiceImpl();
/// server = new Server();
/// server.AddServiceDefinition(Grpc.Health.V1.Health.BindService(serviceImpl));
/// </code>
/// </summary>
/// <param name="applicationLifetime">The application lifetime used to stop long-running streaming RPCs.</param>
public class HealthServiceImpl(IHostApplicationLifetime applicationLifetime) : Grpc.Health.V1.Health.HealthBase
{
// The maximum number of statuses to buffer on the server.
internal const int MaxStatusBufferSize = 5;

private readonly object statusLock = new object();
private readonly Dictionary<string, HealthCheckResponse.Types.ServingStatus> statusMap =
new Dictionary<string, HealthCheckResponse.Types.ServingStatus>();

private readonly object watchersLock = new object();
private readonly Dictionary<string, List<ChannelWriter<HealthCheckResponse>>> watchers =
new Dictionary<string, List<ChannelWriter<HealthCheckResponse>>>();

/// <summary>
/// Sets the health status for given service.
/// </summary>
/// <param name="service">The service. Cannot be null.</param>
/// <param name="status">the health status</param>
public void SetStatus(string service, HealthCheckResponse.Types.ServingStatus status)
{
HealthCheckResponse.Types.ServingStatus previousStatus;
lock (statusLock)
{
previousStatus = GetServiceStatus(service);
statusMap[service] = status;
}

if (status != previousStatus)
{
NotifyStatus(service, status);
}
}

/// <summary>
/// Clears health status for given service.
/// </summary>
/// <param name="service">The service. Cannot be null.</param>
public void ClearStatus(string service)
{
HealthCheckResponse.Types.ServingStatus previousStatus;
lock (statusLock)
{
previousStatus = GetServiceStatus(service);
statusMap.Remove(service);
}

if (previousStatus != HealthCheckResponse.Types.ServingStatus.ServiceUnknown)
{
NotifyStatus(service, HealthCheckResponse.Types.ServingStatus.ServiceUnknown);
}
}

/// <summary>
/// Clears statuses for all services.
/// </summary>
public void ClearAll()
{
List<KeyValuePair<string, HealthCheckResponse.Types.ServingStatus>> statuses;
lock (statusLock)
{
statuses = statusMap.ToList();
statusMap.Clear();
}

foreach (KeyValuePair<string, HealthCheckResponse.Types.ServingStatus> status in statuses)
{
if (status.Value != HealthCheckResponse.Types.ServingStatus.ServiceUnknown)
{
NotifyStatus(status.Key, HealthCheckResponse.Types.ServingStatus.ServiceUnknown);
}
}
}

/// <summary>
/// Performs a health status check.
/// </summary>
/// <param name="request">The check request.</param>
/// <param name="context">The call context.</param>
/// <returns>The asynchronous response.</returns>
public override Task<HealthCheckResponse> Check(HealthCheckRequest request, ServerCallContext context)
{
HealthCheckResponse response = GetHealthCheckResponse(request.Service, throwOnNotFound: true);

return Task.FromResult(response);
}

/// <summary>
/// Performs a watch for the serving status of the requested service.
/// The server will immediately send back a message indicating the current
/// serving status. It will then subsequently send a new message whenever
/// the service's serving status changes.
///
/// If the requested service is unknown when the call is received, the
/// server will send a message setting the serving status to
/// SERVICE_UNKNOWN but will *not* terminate the call. If at some
/// future point, the serving status of the service becomes known, the
/// server will send a new message with the service's serving status.
///
/// If the call terminates with status UNIMPLEMENTED, then clients
/// should assume this method is not supported and should not retry the
/// call. If the call terminates with any other status (including OK),
/// clients should retry the call with appropriate exponential backoff.
/// </summary>
/// <param name="request">The request received from the client.</param>
/// <param name="responseStream">Used for sending responses back to the client.</param>
/// <param name="context">The context of the server-side call handler being invoked.</param>
/// <returns>A task indicating completion of the handler.</returns>
public override async Task Watch(HealthCheckRequest request, IServerStreamWriter<HealthCheckResponse> responseStream, ServerCallContext context)
{
string service = request.Service;

// Channel is used to to marshall multiple callers updating status into a single queue.
// This is required because IServerStreamWriter is not thread safe.
//
// A queue of unwritten statuses could build up if flow control causes responseStream.WriteAsync to await.
// When this number is exceeded the server will discard older statuses. The discarded intermediate statues
// will never be sent to the client.
Channel<HealthCheckResponse> channel = Channel.CreateBounded<HealthCheckResponse>(new BoundedChannelOptions(capacity: MaxStatusBufferSize) {
SingleReader = true,
SingleWriter = false,
FullMode = BoundedChannelFullMode.DropOldest
});

lock (watchersLock)
{
if (!watchers.TryGetValue(service, out List<ChannelWriter<HealthCheckResponse>>? channelWriters))
{
channelWriters = new List<ChannelWriter<HealthCheckResponse>>();
watchers.Add(service, channelWriters);
}

channelWriters.Add(channel.Writer);
}

// Watch calls run until ended by the client canceling them.
context.CancellationToken.Register(() => {
lock (watchersLock)
{
if (watchers.TryGetValue(service, out List<ChannelWriter<HealthCheckResponse>>? channelWriters))
{
// Remove the writer from the watchers
if (channelWriters.Remove(channel.Writer))
{
// Remove empty collection if service has no more response streams
if (channelWriters.Count == 0)
{
watchers.Remove(service);
}
}
}
}

// Signal the writer is complete and the watch method can exit.
channel.Writer.Complete();
});

// Send current status immediately
HealthCheckResponse response = GetHealthCheckResponse(service, throwOnNotFound: false);
await responseStream.WriteAsync(response).ConfigureAwait(false);

try
{
// Read messages. WaitToReadAsync will wait until new messages are available.
// Loop will exit when the call is canceled and the writer is marked as complete or when the application is stopping.
while (await channel.Reader.WaitToReadAsync(applicationLifetime.ApplicationStopping).ConfigureAwait(false))
{
if (channel.Reader.TryRead(out HealthCheckResponse? item))
{
await responseStream.WriteAsync(item).ConfigureAwait(false);
}
}
}
catch (OperationCanceledException)
{
await responseStream.WriteAsync(new HealthCheckResponse { Status = HealthCheckResponse.Types.ServingStatus.NotServing });
}
}

private void NotifyStatus(string service, HealthCheckResponse.Types.ServingStatus status)
{
lock (watchersLock)
{
if (watchers.TryGetValue(service, out List<ChannelWriter<HealthCheckResponse>>? channelWriters))
{
HealthCheckResponse response = new HealthCheckResponse { Status = status };

foreach (ChannelWriter<HealthCheckResponse> writer in channelWriters)
{
if (!writer.TryWrite(response))
{
throw new InvalidOperationException("Unable to queue health check notification.");
}
}
}
}
}

private HealthCheckResponse GetHealthCheckResponse(string service, bool throwOnNotFound)
{
HealthCheckResponse response;
lock (statusLock)
{
if (!statusMap.TryGetValue(service, out HealthCheckResponse.Types.ServingStatus status))
{
if (throwOnNotFound)
{
// TODO(jtattermusch): returning specific status from server handler is not supported yet.
throw new RpcException(new Status(StatusCode.NotFound, ""));
}
else
{
status = HealthCheckResponse.Types.ServingStatus.ServiceUnknown;
}
}
response = new HealthCheckResponse { Status = status };
}

return response;
}

private HealthCheckResponse.Types.ServingStatus GetServiceStatus(string service)
{
GrpcPreconditions.CheckNotNull(service, nameof(service));
if (statusMap.TryGetValue(service, out HealthCheckResponse.Types.ServingStatus s))
{
return s;
}
else
{
// A service with no set status has a status of ServiceUnknown
return HealthCheckResponse.Types.ServingStatus.ServiceUnknown;
}
}
}