Fix/MCP exposure v2 TODO's (#675)
* fix: OAuth tokens secure storage for Windows & Linux * fix(mcp): MCP Tool Re-exposure & Strict Input Validation Fixes the MCP re-exposure bug by correctly handling tool deduplication, input validation with Ajv, and structured output (including images). Also disables experimental API betas by default to prevent 500 errors on external accounts. * fix(mcp): skip official registry prefetch in non-first-party mode Prevents unnecessary calls to Anthropic's MCP registry when using other API providers. * fix(cli): disable experimental API betas by default This prevents 500 errors from Anthropic's API when tool-calling with non-Anthropic accounts or models that don't support certain beta features. * fix: issues raised in the PR review for #675
This commit is contained in:
@@ -11,7 +11,12 @@ import { MCPServerDesktopImportDialog } from '../../components/MCPServerDesktopI
|
||||
import { render } from '../../ink.js';
|
||||
import { KeybindingSetup } from '../../keybindings/KeybindingProviderSetup.js';
|
||||
import { type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS, logEvent } from '../../services/analytics/index.js';
|
||||
import { clearMcpClientConfig, clearServerTokensFromLocalStorage, readClientSecret, saveMcpClientSecret } from '../../services/mcp/auth.js';
|
||||
import {
|
||||
clearMcpClientConfig,
|
||||
clearServerTokensFromSecureStorage,
|
||||
readClientSecret,
|
||||
saveMcpClientSecret,
|
||||
} from '../../services/mcp/auth.js'
|
||||
import { doctorAllServers, doctorServer, type McpDoctorReport, type McpDoctorScopeFilter } from '../../services/mcp/doctor.js';
|
||||
import { connectToServer, getMcpServerConnectionBatchSize } from '../../services/mcp/client.js';
|
||||
import { addMcpConfig, getAllMcpConfigs, getMcpConfigByName, getMcpConfigsByScope, removeMcpConfig } from '../../services/mcp/config.js';
|
||||
|
||||
75
src/entrypoints/mcp.test.ts
Normal file
75
src/entrypoints/mcp.test.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { describe, it, expect, mock } from 'bun:test'
|
||||
import { getCombinedTools, loadReexposedMcpTools } from './mcp.js'
|
||||
import type { Tool as InternalTool } from '../Tool.js'
|
||||
import type { MCPServerConnection } from '../services/mcp/types.js'
|
||||
import type { Tool } from '@modelcontextprotocol/sdk/types.js'
|
||||
|
||||
// Mock the MCP client service to control the tools and connections returned
|
||||
const mockGetMcpToolsCommandsAndResources = mock(async (onConnectionAttempt: any) => {})
|
||||
mock.module('../services/mcp/client.js', () => ({
|
||||
getMcpToolsCommandsAndResources: mockGetMcpToolsCommandsAndResources
|
||||
}))
|
||||
|
||||
describe('getCombinedTools', () => {
|
||||
it('deduplicates builtins when mcpTools have the same name, prioritizing mcpTools', () => {
|
||||
const builtinBash = { name: 'Bash', isMcp: false } as unknown as InternalTool
|
||||
const builtinRead = { name: 'Read', isMcp: false } as unknown as InternalTool
|
||||
const mcpBash = { name: 'Bash', isMcp: true } as unknown as InternalTool
|
||||
|
||||
const builtins = [builtinBash, builtinRead]
|
||||
const mcpTools = [mcpBash]
|
||||
|
||||
const result = getCombinedTools(builtins, mcpTools)
|
||||
|
||||
expect(result).toHaveLength(2)
|
||||
expect(result[0]).toBe(mcpBash)
|
||||
expect(result[1]).toBe(builtinRead)
|
||||
})
|
||||
})
|
||||
|
||||
describe('loadReexposedMcpTools', () => {
|
||||
it('loads tools and clients regardless of connection state (including needs-auth)', async () => {
|
||||
// Setup the mock to simulate yielding a needs-auth server and a connected server
|
||||
mockGetMcpToolsCommandsAndResources.mockImplementation(async (onConnectionAttempt) => {
|
||||
const needsAuthClient = {
|
||||
name: 'auth-server',
|
||||
type: 'needs-auth',
|
||||
config: {}
|
||||
} as MCPServerConnection
|
||||
|
||||
const authTool = {
|
||||
name: 'mcp__auth-server__authenticate',
|
||||
isMcp: true
|
||||
} as unknown as InternalTool
|
||||
|
||||
const connectedClient = {
|
||||
name: 'connected-server',
|
||||
type: 'connected',
|
||||
config: {},
|
||||
client: {}
|
||||
} as MCPServerConnection
|
||||
|
||||
const connectedTool = {
|
||||
name: 'mcp__connected-server__do_thing',
|
||||
isMcp: true
|
||||
} as unknown as InternalTool
|
||||
|
||||
// Simulate the callback behavior
|
||||
onConnectionAttempt({ client: needsAuthClient, tools: [authTool], commands: [] })
|
||||
onConnectionAttempt({ client: connectedClient, tools: [connectedTool], commands: [] })
|
||||
})
|
||||
|
||||
const { mcpClients, mcpTools } = await loadReexposedMcpTools()
|
||||
|
||||
expect(mcpClients).toHaveLength(2)
|
||||
expect(mcpClients[0].type).toBe('needs-auth')
|
||||
expect(mcpClients[1].type).toBe('connected')
|
||||
|
||||
expect(mcpTools).toHaveLength(2)
|
||||
expect(mcpTools[0].name).toBe('mcp__auth-server__authenticate')
|
||||
expect(mcpTools[1].name).toBe('mcp__connected-server__do_thing')
|
||||
|
||||
// Reset mock for other tests
|
||||
mockGetMcpToolsCommandsAndResources.mockReset()
|
||||
})
|
||||
})
|
||||
@@ -7,6 +7,7 @@ process.env.CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS ??= 'true'
|
||||
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
||||
import { ZodError } from 'zod'
|
||||
import {
|
||||
CallToolRequestSchema,
|
||||
type CallToolResult,
|
||||
@@ -17,9 +18,12 @@ import {
|
||||
import { getDefaultAppState } from 'src/state/AppStateStore.js'
|
||||
import review from '../commands/review.js'
|
||||
import type { Command } from '../commands.js'
|
||||
import { getMcpToolsCommandsAndResources } from '../services/mcp/client.js'
|
||||
import type { MCPServerConnection } from '../services/mcp/types.js'
|
||||
import {
|
||||
findToolByName,
|
||||
getEmptyToolPermissionContext,
|
||||
type Tool as InternalTool,
|
||||
type ToolUseContext,
|
||||
} from '../Tool.js'
|
||||
import { getTools } from '../tools.js'
|
||||
@@ -39,6 +43,32 @@ type ToolOutput = Tool['outputSchema']
|
||||
|
||||
const MCP_COMMANDS: Command[] = [review]
|
||||
|
||||
export function getCombinedTools(
|
||||
builtins: InternalTool[],
|
||||
mcpTools: InternalTool[],
|
||||
): InternalTool[] {
|
||||
const mcpToolNames = new Set(mcpTools.map(t => t.name))
|
||||
const deduplicatedBuiltins = builtins.filter(t => !mcpToolNames.has(t.name))
|
||||
|
||||
return [...mcpTools, ...deduplicatedBuiltins]
|
||||
}
|
||||
|
||||
export async function loadReexposedMcpTools(): Promise<{
|
||||
mcpClients: MCPServerConnection[]
|
||||
mcpTools: InternalTool[]
|
||||
}> {
|
||||
const mcpClients: MCPServerConnection[] = []
|
||||
const mcpTools: InternalTool[] = []
|
||||
|
||||
// Load configured MCP clients and their tools
|
||||
await getMcpToolsCommandsAndResources(({ client, tools: clientTools }) => {
|
||||
mcpClients.push(client)
|
||||
mcpTools.push(...clientTools)
|
||||
})
|
||||
|
||||
return { mcpClients, mcpTools }
|
||||
}
|
||||
|
||||
export async function startMCPServer(
|
||||
cwd: string,
|
||||
debug: boolean,
|
||||
@@ -63,12 +93,13 @@ export async function startMCPServer(
|
||||
},
|
||||
)
|
||||
|
||||
const { mcpClients, mcpTools } = await loadReexposedMcpTools()
|
||||
|
||||
server.setRequestHandler(
|
||||
ListToolsRequestSchema,
|
||||
async (): Promise<ListToolsResult> => {
|
||||
// TODO: Also re-expose any MCP tools
|
||||
const toolPermissionContext = getEmptyToolPermissionContext()
|
||||
const tools = getTools(toolPermissionContext)
|
||||
const tools = getCombinedTools(getTools(toolPermissionContext), mcpTools)
|
||||
return {
|
||||
tools: await Promise.all(
|
||||
tools.map(async tool => {
|
||||
@@ -94,7 +125,7 @@ export async function startMCPServer(
|
||||
tools,
|
||||
agents: [],
|
||||
}),
|
||||
inputSchema: zodToJsonSchema(tool.inputSchema) as ToolInput,
|
||||
inputSchema: (tool.inputJSONSchema ?? zodToJsonSchema(tool.inputSchema)) as ToolInput,
|
||||
outputSchema,
|
||||
}
|
||||
}),
|
||||
@@ -107,8 +138,7 @@ export async function startMCPServer(
|
||||
CallToolRequestSchema,
|
||||
async ({ params: { name, arguments: args } }): Promise<CallToolResult> => {
|
||||
const toolPermissionContext = getEmptyToolPermissionContext()
|
||||
// TODO: Also re-expose any MCP tools
|
||||
const tools = getTools(toolPermissionContext)
|
||||
const tools = getCombinedTools(getTools(toolPermissionContext), mcpTools)
|
||||
const tool = findToolByName(tools, name)
|
||||
if (!tool) {
|
||||
throw new Error(`Tool ${name} not found`)
|
||||
@@ -123,7 +153,7 @@ export async function startMCPServer(
|
||||
tools,
|
||||
mainLoopModel: getMainLoopModel(),
|
||||
thinkingConfig: { type: 'disabled' },
|
||||
mcpClients: [],
|
||||
mcpClients,
|
||||
mcpResources: {},
|
||||
isNonInteractiveSession: true,
|
||||
debug,
|
||||
@@ -140,13 +170,16 @@ export async function startMCPServer(
|
||||
updateAttributionState: () => {},
|
||||
}
|
||||
|
||||
// TODO: validate input types with zod
|
||||
try {
|
||||
if (!tool.isEnabled()) {
|
||||
throw new Error(`Tool ${name} is not enabled`)
|
||||
}
|
||||
|
||||
// Validate input types with zod
|
||||
const parsedArgs = tool.inputSchema.parse(args ?? {})
|
||||
|
||||
const validationResult = await tool.validateInput?.(
|
||||
(args as never) ?? {},
|
||||
(parsedArgs as never) ?? {},
|
||||
toolUseContext,
|
||||
)
|
||||
if (validationResult && !validationResult.result) {
|
||||
@@ -155,7 +188,7 @@ export async function startMCPServer(
|
||||
)
|
||||
}
|
||||
const finalResult = await tool.call(
|
||||
(args ?? {}) as never,
|
||||
(parsedArgs ?? {}) as never,
|
||||
toolUseContext,
|
||||
hasPermissionsToUseTool,
|
||||
createAssistantMessage({
|
||||
@@ -163,20 +196,50 @@ export async function startMCPServer(
|
||||
}),
|
||||
)
|
||||
|
||||
let content: CallToolResult['content']
|
||||
const data = finalResult.data as string | { type: string; text?: string; source?: { type: string; media_type: string; data: string } }[] | unknown
|
||||
|
||||
if (typeof data === 'string') {
|
||||
content = [{ type: 'text', text: data }]
|
||||
} else if (Array.isArray(data)) {
|
||||
content = data.map((block: any) => {
|
||||
if (block.type === 'text') {
|
||||
return { type: 'text', text: block.text || '' }
|
||||
} else if (block.type === 'image' && block.source) {
|
||||
return {
|
||||
type: 'image',
|
||||
data: block.source.data,
|
||||
mimeType: block.source.media_type,
|
||||
}
|
||||
} else {
|
||||
// eslint-disable-next-line custom-rules/no-top-level-side-effects, no-console
|
||||
console.warn(`Unmapped content block type from tool ${name}: ${block.type || 'unknown'}`)
|
||||
return { type: 'text', text: jsonStringify(block) }
|
||||
}
|
||||
}) as CallToolResult['content']
|
||||
} else {
|
||||
content = [{ type: 'text', text: jsonStringify(data) }]
|
||||
}
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text:
|
||||
typeof finalResult === 'string'
|
||||
? finalResult
|
||||
: jsonStringify(finalResult.data),
|
||||
},
|
||||
],
|
||||
content,
|
||||
isError: !!(finalResult as any).isError,
|
||||
}
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
|
||||
if (error instanceof ZodError) {
|
||||
return {
|
||||
isError: true,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `Tool ${name} input is invalid:\n${error.errors.map(e => `- ${e.path.join('.')}: ${e.message}`).join('\n')}`,
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
const parts =
|
||||
error instanceof Error ? getErrorParts(error) : [String(error)]
|
||||
const errorText = parts.filter(Boolean).join('\n').trim() || 'Error'
|
||||
@@ -201,3 +264,4 @@ export async function startMCPServer(
|
||||
|
||||
return await runServer()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { Ajv } from 'ajv'
|
||||
import { z } from 'zod/v4'
|
||||
import { buildTool, type ToolDef } from '../../Tool.js'
|
||||
import { buildTool, type ToolDef, type ValidationResult } from '../../Tool.js'
|
||||
import { lazySchema } from '../../utils/lazySchema.js'
|
||||
import type { PermissionResult } from '../../utils/permissions/PermissionResult.js'
|
||||
import type { PermissionResult } from '../../types/permissions.js'
|
||||
import { isOutputLineTruncated } from '../../utils/terminal.js'
|
||||
import { DESCRIPTION, PROMPT } from './prompt.js'
|
||||
import {
|
||||
@@ -37,6 +38,8 @@ export type Output = z.infer<OutputSchema>
|
||||
// Re-export MCPProgress from centralized types to break import cycles
|
||||
export type { MCPProgress } from '../../types/tools.js'
|
||||
|
||||
const ajv = new Ajv({ strict: false })
|
||||
|
||||
export const MCPTool = buildTool({
|
||||
isMcp: true,
|
||||
// Overridden in mcpClient.ts with the real MCP tool name + args
|
||||
@@ -72,6 +75,27 @@ export const MCPTool = buildTool({
|
||||
message: 'MCPTool requires permission.',
|
||||
}
|
||||
},
|
||||
async validateInput(input, context): Promise<ValidationResult> {
|
||||
if (this.inputJSONSchema) {
|
||||
try {
|
||||
const validate = ajv.compile(this.inputJSONSchema)
|
||||
if (!validate(input)) {
|
||||
return {
|
||||
result: false,
|
||||
message: ajv.errorsText(validate.errors),
|
||||
errorCode: 400,
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
result: false,
|
||||
message: `Failed to compile JSON schema for validation: ${error}`,
|
||||
errorCode: 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
return { result: true }
|
||||
},
|
||||
renderToolUseMessage,
|
||||
// Overridden in mcpClient.ts
|
||||
userFacingName: () => 'mcp',
|
||||
@@ -100,3 +124,4 @@ export const MCPTool = buildTool({
|
||||
}
|
||||
},
|
||||
} satisfies ToolDef<InputSchema, Output>)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user