diff --git a/apps/studio/electron/main/chat/index.ts b/apps/studio/electron/main/chat/index.ts index 11a17ab70f..4fd1bc3936 100644 --- a/apps/studio/electron/main/chat/index.ts +++ b/apps/studio/electron/main/chat/index.ts @@ -1,6 +1,6 @@ import { PromptProvider } from '@onlook/ai/src/prompt/provider'; import { listFilesTool, readFileTool } from '@onlook/ai/src/tools'; -import { CLAUDE_MODELS, LLMProvider } from '@onlook/models'; +import { CLAUDE_MODELS, LLMProvider, MCP_MODELS } from '@onlook/models'; import { ChatSuggestionSchema, StreamRequestType, @@ -9,7 +9,13 @@ import { type UsageCheckResult, } from '@onlook/models/chat'; import { MainChannels } from '@onlook/models/constants'; -import { generateObject, streamText, type CoreMessage, type CoreSystemMessage } from 'ai'; +import { + generateObject, + streamText, + type CoreMessage, + type CoreSystemMessage, + type LanguageModelV1, +} from 'ai'; import { mainWindow } from '..'; import { PersistentStorage } from '../storage'; import { initModel } from './llmProvider'; @@ -68,9 +74,7 @@ class LlmManager { } as CoreSystemMessage; messages = [systemMessage, ...messages]; } - const model = await initModel(LLMProvider.ANTHROPIC, CLAUDE_MODELS.SONNET, { - requestType, - }); + const model = await this.getModel(requestType); const { textStream } = await streamText({ model, @@ -156,11 +160,18 @@ class LlmManager { return 'An unknown error occurred'; } + private async getModel(requestType: StreamRequestType): Promise { + // Get the provider and model from settings or use defaults + const settings = PersistentStorage.USER_SETTINGS.read() || {}; + const provider = settings.llmProvider || LLMProvider.ANTHROPIC; + const modelName = settings.llmModel || CLAUDE_MODELS.SONNET; + + return await initModel(provider, modelName, { requestType }); + } + public async generateSuggestions(messages: CoreMessage[]): Promise { try { - const model = await initModel(LLMProvider.ANTHROPIC, CLAUDE_MODELS.HAIKU, { - requestType: StreamRequestType.SUGGESTIONS, - }); + const model = await this.getModel(StreamRequestType.SUGGESTIONS); const { object } = await generateObject({ model, diff --git a/apps/studio/electron/main/chat/llmProvider.ts b/apps/studio/electron/main/chat/llmProvider.ts index a18c34fcf6..105609bb6c 100644 --- a/apps/studio/electron/main/chat/llmProvider.ts +++ b/apps/studio/electron/main/chat/llmProvider.ts @@ -1,7 +1,7 @@ import { createAnthropic } from '@ai-sdk/anthropic'; import type { StreamRequestType } from '@onlook/models/chat'; import { BASE_PROXY_ROUTE, FUNCTIONS_ROUTE, ProxyRoutes } from '@onlook/models/constants'; -import { CLAUDE_MODELS, LLMProvider } from '@onlook/models/llm'; +import { CLAUDE_MODELS, LLMProvider, MCP_MODELS } from '@onlook/models/llm'; import { type LanguageModelV1 } from 'ai'; import { getRefreshedAuthTokens } from '../auth'; export interface OnlookPayload { @@ -10,12 +10,14 @@ export interface OnlookPayload { export async function initModel( provider: LLMProvider, - model: CLAUDE_MODELS, + model: CLAUDE_MODELS | MCP_MODELS, payload: OnlookPayload, ): Promise { switch (provider) { case LLMProvider.ANTHROPIC: - return await getAnthropicProvider(model, payload); + return await getAnthropicProvider(model as CLAUDE_MODELS, payload); + case LLMProvider.MCP: + return await getMCPProvider(model as MCP_MODELS, payload); default: throw new Error(`Unsupported provider: ${provider}`); } @@ -54,3 +56,31 @@ async function getAnthropicProvider( cacheControl: true, }); } + +async function getMCPProvider(model: MCP_MODELS, payload: OnlookPayload): Promise { + // Import the MCP client and adapter + const { MCPClient, createMCPLanguageModel } = await import('@onlook/ai/mcp'); + + // Create a new MCP client + const client = new MCPClient({ + name: 'onlook-mcp-client', + version: '1.0.0', + }); + + // Connect to the MCP server + // Note: The connection details would need to be configured + await client.connect({ + type: 'stdio', + command: 'mcp-server', // This would need to be configured + args: ['--stdio'], + }); + + // Initialize the client + await client.initialize(); + + // Create a language model adapter that implements the LanguageModelV1 interface + return createMCPLanguageModel({ + client, + model: model.toString(), + }); +} diff --git a/bun.lockb b/bun.lockb index 6b3679a839..13eb0c23e2 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/packages/ai/package.json b/packages/ai/package.json index f1e268b76d..37e086c3d3 100644 --- a/packages/ai/package.json +++ b/packages/ai/package.json @@ -31,8 +31,10 @@ "@onlook/typescript": "*" }, "dependencies": { + "@modelcontextprotocol/sdk": "latest", "diff-match-patch": "^1.0.5", "fg": "^0.0.3", - "marked": "^15.0.7" + "marked": "^15.0.7", + "zod": "^3.22.4" } } diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index 7bfe02e857..2fd23f6a0f 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -1,2 +1,3 @@ export * from './coder'; export * from './prompt'; +export * from './mcp'; diff --git a/packages/ai/src/mcp/adapters/index.ts b/packages/ai/src/mcp/adapters/index.ts new file mode 100644 index 0000000000..8e281e0bb2 --- /dev/null +++ b/packages/ai/src/mcp/adapters/index.ts @@ -0,0 +1,5 @@ +/** + * MCP adapters exports + */ + +export * from './language-model'; diff --git a/packages/ai/src/mcp/adapters/language-model.ts b/packages/ai/src/mcp/adapters/language-model.ts new file mode 100644 index 0000000000..992bc62b1d --- /dev/null +++ b/packages/ai/src/mcp/adapters/language-model.ts @@ -0,0 +1,141 @@ +/** + * MCP language model adapter for the AI SDK + */ + +import type { LanguageModelV1, LanguageModelV1CallOptions } from 'ai'; +import { MCPClient } from '../client'; +import { MCPCapabilityManager } from '../context/manager'; +import { MCPError, MCPErrorType } from '../client/types'; + +/** + * Options for creating an MCP language model adapter + */ +export interface MCPLanguageModelOptions { + /** + * MCP client + */ + client: MCPClient; + + /** + * Capability manager + */ + capabilityManager?: MCPCapabilityManager; + + /** + * Model name + */ + model?: string; +} + +/** + * Create an MCP language model adapter + * + * @param options Options for creating the adapter + * @returns Language model adapter + */ +export function createMCPLanguageModel(options: MCPLanguageModelOptions): LanguageModelV1 { + const { client, capabilityManager, model } = options; + + // Create a capability manager if not provided + const manager = capabilityManager || new MCPCapabilityManager(client); + + return { + specificationVersion: 'v1', + provider: 'mcp', + modelId: `mcp-${model || 'default'}`, + defaultObjectGenerationMode: undefined, + doStream: async (options: LanguageModelV1CallOptions) => { + throw new Error('Streaming is not supported by the MCP adapter yet'); + }, + doGenerate: async (options: LanguageModelV1CallOptions) => { + try { + // Refresh capabilities if needed + if (!manager.getCapabilities().tools.length) { + await manager.refreshAll(); + } + + // Find a suitable tool for text generation + const tools = manager.getCapabilities().tools; + const generateTool = tools.find( + (tool) => tool.name === 'generate' || tool.name.includes('generate'), + ); + + if (!generateTool) { + throw new MCPError( + 'No text generation tool found', + MCPErrorType.TOOL_CALL_ERROR, + ); + } + + // Extract parameters from options + const { maxTokens, temperature, topP, stopSequences } = options; + + // Prepare messages for the tool call + const messages = []; + + // Add user content if available + if (options.prompt) { + messages.push({ role: 'user', content: options.prompt }); + } + + // Call the tool with the messages + const result = await client.callTool(generateTool.name, { + messages, + maxTokens, + temperature, + topP, + stopSequences, + }); + + // Extract the text from the result + const text = result.content.map((item) => item.text).join(''); + + // Prepare the request body for rawCall + const requestBody = { + messages, + maxTokens, + temperature, + topP, + stopSequences, + }; + + return { + text, + toolCalls: undefined, + finishReason: 'stop', + usage: { + promptTokens: 0, + completionTokens: 0, + }, + rawCall: { + rawPrompt: messages, + rawSettings: { + maxTokens, + temperature, + topP, + stopSequences, + }, + }, + request: { + body: JSON.stringify(requestBody), + }, + response: { + id: `mcp-response-${Date.now()}`, + timestamp: new Date(), + modelId: `mcp-${model || 'default'}`, + }, + }; + } catch (error) { + if (error instanceof MCPError) { + throw error; + } + + throw new MCPError( + `Failed to generate text: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.TOOL_CALL_ERROR, + error, + ); + } + }, + }; +} diff --git a/packages/ai/src/mcp/capabilities/index.ts b/packages/ai/src/mcp/capabilities/index.ts new file mode 100644 index 0000000000..e1b652418d --- /dev/null +++ b/packages/ai/src/mcp/capabilities/index.ts @@ -0,0 +1,8 @@ +/** + * MCP capability exports + */ + +export * from './tools'; +export * from './resources'; +export * from './prompts'; +export * from './roots'; diff --git a/packages/ai/src/mcp/capabilities/prompts.ts b/packages/ai/src/mcp/capabilities/prompts.ts new file mode 100644 index 0000000000..55bad6596e --- /dev/null +++ b/packages/ai/src/mcp/capabilities/prompts.ts @@ -0,0 +1,84 @@ +/** + * Prompt capability management for MCP client + */ + +import type { Prompt } from '../client/types'; +import { MCPClient } from '../client'; +import { MCPError, MCPErrorType } from '../client/types'; + +/** + * Find a prompt by name + * + * @param client MCP client + * @param name Name of the prompt to find + * @returns Prompt if found, null otherwise + */ +export async function findPromptByName(client: MCPClient, name: string): Promise { + const prompts = await client.listPrompts(); + return prompts.prompts.find((p) => p.name === name) || null; +} + +/** + * Validate prompt arguments + * + * @param prompt Prompt to validate arguments for + * @param args Arguments to validate + * @returns Validated arguments + */ +export function validatePromptArgs( + prompt: Prompt, + args: Record, +): Record { + if (!prompt.arguments || prompt.arguments.length === 0) { + return {}; + } + + const validatedArgs: Record = {}; + const missingRequired: string[] = []; + + for (const arg of prompt.arguments) { + const value = args[arg.name]; + + if (arg.required && (value === undefined || value === null)) { + missingRequired.push(arg.name); + } else if (value !== undefined) { + validatedArgs[arg.name] = value; + } + } + + if (missingRequired.length > 0) { + throw new MCPError( + `Missing required arguments for prompt ${prompt.name}: ${missingRequired.join(', ')}`, + MCPErrorType.VALIDATION_ERROR, + ); + } + + return validatedArgs; +} + +/** + * Get a prompt with validated arguments + * + * @param client MCP client + * @param promptName Name of the prompt to get + * @param args Arguments for the prompt + * @returns Prompt with validated arguments + */ +export async function getPromptWithValidation( + client: MCPClient, + promptName: string, + args: Record, +) { + const prompt = await findPromptByName(client, promptName); + + if (!prompt) { + throw new MCPError(`Prompt not found: ${promptName}`, MCPErrorType.PROMPT_ERROR); + } + + const validatedArgs = validatePromptArgs(prompt, args); + + return { + prompt, + args: validatedArgs, + }; +} diff --git a/packages/ai/src/mcp/capabilities/resources.ts b/packages/ai/src/mcp/capabilities/resources.ts new file mode 100644 index 0000000000..23369dc11a --- /dev/null +++ b/packages/ai/src/mcp/capabilities/resources.ts @@ -0,0 +1,110 @@ +/** + * Resource capability management for MCP client + */ + +import type { Resource, ResourceTemplate } from '../client/types'; +import { MCPClient } from '../client'; +import { MCPError, MCPErrorType } from '../client/types'; + +/** + * Expand a resource template with parameters + * + * @param template Resource template to expand + * @param params Parameters to expand the template with + * @returns Expanded URI + */ +export function expandResourceTemplate( + template: ResourceTemplate, + params: Record, +): string { + let uri = template.uriTemplate; + + for (const [key, value] of Object.entries(params)) { + uri = uri.replace(`{${key}}`, value); + } + + return uri; +} + +/** + * Read a resource with template expansion + * + * @param client MCP client + * @param uriOrTemplate URI or template to read + * @param params Parameters for template expansion + * @returns Resource content + */ +export async function readResourceWithTemplate( + client: MCPClient, + uriOrTemplate: string, + params?: Record, +) { + try { + // If params are provided, treat as a template + if (params) { + const resources = await client.listResources(); + // The SDK types don't match our types exactly + const resourceTemplates = Array.isArray(resources.resourceTemplates) + ? resources.resourceTemplates + : []; + const template = resourceTemplates.find((t: any) => t.uriTemplate === uriOrTemplate); + + if (!template) { + throw new MCPError( + `Resource template not found: ${uriOrTemplate}`, + MCPErrorType.RESOURCE_ERROR, + ); + } + + const uri = expandResourceTemplate(template, params); + return await client.readResource(uri); + } + + // Otherwise, treat as a direct URI + return await client.readResource(uriOrTemplate); + } catch (error) { + if (error instanceof MCPError) { + throw error; + } + + throw new MCPError( + `Failed to read resource: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.RESOURCE_ERROR, + error, + ); + } +} + +/** + * Find a resource by name + * + * @param client MCP client + * @param name Name of the resource to find + * @returns Resource if found, null otherwise + */ +export async function findResourceByName( + client: MCPClient, + name: string, +): Promise { + const resources = await client.listResources(); + return resources.resources.find((r) => r.name === name) || null; +} + +/** + * Find a resource template by name + * + * @param client MCP client + * @param name Name of the resource template to find + * @returns Resource template if found, null otherwise + */ +export async function findResourceTemplateByName( + client: MCPClient, + name: string, +): Promise { + const resources = await client.listResources(); + // The SDK types don't match our types exactly + const resourceTemplates = Array.isArray(resources.resourceTemplates) + ? resources.resourceTemplates + : []; + return resourceTemplates.find((t: any) => t.name === name) || null; +} diff --git a/packages/ai/src/mcp/capabilities/roots.ts b/packages/ai/src/mcp/capabilities/roots.ts new file mode 100644 index 0000000000..94ee0bc922 --- /dev/null +++ b/packages/ai/src/mcp/capabilities/roots.ts @@ -0,0 +1,48 @@ +/** + * Root capability management for MCP client + */ + +import type { Root } from '../client/types'; + +/** + * Create a file system root + * + * @param path Path to the root directory + * @param name Name of the root + * @returns Root declaration + */ +export function createFileSystemRoot(path: string, name: string): Root { + return { + uri: `file://${path}`, + name, + }; +} + +/** + * Create an HTTP root + * + * @param url URL of the root + * @param name Name of the root + * @returns Root declaration + */ +export function createHttpRoot(url: string, name: string): Root { + return { + uri: url.startsWith('http') ? url : `https://${url}`, + name, + }; +} + +/** + * Create a set of common roots for a project + * + * @param projectPath Path to the project directory + * @param projectName Name of the project + * @returns Array of root declarations + */ +export function createProjectRoots(projectPath: string, projectName: string): Root[] { + return [ + createFileSystemRoot(projectPath, projectName), + createFileSystemRoot(`${projectPath}/src`, `${projectName} Source`), + createFileSystemRoot(`${projectPath}/test`, `${projectName} Tests`), + ]; +} diff --git a/packages/ai/src/mcp/capabilities/tools.ts b/packages/ai/src/mcp/capabilities/tools.ts new file mode 100644 index 0000000000..62e1b12a20 --- /dev/null +++ b/packages/ai/src/mcp/capabilities/tools.ts @@ -0,0 +1,167 @@ +/** + * Tool capability management for MCP client + */ + +import { z } from 'zod'; +import type { Tool } from '../client/types'; +import { MCPClient } from '../client'; +import { MCPError, MCPErrorType } from '../client/types'; + +/** + * Validate tool arguments against the tool's input schema + * + * @param tool Tool to validate arguments for + * @param args Arguments to validate + * @returns Validated arguments + */ +export function validateToolArgs( + tool: Tool, + args: Record, +): Record { + try { + const schema = createZodSchema(tool.inputSchema); + return schema.parse(args); + } catch (error) { + throw new MCPError( + `Invalid arguments for tool ${tool.name}: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.VALIDATION_ERROR, + error, + ); + } +} + +/** + * Create a Zod schema from a JSON schema + * + * @param schema JSON schema + * @returns Zod schema + */ +export function createZodSchema(schema: any): z.ZodTypeAny { + if (!schema) { + return z.any(); + } + + const type = schema.type; + + if (type === 'string') { + let stringSchema = z.string(); + if (schema.pattern) { + stringSchema = stringSchema.regex(new RegExp(schema.pattern)); + } + if (schema.minLength !== undefined) { + stringSchema = stringSchema.min(schema.minLength); + } + if (schema.maxLength !== undefined) { + stringSchema = stringSchema.max(schema.maxLength); + } + return stringSchema; + } else if (type === 'number') { + let numberSchema = z.number(); + if (schema.minimum !== undefined) { + numberSchema = numberSchema.min(schema.minimum); + } + if (schema.maximum !== undefined) { + numberSchema = numberSchema.max(schema.maximum); + } + return numberSchema; + } else if (type === 'integer') { + let intSchema = z.number().int(); + if (schema.minimum !== undefined) { + intSchema = intSchema.min(schema.minimum); + } + if (schema.maximum !== undefined) { + intSchema = intSchema.max(schema.maximum); + } + return intSchema; + } else if (type === 'boolean') { + return z.boolean(); + } else if (type === 'array') { + const items = schema.items || {}; + let arraySchema = z.array(createZodSchema(items)); + if (schema.minItems !== undefined) { + arraySchema = arraySchema.min(schema.minItems); + } + if (schema.maxItems !== undefined) { + arraySchema = arraySchema.max(schema.maxItems); + } + return arraySchema; + } else if (type === 'object') { + const properties = schema.properties || {}; + const shape: Record = {}; + + for (const [key, value] of Object.entries(properties)) { + shape[key] = createZodSchema(value); + } + + let objectSchema = z.object(shape); + + if (schema.required && Array.isArray(schema.required)) { + const required = schema.required as string[]; + for (const key of required) { + if (shape[key]) { + shape[key] = shape[key].optional(); + } + } + } + + return objectSchema; + } else { + return z.any(); + } +} + +/** + * Call a tool with validated arguments + * + * @param client MCP client + * @param toolName Name of the tool to call + * @param args Arguments for the tool + * @returns Result of the tool call + */ +export async function callToolWithValidation( + client: MCPClient, + toolName: string, + args: Record, +) { + const tools = await client.listTools(); + const tool = tools.tools.find((t) => t.name === toolName); + + if (!tool) { + throw new MCPError(`Tool not found: ${toolName}`, MCPErrorType.TOOL_CALL_ERROR); + } + + // @ts-ignore - The SDK types don't match our types exactly + const validatedArgs = validateToolArgs(tool, args); + return await client.callTool(tool.name, validatedArgs); +} + +/** + * Generate an example for a tool + * + * @param tool Tool to generate an example for + * @returns Example arguments for the tool + */ +export function generateToolExample(tool: Tool): Record | null { + if (!tool.inputSchema || !tool.inputSchema.properties) { + return null; + } + + const exampleArgs: Record = {}; + + for (const [propName, prop] of Object.entries(tool.inputSchema.properties)) { + const typedProp = prop as { type: string }; + if (typedProp.type === 'string') { + exampleArgs[propName] = 'example_string'; + } else if (typedProp.type === 'number' || typedProp.type === 'integer') { + exampleArgs[propName] = 42; + } else if (typedProp.type === 'boolean') { + exampleArgs[propName] = true; + } else if (typedProp.type === 'array') { + exampleArgs[propName] = []; + } else if (typedProp.type === 'object') { + exampleArgs[propName] = {}; + } + } + + return exampleArgs; +} diff --git a/packages/ai/src/mcp/client/client.ts b/packages/ai/src/mcp/client/client.ts new file mode 100644 index 0000000000..575677f520 --- /dev/null +++ b/packages/ai/src/mcp/client/client.ts @@ -0,0 +1,351 @@ +/** + * Core MCP client implementation + */ + +import { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; +import { createTransport } from './transport'; +import type { + MCPClientOptions, + Root, + ServerCapabilities, + ToolCallResult, + TransportOptions, +} from './types'; +import { MCPError, MCPErrorType } from './types'; + +/** + * MCP client for communicating with MCP servers + */ +export class MCPClient { + private client: Client; + private connected: boolean = false; + private initialized: boolean = false; + private serverCapabilities: ServerCapabilities | null = null; + + /** + * Create a new MCP client + * + * @param options Client options + */ + constructor(options: MCPClientOptions = {}) { + this.client = new Client( + { + name: options.name || 'onlook-mcp-client', + version: options.version || '1.0.0', + }, + {}, + ); + } + + /** + * Connect to an MCP server + * + * @param transportOptions Transport options + */ + async connect(transportOptions: TransportOptions): Promise { + try { + const transport = await createTransport(transportOptions); + await this.connectWithTransport(transport); + } catch (error) { + if (error instanceof MCPError) { + throw error; + } + throw new MCPError( + `Failed to connect to MCP server: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.CONNECTION_ERROR, + error, + ); + } + } + + /** + * Connect to an MCP server with a transport + * + * @param transport Transport instance + */ + async connectWithTransport(transport: Transport): Promise { + try { + await this.client.connect(transport); + this.connected = true; + } catch (error) { + throw new MCPError( + `Failed to connect to MCP server: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.CONNECTION_ERROR, + error, + ); + } + } + + /** + * Initialize the MCP client + * + * @param roots Optional roots to declare to the server + * @returns Server capabilities + */ + async initialize(roots?: Root[]): Promise { + if (!this.connected) { + throw new MCPError( + 'Client is not connected to a server', + MCPErrorType.INITIALIZATION_ERROR, + ); + } + + try { + // @ts-ignore - The SDK types don't match our types exactly + const response = await this.client.initialize({ + roots, + }); + + this.serverCapabilities = response as ServerCapabilities; + this.initialized = true; + return response as ServerCapabilities; + } catch (error) { + throw new MCPError( + `Failed to initialize MCP client: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.INITIALIZATION_ERROR, + error, + ); + } + } + + /** + * List available tools from the server + * + * @returns List of available tools + */ + async listTools() { + this.ensureInitialized(); + + try { + // @ts-ignore - The SDK types don't match our types exactly + return await this.client.listTools(); + } catch (error) { + throw new MCPError( + `Failed to list tools: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.TOOL_CALL_ERROR, + error, + ); + } + } + + /** + * List available prompts from the server + * + * @returns List of available prompts + */ + async listPrompts() { + this.ensureInitialized(); + + try { + // @ts-ignore - The SDK types don't match our types exactly + return await this.client.listPrompts(); + } catch (error) { + throw new MCPError( + `Failed to list prompts: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.PROMPT_ERROR, + error, + ); + } + } + + /** + * List available resources from the server + * + * @returns List of available resources + */ + async listResources() { + this.ensureInitialized(); + + try { + // @ts-ignore - The SDK types don't match our types exactly + return await this.client.listResources(); + } catch (error) { + throw new MCPError( + `Failed to list resources: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.RESOURCE_ERROR, + error, + ); + } + } + + /** + * Call a tool on the server + * + * @param name Name of the tool to call + * @param args Arguments for the tool + * @returns Result of the tool call + */ + async callTool(name: string, args: Record): Promise { + this.ensureInitialized(); + + try { + // @ts-ignore - The SDK types don't match our types exactly + const result = await this.client.callTool({ + name, + arguments: args, + }); + return result as ToolCallResult; + } catch (error) { + throw new MCPError( + `Failed to call tool ${name}: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.TOOL_CALL_ERROR, + error, + ); + } + } + + /** + * Read a resource from the server + * + * @param uri URI of the resource to read + * @returns Resource content + */ + async readResource(uri: string) { + this.ensureInitialized(); + + try { + // @ts-ignore - The SDK types don't match our types exactly + return await this.client.readResource({ + uri, + }); + } catch (error) { + throw new MCPError( + `Failed to read resource ${uri}: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.RESOURCE_ERROR, + error, + ); + } + } + + /** + * Subscribe to a resource for updates + * + * @param uri URI of the resource to subscribe to + */ + async subscribeToResource(uri: string) { + this.ensureInitialized(); + + try { + // @ts-ignore - The SDK types don't match our types exactly + await this.client.subscribeResource({ + uri, + }); + } catch (error) { + throw new MCPError( + `Failed to subscribe to resource ${uri}: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.RESOURCE_ERROR, + error, + ); + } + } + + /** + * Unsubscribe from a resource + * + * @param uri URI of the resource to unsubscribe from + */ + async unsubscribeFromResource(uri: string) { + this.ensureInitialized(); + + try { + // @ts-ignore - The SDK types don't match our types exactly + await this.client.unsubscribeResource({ + uri, + }); + } catch (error) { + throw new MCPError( + `Failed to unsubscribe from resource ${uri}: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.RESOURCE_ERROR, + error, + ); + } + } + + /** + * Set a request handler for the client + * + * @param method Method to handle + * @param handler Handler function + */ + setRequestHandler(method: string, handler: (params: T) => Promise) { + // @ts-ignore - The SDK types don't match our types exactly + this.client.setRequestHandler(method, handler); + } + + /** + * Set a notification handler for the client + * + * @param method Method to handle + * @param handler Handler function + */ + setNotificationHandler(method: string, handler: (params: T) => void) { + // @ts-ignore - The SDK types don't match our types exactly + this.client.setNotificationHandler(method, handler); + } + + /** + * Close the connection to the server + */ + async close() { + try { + await this.client.close(); + this.connected = false; + this.initialized = false; + this.serverCapabilities = null; + } catch (error) { + throw new MCPError( + `Failed to close MCP client: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.CONNECTION_ERROR, + error, + ); + } + } + + /** + * Get the underlying client instance + * + * @returns Client instance + */ + getClient(): Client { + return this.client; + } + + /** + * Get the server capabilities + * + * @returns Server capabilities + */ + getServerCapabilities(): ServerCapabilities | null { + return this.serverCapabilities; + } + + /** + * Check if the client is connected + * + * @returns Whether the client is connected + */ + isConnected(): boolean { + return this.connected; + } + + /** + * Check if the client is initialized + * + * @returns Whether the client is initialized + */ + isInitialized(): boolean { + return this.initialized; + } + + /** + * Ensure that the client is initialized + * + * @throws MCPError if the client is not initialized + */ + private ensureInitialized() { + if (!this.initialized) { + throw new MCPError('Client is not initialized', MCPErrorType.INITIALIZATION_ERROR); + } + } +} diff --git a/packages/ai/src/mcp/client/index.ts b/packages/ai/src/mcp/client/index.ts new file mode 100644 index 0000000000..1d2383a314 --- /dev/null +++ b/packages/ai/src/mcp/client/index.ts @@ -0,0 +1,7 @@ +/** + * MCP client exports + */ + +export * from './client'; +export * from './transport'; +export * from './types'; diff --git a/packages/ai/src/mcp/client/transport.ts b/packages/ai/src/mcp/client/transport.ts new file mode 100644 index 0000000000..d1f35e2c7c --- /dev/null +++ b/packages/ai/src/mcp/client/transport.ts @@ -0,0 +1,57 @@ +/** + * Transport implementations for MCP client + */ + +import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; +import type { TransportOptions } from './types'; +import { MCPError, MCPErrorType } from './types'; + +/** + * Create a transport for connecting to an MCP server + * + * @param options Transport options + * @returns Transport instance + */ +export async function createTransport(options: TransportOptions): Promise { + try { + switch (options.type) { + case 'stdio': + if (!options.command) { + throw new MCPError( + 'Command is required for stdio transport', + MCPErrorType.VALIDATION_ERROR, + ); + } + return new StdioClientTransport({ + command: options.command, + args: options.args || [], + }); + case 'websocket': + if (!options.url) { + throw new MCPError( + 'URL is required for websocket transport', + MCPErrorType.VALIDATION_ERROR, + ); + } + throw new MCPError( + 'WebSocket transport is not yet implemented', + MCPErrorType.VALIDATION_ERROR, + ); + default: + throw new MCPError( + `Unsupported transport type: ${(options as any).type}`, + MCPErrorType.VALIDATION_ERROR, + ); + } + } catch (error) { + if (error instanceof MCPError) { + throw error; + } + throw new MCPError( + `Failed to create transport: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.TRANSPORT_ERROR, + error, + ); + } +} diff --git a/packages/ai/src/mcp/client/types.ts b/packages/ai/src/mcp/client/types.ts new file mode 100644 index 0000000000..b323413e06 --- /dev/null +++ b/packages/ai/src/mcp/client/types.ts @@ -0,0 +1,182 @@ +/** + * Type definitions for the MCP client + */ + +import type { Client } from '@modelcontextprotocol/sdk/client/index.js'; +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; + +/** + * Tool definition + */ +export interface Tool { + name: string; + description: string; + inputSchema: { + type: string; + properties?: Record< + string, + { + type: string; + description?: string; + } + >; + required?: string[]; + }; +} + +/** + * Prompt definition + */ +export interface Prompt { + name: string; + description?: string; + arguments?: Array<{ + name: string; + description?: string; + required?: boolean; + }>; +} + +/** + * Resource definition + */ +export interface Resource { + name: string; + description?: string; + uri: string; +} + +/** + * Resource template definition + */ +export interface ResourceTemplate { + name: string; + description: string; + uriTemplate: string; +} + +/** + * Options for creating an MCP client + */ +export interface MCPClientOptions { + /** + * Name of the client + * @default "onlook-mcp-client" + */ + name?: string; + + /** + * Version of the client + * @default "1.0.0" + */ + version?: string; +} + +/** + * Transport options for connecting to an MCP server + */ +export interface TransportOptions { + /** + * Type of transport to use + */ + type: 'stdio' | 'websocket'; + + /** + * Command to execute for stdio transport + */ + command?: string; + + /** + * Arguments for the command for stdio transport + */ + args?: string[]; + + /** + * URL for websocket transport + */ + url?: string; +} + +/** + * Root declaration for MCP servers + */ +export interface Root { + /** + * URI of the root + */ + uri: string; + + /** + * Name of the root + */ + name: string; +} + +/** + * Server capabilities returned from initialization + */ +export interface ServerCapabilities { + capabilities: { + tools?: Record; + resources?: Record; + prompts?: Record; + roots?: Record; + sampling?: Record; + }; + serverInfo: { + name: string; + version: string; + }; +} + +/** + * Result of a tool call + */ +export interface ToolCallResult { + /** + * Whether the tool call resulted in an error + */ + isError?: boolean; + + /** + * Content of the tool call result + */ + content: Array<{ + type: string; + text: string; + }>; +} + +/** + * MCP error types + */ +export enum MCPErrorType { + CONNECTION_ERROR = 'connection_error', + INITIALIZATION_ERROR = 'initialization_error', + TOOL_CALL_ERROR = 'tool_call_error', + RESOURCE_ERROR = 'resource_error', + PROMPT_ERROR = 'prompt_error', + TRANSPORT_ERROR = 'transport_error', + VALIDATION_ERROR = 'validation_error', + UNKNOWN_ERROR = 'unknown_error', +} + +/** + * MCP error + */ +export class MCPError extends Error { + type: MCPErrorType; + cause?: unknown; + + constructor(message: string, type: MCPErrorType, cause?: unknown) { + super(message); + this.name = 'MCPError'; + this.type = type; + this.cause = cause; + } +} + +/** + * Export SDK types for convenience + */ +export type { Client, Transport }; diff --git a/packages/ai/src/mcp/context/formatter.ts b/packages/ai/src/mcp/context/formatter.ts new file mode 100644 index 0000000000..e661dec61b --- /dev/null +++ b/packages/ai/src/mcp/context/formatter.ts @@ -0,0 +1,104 @@ +/** + * Context formatting for MCP client + */ + +import type { Tool, Prompt, Resource } from '../client/types'; +import { generateToolExample } from '../capabilities'; + +/** + * Format MCP capabilities for LLM consumption + */ +export class LLMContextFormatter { + /** + * Format tools for LLM consumption + * + * @param tools Tools to format + * @param format Format to use + * @returns Formatted tools + */ + static formatTools( + tools: Tool[], + format: 'anthropic' | 'openai' | 'generic' = 'generic', + ): Record { + if (format === 'anthropic') { + return { + tools: tools.map((tool) => ({ + name: tool.name, + description: tool.description, + input_schema: tool.inputSchema, + })), + }; + } else if (format === 'openai') { + return { + functions: tools.map((tool) => ({ + name: tool.name, + description: tool.description, + parameters: tool.inputSchema, + })), + }; + } else { + return { + tools: tools.map((tool) => ({ + name: tool.name, + description: tool.description, + parameters: tool.inputSchema, + example: generateToolExample(tool), + })), + }; + } + } + + /** + * Format prompts for LLM consumption + * + * @param prompts Prompts to format + * @returns Formatted prompts + */ + static formatPrompts(prompts: Prompt[]): Record { + return { + prompts: prompts.map((prompt) => ({ + name: prompt.name, + description: prompt.description, + arguments: prompt.arguments, + })), + }; + } + + /** + * Format resources for LLM consumption + * + * @param resources Resources to format + * @returns Formatted resources + */ + static formatResources(resources: Resource[]): Record { + return { + resources: resources.map((resource) => ({ + name: resource.name, + description: resource.description, + uri: resource.uri, + })), + }; + } + + /** + * Format all capabilities for LLM consumption + * + * @param capabilities Capabilities to format + * @param format Format to use + * @returns Formatted capabilities + */ + static formatAllCapabilities( + capabilities: { + tools: Tool[]; + prompts: Prompt[]; + resources: Resource[]; + }, + format: 'anthropic' | 'openai' | 'generic' = 'generic', + ): Record { + return { + ...this.formatTools(capabilities.tools, format), + ...this.formatPrompts(capabilities.prompts), + ...this.formatResources(capabilities.resources), + }; + } +} diff --git a/packages/ai/src/mcp/context/index.ts b/packages/ai/src/mcp/context/index.ts new file mode 100644 index 0000000000..4bf7eef67b --- /dev/null +++ b/packages/ai/src/mcp/context/index.ts @@ -0,0 +1,6 @@ +/** + * MCP context exports + */ + +export * from './manager'; +export * from './formatter'; diff --git a/packages/ai/src/mcp/context/manager.ts b/packages/ai/src/mcp/context/manager.ts new file mode 100644 index 0000000000..2e8662d741 --- /dev/null +++ b/packages/ai/src/mcp/context/manager.ts @@ -0,0 +1,139 @@ +/** + * Context management for MCP client + */ + +import type { Tool, Prompt, Resource } from '../client/types'; +import { MCPClient } from '../client'; +import { generateToolExample } from '../capabilities'; + +/** + * Manager for MCP capabilities and context + */ +export class MCPCapabilityManager { + private client: MCPClient; + private capabilities: { + tools?: Tool[]; + prompts?: Prompt[]; + resources?: Resource[]; + } = {}; + private lastUpdate: number | null = null; + + /** + * Create a new capability manager + * + * @param client MCP client + */ + constructor(client: MCPClient) { + this.client = client; + + // Set up notification handlers for capability changes + client.setNotificationHandler( + 'notifications/tools/list_changed', + this.handleToolsChanged.bind(this), + ); + client.setNotificationHandler( + 'notifications/prompts/list_changed', + this.handlePromptsChanged.bind(this), + ); + client.setNotificationHandler( + 'notifications/resources/list_changed', + this.handleResourcesChanged.bind(this), + ); + } + + /** + * Refresh all capabilities + */ + async refreshAll(): Promise { + await this.refreshTools(); + await this.refreshPrompts(); + await this.refreshResources(); + this.lastUpdate = Date.now(); + } + + /** + * Refresh tools + */ + async refreshTools(): Promise { + try { + const tools = await this.client.listTools(); + // @ts-ignore - The SDK types don't match our types exactly + this.capabilities.tools = tools.tools; + } catch (error) { + console.error('Error refreshing tools:', error); + } + } + + /** + * Refresh prompts + */ + async refreshPrompts(): Promise { + try { + const prompts = await this.client.listPrompts(); + // @ts-ignore - The SDK types don't match our types exactly + this.capabilities.prompts = prompts.prompts; + } catch (error) { + console.error('Error refreshing prompts:', error); + } + } + + /** + * Refresh resources + */ + async refreshResources(): Promise { + try { + const resources = await this.client.listResources(); + // @ts-ignore - The SDK types don't match our types exactly + this.capabilities.resources = resources.resources; + } catch (error) { + console.error('Error refreshing resources:', error); + } + } + + /** + * Handle tools changed notification + */ + private handleToolsChanged(): void { + this.refreshTools(); + } + + /** + * Handle prompts changed notification + */ + private handlePromptsChanged(): void { + this.refreshPrompts(); + } + + /** + * Handle resources changed notification + */ + private handleResourcesChanged(): void { + this.refreshResources(); + } + + /** + * Get all capabilities + * + * @returns All capabilities + */ + getCapabilities(): { + tools: Tool[]; + prompts: Prompt[]; + resources: Resource[]; + } { + return { + tools: this.capabilities.tools || [], + prompts: this.capabilities.prompts || [], + resources: this.capabilities.resources || [], + }; + } + + /** + * Get the last update time + * + * @returns Last update time in milliseconds since epoch + */ + getLastUpdateTime(): number | null { + return this.lastUpdate; + } +} diff --git a/packages/ai/src/mcp/examples/basic-client.ts b/packages/ai/src/mcp/examples/basic-client.ts new file mode 100644 index 0000000000..2de49ec6aa --- /dev/null +++ b/packages/ai/src/mcp/examples/basic-client.ts @@ -0,0 +1,82 @@ +/** + * Basic example of using the MCP client + */ + +import { MCPClient } from '../client'; +import { MCPCapabilityManager } from '../context'; +import { LLMContextFormatter } from '../context'; +import { createFileSystemRoot } from '../capabilities'; +import { MCPLogger, LogLevel } from '../utils'; + +/** + * Run the example + */ +async function runExample() { + // Create a logger + const logger = new MCPLogger({ + level: LogLevel.DEBUG, + prefix: 'MCP-Example', + }); + + logger.info('Creating MCP client...'); + + // Create a client + const client = new MCPClient({ + name: 'onlook-mcp-example', + version: '1.0.0', + }); + + try { + // Connect to a server + logger.info('Connecting to MCP server...'); + await client.connect({ + type: 'stdio', + command: 'mcp-server', + args: ['--stdio'], + }); + + // Initialize the client + logger.info('Initializing MCP client...'); + const roots = [createFileSystemRoot('/path/to/project', 'Project Root')]; + const capabilities = await client.initialize(roots); + logger.info('Server capabilities:', capabilities); + + // Create a capability manager + logger.info('Creating capability manager...'); + const capabilityManager = new MCPCapabilityManager(client); + await capabilityManager.refreshAll(); + + // Get capabilities + const allCapabilities = capabilityManager.getCapabilities(); + logger.info('Available tools:', allCapabilities.tools.length); + logger.info('Available prompts:', allCapabilities.prompts.length); + logger.info('Available resources:', allCapabilities.resources.length); + + // Format capabilities for LLM + logger.info('Formatting capabilities for LLM...'); + const formattedCapabilities = LLMContextFormatter.formatAllCapabilities( + allCapabilities, + 'anthropic', + ); + logger.info('Formatted capabilities:', formattedCapabilities); + + // Call a tool + if (allCapabilities.tools.length > 0) { + const tool = allCapabilities.tools[0]; + logger.info(`Calling tool ${tool.name}...`); + const result = await client.callTool(tool.name, {}); + logger.info('Tool result:', result); + } + + // Close the client + logger.info('Closing MCP client...'); + await client.close(); + } catch (error) { + logger.error('Error:', error); + } +} + +// Run the example if this file is executed directly +if (typeof import.meta.main === 'boolean' && import.meta.main) { + runExample().catch(console.error); +} diff --git a/packages/ai/src/mcp/index.ts b/packages/ai/src/mcp/index.ts new file mode 100644 index 0000000000..2f1d453413 --- /dev/null +++ b/packages/ai/src/mcp/index.ts @@ -0,0 +1,13 @@ +/** + * Model Context Protocol (MCP) client implementation for Onlook + * + * This module provides a client implementation for the Model Context Protocol, + * enabling Onlook to communicate with MCP servers to access tools, resources, + * and prompts through a standardized protocol. + */ + +export * from './client'; +export * from './capabilities'; +export * from './context'; +export * from './utils'; +export * from './adapters'; diff --git a/packages/ai/src/mcp/utils/error.ts b/packages/ai/src/mcp/utils/error.ts new file mode 100644 index 0000000000..af92b0a1a9 --- /dev/null +++ b/packages/ai/src/mcp/utils/error.ts @@ -0,0 +1,93 @@ +/** + * Error utilities for MCP client + */ + +import { MCPError, MCPErrorType } from '../client/types'; + +/** + * Create a connection error + * + * @param message Error message + * @param cause Error cause + * @returns MCP error + */ +export function createConnectionError(message: string, cause?: unknown): MCPError { + return new MCPError(message, MCPErrorType.CONNECTION_ERROR, cause); +} + +/** + * Create an initialization error + * + * @param message Error message + * @param cause Error cause + * @returns MCP error + */ +export function createInitializationError(message: string, cause?: unknown): MCPError { + return new MCPError(message, MCPErrorType.INITIALIZATION_ERROR, cause); +} + +/** + * Create a tool call error + * + * @param message Error message + * @param cause Error cause + * @returns MCP error + */ +export function createToolCallError(message: string, cause?: unknown): MCPError { + return new MCPError(message, MCPErrorType.TOOL_CALL_ERROR, cause); +} + +/** + * Create a resource error + * + * @param message Error message + * @param cause Error cause + * @returns MCP error + */ +export function createResourceError(message: string, cause?: unknown): MCPError { + return new MCPError(message, MCPErrorType.RESOURCE_ERROR, cause); +} + +/** + * Create a prompt error + * + * @param message Error message + * @param cause Error cause + * @returns MCP error + */ +export function createPromptError(message: string, cause?: unknown): MCPError { + return new MCPError(message, MCPErrorType.PROMPT_ERROR, cause); +} + +/** + * Create a transport error + * + * @param message Error message + * @param cause Error cause + * @returns MCP error + */ +export function createTransportError(message: string, cause?: unknown): MCPError { + return new MCPError(message, MCPErrorType.TRANSPORT_ERROR, cause); +} + +/** + * Create a validation error + * + * @param message Error message + * @param cause Error cause + * @returns MCP error + */ +export function createValidationError(message: string, cause?: unknown): MCPError { + return new MCPError(message, MCPErrorType.VALIDATION_ERROR, cause); +} + +/** + * Create an unknown error + * + * @param message Error message + * @param cause Error cause + * @returns MCP error + */ +export function createUnknownError(message: string, cause?: unknown): MCPError { + return new MCPError(message, MCPErrorType.UNKNOWN_ERROR, cause); +} diff --git a/packages/ai/src/mcp/utils/index.ts b/packages/ai/src/mcp/utils/index.ts new file mode 100644 index 0000000000..5b4d37f0c5 --- /dev/null +++ b/packages/ai/src/mcp/utils/index.ts @@ -0,0 +1,7 @@ +/** + * MCP utility exports + */ + +export * from './validation'; +export * from './error'; +export * from './logging'; diff --git a/packages/ai/src/mcp/utils/logging.ts b/packages/ai/src/mcp/utils/logging.ts new file mode 100644 index 0000000000..672cd9ceef --- /dev/null +++ b/packages/ai/src/mcp/utils/logging.ts @@ -0,0 +1,130 @@ +/** + * Logging utilities for MCP client + */ + +/** + * Log levels + */ +export enum LogLevel { + DEBUG = 'debug', + INFO = 'info', + WARN = 'warn', + ERROR = 'error', +} + +/** + * Logger for MCP client + */ +export class MCPLogger { + private level: LogLevel; + private prefix: string; + + /** + * Create a new logger + * + * @param options Logger options + */ + constructor( + options: { + level?: LogLevel; + prefix?: string; + } = {}, + ) { + this.level = options.level || LogLevel.INFO; + this.prefix = options.prefix || 'MCP'; + } + + /** + * Log a debug message + * + * @param message Message to log + * @param args Additional arguments + */ + debug(message: string, ...args: unknown[]): void { + if (this.shouldLog(LogLevel.DEBUG)) { + console.debug(`[${this.prefix}] ${message}`, ...args); + } + } + + /** + * Log an info message + * + * @param message Message to log + * @param args Additional arguments + */ + info(message: string, ...args: unknown[]): void { + if (this.shouldLog(LogLevel.INFO)) { + console.info(`[${this.prefix}] ${message}`, ...args); + } + } + + /** + * Log a warning message + * + * @param message Message to log + * @param args Additional arguments + */ + warn(message: string, ...args: unknown[]): void { + if (this.shouldLog(LogLevel.WARN)) { + console.warn(`[${this.prefix}] ${message}`, ...args); + } + } + + /** + * Log an error message + * + * @param message Message to log + * @param args Additional arguments + */ + error(message: string, ...args: unknown[]): void { + if (this.shouldLog(LogLevel.ERROR)) { + console.error(`[${this.prefix}] ${message}`, ...args); + } + } + + /** + * Set the log level + * + * @param level Log level + */ + setLevel(level: LogLevel): void { + this.level = level; + } + + /** + * Set the prefix + * + * @param prefix Prefix + */ + setPrefix(prefix: string): void { + this.prefix = prefix; + } + + /** + * Check if a message should be logged + * + * @param level Log level + * @returns Whether the message should be logged + */ + private shouldLog(level: LogLevel): boolean { + const levels = [LogLevel.DEBUG, LogLevel.INFO, LogLevel.WARN, LogLevel.ERROR]; + const currentLevelIndex = levels.indexOf(this.level); + const messageLevelIndex = levels.indexOf(level); + return messageLevelIndex >= currentLevelIndex; + } +} + +/** + * Create a new logger + * + * @param options Logger options + * @returns Logger + */ +export function createLogger( + options: { + level?: LogLevel; + prefix?: string; + } = {}, +): MCPLogger { + return new MCPLogger(options); +} diff --git a/packages/ai/src/mcp/utils/validation.ts b/packages/ai/src/mcp/utils/validation.ts new file mode 100644 index 0000000000..d01148a97a --- /dev/null +++ b/packages/ai/src/mcp/utils/validation.ts @@ -0,0 +1,69 @@ +/** + * Validation utilities for MCP client + */ + +import { z } from 'zod'; +import { MCPError, MCPErrorType } from '../client/types'; + +/** + * Validate a value against a schema + * + * @param schema Schema to validate against + * @param value Value to validate + * @param errorMessage Error message to use if validation fails + * @returns Validated value + */ +export function validate(schema: z.ZodType, value: unknown, errorMessage: string): T { + try { + return schema.parse(value); + } catch (error) { + throw new MCPError( + `${errorMessage}: ${error instanceof Error ? error.message : String(error)}`, + MCPErrorType.VALIDATION_ERROR, + error, + ); + } +} + +/** + * Create a schema for transport options + * + * @returns Schema for transport options + */ +export function createTransportOptionsSchema() { + return z.discriminatedUnion('type', [ + z.object({ + type: z.literal('stdio'), + command: z.string(), + args: z.array(z.string()).optional(), + }), + z.object({ + type: z.literal('websocket'), + url: z.string().url(), + }), + ]); +} + +/** + * Create a schema for client options + * + * @returns Schema for client options + */ +export function createClientOptionsSchema() { + return z.object({ + name: z.string().optional(), + version: z.string().optional(), + }); +} + +/** + * Create a schema for root declaration + * + * @returns Schema for root declaration + */ +export function createRootSchema() { + return z.object({ + uri: z.string(), + name: z.string(), + }); +} diff --git a/packages/ai/test/mcp/client.test.ts b/packages/ai/test/mcp/client.test.ts new file mode 100644 index 0000000000..e2e2bfd64d --- /dev/null +++ b/packages/ai/test/mcp/client.test.ts @@ -0,0 +1,234 @@ +/** + * Tests for the MCP client + */ + +import { describe, test, expect, beforeEach, afterEach, mock } from 'bun:test'; +import { MCPClient } from '../../src/mcp/client'; +import { MCPError, MCPErrorType } from '../../src/mcp/client/types'; + +// Mock the transport +const mockTransport = { + connect: mock(() => Promise.resolve()), + send: mock(() => Promise.resolve()), + close: mock(() => Promise.resolve()), + onMessage: mock(() => {}), +}; + +// Mock the client +mock.module('@modelcontextprotocol/sdk/client/index.js', () => ({ + Client: class MockClient { + constructor() {} + connect = mock(() => Promise.resolve()); + initialize = mock(() => + Promise.resolve({ + capabilities: { + tools: {}, + resources: {}, + prompts: {}, + }, + serverInfo: { + name: 'mock-server', + version: '1.0.0', + }, + }), + ); + listTools = mock(() => + Promise.resolve({ + tools: [ + { + name: 'mock-tool', + description: 'A mock tool', + inputSchema: { + type: 'object', + properties: { + input: { + type: 'string', + }, + }, + }, + }, + ], + }), + ); + listPrompts = mock(() => + Promise.resolve({ + prompts: [ + { + name: 'mock-prompt', + description: 'A mock prompt', + arguments: [ + { + name: 'input', + description: 'Input for the prompt', + required: true, + }, + ], + }, + ], + }), + ); + listResources = mock(() => + Promise.resolve({ + resources: [ + { + name: 'mock-resource', + description: 'A mock resource', + uri: 'mock://resource', + }, + ], + resourceTemplates: [ + { + name: 'mock-template', + description: 'A mock template', + uriTemplate: 'mock://template/{param}', + }, + ], + }), + ); + callTool = mock(() => + Promise.resolve({ + content: [ + { + type: 'text', + text: 'Mock tool result', + }, + ], + }), + ); + readResource = mock(() => + Promise.resolve({ + content: 'Mock resource content', + }), + ); + subscribeToResource = mock(() => Promise.resolve()); + unsubscribeFromResource = mock(() => Promise.resolve()); + setRequestHandler = mock(() => {}); + setNotificationHandler = mock(() => {}); + close = mock(() => Promise.resolve()); + }, +})); + +// Mock the transport creation +mock.module('../../src/mcp/client/transport', () => ({ + createTransport: mock(() => Promise.resolve(mockTransport)), +})); + +describe('MCPClient', () => { + let client: MCPClient; + + beforeEach(() => { + client = new MCPClient({ + name: 'test-client', + version: '1.0.0', + }); + }); + + afterEach(() => { + mock.restore(); + }); + + test('should create a client with default options', () => { + const defaultClient = new MCPClient(); + expect(defaultClient).toBeDefined(); + }); + + test('should connect to a server', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + expect(client.isConnected()).toBe(true); + }); + + test('should initialize the client', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + const capabilities = await client.initialize(); + expect(client.isInitialized()).toBe(true); + expect(capabilities).toBeDefined(); + expect(capabilities.serverInfo.name).toBe('mock-server'); + }); + + test('should list tools', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + await client.initialize(); + const tools = await client.listTools(); + expect(tools).toBeDefined(); + expect(tools.tools.length).toBe(1); + expect(tools.tools[0].name).toBe('mock-tool'); + }); + + test('should list prompts', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + await client.initialize(); + const prompts = await client.listPrompts(); + expect(prompts).toBeDefined(); + expect(prompts.prompts.length).toBe(1); + expect(prompts.prompts[0].name).toBe('mock-prompt'); + }); + + test('should list resources', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + await client.initialize(); + const resources = await client.listResources(); + expect(resources).toBeDefined(); + expect(resources.resources.length).toBe(1); + expect(resources.resources[0].name).toBe('mock-resource'); + }); + + test('should call a tool', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + await client.initialize(); + const result = await client.callTool('mock-tool', { input: 'test' }); + expect(result).toBeDefined(); + expect(result.content[0].text).toBe('Mock tool result'); + }); + + test('should read a resource', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + await client.initialize(); + const result = await client.readResource('mock://resource'); + expect(result).toBeDefined(); + expect(result.content).toBe('Mock resource content'); + }); + + test('should close the client', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + await client.initialize(); + await client.close(); + expect(client.isConnected()).toBe(false); + expect(client.isInitialized()).toBe(false); + }); + + test('should throw an error if not connected', async () => { + await expect(client.initialize()).rejects.toThrow(MCPError); + }); + + test('should throw an error if not initialized', async () => { + await client.connect({ + type: 'stdio', + command: 'mock-command', + }); + await expect(client.listTools()).rejects.toThrow(MCPError); + }); +}); diff --git a/packages/models/src/llm/index.ts b/packages/models/src/llm/index.ts index d6e714b568..56431bac1a 100644 --- a/packages/models/src/llm/index.ts +++ b/packages/models/src/llm/index.ts @@ -1,8 +1,16 @@ export enum LLMProvider { ANTHROPIC = 'anthropic', + OPENAI = 'openai', + OLLAMA = 'ollama', + MCP = 'mcp', // Add MCP as a provider type } export enum CLAUDE_MODELS { SONNET = 'claude-3-7-sonnet-20250219', HAIKU = 'claude-3-5-haiku-20241022', } + +export enum MCP_MODELS { + DEFAULT = 'default', + // Add more MCP models as needed +} diff --git a/packages/models/src/settings/index.ts b/packages/models/src/settings/index.ts index de454ad987..048539ddb1 100644 --- a/packages/models/src/settings/index.ts +++ b/packages/models/src/settings/index.ts @@ -7,6 +7,8 @@ export interface UserSettings { signInMethod?: string; editor?: EditorSettings; chat?: ChatSettings; + llmProvider?: string; + llmModel?: string; } export interface EditorSettings {