Skip to content

Add MCP client implementation #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions apps/studio/electron/main/chat/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { PromptProvider } from '@onlook/ai/src/prompt/provider';
import { listFilesTool, readFileTool } from '@onlook/ai/src/tools';
import { CLAUDE_MODELS, LLMProvider } from '@onlook/models';
import { CLAUDE_MODELS, LLMProvider, MCP_MODELS } from '@onlook/models';
import {
ChatSuggestionSchema,
StreamRequestType,
Expand All @@ -9,7 +9,13 @@ import {
type UsageCheckResult,
} from '@onlook/models/chat';
import { MainChannels } from '@onlook/models/constants';
import { generateObject, streamText, type CoreMessage, type CoreSystemMessage } from 'ai';
import {
generateObject,
streamText,
type CoreMessage,
type CoreSystemMessage,
type LanguageModelV1,
} from 'ai';
import { mainWindow } from '..';
import { PersistentStorage } from '../storage';
import { initModel } from './llmProvider';
Expand Down Expand Up @@ -68,9 +74,7 @@ class LlmManager {
} as CoreSystemMessage;
messages = [systemMessage, ...messages];
}
const model = await initModel(LLMProvider.ANTHROPIC, CLAUDE_MODELS.SONNET, {
requestType,
});
const model = await this.getModel(requestType);

const { textStream } = await streamText({
model,
Expand Down Expand Up @@ -156,11 +160,18 @@ class LlmManager {
return 'An unknown error occurred';
}

private async getModel(requestType: StreamRequestType): Promise<LanguageModelV1> {
    // Resolve the provider/model pair from persisted user settings,
    // falling back to Anthropic Claude Sonnet when nothing is configured.
    const settings = PersistentStorage.USER_SETTINGS.read() || {};
    return await initModel(
        settings.llmProvider || LLMProvider.ANTHROPIC,
        settings.llmModel || CLAUDE_MODELS.SONNET,
        { requestType },
    );
}

public async generateSuggestions(messages: CoreMessage[]): Promise<ChatSuggestion[]> {
try {
const model = await initModel(LLMProvider.ANTHROPIC, CLAUDE_MODELS.HAIKU, {
requestType: StreamRequestType.SUGGESTIONS,
});
const model = await this.getModel(StreamRequestType.SUGGESTIONS);

const { object } = await generateObject({
model,
Expand Down
36 changes: 33 additions & 3 deletions apps/studio/electron/main/chat/llmProvider.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { createAnthropic } from '@ai-sdk/anthropic';
import type { StreamRequestType } from '@onlook/models/chat';
import { BASE_PROXY_ROUTE, FUNCTIONS_ROUTE, ProxyRoutes } from '@onlook/models/constants';
import { CLAUDE_MODELS, LLMProvider } from '@onlook/models/llm';
import { CLAUDE_MODELS, LLMProvider, MCP_MODELS } from '@onlook/models/llm';
import { type LanguageModelV1 } from 'ai';
import { getRefreshedAuthTokens } from '../auth';
export interface OnlookPayload {
Expand All @@ -10,12 +10,14 @@ export interface OnlookPayload {

export async function initModel(
provider: LLMProvider,
model: CLAUDE_MODELS,
model: CLAUDE_MODELS | MCP_MODELS,
payload: OnlookPayload,
): Promise<LanguageModelV1> {
switch (provider) {
case LLMProvider.ANTHROPIC:
return await getAnthropicProvider(model, payload);
return await getAnthropicProvider(model as CLAUDE_MODELS, payload);
case LLMProvider.MCP:
return await getMCPProvider(model as MCP_MODELS, payload);
default:
throw new Error(`Unsupported provider: ${provider}`);
}
Expand Down Expand Up @@ -54,3 +56,31 @@ async function getAnthropicProvider(
cacheControl: true,
});
}

/**
 * Create a LanguageModelV1 backed by an MCP server reached over stdio.
 *
 * NOTE(review): a new client (and server process) is spawned on every call and
 * is never closed — verify whether MCPClient exposes a close/disconnect hook
 * and whether this client should be cached and reused across requests.
 *
 * @param model MCP model identifier forwarded to the adapter
 * @param payload Currently unused in this function (kept for signature parity
 *                with getAnthropicProvider)
 * @returns AI SDK language model adapter wrapping the MCP client
 */
async function getMCPProvider(model: MCP_MODELS, payload: OnlookPayload): Promise<LanguageModelV1> {
    // Import the MCP client and adapter lazily so the MCP SDK is only loaded
    // when the MCP provider is actually selected.
    const { MCPClient, createMCPLanguageModel } = await import('@onlook/ai/mcp');

    // Create a new MCP client
    const client = new MCPClient({
        name: 'onlook-mcp-client',
        version: '1.0.0',
    });

    // Connect to the MCP server over stdio.
    // TODO(review): the server command is hard-coded; it should come from user
    // settings/configuration rather than assuming an `mcp-server` binary on PATH.
    await client.connect({
        type: 'stdio',
        command: 'mcp-server', // This would need to be configured
        args: ['--stdio'],
    });

    // Initialize the client (MCP handshake). If this throws, the process
    // spawned by connect() above is leaked — see NOTE in the doc comment.
    await client.initialize();

    // Create a language model adapter that implements the LanguageModelV1 interface
    return createMCPLanguageModel({
        client,
        model: model.toString(),
    });
}
Binary file modified bun.lockb
Binary file not shown.
4 changes: 3 additions & 1 deletion packages/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
"@onlook/typescript": "*"
},
"dependencies": {
"@modelcontextprotocol/sdk": "latest",
"diff-match-patch": "^1.0.5",
"fg": "^0.0.3",
"marked": "^15.0.7"
"marked": "^15.0.7",
"zod": "^3.22.4"
}
}
1 change: 1 addition & 0 deletions packages/ai/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export * from './coder';
export * from './prompt';
export * from './mcp';
5 changes: 5 additions & 0 deletions packages/ai/src/mcp/adapters/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
/**
* MCP adapters exports
*/

export * from './language-model';
141 changes: 141 additions & 0 deletions packages/ai/src/mcp/adapters/language-model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/**
* MCP language model adapter for the AI SDK
*/

import type { LanguageModelV1, LanguageModelV1CallOptions } from 'ai';
import { MCPClient } from '../client';
import { MCPCapabilityManager } from '../context/manager';
import { MCPError, MCPErrorType } from '../client/types';

/**
 * Options for creating an MCP language model adapter
 */
export interface MCPLanguageModelOptions {
    /**
     * MCP client used to call the server's tools
     */
    client: MCPClient;

    /**
     * Capability manager; when omitted, the adapter creates one around `client`
     */
    capabilityManager?: MCPCapabilityManager;

    /**
     * Model name; embedded into the adapter's modelId as `mcp-<model>`,
     * defaulting to "default" when omitted
     */
    model?: string;
}

/**
 * Create an MCP language model adapter
 *
 * Bridges an MCP server's tool-calling surface to the AI SDK's
 * `LanguageModelV1` interface. Text generation is delegated to the first
 * server tool whose name contains "generate"; streaming is not implemented.
 *
 * @param options Options for creating the adapter
 * @returns Language model adapter
 */
export function createMCPLanguageModel(options: MCPLanguageModelOptions): LanguageModelV1 {
    const { client, capabilityManager, model } = options;

    // Reuse the supplied capability manager or lazily create one for this client.
    const manager = capabilityManager ?? new MCPCapabilityManager(client);

    return {
        specificationVersion: 'v1',
        provider: 'mcp',
        modelId: `mcp-${model || 'default'}`,
        defaultObjectGenerationMode: undefined,
        doStream: async (_callOptions: LanguageModelV1CallOptions) => {
            // Streaming would require mapping MCP progress notifications onto
            // the AI SDK stream protocol; not implemented yet.
            throw new Error('Streaming is not supported by the MCP adapter yet');
        },
        // Named `callOptions` (not `options`) to avoid shadowing the factory's
        // `options` parameter above.
        doGenerate: async (callOptions: LanguageModelV1CallOptions) => {
            try {
                // Refresh capabilities once if no tools are known yet.
                if (!manager.getCapabilities().tools.length) {
                    await manager.refreshAll();
                }

                // Find a tool that looks like a text-generation entry point.
                // `includes` also covers an exact 'generate' match, so no
                // separate equality check is needed.
                const tools = manager.getCapabilities().tools;
                const generateTool = tools.find((tool) => tool.name.includes('generate'));

                if (!generateTool) {
                    throw new MCPError(
                        'No text generation tool found',
                        MCPErrorType.TOOL_CALL_ERROR,
                    );
                }

                // Sampling parameters forwarded verbatim to the MCP tool.
                const { maxTokens, temperature, topP, stopSequences } = callOptions;

                // Prepare messages for the tool call.
                // NOTE(review): callOptions.prompt is the AI SDK's structured
                // prompt (a message array), not a plain string — confirm the
                // MCP server's generate tool accepts it in this shape.
                const messages: Array<{ role: string; content: unknown }> = [];
                if (callOptions.prompt) {
                    messages.push({ role: 'user', content: callOptions.prompt });
                }

                // Call the tool with the messages
                const result = await client.callTool(generateTool.name, {
                    messages,
                    maxTokens,
                    temperature,
                    topP,
                    stopSequences,
                });

                // Concatenate only textual content items; non-text items
                // (e.g. images) have no `text` and previously joined the
                // literal string "undefined" into the output.
                const text = result.content
                    .filter((item) => typeof item.text === 'string')
                    .map((item) => item.text)
                    .join('');

                // Shared between rawCall reporting and the serialized request body.
                const rawSettings = { maxTokens, temperature, topP, stopSequences };

                return {
                    text,
                    toolCalls: undefined,
                    finishReason: 'stop',
                    // MCP tool results carry no token accounting, so usage is zeroed.
                    usage: {
                        promptTokens: 0,
                        completionTokens: 0,
                    },
                    rawCall: {
                        rawPrompt: messages,
                        rawSettings,
                    },
                    request: {
                        body: JSON.stringify({ messages, ...rawSettings }),
                    },
                    response: {
                        id: `mcp-response-${Date.now()}`,
                        timestamp: new Date(),
                        modelId: `mcp-${model || 'default'}`,
                    },
                };
            } catch (error) {
                // Pass MCP errors through; wrap anything else for the caller.
                if (error instanceof MCPError) {
                    throw error;
                }

                throw new MCPError(
                    `Failed to generate text: ${error instanceof Error ? error.message : String(error)}`,
                    MCPErrorType.TOOL_CALL_ERROR,
                    error,
                );
            }
        },
    };
}
8 changes: 8 additions & 0 deletions packages/ai/src/mcp/capabilities/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/**
* MCP capability exports
*/

export * from './tools';
export * from './resources';
export * from './prompts';
export * from './roots';
84 changes: 84 additions & 0 deletions packages/ai/src/mcp/capabilities/prompts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/**
* Prompt capability management for MCP client
*/

import type { Prompt } from '../client/types';
import { MCPClient } from '../client';
import { MCPError, MCPErrorType } from '../client/types';

/**
 * Find a prompt by name
 *
 * @param client MCP client
 * @param name Name of the prompt to find
 * @returns Prompt if found, null otherwise
 */
export async function findPromptByName(client: MCPClient, name: string): Promise<Prompt | null> {
    const { prompts } = await client.listPrompts();
    const match = prompts.find((candidate) => candidate.name === name);
    return match ?? null;
}

/**
 * Validate prompt arguments
 *
 * Checks the supplied args against the prompt's declared argument list:
 * every required argument must be present (non-null), and only declared
 * arguments are carried over into the result.
 *
 * @param prompt Prompt to validate arguments for
 * @param args Arguments to validate
 * @returns Validated arguments
 */
export function validatePromptArgs(
    prompt: Prompt,
    args: Record<string, unknown>,
): Record<string, unknown> {
    const declared = prompt.arguments ?? [];
    if (declared.length === 0) {
        return {};
    }

    const accepted: Record<string, unknown> = {};
    const missing: string[] = [];

    for (const spec of declared) {
        const supplied = args[spec.name];
        if (spec.required && supplied == null) {
            // Required but absent (or explicitly null) — collect for the error.
            missing.push(spec.name);
        } else if (supplied !== undefined) {
            accepted[spec.name] = supplied;
        }
    }

    if (missing.length > 0) {
        throw new MCPError(
            `Missing required arguments for prompt ${prompt.name}: ${missing.join(', ')}`,
            MCPErrorType.VALIDATION_ERROR,
        );
    }

    return accepted;
}

/**
* Get a prompt with validated arguments
*
* @param client MCP client
* @param promptName Name of the prompt to get
* @param args Arguments for the prompt
* @returns Prompt with validated arguments
*/
export async function getPromptWithValidation(
client: MCPClient,
promptName: string,
args: Record<string, unknown>,
) {
const prompt = await findPromptByName(client, promptName);

if (!prompt) {
throw new MCPError(`Prompt not found: ${promptName}`, MCPErrorType.PROMPT_ERROR);
}

const validatedArgs = validatePromptArgs(prompt, args);

return {
prompt,
args: validatedArgs,
};
}
Loading