-
Notifications
You must be signed in to change notification settings - Fork 161
/
Copy pathSseIntegrationTests.cs
158 lines (131 loc) · 6.35 KB
/
SseIntegrationTests.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Messages;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using ModelContextProtocol.Utils.Json;
namespace ModelContextProtocol.Tests;
/// <summary>
/// Integration tests that host an MCP server over SSE on an in-memory Kestrel test host
/// and connect to it with an <see cref="SseClientTransport"/>-based MCP client.
/// </summary>
public class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper)
{
    /// <summary>How long a test waits for a notification to arrive before failing.</summary>
    private static readonly TimeSpan NotificationTimeout = TimeSpan.FromSeconds(10);

    // Client transport options pointing at the test host's "/sse" endpoint.
    private readonly SseClientTransportOptions _defaultTransportOptions = new()
    {
        Endpoint = new Uri("http://localhost/sse"),
        Name = "In-memory Test Server",
    };

    /// <summary>
    /// Creates an MCP client that talks to the test host through an <see cref="SseClientTransport"/>.
    /// </summary>
    /// <param name="httpClient">HTTP client used by the SSE transport (the tests pass the one from <c>CreateHttpClient()</c>).</param>
    /// <param name="clientOptions">Optional client options, forwarded as-is to <c>McpClientFactory.CreateAsync</c>.</param>
    /// <returns>A task producing the connected client; callers are responsible for disposing it.</returns>
    private Task<IMcpClient> ConnectMcpClient(HttpClient httpClient, McpClientOptions? clientOptions = null)
        => McpClientFactory.CreateAsync(
            new SseClientTransport(_defaultTransportOptions, httpClient, LoggerFactory),
            clientOptions,
            LoggerFactory,
            TestContext.Current.CancellationToken);

    [Fact]
    public async Task ConnectAndReceiveMessage_InMemoryServer()
    {
        await using var app = Builder.Build();
        app.MapMcp();
        await app.StartAsync(TestContext.Current.CancellationToken);

        using var httpClient = CreateHttpClient();
        await using var mcpClient = await ConnectMcpClient(httpClient);

        // Send a test message through the POST endpoint. A notification has no response to
        // inspect, so the test passes when connecting and sending complete without throwing.
        await mcpClient.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken);
    }

    [Fact]
    public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventUri()
    {
        await using var app = Builder.Build();
        MapAbsoluteEndpointUriMcp(app);
        await app.StartAsync(TestContext.Current.CancellationToken);

        using var httpClient = CreateHttpClient();
        await using var mcpClient = await ConnectMcpClient(httpClient);

        // Send a test message through the POST endpoint. A notification has no response to
        // inspect, so the test passes when connecting and sending complete without throwing.
        await mcpClient.SendNotificationAsync("test/message", new { message = "Hello, SSE!" }, cancellationToken: TestContext.Current.CancellationToken);
    }

    [Fact]
    public async Task ConnectAndReceiveNotification_InMemoryServer()
    {
        // RunContinuationsAsynchronously keeps the awaiting test code from running inline on
        // whichever thread delivers the notification and completes the source.
        var serverReceivedMessage = new TaskCompletionSource<string?>(TaskCreationOptions.RunContinuationsAsynchronously);
        var clientReceivedMessage = new TaskCompletionSource<string?>(TaskCreationOptions.RunContinuationsAsynchronously);

        await using var app = Builder.Build();
        app.MapMcp(runSessionAsync: (httpContext, mcpServer, cancellationToken) =>
        {
            mcpServer.RegisterNotificationHandler("test/notification", async (notification, cancellationToken) =>
            {
                // Record what the server saw rather than asserting here: an assertion failing
                // inside this handler would skip the reply below and show up only as a timeout.
                serverReceivedMessage.TrySetResult(notification.Params?["message"]?.GetValue<string>());
                await mcpServer.SendNotificationAsync("test/notification", new { message = "Hello from server!" }, cancellationToken: cancellationToken);
            });
            return mcpServer.RunAsync(cancellationToken);
        });
        await app.StartAsync(TestContext.Current.CancellationToken);

        using var httpClient = CreateHttpClient();
        await using var mcpClient = await ConnectMcpClient(httpClient);

        mcpClient.RegisterNotificationHandler("test/notification", (notification, _) =>
        {
            // TrySetResult tolerates a duplicate delivery instead of throwing from the handler.
            clientReceivedMessage.TrySetResult(notification.Params?["message"]?.GetValue<string>());
            return Task.CompletedTask;
        });

        // Send a test message through the POST endpoint; the server handler replies with its own notification.
        await mcpClient.SendNotificationAsync("test/notification", new { message = "Hello from client!" }, cancellationToken: TestContext.Current.CancellationToken);

        var serverMessage = await serverReceivedMessage.Task.WaitAsync(NotificationTimeout, TestContext.Current.CancellationToken);
        Assert.Equal("Hello from client!", serverMessage);

        var clientMessage = await clientReceivedMessage.Task.WaitAsync(NotificationTimeout, TestContext.Current.CancellationToken);
        Assert.Equal("Hello from server!", clientMessage);
    }

    /// <summary>
    /// Maps a hand-rolled "/sse" + "/message" endpoint pair. The SSE transport is constructed with
    /// the absolute message URI <c>http://localhost/message</c>. Presumably that URI is what the
    /// client receives as its message endpoint, which is what the "WithFullEndpointEventUri" test
    /// exercises (TODO confirm against <see cref="SseResponseStreamTransport"/>).
    /// </summary>
    /// <param name="endpoints">Route builder to map onto; its service provider supplies logging and server options.</param>
    private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints)
    {
        var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
        var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
        var routeGroup = endpoints.MapGroup("");

        // The most recently started SSE session. The POST endpoint forwards client messages to it.
        // This helper tracks a single session only.
        SseResponseStreamTransport? session = null;

        routeGroup.MapGet("/sse", async context =>
        {
            var response = context.Response;
            var requestAborted = context.RequestAborted;
            response.Headers.ContentType = "text/event-stream";

            await using var transport = new SseResponseStreamTransport(response.Body, "http://localhost/message");
            session = transport;
            try
            {
                var transportTask = transport.RunAsync(cancellationToken: requestAborted);
                await using var server = McpServerFactory.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider);
                try
                {
                    await server.RunAsync(requestAborted);
                }
                finally
                {
                    // Dispose before awaiting transportTask; presumably disposal is what ends
                    // transport.RunAsync once the server stops.
                    await transport.DisposeAsync();
                    await transportTask;
                }
            }
            catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
            {
                // RequestAborted always triggers when the client disconnects before a complete response body is written,
                // but this is how SSE connections are typically closed.
            }
            finally
            {
                // Stop routing POSTs to a transport that is being torn down, unless a newer
                // SSE connection has already replaced it.
                if (ReferenceEquals(session, transport))
                {
                    session = null;
                }
            }
        });

        routeGroup.MapPost("/message", async context =>
        {
            // Read the shared variable once; the SSE handler may reset it concurrently.
            var currentSession = session;
            if (currentSession is null)
            {
                await Results.BadRequest("Session not started.").ExecuteAsync(context);
                return;
            }

            var message = (IJsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(IJsonRpcMessage)), context.RequestAborted);
            if (message is null)
            {
                await Results.BadRequest("No message in request body.").ExecuteAsync(context);
                return;
            }

            await currentSession.OnMessageReceivedAsync(message, context.RequestAborted);
            context.Response.StatusCode = StatusCodes.Status202Accepted;
            await context.Response.WriteAsync("Accepted");
        });
    }
}