diff --git a/package-lock.json b/package-lock.json index e9c4511f..3f1fc41a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,7 +9,7 @@ "version": "0.0.15", "license": "Apache-2.0", "dependencies": { - "@modelcontextprotocol/sdk": "^1.6.1", + "@modelcontextprotocol/sdk": "^1.10.1", "commander": "^13.1.0", "playwright": "1.53.0-alpha-1745357020000", "yaml": "^2.7.1", @@ -228,17 +228,18 @@ } }, "node_modules/@modelcontextprotocol/sdk": { - "version": "1.7.0", - "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.7.0.tgz", - "integrity": "sha512-IYPe/FLpvF3IZrd/f5p5ffmWhMc3aEMuM2wGJASDqC2Ge7qatVCdbfPx3n/5xFeb19xN0j/911M2AaFuircsWA==", + "version": "1.10.1", + "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.10.1.tgz", + "integrity": "sha512-xNYdFdkJqEfIaTVP1gPKoEvluACHZsHZegIoICX8DM1o6Qf3G5u2BQJHmgd0n4YgRPqqK/u1ujQvrgAxxSJT9w==", "license": "MIT", "dependencies": { "content-type": "^1.0.5", "cors": "^2.8.5", + "cross-spawn": "^7.0.3", "eventsource": "^3.0.2", "express": "^5.0.1", "express-rate-limit": "^7.5.0", - "pkce-challenge": "^4.1.0", + "pkce-challenge": "^5.0.0", "raw-body": "^3.0.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" @@ -1091,7 +1092,6 @@ "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", - "dev": true, "license": "MIT", "dependencies": { "path-key": "^3.1.0", @@ -2786,7 +2786,6 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", - "dev": true, "license": "ISC" }, "node_modules/js-yaml": { @@ -3256,7 +3255,6 @@ "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -3292,9 +3290,9 @@ } }, "node_modules/pkce-challenge": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-4.1.0.tgz", - "integrity": "sha512-ZBmhE1C9LcPoH9XZSdwiPtbPHZROwAnMy+kIFQVrnMCxY4Cudlz3gBOpzilgc0jOgRaiT3sIWfpMomW2ar2orQ==", + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-5.0.0.tgz", + "integrity": "sha512-ueGLflrrnvwB3xuo/uGob5pd5FN7l0MsLf0Z87o/UQmRtwjvfylfc9MurIxRAWywCYTgrvpXBcqjV4OfCYGCIQ==", "license": "MIT", "engines": { "node": ">=16.20.0" @@ -3796,7 +3794,6 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", - "dev": true, "license": "MIT", "dependencies": { "shebang-regex": "^3.0.0" @@ -3809,7 +3806,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -4238,7 +4234,6 @@ "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", - "dev": true, "license": "ISC", "dependencies": { "isexe": "^2.0.0" diff --git a/package.json b/package.json index a27c613a..fa8144e8 100644 --- a/package.json +++ b/package.json @@ -34,7 +34,7 @@ } }, "dependencies": { - "@modelcontextprotocol/sdk": "^1.6.1", + "@modelcontextprotocol/sdk": "^1.10.1", "commander": "^13.1.0", "playwright": "1.53.0-alpha-1745357020000", "yaml": "^2.7.1", diff --git a/src/program.ts b/src/program.ts index c678c376..4c5f5632 100644 --- a/src/program.ts +++ b/src/program.ts @@ -14,18 +14,13 @@ * limitations under the License. */ -import http from 'http'; - import { program } from 'commander'; -import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; -import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; - import { createServer } from './index'; import { ServerList } from './server'; -import assert from 'assert'; import { ToolCapability } from './tools/tool'; +import { startHttpTransport, startStdioTransport } from './transport'; const packageJSON = require('../package.json'); @@ -53,12 +48,10 @@ program })); setupExitWatchdog(serverList); - if (options.port) { - startSSEServer(+options.port, options.host || 'localhost', serverList); - } else { - const server = await serverList.create(); - await server.connect(new StdioServerTransport()); - } + if (options.port) + startHttpTransport(+options.port, options.host, serverList); + else + await startStdioTransport(serverList); }); function setupExitWatchdog(serverList: ServerList) { @@ -74,64 +67,3 @@ function setupExitWatchdog(serverList: ServerList) { } program.parse(process.argv); - -function startSSEServer(port: number, host: string, serverList: ServerList) { - const sessions = new Map(); - const httpServer = http.createServer(async (req, res) => { - if (req.method === 'POST') { - const searchParams = new URL(`http://localhost${req.url}`).searchParams; - const sessionId = searchParams.get('sessionId'); - if (!sessionId) { - res.statusCode = 400; - res.end('Missing sessionId'); - return; - } - const transport = sessions.get(sessionId); - if (!transport) { - res.statusCode = 404; - res.end('Session not found'); - return; - } - - await transport.handlePostMessage(req, res); - return; - } else if (req.method === 'GET') { - const transport = new SSEServerTransport('/sse', res); - sessions.set(transport.sessionId, transport); - const server = await serverList.create(); - res.on('close', () => { - sessions.delete(transport.sessionId); - serverList.close(server).catch(e => console.error(e)); - }); - await server.connect(transport); - return; - } else { - res.statusCode = 405; - res.end('Method not allowed'); - } - }); - - httpServer.listen(port, host, () => { - const address = httpServer.address(); - assert(address, 'Could not bind server socket'); - let url: string; - if (typeof address === 'string') { - url = address; - } else { - const resolvedPort = address.port; - let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`; - if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]') - resolvedHost = host === 'localhost' ? 'localhost' : resolvedHost; - url = `http://${resolvedHost}:${resolvedPort}`; - } - console.log(`Listening on ${url}`); - console.log('Put this in your client config:'); - console.log(JSON.stringify({ - 'mcpServers': { - 'playwright': { - 'url': `${url}/sse` - } - } - }, undefined, 2)); - }); -} diff --git a/src/transport.ts b/src/transport.ts new file mode 100644 index 00000000..b21db068 --- /dev/null +++ b/src/transport.ts @@ -0,0 +1,127 @@ +/** + * Copyright (c) Microsoft Corporation. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import http from 'node:http'; +import assert from 'node:assert'; +import crypto from 'node:crypto'; + +import { ServerList } from './server'; +import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'; +import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js'; +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; + +export async function startStdioTransport(serverList: ServerList) { + const server = await serverList.create(); + await server.connect(new StdioServerTransport()); +} + +async function handleSSE(req: http.IncomingMessage, res: http.ServerResponse, url: URL, serverList: ServerList, sessions: Map) { + if (req.method === 'POST') { + const sessionId = url.searchParams.get('sessionId'); + if (!sessionId) { + res.statusCode = 400; + return res.end('Missing sessionId'); + } + + const transport = sessions.get(sessionId); + if (!transport) { + res.statusCode = 404; + return res.end('Session not found'); + } + + return await transport.handlePostMessage(req, res); + } else if (req.method === 'GET') { + const transport = new SSEServerTransport('/sse', res); + sessions.set(transport.sessionId, transport); + const server = await serverList.create(); + res.on('close', () => { + sessions.delete(transport.sessionId); + serverList.close(server).catch(e => console.error(e)); + }); + return await server.connect(transport); + } + + res.statusCode = 405; + res.end('Method not allowed'); +} + +async function handleStreamable(req: http.IncomingMessage, res: http.ServerResponse, serverList: ServerList, sessions: Map) { + const sessionId = req.headers['mcp-session-id'] as string | undefined; + if (sessionId) { + const transport = sessions.get(sessionId); + if (!transport) { + res.statusCode = 404; + res.end('Session not found'); + return; + } + return await transport.handleRequest(req, res); + } + + if (req.method === 'POST') { + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => crypto.randomUUID(), + onsessioninitialized: sessionId => { + sessions.set(sessionId, transport); + } + }); + transport.onclose = () => { + if (transport.sessionId) + sessions.delete(transport.sessionId); + }; + const server = await serverList.create(); + await server.connect(transport); + return await transport.handleRequest(req, res); + } + + res.statusCode = 400; + res.end('Invalid request'); +} + +export function startHttpTransport(port: number, hostname: string | undefined, serverList: ServerList) { + const sseSessions = new Map(); + const streamableSessions = new Map(); + const httpServer = http.createServer(async (req, res) => { + const url = new URL(`http://localhost${req.url}`); + if (url.pathname.startsWith('/mcp')) + await handleStreamable(req, res, serverList, streamableSessions); + else + await handleSSE(req, res, url, serverList, sseSessions); + }); + httpServer.listen(port, hostname, () => { + const address = httpServer.address(); + assert(address, 'Could not bind server socket'); + let url: string; + if (typeof address === 'string') { + url = address; + } else { + const resolvedPort = address.port; + let resolvedHost = address.family === 'IPv4' ? address.address : `[${address.address}]`; + if (resolvedHost === '0.0.0.0' || resolvedHost === '[::]') + resolvedHost = 'localhost'; + url = `http://${resolvedHost}:${resolvedPort}`; + } + console.log(`Listening on ${url}`); + console.log('Put this in your client config:'); + console.log(JSON.stringify({ + 'mcpServers': { + 'playwright': { + 'url': `${url}/sse` + } + } + }, undefined, 2)); + console.log('If your client supports streamable HTTP, you can use the /mcp endpoint instead.'); + }); +} diff --git a/tests/sse.spec.ts b/tests/sse.spec.ts index ad627efd..eaacd154 100644 --- a/tests/sse.spec.ts +++ b/tests/sse.spec.ts @@ -16,27 +16,45 @@ import { spawn } from 'node:child_process'; import path from 'node:path'; -import { test } from './fixtures'; +import { test as baseTest } from './fixtures'; +import { expect } from 'playwright/test'; -test('sse transport', async () => { - const cp = spawn('node', [path.join(__dirname, '../cli.js'), '--port', '0'], { stdio: 'pipe' }); - try { - let stdout = ''; - const url = await new Promise(resolve => cp.stdout?.on('data', data => { - stdout += data.toString(); - const match = stdout.match(/Listening on (http:\/\/.*)/); - if (match) - resolve(match[1]); - })); +const test = baseTest.extend<{ serverEndpoint: string }>({ + serverEndpoint: async ({}, use) => { + const cp = spawn('node', [path.join(__dirname, '../cli.js'), '--port', '0'], { stdio: 'pipe' }); + try { + let stdout = ''; + const url = await new Promise(resolve => cp.stdout?.on('data', data => { + stdout += data.toString(); + const match = stdout.match(/Listening on (http:\/\/.*)/); + if (match) + resolve(match[1]); + })); - // need dynamic import b/c of some ESM nonsense - const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js'); - const { Client } = await import('@modelcontextprotocol/sdk/client/index.js'); - const transport = new SSEClientTransport(new URL(url)); - const client = new Client({ name: 'test', version: '1.0.0' }); - await client.connect(transport); - await client.ping(); - } finally { - cp.kill(); - } + await use(url); + } finally { + cp.kill(); + } + }, +}); + +test('sse transport', async ({ serverEndpoint }) => { + // need dynamic import b/c of some ESM nonsense + const { SSEClientTransport } = await import('@modelcontextprotocol/sdk/client/sse.js'); + const { Client } = await import('@modelcontextprotocol/sdk/client/index.js'); + const transport = new SSEClientTransport(new URL(serverEndpoint)); + const client = new Client({ name: 'test', version: '1.0.0' }); + await client.connect(transport); + await client.ping(); +}); + +test('streamable http transport', async ({ serverEndpoint }) => { + // need dynamic import b/c of some ESM nonsense + const { StreamableHTTPClientTransport } = await import('@modelcontextprotocol/sdk/client/streamableHttp.js'); + const { Client } = await import('@modelcontextprotocol/sdk/client/index.js'); + const transport = new StreamableHTTPClientTransport(new URL('/mcp', serverEndpoint)); + const client = new Client({ name: 'test', version: '1.0.0' }); + await client.connect(transport); + await client.ping(); + expect(transport.sessionId, 'has session support').toBeDefined(); });