diff --git a/src/services/mcp/auth.test.ts b/src/services/mcp/auth.test.ts new file mode 100644 index 00000000..d2888041 --- /dev/null +++ b/src/services/mcp/auth.test.ts @@ -0,0 +1,61 @@ +import assert from 'node:assert/strict' +import test from 'node:test' + +import { validateOAuthCallbackParams } from './auth.js' + +test('OAuth callback rejects error parameters before state validation can be bypassed', () => { + const result = validateOAuthCallbackParams( + { + error: 'access_denied', + error_description: 'denied by provider', + }, + 'expected-state', + ) + + assert.deepEqual(result, { type: 'state_mismatch' }) +}) + +test('OAuth callback accepts provider errors only when state matches', () => { + const result = validateOAuthCallbackParams( + { + state: 'expected-state', + error: 'access_denied', + error_description: 'denied by provider', + error_uri: 'https://example.test/error', + }, + 'expected-state', + ) + + assert.deepEqual(result, { + type: 'error', + error: 'access_denied', + errorDescription: 'denied by provider', + errorUri: 'https://example.test/error', + message: + 'OAuth error: access_denied - denied by provider (See: https://example.test/error)', + }) +}) + +test('OAuth callback accepts authorization codes only when state matches', () => { + assert.deepEqual( + validateOAuthCallbackParams( + { + state: 'expected-state', + code: 'auth-code', + }, + 'expected-state', + ), + { type: 'code', code: 'auth-code' }, + ) + + assert.deepEqual( + validateOAuthCallbackParams( + { + state: 'wrong-state', + code: 'auth-code', + }, + 'expected-state', + ), + { type: 'state_mismatch' }, + ) +}) diff --git a/src/services/mcp/auth.ts b/src/services/mcp/auth.ts index 939454cc..99dfe764 100644 --- a/src/services/mcp/auth.ts +++ b/src/services/mcp/auth.ts @@ -124,6 +124,74 @@ function redactSensitiveUrlParams(url: string): string { } } +type OAuthCallbackParamValue = string | string[] | null | undefined + +type OAuthCallbackValidationResult = + | { type: 'code'; code: string } + | { + type: 'error' + error: string + errorDescription: string + errorUri: string + message: string + } + | { type: 'missing_result' } + | { type: 'state_mismatch' } + +function getFirstOAuthCallbackParam( + value: OAuthCallbackParamValue, +): string | undefined { + if (Array.isArray(value)) { + return value.find(item => item.length > 0) + } + return value && value.length > 0 ? value : undefined +} + +export function validateOAuthCallbackParams( + params: { + code?: OAuthCallbackParamValue + state?: OAuthCallbackParamValue + error?: OAuthCallbackParamValue + error_description?: OAuthCallbackParamValue + error_uri?: OAuthCallbackParamValue + }, + oauthState: string, +): OAuthCallbackValidationResult { + const code = getFirstOAuthCallbackParam(params.code) + const state = getFirstOAuthCallbackParam(params.state) + const error = getFirstOAuthCallbackParam(params.error) + const errorDescription = + getFirstOAuthCallbackParam(params.error_description) ?? '' + const errorUri = getFirstOAuthCallbackParam(params.error_uri) ?? '' + + if (state !== oauthState) { + return { type: 'state_mismatch' } + } + + if (error) { + let message = `OAuth error: ${error}` + if (errorDescription) { + message += ` - ${errorDescription}` + } + if (errorUri) { + message += ` (See: ${errorUri})` + } + return { + type: 'error', + error, + errorDescription, + errorUri, + message, + } + } + + if (code) { + return { type: 'code', code } + } + + return { type: 'missing_result' } +} + /** * Some OAuth servers (notably Slack) return HTTP 200 for all responses, * signaling errors via the JSON body instead. The SDK's executeTokenRequest @@ -1058,30 +1126,31 @@ export async function performMCPOAuthFlow( options.onWaitingForCallback((callbackUrl: string) => { try { const parsed = new URL(callbackUrl) - const code = parsed.searchParams.get('code') - const state = parsed.searchParams.get('state') - const error = parsed.searchParams.get('error') + const result = validateOAuthCallbackParams( + { + code: parsed.searchParams.get('code'), + state: parsed.searchParams.get('state'), + error: parsed.searchParams.get('error'), + error_description: + parsed.searchParams.get('error_description'), + error_uri: parsed.searchParams.get('error_uri'), + }, + oauthState, + ) - if (error) { - const errorDescription = - parsed.searchParams.get('error_description') || '' - cleanup() - rejectOnce( - new Error(`OAuth error: ${error} - ${errorDescription}`), - ) + if (result.type === 'state_mismatch') { + // Ignore so a stray or malicious URL cannot cancel an active flow. return } - if (!code) { - // Not a valid callback URL, ignore so the user can try again + if (result.type === 'missing_result') { + // Not a valid callback URL, ignore so the user can try again. return } - if (state !== oauthState) { + if (result.type === 'error') { cleanup() - rejectOnce( - new Error('OAuth state mismatch - possible CSRF attack'), - ) + rejectOnce(new Error(result.message)) return } @@ -1090,7 +1159,7 @@ export async function performMCPOAuthFlow( `Received auth code via manual callback URL`, ) cleanup() - resolveOnce(code) + resolveOnce(result.code) } catch { // Invalid URL, ignore so the user can try again } @@ -1101,53 +1170,49 @@ export async function performMCPOAuthFlow( const parsedUrl = parse(req.url || '', true) if (parsedUrl.pathname === '/callback') { - const code = parsedUrl.query.code as string - const state = parsedUrl.query.state as string - const error = parsedUrl.query.error - const errorDescription = parsedUrl.query.error_description as string - const errorUri = parsedUrl.query.error_uri as string + const result = validateOAuthCallbackParams( + parsedUrl.query, + oauthState, + ) // Validate OAuth state to prevent CSRF attacks - if (!error && state !== oauthState) { + if (result.type === 'state_mismatch') { res.writeHead(400, { 'Content-Type': 'text/html' }) res.end( `

Authentication Error

Invalid state parameter. Please try again.

You can close this window.

`, ) - cleanup() - rejectOnce(new Error('OAuth state mismatch - possible CSRF attack')) return } - if (error) { + if (result.type === 'missing_result') { + res.writeHead(400, { 'Content-Type': 'text/html' }) + res.end( + `

Authentication Error

Missing OAuth result. Please try again.

You can close this window.

`, + ) + return + } + + if (result.type === 'error') { res.writeHead(200, { 'Content-Type': 'text/html' }) // Sanitize error messages to prevent XSS - const sanitizedError = xss(String(error)) - const sanitizedErrorDescription = errorDescription - ? xss(String(errorDescription)) + const sanitizedError = xss(result.error) + const sanitizedErrorDescription = result.errorDescription + ? xss(result.errorDescription) : '' res.end( `

Authentication Error

${sanitizedError}: ${sanitizedErrorDescription}

You can close this window.

`, ) cleanup() - let errorMessage = `OAuth error: ${error}` - if (errorDescription) { - errorMessage += ` - ${errorDescription}` - } - if (errorUri) { - errorMessage += ` (See: ${errorUri})` - } - rejectOnce(new Error(errorMessage)) + rejectOnce(new Error(result.message)) return } - if (code) { - res.writeHead(200, { 'Content-Type': 'text/html' }) - res.end( - `

Authentication Successful

You can close this window. Return to Claude Code.

`, - ) - cleanup() - resolveOnce(code) - } + res.writeHead(200, { 'Content-Type': 'text/html' }) + res.end( + `

Authentication Successful

You can close this window. Return to Claude Code.

`, + ) + cleanup() + resolveOnce(result.code) } })