fix: enforce MCP OAuth callback state before errors (#775)

This commit is contained in:
Kevin Codex
2026-04-20 09:36:05 +08:00
committed by GitHub
parent f166ec1a4e
commit 739b8d1f40
2 changed files with 171 additions and 45 deletions

View File

@@ -0,0 +1,61 @@
import assert from 'node:assert/strict'
import test from 'node:test'
import { validateOAuthCallbackParams } from './auth.js'
test('OAuth callback rejects error parameters before state validation can be bypassed', () => {
const result = validateOAuthCallbackParams(
{
error: 'access_denied',
error_description: 'denied by provider',
},
'expected-state',
)
assert.deepEqual(result, { type: 'state_mismatch' })
})
test('OAuth callback accepts provider errors only when state matches', () => {
const result = validateOAuthCallbackParams(
{
state: 'expected-state',
error: 'access_denied',
error_description: 'denied by provider',
error_uri: 'https://example.test/error',
},
'expected-state',
)
assert.deepEqual(result, {
type: 'error',
error: 'access_denied',
errorDescription: 'denied by provider',
errorUri: 'https://example.test/error',
message:
'OAuth error: access_denied - denied by provider (See: https://example.test/error)',
})
})
test('OAuth callback accepts authorization codes only when state matches', () => {
assert.deepEqual(
validateOAuthCallbackParams(
{
state: 'expected-state',
code: 'auth-code',
},
'expected-state',
),
{ type: 'code', code: 'auth-code' },
)
assert.deepEqual(
validateOAuthCallbackParams(
{
state: 'wrong-state',
code: 'auth-code',
},
'expected-state',
),
{ type: 'state_mismatch' },
)
})

View File

@@ -124,6 +124,74 @@ function redactSensitiveUrlParams(url: string): string {
} }
} }
type OAuthCallbackParamValue = string | string[] | null | undefined
type OAuthCallbackValidationResult =
| { type: 'code'; code: string }
| {
type: 'error'
error: string
errorDescription: string
errorUri: string
message: string
}
| { type: 'missing_result' }
| { type: 'state_mismatch' }
function getFirstOAuthCallbackParam(
value: OAuthCallbackParamValue,
): string | undefined {
if (Array.isArray(value)) {
return value.find(item => item.length > 0)
}
return value && value.length > 0 ? value : undefined
}
export function validateOAuthCallbackParams(
params: {
code?: OAuthCallbackParamValue
state?: OAuthCallbackParamValue
error?: OAuthCallbackParamValue
error_description?: OAuthCallbackParamValue
error_uri?: OAuthCallbackParamValue
},
oauthState: string,
): OAuthCallbackValidationResult {
const code = getFirstOAuthCallbackParam(params.code)
const state = getFirstOAuthCallbackParam(params.state)
const error = getFirstOAuthCallbackParam(params.error)
const errorDescription =
getFirstOAuthCallbackParam(params.error_description) ?? ''
const errorUri = getFirstOAuthCallbackParam(params.error_uri) ?? ''
if (state !== oauthState) {
return { type: 'state_mismatch' }
}
if (error) {
let message = `OAuth error: ${error}`
if (errorDescription) {
message += ` - ${errorDescription}`
}
if (errorUri) {
message += ` (See: ${errorUri})`
}
return {
type: 'error',
error,
errorDescription,
errorUri,
message,
}
}
if (code) {
return { type: 'code', code }
}
return { type: 'missing_result' }
}
/** /**
* Some OAuth servers (notably Slack) return HTTP 200 for all responses, * Some OAuth servers (notably Slack) return HTTP 200 for all responses,
* signaling errors via the JSON body instead. The SDK's executeTokenRequest * signaling errors via the JSON body instead. The SDK's executeTokenRequest
@@ -1058,30 +1126,31 @@ export async function performMCPOAuthFlow(
options.onWaitingForCallback((callbackUrl: string) => { options.onWaitingForCallback((callbackUrl: string) => {
try { try {
const parsed = new URL(callbackUrl) const parsed = new URL(callbackUrl)
const code = parsed.searchParams.get('code') const result = validateOAuthCallbackParams(
const state = parsed.searchParams.get('state') {
const error = parsed.searchParams.get('error') code: parsed.searchParams.get('code'),
state: parsed.searchParams.get('state'),
error: parsed.searchParams.get('error'),
error_description:
parsed.searchParams.get('error_description'),
error_uri: parsed.searchParams.get('error_uri'),
},
oauthState,
)
if (error) { if (result.type === 'state_mismatch') {
const errorDescription = // Ignore so a stray or malicious URL cannot cancel an active flow.
parsed.searchParams.get('error_description') || ''
cleanup()
rejectOnce(
new Error(`OAuth error: ${error} - ${errorDescription}`),
)
return return
} }
if (!code) { if (result.type === 'missing_result') {
// Not a valid callback URL, ignore so the user can try again // Not a valid callback URL, ignore so the user can try again.
return return
} }
if (state !== oauthState) { if (result.type === 'error') {
cleanup() cleanup()
rejectOnce( rejectOnce(new Error(result.message))
new Error('OAuth state mismatch - possible CSRF attack'),
)
return return
} }
@@ -1090,7 +1159,7 @@ export async function performMCPOAuthFlow(
`Received auth code via manual callback URL`, `Received auth code via manual callback URL`,
) )
cleanup() cleanup()
resolveOnce(code) resolveOnce(result.code)
} catch { } catch {
// Invalid URL, ignore so the user can try again // Invalid URL, ignore so the user can try again
} }
@@ -1101,53 +1170,49 @@ export async function performMCPOAuthFlow(
const parsedUrl = parse(req.url || '', true) const parsedUrl = parse(req.url || '', true)
if (parsedUrl.pathname === '/callback') { if (parsedUrl.pathname === '/callback') {
const code = parsedUrl.query.code as string const result = validateOAuthCallbackParams(
const state = parsedUrl.query.state as string parsedUrl.query,
const error = parsedUrl.query.error oauthState,
const errorDescription = parsedUrl.query.error_description as string )
const errorUri = parsedUrl.query.error_uri as string
// Validate OAuth state to prevent CSRF attacks // Validate OAuth state to prevent CSRF attacks
if (!error && state !== oauthState) { if (result.type === 'state_mismatch') {
res.writeHead(400, { 'Content-Type': 'text/html' }) res.writeHead(400, { 'Content-Type': 'text/html' })
res.end( res.end(
`<h1>Authentication Error</h1><p>Invalid state parameter. Please try again.</p><p>You can close this window.</p>`, `<h1>Authentication Error</h1><p>Invalid state parameter. Please try again.</p><p>You can close this window.</p>`,
) )
cleanup()
rejectOnce(new Error('OAuth state mismatch - possible CSRF attack'))
return return
} }
if (error) { if (result.type === 'missing_result') {
res.writeHead(400, { 'Content-Type': 'text/html' })
res.end(
`<h1>Authentication Error</h1><p>Missing OAuth result. Please try again.</p><p>You can close this window.</p>`,
)
return
}
if (result.type === 'error') {
res.writeHead(200, { 'Content-Type': 'text/html' }) res.writeHead(200, { 'Content-Type': 'text/html' })
// Sanitize error messages to prevent XSS // Sanitize error messages to prevent XSS
const sanitizedError = xss(String(error)) const sanitizedError = xss(result.error)
const sanitizedErrorDescription = errorDescription const sanitizedErrorDescription = result.errorDescription
? xss(String(errorDescription)) ? xss(result.errorDescription)
: '' : ''
res.end( res.end(
`<h1>Authentication Error</h1><p>${sanitizedError}: ${sanitizedErrorDescription}</p><p>You can close this window.</p>`, `<h1>Authentication Error</h1><p>${sanitizedError}: ${sanitizedErrorDescription}</p><p>You can close this window.</p>`,
) )
cleanup() cleanup()
let errorMessage = `OAuth error: ${error}` rejectOnce(new Error(result.message))
if (errorDescription) {
errorMessage += ` - ${errorDescription}`
}
if (errorUri) {
errorMessage += ` (See: ${errorUri})`
}
rejectOnce(new Error(errorMessage))
return return
} }
if (code) { res.writeHead(200, { 'Content-Type': 'text/html' })
res.writeHead(200, { 'Content-Type': 'text/html' }) res.end(
res.end( `<h1>Authentication Successful</h1><p>You can close this window. Return to Claude Code.</p>`,
`<h1>Authentication Successful</h1><p>You can close this window. Return to Claude Code.</p>`, )
) cleanup()
cleanup() resolveOnce(result.code)
resolveOnce(code)
}
} }
}) })