diff --git a/src/services/mcp/client.test.ts b/src/services/mcp/client.test.ts new file mode 100644 index 00000000..6f69ee7b --- /dev/null +++ b/src/services/mcp/client.test.ts @@ -0,0 +1,48 @@ +import assert from 'node:assert/strict' +import test from 'node:test' + +import { cleanupFailedConnection } from './client.js' + +test('cleanupFailedConnection awaits transport close before resolving', async () => { + let closed = false + let resolveClose: (() => void) | undefined + + const transport = { + close: async () => + await new Promise(resolve => { + resolveClose = () => { + closed = true + resolve() + } + }), + } + + const cleanupPromise = cleanupFailedConnection(transport) + + assert.equal(closed, false) + resolveClose?.() + await cleanupPromise + assert.equal(closed, true) +}) + +test('cleanupFailedConnection closes in-process server and transport', async () => { + let inProcessClosed = false + let transportClosed = false + + const inProcessServer = { + close: async () => { + inProcessClosed = true + }, + } + + const transport = { + close: async () => { + transportClosed = true + }, + } + + await cleanupFailedConnection(transport, inProcessServer) + + assert.equal(inProcessClosed, true) + assert.equal(transportClosed, true) +}) diff --git a/src/services/mcp/client.ts b/src/services/mcp/client.ts index b053dbb6..8857b56c 100644 --- a/src/services/mcp/client.ts +++ b/src/services/mcp/client.ts @@ -560,6 +560,22 @@ function getRemoteMcpServerConnectionBatchSize(): number { ) } +type InProcessMcpServer = { + connect(t: Transport): Promise + close(): Promise +} + +export async function cleanupFailedConnection( + transport: Pick, + inProcessServer?: Pick, +): Promise { + if (inProcessServer) { + await inProcessServer.close().catch(() => {}) + } + + await transport.close().catch(() => {}) +} + function isLocalMcpServer(config: ScopedMcpServerConfig): boolean { return !config.type || config.type === 'stdio' || config.type === 'sdk' } @@ -606,9 +622,7 @@ export const connectToServer = memoize( }, ): Promise => { const connectStartTime = Date.now() - let inProcessServer: - | { connect(t: Transport): Promise; close(): Promise } - | undefined + let inProcessServer: InProcessMcpServer | undefined try { let transport @@ -1145,9 +1159,10 @@ export const connectToServer = memoize( }) } if (inProcessServer) { - inProcessServer.close().catch(() => { }) + await cleanupFailedConnection(transport, inProcessServer) + } else { + await cleanupFailedConnection(transport) } - transport.close().catch(() => { }) if (stderrOutput) { logMCPError(name, `Server stderr: ${stderrOutput}`) }