fix: enforce MCP OAuth callback state before errors (#775)
This commit is contained in:
61
src/services/mcp/auth.test.ts
Normal file
61
src/services/mcp/auth.test.ts
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import assert from 'node:assert/strict'
|
||||||
|
import test from 'node:test'
|
||||||
|
|
||||||
|
import { validateOAuthCallbackParams } from './auth.js'
|
||||||
|
|
||||||
|
test('OAuth callback rejects error parameters before state validation can be bypassed', () => {
|
||||||
|
const result = validateOAuthCallbackParams(
|
||||||
|
{
|
||||||
|
error: 'access_denied',
|
||||||
|
error_description: 'denied by provider',
|
||||||
|
},
|
||||||
|
'expected-state',
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.deepEqual(result, { type: 'state_mismatch' })
|
||||||
|
})
|
||||||
|
|
||||||
|
test('OAuth callback accepts provider errors only when state matches', () => {
|
||||||
|
const result = validateOAuthCallbackParams(
|
||||||
|
{
|
||||||
|
state: 'expected-state',
|
||||||
|
error: 'access_denied',
|
||||||
|
error_description: 'denied by provider',
|
||||||
|
error_uri: 'https://example.test/error',
|
||||||
|
},
|
||||||
|
'expected-state',
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.deepEqual(result, {
|
||||||
|
type: 'error',
|
||||||
|
error: 'access_denied',
|
||||||
|
errorDescription: 'denied by provider',
|
||||||
|
errorUri: 'https://example.test/error',
|
||||||
|
message:
|
||||||
|
'OAuth error: access_denied - denied by provider (See: https://example.test/error)',
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
test('OAuth callback accepts authorization codes only when state matches', () => {
|
||||||
|
assert.deepEqual(
|
||||||
|
validateOAuthCallbackParams(
|
||||||
|
{
|
||||||
|
state: 'expected-state',
|
||||||
|
code: 'auth-code',
|
||||||
|
},
|
||||||
|
'expected-state',
|
||||||
|
),
|
||||||
|
{ type: 'code', code: 'auth-code' },
|
||||||
|
)
|
||||||
|
|
||||||
|
assert.deepEqual(
|
||||||
|
validateOAuthCallbackParams(
|
||||||
|
{
|
||||||
|
state: 'wrong-state',
|
||||||
|
code: 'auth-code',
|
||||||
|
},
|
||||||
|
'expected-state',
|
||||||
|
),
|
||||||
|
{ type: 'state_mismatch' },
|
||||||
|
)
|
||||||
|
})
|
||||||
@@ -124,6 +124,74 @@ function redactSensitiveUrlParams(url: string): string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OAuthCallbackParamValue = string | string[] | null | undefined
|
||||||
|
|
||||||
|
type OAuthCallbackValidationResult =
|
||||||
|
| { type: 'code'; code: string }
|
||||||
|
| {
|
||||||
|
type: 'error'
|
||||||
|
error: string
|
||||||
|
errorDescription: string
|
||||||
|
errorUri: string
|
||||||
|
message: string
|
||||||
|
}
|
||||||
|
| { type: 'missing_result' }
|
||||||
|
| { type: 'state_mismatch' }
|
||||||
|
|
||||||
|
function getFirstOAuthCallbackParam(
|
||||||
|
value: OAuthCallbackParamValue,
|
||||||
|
): string | undefined {
|
||||||
|
if (Array.isArray(value)) {
|
||||||
|
return value.find(item => item.length > 0)
|
||||||
|
}
|
||||||
|
return value && value.length > 0 ? value : undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
export function validateOAuthCallbackParams(
|
||||||
|
params: {
|
||||||
|
code?: OAuthCallbackParamValue
|
||||||
|
state?: OAuthCallbackParamValue
|
||||||
|
error?: OAuthCallbackParamValue
|
||||||
|
error_description?: OAuthCallbackParamValue
|
||||||
|
error_uri?: OAuthCallbackParamValue
|
||||||
|
},
|
||||||
|
oauthState: string,
|
||||||
|
): OAuthCallbackValidationResult {
|
||||||
|
const code = getFirstOAuthCallbackParam(params.code)
|
||||||
|
const state = getFirstOAuthCallbackParam(params.state)
|
||||||
|
const error = getFirstOAuthCallbackParam(params.error)
|
||||||
|
const errorDescription =
|
||||||
|
getFirstOAuthCallbackParam(params.error_description) ?? ''
|
||||||
|
const errorUri = getFirstOAuthCallbackParam(params.error_uri) ?? ''
|
||||||
|
|
||||||
|
if (state !== oauthState) {
|
||||||
|
return { type: 'state_mismatch' }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
let message = `OAuth error: ${error}`
|
||||||
|
if (errorDescription) {
|
||||||
|
message += ` - ${errorDescription}`
|
||||||
|
}
|
||||||
|
if (errorUri) {
|
||||||
|
message += ` (See: ${errorUri})`
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
type: 'error',
|
||||||
|
error,
|
||||||
|
errorDescription,
|
||||||
|
errorUri,
|
||||||
|
message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (code) {
|
||||||
|
return { type: 'code', code }
|
||||||
|
}
|
||||||
|
|
||||||
|
return { type: 'missing_result' }
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Some OAuth servers (notably Slack) return HTTP 200 for all responses,
|
* Some OAuth servers (notably Slack) return HTTP 200 for all responses,
|
||||||
* signaling errors via the JSON body instead. The SDK's executeTokenRequest
|
* signaling errors via the JSON body instead. The SDK's executeTokenRequest
|
||||||
@@ -1058,30 +1126,31 @@ export async function performMCPOAuthFlow(
|
|||||||
options.onWaitingForCallback((callbackUrl: string) => {
|
options.onWaitingForCallback((callbackUrl: string) => {
|
||||||
try {
|
try {
|
||||||
const parsed = new URL(callbackUrl)
|
const parsed = new URL(callbackUrl)
|
||||||
const code = parsed.searchParams.get('code')
|
const result = validateOAuthCallbackParams(
|
||||||
const state = parsed.searchParams.get('state')
|
{
|
||||||
const error = parsed.searchParams.get('error')
|
code: parsed.searchParams.get('code'),
|
||||||
|
state: parsed.searchParams.get('state'),
|
||||||
if (error) {
|
error: parsed.searchParams.get('error'),
|
||||||
const errorDescription =
|
error_description:
|
||||||
parsed.searchParams.get('error_description') || ''
|
parsed.searchParams.get('error_description'),
|
||||||
cleanup()
|
error_uri: parsed.searchParams.get('error_uri'),
|
||||||
rejectOnce(
|
},
|
||||||
new Error(`OAuth error: ${error} - ${errorDescription}`),
|
oauthState,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (result.type === 'state_mismatch') {
|
||||||
|
// Ignore so a stray or malicious URL cannot cancel an active flow.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!code) {
|
if (result.type === 'missing_result') {
|
||||||
// Not a valid callback URL, ignore so the user can try again
|
// Not a valid callback URL, ignore so the user can try again.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (state !== oauthState) {
|
if (result.type === 'error') {
|
||||||
cleanup()
|
cleanup()
|
||||||
rejectOnce(
|
rejectOnce(new Error(result.message))
|
||||||
new Error('OAuth state mismatch - possible CSRF attack'),
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1090,7 +1159,7 @@ export async function performMCPOAuthFlow(
|
|||||||
`Received auth code via manual callback URL`,
|
`Received auth code via manual callback URL`,
|
||||||
)
|
)
|
||||||
cleanup()
|
cleanup()
|
||||||
resolveOnce(code)
|
resolveOnce(result.code)
|
||||||
} catch {
|
} catch {
|
||||||
// Invalid URL, ignore so the user can try again
|
// Invalid URL, ignore so the user can try again
|
||||||
}
|
}
|
||||||
@@ -1101,53 +1170,49 @@ export async function performMCPOAuthFlow(
|
|||||||
const parsedUrl = parse(req.url || '', true)
|
const parsedUrl = parse(req.url || '', true)
|
||||||
|
|
||||||
if (parsedUrl.pathname === '/callback') {
|
if (parsedUrl.pathname === '/callback') {
|
||||||
const code = parsedUrl.query.code as string
|
const result = validateOAuthCallbackParams(
|
||||||
const state = parsedUrl.query.state as string
|
parsedUrl.query,
|
||||||
const error = parsedUrl.query.error
|
oauthState,
|
||||||
const errorDescription = parsedUrl.query.error_description as string
|
)
|
||||||
const errorUri = parsedUrl.query.error_uri as string
|
|
||||||
|
|
||||||
// Validate OAuth state to prevent CSRF attacks
|
// Validate OAuth state to prevent CSRF attacks
|
||||||
if (!error && state !== oauthState) {
|
if (result.type === 'state_mismatch') {
|
||||||
res.writeHead(400, { 'Content-Type': 'text/html' })
|
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||||
res.end(
|
res.end(
|
||||||
`<h1>Authentication Error</h1><p>Invalid state parameter. Please try again.</p><p>You can close this window.</p>`,
|
`<h1>Authentication Error</h1><p>Invalid state parameter. Please try again.</p><p>You can close this window.</p>`,
|
||||||
)
|
)
|
||||||
cleanup()
|
|
||||||
rejectOnce(new Error('OAuth state mismatch - possible CSRF attack'))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (error) {
|
if (result.type === 'missing_result') {
|
||||||
|
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||||
|
res.end(
|
||||||
|
`<h1>Authentication Error</h1><p>Missing OAuth result. Please try again.</p><p>You can close this window.</p>`,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result.type === 'error') {
|
||||||
res.writeHead(200, { 'Content-Type': 'text/html' })
|
res.writeHead(200, { 'Content-Type': 'text/html' })
|
||||||
// Sanitize error messages to prevent XSS
|
// Sanitize error messages to prevent XSS
|
||||||
const sanitizedError = xss(String(error))
|
const sanitizedError = xss(result.error)
|
||||||
const sanitizedErrorDescription = errorDescription
|
const sanitizedErrorDescription = result.errorDescription
|
||||||
? xss(String(errorDescription))
|
? xss(result.errorDescription)
|
||||||
: ''
|
: ''
|
||||||
res.end(
|
res.end(
|
||||||
`<h1>Authentication Error</h1><p>${sanitizedError}: ${sanitizedErrorDescription}</p><p>You can close this window.</p>`,
|
`<h1>Authentication Error</h1><p>${sanitizedError}: ${sanitizedErrorDescription}</p><p>You can close this window.</p>`,
|
||||||
)
|
)
|
||||||
cleanup()
|
cleanup()
|
||||||
let errorMessage = `OAuth error: ${error}`
|
rejectOnce(new Error(result.message))
|
||||||
if (errorDescription) {
|
|
||||||
errorMessage += ` - ${errorDescription}`
|
|
||||||
}
|
|
||||||
if (errorUri) {
|
|
||||||
errorMessage += ` (See: ${errorUri})`
|
|
||||||
}
|
|
||||||
rejectOnce(new Error(errorMessage))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if (code) {
|
|
||||||
res.writeHead(200, { 'Content-Type': 'text/html' })
|
res.writeHead(200, { 'Content-Type': 'text/html' })
|
||||||
res.end(
|
res.end(
|
||||||
`<h1>Authentication Successful</h1><p>You can close this window. Return to Claude Code.</p>`,
|
`<h1>Authentication Successful</h1><p>You can close this window. Return to Claude Code.</p>`,
|
||||||
)
|
)
|
||||||
cleanup()
|
cleanup()
|
||||||
resolveOnce(code)
|
resolveOnce(result.code)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user