Fix/MCP exposure v2 TODO's (#675)
* fix: OAuth tokens secure storage for Windows & Linux * fix(mcp): MCP Tool Re-exposure & Strict Input Validation Fixes the MCP re-exposure bug by correctly handling tool deduplication, input validation with Ajv, and structured output (including images). Also disables experimental API betas by default to prevent 500 errors on external accounts. * fix(mcp): skip official registry prefetch in non-first-party mode Prevents unnecessary calls to Anthropic's MCP registry when using other API providers. * fix(cli): disable experimental API betas by default This prevents 500 errors from Anthropic's API when tool-calling with non-Anthropic accounts or models that don't support certain beta features. * fix: issues raised in the PR review for #675
This commit is contained in:
@@ -11,7 +11,12 @@ import { MCPServerDesktopImportDialog } from '../../components/MCPServerDesktopI
|
|||||||
import { render } from '../../ink.js';
|
import { render } from '../../ink.js';
|
||||||
import { KeybindingSetup } from '../../keybindings/KeybindingProviderSetup.js';
|
import { KeybindingSetup } from '../../keybindings/KeybindingProviderSetup.js';
|
||||||
import { type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS, logEvent } from '../../services/analytics/index.js';
|
import { type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS, logEvent } from '../../services/analytics/index.js';
|
||||||
import { clearMcpClientConfig, clearServerTokensFromLocalStorage, readClientSecret, saveMcpClientSecret } from '../../services/mcp/auth.js';
|
import {
|
||||||
|
clearMcpClientConfig,
|
||||||
|
clearServerTokensFromSecureStorage,
|
||||||
|
readClientSecret,
|
||||||
|
saveMcpClientSecret,
|
||||||
|
} from '../../services/mcp/auth.js'
|
||||||
import { doctorAllServers, doctorServer, type McpDoctorReport, type McpDoctorScopeFilter } from '../../services/mcp/doctor.js';
|
import { doctorAllServers, doctorServer, type McpDoctorReport, type McpDoctorScopeFilter } from '../../services/mcp/doctor.js';
|
||||||
import { connectToServer, getMcpServerConnectionBatchSize } from '../../services/mcp/client.js';
|
import { connectToServer, getMcpServerConnectionBatchSize } from '../../services/mcp/client.js';
|
||||||
import { addMcpConfig, getAllMcpConfigs, getMcpConfigByName, getMcpConfigsByScope, removeMcpConfig } from '../../services/mcp/config.js';
|
import { addMcpConfig, getAllMcpConfigs, getMcpConfigByName, getMcpConfigsByScope, removeMcpConfig } from '../../services/mcp/config.js';
|
||||||
|
|||||||
75
src/entrypoints/mcp.test.ts
Normal file
75
src/entrypoints/mcp.test.ts
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
import { describe, it, expect, mock } from 'bun:test'
|
||||||
|
import { getCombinedTools, loadReexposedMcpTools } from './mcp.js'
|
||||||
|
import type { Tool as InternalTool } from '../Tool.js'
|
||||||
|
import type { MCPServerConnection } from '../services/mcp/types.js'
|
||||||
|
import type { Tool } from '@modelcontextprotocol/sdk/types.js'
|
||||||
|
|
||||||
|
// Mock the MCP client service to control the tools and connections returned
|
||||||
|
const mockGetMcpToolsCommandsAndResources = mock(async (onConnectionAttempt: any) => {})
|
||||||
|
mock.module('../services/mcp/client.js', () => ({
|
||||||
|
getMcpToolsCommandsAndResources: mockGetMcpToolsCommandsAndResources
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('getCombinedTools', () => {
|
||||||
|
it('deduplicates builtins when mcpTools have the same name, prioritizing mcpTools', () => {
|
||||||
|
const builtinBash = { name: 'Bash', isMcp: false } as unknown as InternalTool
|
||||||
|
const builtinRead = { name: 'Read', isMcp: false } as unknown as InternalTool
|
||||||
|
const mcpBash = { name: 'Bash', isMcp: true } as unknown as InternalTool
|
||||||
|
|
||||||
|
const builtins = [builtinBash, builtinRead]
|
||||||
|
const mcpTools = [mcpBash]
|
||||||
|
|
||||||
|
const result = getCombinedTools(builtins, mcpTools)
|
||||||
|
|
||||||
|
expect(result).toHaveLength(2)
|
||||||
|
expect(result[0]).toBe(mcpBash)
|
||||||
|
expect(result[1]).toBe(builtinRead)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('loadReexposedMcpTools', () => {
|
||||||
|
it('loads tools and clients regardless of connection state (including needs-auth)', async () => {
|
||||||
|
// Setup the mock to simulate yielding a needs-auth server and a connected server
|
||||||
|
mockGetMcpToolsCommandsAndResources.mockImplementation(async (onConnectionAttempt) => {
|
||||||
|
const needsAuthClient = {
|
||||||
|
name: 'auth-server',
|
||||||
|
type: 'needs-auth',
|
||||||
|
config: {}
|
||||||
|
} as MCPServerConnection
|
||||||
|
|
||||||
|
const authTool = {
|
||||||
|
name: 'mcp__auth-server__authenticate',
|
||||||
|
isMcp: true
|
||||||
|
} as unknown as InternalTool
|
||||||
|
|
||||||
|
const connectedClient = {
|
||||||
|
name: 'connected-server',
|
||||||
|
type: 'connected',
|
||||||
|
config: {},
|
||||||
|
client: {}
|
||||||
|
} as MCPServerConnection
|
||||||
|
|
||||||
|
const connectedTool = {
|
||||||
|
name: 'mcp__connected-server__do_thing',
|
||||||
|
isMcp: true
|
||||||
|
} as unknown as InternalTool
|
||||||
|
|
||||||
|
// Simulate the callback behavior
|
||||||
|
onConnectionAttempt({ client: needsAuthClient, tools: [authTool], commands: [] })
|
||||||
|
onConnectionAttempt({ client: connectedClient, tools: [connectedTool], commands: [] })
|
||||||
|
})
|
||||||
|
|
||||||
|
const { mcpClients, mcpTools } = await loadReexposedMcpTools()
|
||||||
|
|
||||||
|
expect(mcpClients).toHaveLength(2)
|
||||||
|
expect(mcpClients[0].type).toBe('needs-auth')
|
||||||
|
expect(mcpClients[1].type).toBe('connected')
|
||||||
|
|
||||||
|
expect(mcpTools).toHaveLength(2)
|
||||||
|
expect(mcpTools[0].name).toBe('mcp__auth-server__authenticate')
|
||||||
|
expect(mcpTools[1].name).toBe('mcp__connected-server__do_thing')
|
||||||
|
|
||||||
|
// Reset mock for other tests
|
||||||
|
mockGetMcpToolsCommandsAndResources.mockReset()
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -7,6 +7,7 @@ process.env.CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS ??= 'true'
|
|||||||
|
|
||||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
||||||
|
import { ZodError } from 'zod'
|
||||||
import {
|
import {
|
||||||
CallToolRequestSchema,
|
CallToolRequestSchema,
|
||||||
type CallToolResult,
|
type CallToolResult,
|
||||||
@@ -17,9 +18,12 @@ import {
|
|||||||
import { getDefaultAppState } from 'src/state/AppStateStore.js'
|
import { getDefaultAppState } from 'src/state/AppStateStore.js'
|
||||||
import review from '../commands/review.js'
|
import review from '../commands/review.js'
|
||||||
import type { Command } from '../commands.js'
|
import type { Command } from '../commands.js'
|
||||||
|
import { getMcpToolsCommandsAndResources } from '../services/mcp/client.js'
|
||||||
|
import type { MCPServerConnection } from '../services/mcp/types.js'
|
||||||
import {
|
import {
|
||||||
findToolByName,
|
findToolByName,
|
||||||
getEmptyToolPermissionContext,
|
getEmptyToolPermissionContext,
|
||||||
|
type Tool as InternalTool,
|
||||||
type ToolUseContext,
|
type ToolUseContext,
|
||||||
} from '../Tool.js'
|
} from '../Tool.js'
|
||||||
import { getTools } from '../tools.js'
|
import { getTools } from '../tools.js'
|
||||||
@@ -39,6 +43,32 @@ type ToolOutput = Tool['outputSchema']
|
|||||||
|
|
||||||
const MCP_COMMANDS: Command[] = [review]
|
const MCP_COMMANDS: Command[] = [review]
|
||||||
|
|
||||||
|
export function getCombinedTools(
|
||||||
|
builtins: InternalTool[],
|
||||||
|
mcpTools: InternalTool[],
|
||||||
|
): InternalTool[] {
|
||||||
|
const mcpToolNames = new Set(mcpTools.map(t => t.name))
|
||||||
|
const deduplicatedBuiltins = builtins.filter(t => !mcpToolNames.has(t.name))
|
||||||
|
|
||||||
|
return [...mcpTools, ...deduplicatedBuiltins]
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function loadReexposedMcpTools(): Promise<{
|
||||||
|
mcpClients: MCPServerConnection[]
|
||||||
|
mcpTools: InternalTool[]
|
||||||
|
}> {
|
||||||
|
const mcpClients: MCPServerConnection[] = []
|
||||||
|
const mcpTools: InternalTool[] = []
|
||||||
|
|
||||||
|
// Load configured MCP clients and their tools
|
||||||
|
await getMcpToolsCommandsAndResources(({ client, tools: clientTools }) => {
|
||||||
|
mcpClients.push(client)
|
||||||
|
mcpTools.push(...clientTools)
|
||||||
|
})
|
||||||
|
|
||||||
|
return { mcpClients, mcpTools }
|
||||||
|
}
|
||||||
|
|
||||||
export async function startMCPServer(
|
export async function startMCPServer(
|
||||||
cwd: string,
|
cwd: string,
|
||||||
debug: boolean,
|
debug: boolean,
|
||||||
@@ -63,12 +93,13 @@ export async function startMCPServer(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const { mcpClients, mcpTools } = await loadReexposedMcpTools()
|
||||||
|
|
||||||
server.setRequestHandler(
|
server.setRequestHandler(
|
||||||
ListToolsRequestSchema,
|
ListToolsRequestSchema,
|
||||||
async (): Promise<ListToolsResult> => {
|
async (): Promise<ListToolsResult> => {
|
||||||
// TODO: Also re-expose any MCP tools
|
|
||||||
const toolPermissionContext = getEmptyToolPermissionContext()
|
const toolPermissionContext = getEmptyToolPermissionContext()
|
||||||
const tools = getTools(toolPermissionContext)
|
const tools = getCombinedTools(getTools(toolPermissionContext), mcpTools)
|
||||||
return {
|
return {
|
||||||
tools: await Promise.all(
|
tools: await Promise.all(
|
||||||
tools.map(async tool => {
|
tools.map(async tool => {
|
||||||
@@ -94,7 +125,7 @@ export async function startMCPServer(
|
|||||||
tools,
|
tools,
|
||||||
agents: [],
|
agents: [],
|
||||||
}),
|
}),
|
||||||
inputSchema: zodToJsonSchema(tool.inputSchema) as ToolInput,
|
inputSchema: (tool.inputJSONSchema ?? zodToJsonSchema(tool.inputSchema)) as ToolInput,
|
||||||
outputSchema,
|
outputSchema,
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
@@ -107,8 +138,7 @@ export async function startMCPServer(
|
|||||||
CallToolRequestSchema,
|
CallToolRequestSchema,
|
||||||
async ({ params: { name, arguments: args } }): Promise<CallToolResult> => {
|
async ({ params: { name, arguments: args } }): Promise<CallToolResult> => {
|
||||||
const toolPermissionContext = getEmptyToolPermissionContext()
|
const toolPermissionContext = getEmptyToolPermissionContext()
|
||||||
// TODO: Also re-expose any MCP tools
|
const tools = getCombinedTools(getTools(toolPermissionContext), mcpTools)
|
||||||
const tools = getTools(toolPermissionContext)
|
|
||||||
const tool = findToolByName(tools, name)
|
const tool = findToolByName(tools, name)
|
||||||
if (!tool) {
|
if (!tool) {
|
||||||
throw new Error(`Tool ${name} not found`)
|
throw new Error(`Tool ${name} not found`)
|
||||||
@@ -123,7 +153,7 @@ export async function startMCPServer(
|
|||||||
tools,
|
tools,
|
||||||
mainLoopModel: getMainLoopModel(),
|
mainLoopModel: getMainLoopModel(),
|
||||||
thinkingConfig: { type: 'disabled' },
|
thinkingConfig: { type: 'disabled' },
|
||||||
mcpClients: [],
|
mcpClients,
|
||||||
mcpResources: {},
|
mcpResources: {},
|
||||||
isNonInteractiveSession: true,
|
isNonInteractiveSession: true,
|
||||||
debug,
|
debug,
|
||||||
@@ -140,13 +170,16 @@ export async function startMCPServer(
|
|||||||
updateAttributionState: () => {},
|
updateAttributionState: () => {},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: validate input types with zod
|
|
||||||
try {
|
try {
|
||||||
if (!tool.isEnabled()) {
|
if (!tool.isEnabled()) {
|
||||||
throw new Error(`Tool ${name} is not enabled`)
|
throw new Error(`Tool ${name} is not enabled`)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate input types with zod
|
||||||
|
const parsedArgs = tool.inputSchema.parse(args ?? {})
|
||||||
|
|
||||||
const validationResult = await tool.validateInput?.(
|
const validationResult = await tool.validateInput?.(
|
||||||
(args as never) ?? {},
|
(parsedArgs as never) ?? {},
|
||||||
toolUseContext,
|
toolUseContext,
|
||||||
)
|
)
|
||||||
if (validationResult && !validationResult.result) {
|
if (validationResult && !validationResult.result) {
|
||||||
@@ -155,7 +188,7 @@ export async function startMCPServer(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
const finalResult = await tool.call(
|
const finalResult = await tool.call(
|
||||||
(args ?? {}) as never,
|
(parsedArgs ?? {}) as never,
|
||||||
toolUseContext,
|
toolUseContext,
|
||||||
hasPermissionsToUseTool,
|
hasPermissionsToUseTool,
|
||||||
createAssistantMessage({
|
createAssistantMessage({
|
||||||
@@ -163,20 +196,50 @@ export async function startMCPServer(
|
|||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
let content: CallToolResult['content']
|
||||||
|
const data = finalResult.data as string | { type: string; text?: string; source?: { type: string; media_type: string; data: string } }[] | unknown
|
||||||
|
|
||||||
|
if (typeof data === 'string') {
|
||||||
|
content = [{ type: 'text', text: data }]
|
||||||
|
} else if (Array.isArray(data)) {
|
||||||
|
content = data.map((block: any) => {
|
||||||
|
if (block.type === 'text') {
|
||||||
|
return { type: 'text', text: block.text || '' }
|
||||||
|
} else if (block.type === 'image' && block.source) {
|
||||||
|
return {
|
||||||
|
type: 'image',
|
||||||
|
data: block.source.data,
|
||||||
|
mimeType: block.source.media_type,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// eslint-disable-next-line custom-rules/no-top-level-side-effects, no-console
|
||||||
|
console.warn(`Unmapped content block type from tool ${name}: ${block.type || 'unknown'}`)
|
||||||
|
return { type: 'text', text: jsonStringify(block) }
|
||||||
|
}
|
||||||
|
}) as CallToolResult['content']
|
||||||
|
} else {
|
||||||
|
content = [{ type: 'text', text: jsonStringify(data) }]
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
content: [
|
content,
|
||||||
{
|
isError: !!(finalResult as any).isError,
|
||||||
type: 'text' as const,
|
|
||||||
text:
|
|
||||||
typeof finalResult === 'string'
|
|
||||||
? finalResult
|
|
||||||
: jsonStringify(finalResult.data),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logError(error)
|
logError(error)
|
||||||
|
|
||||||
|
if (error instanceof ZodError) {
|
||||||
|
return {
|
||||||
|
isError: true,
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: `Tool ${name} input is invalid:\n${error.errors.map(e => `- ${e.path.join('.')}: ${e.message}`).join('\n')}`,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const parts =
|
const parts =
|
||||||
error instanceof Error ? getErrorParts(error) : [String(error)]
|
error instanceof Error ? getErrorParts(error) : [String(error)]
|
||||||
const errorText = parts.filter(Boolean).join('\n').trim() || 'Error'
|
const errorText = parts.filter(Boolean).join('\n').trim() || 'Error'
|
||||||
@@ -201,3 +264,4 @@ export async function startMCPServer(
|
|||||||
|
|
||||||
return await runServer()
|
return await runServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
|
import { Ajv } from 'ajv'
|
||||||
import { z } from 'zod/v4'
|
import { z } from 'zod/v4'
|
||||||
import { buildTool, type ToolDef } from '../../Tool.js'
|
import { buildTool, type ToolDef, type ValidationResult } from '../../Tool.js'
|
||||||
import { lazySchema } from '../../utils/lazySchema.js'
|
import { lazySchema } from '../../utils/lazySchema.js'
|
||||||
import type { PermissionResult } from '../../utils/permissions/PermissionResult.js'
|
import type { PermissionResult } from '../../types/permissions.js'
|
||||||
import { isOutputLineTruncated } from '../../utils/terminal.js'
|
import { isOutputLineTruncated } from '../../utils/terminal.js'
|
||||||
import { DESCRIPTION, PROMPT } from './prompt.js'
|
import { DESCRIPTION, PROMPT } from './prompt.js'
|
||||||
import {
|
import {
|
||||||
@@ -37,6 +38,8 @@ export type Output = z.infer<OutputSchema>
|
|||||||
// Re-export MCPProgress from centralized types to break import cycles
|
// Re-export MCPProgress from centralized types to break import cycles
|
||||||
export type { MCPProgress } from '../../types/tools.js'
|
export type { MCPProgress } from '../../types/tools.js'
|
||||||
|
|
||||||
|
const ajv = new Ajv({ strict: false })
|
||||||
|
|
||||||
export const MCPTool = buildTool({
|
export const MCPTool = buildTool({
|
||||||
isMcp: true,
|
isMcp: true,
|
||||||
// Overridden in mcpClient.ts with the real MCP tool name + args
|
// Overridden in mcpClient.ts with the real MCP tool name + args
|
||||||
@@ -72,6 +75,27 @@ export const MCPTool = buildTool({
|
|||||||
message: 'MCPTool requires permission.',
|
message: 'MCPTool requires permission.',
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
async validateInput(input, context): Promise<ValidationResult> {
|
||||||
|
if (this.inputJSONSchema) {
|
||||||
|
try {
|
||||||
|
const validate = ajv.compile(this.inputJSONSchema)
|
||||||
|
if (!validate(input)) {
|
||||||
|
return {
|
||||||
|
result: false,
|
||||||
|
message: ajv.errorsText(validate.errors),
|
||||||
|
errorCode: 400,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
return {
|
||||||
|
result: false,
|
||||||
|
message: `Failed to compile JSON schema for validation: ${error}`,
|
||||||
|
errorCode: 500,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return { result: true }
|
||||||
|
},
|
||||||
renderToolUseMessage,
|
renderToolUseMessage,
|
||||||
// Overridden in mcpClient.ts
|
// Overridden in mcpClient.ts
|
||||||
userFacingName: () => 'mcp',
|
userFacingName: () => 'mcp',
|
||||||
@@ -100,3 +124,4 @@ export const MCPTool = buildTool({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
} satisfies ToolDef<InputSchema, Output>)
|
} satisfies ToolDef<InputSchema, Output>)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user