From ab0d32d2935f0d639a6192b2365d095ac3714038 Mon Sep 17 00:00:00 2001 From: Simon Knott Date: Wed, 23 Apr 2025 10:30:01 +0200 Subject: [PATCH] chore: conditionally disable modal state tools --- src/server.ts | 19 ++++++++++++++----- tests/capabilities.spec.ts | 4 ---- tests/files.spec.ts | 6 ++++++ 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/server.ts b/src/server.ts index 2716528e..827d5d1e 100644 --- a/src/server.ts +++ b/src/server.ts @@ -42,8 +42,10 @@ export function createServerWithTools(options: Options): Server { }); server.setRequestHandler(ListToolsRequestSchema, async () => { + const modalStates = context.modalStates().map(state => state.type); + const activeTools = tools.filter(tool => !tool.clearsModalState || modalStates.includes(tool.clearsModalState)); return { - tools: tools.map(tool => ({ + tools: activeTools.map(tool => ({ name: tool.schema.name, description: tool.schema.description, inputSchema: zodToJsonSchema(tool.schema.inputSchema) @@ -64,9 +66,9 @@ export function createServerWithTools(options: Options): Server { }; } - const modalStates = context.modalStates().map(state => state.type); - if ((tool.clearsModalState && !modalStates.includes(tool.clearsModalState)) || - (!tool.clearsModalState && modalStates.length)) { + const modalStates = new Set(context.modalStates().map(state => state.type)); + if ((tool.clearsModalState && !modalStates.has(tool.clearsModalState)) || + (!tool.clearsModalState && modalStates.size)) { const text = [ `Tool "${request.params.name}" does not handle the modal state.`, ...context.modalStatesMarkdown(), @@ -78,7 +80,14 @@ export function createServerWithTools(options: Options): Server { } try { - return await context.run(tool, request.params.arguments); + const response = await context.run(tool, request.params.arguments); + + const newModalStates = context.modalStates().map(state => state.type); + const modalStateChanged = newModalStates.length !== modalStates.size || newModalStates.some(state => !modalStates.has(state)); + if (modalStateChanged) + await server.sendToolListChanged(); + + return response; } catch (error) { return { content: [{ type: 'text', text: String(error) }], diff --git a/tests/capabilities.spec.ts b/tests/capabilities.spec.ts index f4ad6889..f46a863f 100644 --- a/tests/capabilities.spec.ts +++ b/tests/capabilities.spec.ts @@ -22,8 +22,6 @@ test('test snapshot tool list', async ({ client }) => { 'browser_click', 'browser_console_messages', 'browser_drag', - 'browser_file_upload', - 'browser_handle_dialog', 'browser_hover', 'browser_select_option', 'browser_type', @@ -51,8 +49,6 @@ test('test vision tool list', async ({ visionClient }) => { expect(new Set(visionTools.map(t => t.name))).toEqual(new Set([ 'browser_close', 'browser_console_messages', - 'browser_file_upload', - 'browser_handle_dialog', 'browser_install', 'browser_navigate_back', 'browser_navigate_forward', diff --git a/tests/files.spec.ts b/tests/files.spec.ts index 6e9b7979..2488efa9 100644 --- a/tests/files.spec.ts +++ b/tests/files.spec.ts @@ -14,10 +14,14 @@ * limitations under the License. */ +import { Notification } from '@modelcontextprotocol/sdk/types.js'; import { test, expect } from './fixtures'; import fs from 'fs/promises'; test('browser_file_upload', async ({ client }) => { + const notifications: Notification[] = []; + client.fallbackNotificationHandler = async notification => { notifications.push(notification); }; + expect(await client.callTool({ name: 'browser_navigate', arguments: { @@ -39,6 +43,8 @@ test('browser_file_upload', async ({ client }) => { })).toContainTextContent(`### Modal state - [File chooser]: can be handled by the "browser_file_upload" tool`); + expect(notifications).toContainEqual(expect.objectContaining({ method: 'notifications/tools/list_changed' })); + const filePath = test.info().outputPath('test.txt'); await fs.writeFile(filePath, 'Hello, world!');