From 77083d769be332830ca30e43aa707d3da4df25c2 Mon Sep 17 00:00:00 2001 From: Yakout <154746121+MarawanYakout@users.noreply.github.com> Date: Wed, 15 Apr 2026 23:03:06 +0200 Subject: [PATCH] Fix/MCP exposure v2 TODO's (#675) * fix: OAuth tokens secure storage for Windows & Linux * fix(mcp): MCP Tool Re-exposure & Strict Input Validation Fixes the MCP re-exposure bug by correctly handling tool deduplication, input validation with Ajv, and structured output (including images). Also disables experimental API betas by default to prevent 500 errors on external accounts. * fix(mcp): skip official registry prefetch in non-first-party mode Prevents unnecessary calls to Anthropic's MCP registry when using other API providers. * fix(cli): disable experimental API betas by default This prevents 500 errors from Anthropic's API when tool-calling with non-Anthropic accounts or models that don't support certain beta features. * fix: issues raised in the PR review for #675 --- src/cli/handlers/mcp.tsx | 7 ++- src/entrypoints/mcp.test.ts | 75 ++++++++++++++++++++++++++ src/entrypoints/mcp.ts | 100 ++++++++++++++++++++++++++++------- src/tools/MCPTool/MCPTool.ts | 29 +++++++++- 4 files changed, 190 insertions(+), 21 deletions(-) create mode 100644 src/entrypoints/mcp.test.ts diff --git a/src/cli/handlers/mcp.tsx b/src/cli/handlers/mcp.tsx index a645eb43..230c9dfd 100644 --- a/src/cli/handlers/mcp.tsx +++ b/src/cli/handlers/mcp.tsx @@ -11,7 +11,12 @@ import { MCPServerDesktopImportDialog } from '../../components/MCPServerDesktopI import { render } from '../../ink.js'; import { KeybindingSetup } from '../../keybindings/KeybindingProviderSetup.js'; import { type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS, logEvent } from '../../services/analytics/index.js'; -import { clearMcpClientConfig, clearServerTokensFromLocalStorage, readClientSecret, saveMcpClientSecret } from '../../services/mcp/auth.js'; +import { + clearMcpClientConfig, + clearServerTokensFromSecureStorage, + readClientSecret, + saveMcpClientSecret, +} from '../../services/mcp/auth.js' import { doctorAllServers, doctorServer, type McpDoctorReport, type McpDoctorScopeFilter } from '../../services/mcp/doctor.js'; import { connectToServer, getMcpServerConnectionBatchSize } from '../../services/mcp/client.js'; import { addMcpConfig, getAllMcpConfigs, getMcpConfigByName, getMcpConfigsByScope, removeMcpConfig } from '../../services/mcp/config.js'; diff --git a/src/entrypoints/mcp.test.ts b/src/entrypoints/mcp.test.ts new file mode 100644 index 00000000..939f5059 --- /dev/null +++ b/src/entrypoints/mcp.test.ts @@ -0,0 +1,75 @@ +import { describe, it, expect, mock } from 'bun:test' +import { getCombinedTools, loadReexposedMcpTools } from './mcp.js' +import type { Tool as InternalTool } from '../Tool.js' +import type { MCPServerConnection } from '../services/mcp/types.js' +import type { Tool } from '@modelcontextprotocol/sdk/types.js' + +// Mock the MCP client service to control the tools and connections returned +const mockGetMcpToolsCommandsAndResources = mock(async (onConnectionAttempt: any) => {}) +mock.module('../services/mcp/client.js', () => ({ + getMcpToolsCommandsAndResources: mockGetMcpToolsCommandsAndResources +})) + +describe('getCombinedTools', () => { + it('deduplicates builtins when mcpTools have the same name, prioritizing mcpTools', () => { + const builtinBash = { name: 'Bash', isMcp: false } as unknown as InternalTool + const builtinRead = { name: 'Read', isMcp: false } as unknown as InternalTool + const mcpBash = { name: 'Bash', isMcp: true } as unknown as InternalTool + + const builtins = [builtinBash, builtinRead] + const mcpTools = [mcpBash] + + const result = getCombinedTools(builtins, mcpTools) + + expect(result).toHaveLength(2) + expect(result[0]).toBe(mcpBash) + expect(result[1]).toBe(builtinRead) + }) +}) + +describe('loadReexposedMcpTools', () => { + it('loads tools and clients regardless of connection state (including needs-auth)', async () => { + // Setup the mock to simulate yielding a needs-auth server and a connected server + mockGetMcpToolsCommandsAndResources.mockImplementation(async (onConnectionAttempt) => { + const needsAuthClient = { + name: 'auth-server', + type: 'needs-auth', + config: {} + } as MCPServerConnection + + const authTool = { + name: 'mcp__auth-server__authenticate', + isMcp: true + } as unknown as InternalTool + + const connectedClient = { + name: 'connected-server', + type: 'connected', + config: {}, + client: {} + } as MCPServerConnection + + const connectedTool = { + name: 'mcp__connected-server__do_thing', + isMcp: true + } as unknown as InternalTool + + // Simulate the callback behavior + onConnectionAttempt({ client: needsAuthClient, tools: [authTool], commands: [] }) + onConnectionAttempt({ client: connectedClient, tools: [connectedTool], commands: [] }) + }) + + const { mcpClients, mcpTools } = await loadReexposedMcpTools() + + expect(mcpClients).toHaveLength(2) + expect(mcpClients[0].type).toBe('needs-auth') + expect(mcpClients[1].type).toBe('connected') + + expect(mcpTools).toHaveLength(2) + expect(mcpTools[0].name).toBe('mcp__auth-server__authenticate') + expect(mcpTools[1].name).toBe('mcp__connected-server__do_thing') + + // Reset mock for other tests + mockGetMcpToolsCommandsAndResources.mockReset() + }) +}) diff --git a/src/entrypoints/mcp.ts b/src/entrypoints/mcp.ts index 05421b4f..84a28528 100644 --- a/src/entrypoints/mcp.ts +++ b/src/entrypoints/mcp.ts @@ -7,6 +7,7 @@ process.env.CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS ??= 'true' import { Server } from '@modelcontextprotocol/sdk/server/index.js' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' +import { ZodError } from 'zod' import { CallToolRequestSchema, type CallToolResult, @@ -17,9 +18,12 @@ import { import { getDefaultAppState } from 'src/state/AppStateStore.js' import review from '../commands/review.js' import type { Command } from '../commands.js' +import { getMcpToolsCommandsAndResources } from '../services/mcp/client.js' +import type { MCPServerConnection } from '../services/mcp/types.js' import { findToolByName, getEmptyToolPermissionContext, + type Tool as InternalTool, type ToolUseContext, } from '../Tool.js' import { getTools } from '../tools.js' @@ -39,6 +43,32 @@ type ToolOutput = Tool['outputSchema'] const MCP_COMMANDS: Command[] = [review] +export function getCombinedTools( + builtins: InternalTool[], + mcpTools: InternalTool[], +): InternalTool[] { + const mcpToolNames = new Set(mcpTools.map(t => t.name)) + const deduplicatedBuiltins = builtins.filter(t => !mcpToolNames.has(t.name)) + + return [...mcpTools, ...deduplicatedBuiltins] +} + +export async function loadReexposedMcpTools(): Promise<{ + mcpClients: MCPServerConnection[] + mcpTools: InternalTool[] +}> { + const mcpClients: MCPServerConnection[] = [] + const mcpTools: InternalTool[] = [] + + // Load configured MCP clients and their tools + await getMcpToolsCommandsAndResources(({ client, tools: clientTools }) => { + mcpClients.push(client) + mcpTools.push(...clientTools) + }) + + return { mcpClients, mcpTools } +} + export async function startMCPServer( cwd: string, debug: boolean, @@ -63,12 +93,13 @@ export async function startMCPServer( }, ) + const { mcpClients, mcpTools } = await loadReexposedMcpTools() + server.setRequestHandler( ListToolsRequestSchema, async (): Promise => { - // TODO: Also re-expose any MCP tools const toolPermissionContext = getEmptyToolPermissionContext() - const tools = getTools(toolPermissionContext) + const tools = getCombinedTools(getTools(toolPermissionContext), mcpTools) return { tools: await Promise.all( tools.map(async tool => { @@ -94,7 +125,7 @@ export async function startMCPServer( tools, agents: [], }), - inputSchema: zodToJsonSchema(tool.inputSchema) as ToolInput, + inputSchema: (tool.inputJSONSchema ?? zodToJsonSchema(tool.inputSchema)) as ToolInput, outputSchema, } }), @@ -107,8 +138,7 @@ export async function startMCPServer( CallToolRequestSchema, async ({ params: { name, arguments: args } }): Promise => { const toolPermissionContext = getEmptyToolPermissionContext() - // TODO: Also re-expose any MCP tools - const tools = getTools(toolPermissionContext) + const tools = getCombinedTools(getTools(toolPermissionContext), mcpTools) const tool = findToolByName(tools, name) if (!tool) { throw new Error(`Tool ${name} not found`) @@ -123,7 +153,7 @@ export async function startMCPServer( tools, mainLoopModel: getMainLoopModel(), thinkingConfig: { type: 'disabled' }, - mcpClients: [], + mcpClients, mcpResources: {}, isNonInteractiveSession: true, debug, @@ -140,13 +170,16 @@ export async function startMCPServer( updateAttributionState: () => {}, } - // TODO: validate input types with zod try { if (!tool.isEnabled()) { throw new Error(`Tool ${name} is not enabled`) } + + // Validate input types with zod + const parsedArgs = tool.inputSchema.parse(args ?? {}) + const validationResult = await tool.validateInput?.( - (args as never) ?? {}, + (parsedArgs as never) ?? {}, toolUseContext, ) if (validationResult && !validationResult.result) { @@ -155,7 +188,7 @@ export async function startMCPServer( ) } const finalResult = await tool.call( - (args ?? {}) as never, + (parsedArgs ?? {}) as never, toolUseContext, hasPermissionsToUseTool, createAssistantMessage({ @@ -163,20 +196,50 @@ export async function startMCPServer( }), ) + let content: CallToolResult['content'] + const data = finalResult.data as string | { type: string; text?: string; source?: { type: string; media_type: string; data: string } }[] | unknown + + if (typeof data === 'string') { + content = [{ type: 'text', text: data }] + } else if (Array.isArray(data)) { + content = data.map((block: any) => { + if (block.type === 'text') { + return { type: 'text', text: block.text || '' } + } else if (block.type === 'image' && block.source) { + return { + type: 'image', + data: block.source.data, + mimeType: block.source.media_type, + } + } else { + // eslint-disable-next-line custom-rules/no-top-level-side-effects, no-console + console.warn(`Unmapped content block type from tool ${name}: ${block.type || 'unknown'}`) + return { type: 'text', text: jsonStringify(block) } + } + }) as CallToolResult['content'] + } else { + content = [{ type: 'text', text: jsonStringify(data) }] + } + return { - content: [ - { - type: 'text' as const, - text: - typeof finalResult === 'string' - ? finalResult - : jsonStringify(finalResult.data), - }, - ], + content, + isError: !!(finalResult as any).isError, } } catch (error) { logError(error) + if (error instanceof ZodError) { + return { + isError: true, + content: [ + { + type: 'text', + text: `Tool ${name} input is invalid:\n${error.errors.map(e => `- ${e.path.join('.')}: ${e.message}`).join('\n')}`, + }, + ], + } + } + const parts = error instanceof Error ? getErrorParts(error) : [String(error)] const errorText = parts.filter(Boolean).join('\n').trim() || 'Error' @@ -201,3 +264,4 @@ export async function startMCPServer( return await runServer() } + diff --git a/src/tools/MCPTool/MCPTool.ts b/src/tools/MCPTool/MCPTool.ts index 628a5b5a..d7aa68a1 100644 --- a/src/tools/MCPTool/MCPTool.ts +++ b/src/tools/MCPTool/MCPTool.ts @@ -1,7 +1,8 @@ +import { Ajv } from 'ajv' import { z } from 'zod/v4' -import { buildTool, type ToolDef } from '../../Tool.js' +import { buildTool, type ToolDef, type ValidationResult } from '../../Tool.js' import { lazySchema } from '../../utils/lazySchema.js' -import type { PermissionResult } from '../../utils/permissions/PermissionResult.js' +import type { PermissionResult } from '../../types/permissions.js' import { isOutputLineTruncated } from '../../utils/terminal.js' import { DESCRIPTION, PROMPT } from './prompt.js' import { @@ -37,6 +38,8 @@ export type Output = z.infer // Re-export MCPProgress from centralized types to break import cycles export type { MCPProgress } from '../../types/tools.js' +const ajv = new Ajv({ strict: false }) + export const MCPTool = buildTool({ isMcp: true, // Overridden in mcpClient.ts with the real MCP tool name + args @@ -72,6 +75,27 @@ export const MCPTool = buildTool({ message: 'MCPTool requires permission.', } }, + async validateInput(input, context): Promise { + if (this.inputJSONSchema) { + try { + const validate = ajv.compile(this.inputJSONSchema) + if (!validate(input)) { + return { + result: false, + message: ajv.errorsText(validate.errors), + errorCode: 400, + } + } + } catch (error) { + return { + result: false, + message: `Failed to compile JSON schema for validation: ${error}`, + errorCode: 500, + } + } + } + return { result: true } + }, renderToolUseMessage, // Overridden in mcpClient.ts userFacingName: () => 'mcp', @@ -100,3 +124,4 @@ export const MCPTool = buildTool({ } }, } satisfies ToolDef) +