fix: enforce MCP OAuth callback state before errors (#775)

This commit is contained in:
Kevin Codex
2026-04-20 09:36:05 +08:00
committed by GitHub
parent f166ec1a4e
commit 739b8d1f40
2 changed files with 171 additions and 45 deletions

View File

@@ -0,0 +1,61 @@
import assert from 'node:assert/strict'
import test from 'node:test'
import { validateOAuthCallbackParams } from './auth.js'
test('OAuth callback rejects error parameters before state validation can be bypassed', () => {
const result = validateOAuthCallbackParams(
{
error: 'access_denied',
error_description: 'denied by provider',
},
'expected-state',
)
assert.deepEqual(result, { type: 'state_mismatch' })
})
test('OAuth callback accepts provider errors only when state matches', () => {
const result = validateOAuthCallbackParams(
{
state: 'expected-state',
error: 'access_denied',
error_description: 'denied by provider',
error_uri: 'https://example.test/error',
},
'expected-state',
)
assert.deepEqual(result, {
type: 'error',
error: 'access_denied',
errorDescription: 'denied by provider',
errorUri: 'https://example.test/error',
message:
'OAuth error: access_denied - denied by provider (See: https://example.test/error)',
})
})
test('OAuth callback accepts authorization codes only when state matches', () => {
assert.deepEqual(
validateOAuthCallbackParams(
{
state: 'expected-state',
code: 'auth-code',
},
'expected-state',
),
{ type: 'code', code: 'auth-code' },
)
assert.deepEqual(
validateOAuthCallbackParams(
{
state: 'wrong-state',
code: 'auth-code',
},
'expected-state',
),
{ type: 'state_mismatch' },
)
})

View File

@@ -124,6 +124,74 @@ function redactSensitiveUrlParams(url: string): string {
}
}
type OAuthCallbackParamValue = string | string[] | null | undefined
type OAuthCallbackValidationResult =
| { type: 'code'; code: string }
| {
type: 'error'
error: string
errorDescription: string
errorUri: string
message: string
}
| { type: 'missing_result' }
| { type: 'state_mismatch' }
function getFirstOAuthCallbackParam(
value: OAuthCallbackParamValue,
): string | undefined {
if (Array.isArray(value)) {
return value.find(item => item.length > 0)
}
return value && value.length > 0 ? value : undefined
}
export function validateOAuthCallbackParams(
params: {
code?: OAuthCallbackParamValue
state?: OAuthCallbackParamValue
error?: OAuthCallbackParamValue
error_description?: OAuthCallbackParamValue
error_uri?: OAuthCallbackParamValue
},
oauthState: string,
): OAuthCallbackValidationResult {
const code = getFirstOAuthCallbackParam(params.code)
const state = getFirstOAuthCallbackParam(params.state)
const error = getFirstOAuthCallbackParam(params.error)
const errorDescription =
getFirstOAuthCallbackParam(params.error_description) ?? ''
const errorUri = getFirstOAuthCallbackParam(params.error_uri) ?? ''
if (state !== oauthState) {
return { type: 'state_mismatch' }
}
if (error) {
let message = `OAuth error: ${error}`
if (errorDescription) {
message += ` - ${errorDescription}`
}
if (errorUri) {
message += ` (See: ${errorUri})`
}
return {
type: 'error',
error,
errorDescription,
errorUri,
message,
}
}
if (code) {
return { type: 'code', code }
}
return { type: 'missing_result' }
}
/**
* Some OAuth servers (notably Slack) return HTTP 200 for all responses,
* signaling errors via the JSON body instead. The SDK's executeTokenRequest
@@ -1058,30 +1126,31 @@ export async function performMCPOAuthFlow(
options.onWaitingForCallback((callbackUrl: string) => {
try {
const parsed = new URL(callbackUrl)
const code = parsed.searchParams.get('code')
const state = parsed.searchParams.get('state')
const error = parsed.searchParams.get('error')
if (error) {
const errorDescription =
parsed.searchParams.get('error_description') || ''
cleanup()
rejectOnce(
new Error(`OAuth error: ${error} - ${errorDescription}`),
const result = validateOAuthCallbackParams(
{
code: parsed.searchParams.get('code'),
state: parsed.searchParams.get('state'),
error: parsed.searchParams.get('error'),
error_description:
parsed.searchParams.get('error_description'),
error_uri: parsed.searchParams.get('error_uri'),
},
oauthState,
)
if (result.type === 'state_mismatch') {
// Ignore so a stray or malicious URL cannot cancel an active flow.
return
}
if (!code) {
// Not a valid callback URL, ignore so the user can try again
if (result.type === 'missing_result') {
// Not a valid callback URL, ignore so the user can try again.
return
}
if (state !== oauthState) {
if (result.type === 'error') {
cleanup()
rejectOnce(
new Error('OAuth state mismatch - possible CSRF attack'),
)
rejectOnce(new Error(result.message))
return
}
@@ -1090,7 +1159,7 @@ export async function performMCPOAuthFlow(
`Received auth code via manual callback URL`,
)
cleanup()
resolveOnce(code)
resolveOnce(result.code)
} catch {
// Invalid URL, ignore so the user can try again
}
@@ -1101,53 +1170,49 @@ export async function performMCPOAuthFlow(
const parsedUrl = parse(req.url || '', true)
if (parsedUrl.pathname === '/callback') {
const code = parsedUrl.query.code as string
const state = parsedUrl.query.state as string
const error = parsedUrl.query.error
const errorDescription = parsedUrl.query.error_description as string
const errorUri = parsedUrl.query.error_uri as string
const result = validateOAuthCallbackParams(
parsedUrl.query,
oauthState,
)
// Validate OAuth state to prevent CSRF attacks
if (!error && state !== oauthState) {
if (result.type === 'state_mismatch') {
res.writeHead(400, { 'Content-Type': 'text/html' })
res.end(
`<h1>Authentication Error</h1><p>Invalid state parameter. Please try again.</p><p>You can close this window.</p>`,
)
cleanup()
rejectOnce(new Error('OAuth state mismatch - possible CSRF attack'))
return
}
if (error) {
if (result.type === 'missing_result') {
res.writeHead(400, { 'Content-Type': 'text/html' })
res.end(
`<h1>Authentication Error</h1><p>Missing OAuth result. Please try again.</p><p>You can close this window.</p>`,
)
return
}
if (result.type === 'error') {
res.writeHead(200, { 'Content-Type': 'text/html' })
// Sanitize error messages to prevent XSS
const sanitizedError = xss(String(error))
const sanitizedErrorDescription = errorDescription
? xss(String(errorDescription))
const sanitizedError = xss(result.error)
const sanitizedErrorDescription = result.errorDescription
? xss(result.errorDescription)
: ''
res.end(
`<h1>Authentication Error</h1><p>${sanitizedError}: ${sanitizedErrorDescription}</p><p>You can close this window.</p>`,
)
cleanup()
let errorMessage = `OAuth error: ${error}`
if (errorDescription) {
errorMessage += ` - ${errorDescription}`
}
if (errorUri) {
errorMessage += ` (See: ${errorUri})`
}
rejectOnce(new Error(errorMessage))
rejectOnce(new Error(result.message))
return
}
if (code) {
res.writeHead(200, { 'Content-Type': 'text/html' })
res.end(
`<h1>Authentication Successful</h1><p>You can close this window. Return to Claude Code.</p>`,
)
cleanup()
resolveOnce(code)
}
resolveOnce(result.code)
}
})