fix: enforce MCP OAuth callback state before errors (#775)
This commit is contained in:
61
src/services/mcp/auth.test.ts
Normal file
61
src/services/mcp/auth.test.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import assert from 'node:assert/strict'
|
||||
import test from 'node:test'
|
||||
|
||||
import { validateOAuthCallbackParams } from './auth.js'
|
||||
|
||||
test('OAuth callback rejects error parameters before state validation can be bypassed', () => {
|
||||
const result = validateOAuthCallbackParams(
|
||||
{
|
||||
error: 'access_denied',
|
||||
error_description: 'denied by provider',
|
||||
},
|
||||
'expected-state',
|
||||
)
|
||||
|
||||
assert.deepEqual(result, { type: 'state_mismatch' })
|
||||
})
|
||||
|
||||
test('OAuth callback accepts provider errors only when state matches', () => {
|
||||
const result = validateOAuthCallbackParams(
|
||||
{
|
||||
state: 'expected-state',
|
||||
error: 'access_denied',
|
||||
error_description: 'denied by provider',
|
||||
error_uri: 'https://example.test/error',
|
||||
},
|
||||
'expected-state',
|
||||
)
|
||||
|
||||
assert.deepEqual(result, {
|
||||
type: 'error',
|
||||
error: 'access_denied',
|
||||
errorDescription: 'denied by provider',
|
||||
errorUri: 'https://example.test/error',
|
||||
message:
|
||||
'OAuth error: access_denied - denied by provider (See: https://example.test/error)',
|
||||
})
|
||||
})
|
||||
|
||||
test('OAuth callback accepts authorization codes only when state matches', () => {
|
||||
assert.deepEqual(
|
||||
validateOAuthCallbackParams(
|
||||
{
|
||||
state: 'expected-state',
|
||||
code: 'auth-code',
|
||||
},
|
||||
'expected-state',
|
||||
),
|
||||
{ type: 'code', code: 'auth-code' },
|
||||
)
|
||||
|
||||
assert.deepEqual(
|
||||
validateOAuthCallbackParams(
|
||||
{
|
||||
state: 'wrong-state',
|
||||
code: 'auth-code',
|
||||
},
|
||||
'expected-state',
|
||||
),
|
||||
{ type: 'state_mismatch' },
|
||||
)
|
||||
})
|
||||
@@ -124,6 +124,74 @@ function redactSensitiveUrlParams(url: string): string {
|
||||
}
|
||||
}
|
||||
|
||||
type OAuthCallbackParamValue = string | string[] | null | undefined
|
||||
|
||||
type OAuthCallbackValidationResult =
|
||||
| { type: 'code'; code: string }
|
||||
| {
|
||||
type: 'error'
|
||||
error: string
|
||||
errorDescription: string
|
||||
errorUri: string
|
||||
message: string
|
||||
}
|
||||
| { type: 'missing_result' }
|
||||
| { type: 'state_mismatch' }
|
||||
|
||||
function getFirstOAuthCallbackParam(
|
||||
value: OAuthCallbackParamValue,
|
||||
): string | undefined {
|
||||
if (Array.isArray(value)) {
|
||||
return value.find(item => item.length > 0)
|
||||
}
|
||||
return value && value.length > 0 ? value : undefined
|
||||
}
|
||||
|
||||
export function validateOAuthCallbackParams(
|
||||
params: {
|
||||
code?: OAuthCallbackParamValue
|
||||
state?: OAuthCallbackParamValue
|
||||
error?: OAuthCallbackParamValue
|
||||
error_description?: OAuthCallbackParamValue
|
||||
error_uri?: OAuthCallbackParamValue
|
||||
},
|
||||
oauthState: string,
|
||||
): OAuthCallbackValidationResult {
|
||||
const code = getFirstOAuthCallbackParam(params.code)
|
||||
const state = getFirstOAuthCallbackParam(params.state)
|
||||
const error = getFirstOAuthCallbackParam(params.error)
|
||||
const errorDescription =
|
||||
getFirstOAuthCallbackParam(params.error_description) ?? ''
|
||||
const errorUri = getFirstOAuthCallbackParam(params.error_uri) ?? ''
|
||||
|
||||
if (state !== oauthState) {
|
||||
return { type: 'state_mismatch' }
|
||||
}
|
||||
|
||||
if (error) {
|
||||
let message = `OAuth error: ${error}`
|
||||
if (errorDescription) {
|
||||
message += ` - ${errorDescription}`
|
||||
}
|
||||
if (errorUri) {
|
||||
message += ` (See: ${errorUri})`
|
||||
}
|
||||
return {
|
||||
type: 'error',
|
||||
error,
|
||||
errorDescription,
|
||||
errorUri,
|
||||
message,
|
||||
}
|
||||
}
|
||||
|
||||
if (code) {
|
||||
return { type: 'code', code }
|
||||
}
|
||||
|
||||
return { type: 'missing_result' }
|
||||
}
|
||||
|
||||
/**
|
||||
* Some OAuth servers (notably Slack) return HTTP 200 for all responses,
|
||||
* signaling errors via the JSON body instead. The SDK's executeTokenRequest
|
||||
@@ -1058,30 +1126,31 @@ export async function performMCPOAuthFlow(
|
||||
options.onWaitingForCallback((callbackUrl: string) => {
|
||||
try {
|
||||
const parsed = new URL(callbackUrl)
|
||||
const code = parsed.searchParams.get('code')
|
||||
const state = parsed.searchParams.get('state')
|
||||
const error = parsed.searchParams.get('error')
|
||||
|
||||
if (error) {
|
||||
const errorDescription =
|
||||
parsed.searchParams.get('error_description') || ''
|
||||
cleanup()
|
||||
rejectOnce(
|
||||
new Error(`OAuth error: ${error} - ${errorDescription}`),
|
||||
const result = validateOAuthCallbackParams(
|
||||
{
|
||||
code: parsed.searchParams.get('code'),
|
||||
state: parsed.searchParams.get('state'),
|
||||
error: parsed.searchParams.get('error'),
|
||||
error_description:
|
||||
parsed.searchParams.get('error_description'),
|
||||
error_uri: parsed.searchParams.get('error_uri'),
|
||||
},
|
||||
oauthState,
|
||||
)
|
||||
|
||||
if (result.type === 'state_mismatch') {
|
||||
// Ignore so a stray or malicious URL cannot cancel an active flow.
|
||||
return
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
// Not a valid callback URL, ignore so the user can try again
|
||||
if (result.type === 'missing_result') {
|
||||
// Not a valid callback URL, ignore so the user can try again.
|
||||
return
|
||||
}
|
||||
|
||||
if (state !== oauthState) {
|
||||
if (result.type === 'error') {
|
||||
cleanup()
|
||||
rejectOnce(
|
||||
new Error('OAuth state mismatch - possible CSRF attack'),
|
||||
)
|
||||
rejectOnce(new Error(result.message))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1090,7 +1159,7 @@ export async function performMCPOAuthFlow(
|
||||
`Received auth code via manual callback URL`,
|
||||
)
|
||||
cleanup()
|
||||
resolveOnce(code)
|
||||
resolveOnce(result.code)
|
||||
} catch {
|
||||
// Invalid URL, ignore so the user can try again
|
||||
}
|
||||
@@ -1101,53 +1170,49 @@ export async function performMCPOAuthFlow(
|
||||
const parsedUrl = parse(req.url || '', true)
|
||||
|
||||
if (parsedUrl.pathname === '/callback') {
|
||||
const code = parsedUrl.query.code as string
|
||||
const state = parsedUrl.query.state as string
|
||||
const error = parsedUrl.query.error
|
||||
const errorDescription = parsedUrl.query.error_description as string
|
||||
const errorUri = parsedUrl.query.error_uri as string
|
||||
const result = validateOAuthCallbackParams(
|
||||
parsedUrl.query,
|
||||
oauthState,
|
||||
)
|
||||
|
||||
// Validate OAuth state to prevent CSRF attacks
|
||||
if (!error && state !== oauthState) {
|
||||
if (result.type === 'state_mismatch') {
|
||||
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
`<h1>Authentication Error</h1><p>Invalid state parameter. Please try again.</p><p>You can close this window.</p>`,
|
||||
)
|
||||
cleanup()
|
||||
rejectOnce(new Error('OAuth state mismatch - possible CSRF attack'))
|
||||
return
|
||||
}
|
||||
|
||||
if (error) {
|
||||
if (result.type === 'missing_result') {
|
||||
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
`<h1>Authentication Error</h1><p>Missing OAuth result. Please try again.</p><p>You can close this window.</p>`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if (result.type === 'error') {
|
||||
res.writeHead(200, { 'Content-Type': 'text/html' })
|
||||
// Sanitize error messages to prevent XSS
|
||||
const sanitizedError = xss(String(error))
|
||||
const sanitizedErrorDescription = errorDescription
|
||||
? xss(String(errorDescription))
|
||||
const sanitizedError = xss(result.error)
|
||||
const sanitizedErrorDescription = result.errorDescription
|
||||
? xss(result.errorDescription)
|
||||
: ''
|
||||
res.end(
|
||||
`<h1>Authentication Error</h1><p>${sanitizedError}: ${sanitizedErrorDescription}</p><p>You can close this window.</p>`,
|
||||
)
|
||||
cleanup()
|
||||
let errorMessage = `OAuth error: ${error}`
|
||||
if (errorDescription) {
|
||||
errorMessage += ` - ${errorDescription}`
|
||||
}
|
||||
if (errorUri) {
|
||||
errorMessage += ` (See: ${errorUri})`
|
||||
}
|
||||
rejectOnce(new Error(errorMessage))
|
||||
rejectOnce(new Error(result.message))
|
||||
return
|
||||
}
|
||||
|
||||
if (code) {
|
||||
res.writeHead(200, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
`<h1>Authentication Successful</h1><p>You can close this window. Return to Claude Code.</p>`,
|
||||
)
|
||||
cleanup()
|
||||
resolveOnce(code)
|
||||
}
|
||||
resolveOnce(result.code)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user