tests: avoid global fetch mutation in GitHub device flow tests (#702)

This commit is contained in:
dhenuh
2026-04-15 04:38:46 -07:00
committed by GitHub
parent 7187fc007a
commit 114f772a4a

View File

@@ -1,4 +1,4 @@
import { afterEach, beforeEach, describe, expect, mock, test } from 'bun:test' import { afterEach, describe, expect, mock, test } from 'bun:test'
import { import {
DEFAULT_GITHUB_DEVICE_SCOPE, DEFAULT_GITHUB_DEVICE_SCOPE,
@@ -12,22 +12,15 @@ async function importFreshModule() {
return import(`./deviceFlow.ts?ts=${Date.now()}-${Math.random()}`) return import(`./deviceFlow.ts?ts=${Date.now()}-${Math.random()}`)
} }
afterEach(() => {
mock.restore()
})
describe('requestDeviceCode', () => { describe('requestDeviceCode', () => {
const originalFetch = globalThis.fetch
beforeEach(() => {
mock.restore()
globalThis.fetch = originalFetch
})
afterEach(() => {
globalThis.fetch = originalFetch
})
test('parses successful device code response', async () => { test('parses successful device code response', async () => {
const { requestDeviceCode } = await importFreshModule() const { requestDeviceCode } = await importFreshModule()
globalThis.fetch = mock(() => const fetchImpl = mock(() =>
Promise.resolve( Promise.resolve(
new Response( new Response(
JSON.stringify({ JSON.stringify({
@@ -44,7 +37,7 @@ describe('requestDeviceCode', () => {
const r = await requestDeviceCode({ const r = await requestDeviceCode({
clientId: 'test-client', clientId: 'test-client',
fetchImpl: globalThis.fetch, fetchImpl,
}) })
expect(r.device_code).toBe('abc') expect(r.device_code).toBe('abc')
expect(r.user_code).toBe('ABCD-1234') expect(r.user_code).toBe('ABCD-1234')
@@ -57,17 +50,17 @@ describe('requestDeviceCode', () => {
const { requestDeviceCode, GitHubDeviceFlowError } = const { requestDeviceCode, GitHubDeviceFlowError } =
await importFreshModule() await importFreshModule()
globalThis.fetch = mock(() => const fetchImpl = mock(() =>
Promise.resolve(new Response('bad', { status: 500 })), Promise.resolve(new Response('bad', { status: 500 })),
) )
await expect( await expect(
requestDeviceCode({ clientId: 'x', fetchImpl: globalThis.fetch }), requestDeviceCode({ clientId: 'x', fetchImpl }),
).rejects.toThrow(GitHubDeviceFlowError) ).rejects.toThrow(GitHubDeviceFlowError)
}) })
test('uses OAuth-safe default scope', async () => { test('uses OAuth-safe default scope', async () => {
let capturedScope = '' let capturedScope = ''
globalThis.fetch = mock((_url: RequestInfo | URL, init?: RequestInit) => { const fetchImpl = mock((_url: RequestInfo | URL, init?: RequestInit) => {
const body = init?.body const body = init?.body
if (body instanceof URLSearchParams) { if (body instanceof URLSearchParams) {
capturedScope = body.get('scope') ?? '' capturedScope = body.get('scope') ?? ''
@@ -87,7 +80,7 @@ describe('requestDeviceCode', () => {
) )
}) })
await requestDeviceCode({ clientId: 'test-client', fetchImpl: globalThis.fetch }) await requestDeviceCode({ clientId: 'test-client', fetchImpl })
expect(capturedScope).toBe(DEFAULT_GITHUB_DEVICE_SCOPE) expect(capturedScope).toBe(DEFAULT_GITHUB_DEVICE_SCOPE)
expect(capturedScope).toBe('read:user') expect(capturedScope).toBe('read:user')
}) })
@@ -96,7 +89,7 @@ describe('requestDeviceCode', () => {
const scopesSeen: string[] = [] const scopesSeen: string[] = []
let callCount = 0 let callCount = 0
globalThis.fetch = mock((_url: RequestInfo | URL, init?: RequestInit) => { const fetchImpl = mock((_url: RequestInfo | URL, init?: RequestInit) => {
const body = init?.body const body = init?.body
const scope = const scope =
body instanceof URLSearchParams body instanceof URLSearchParams
@@ -132,7 +125,7 @@ describe('requestDeviceCode', () => {
const result = await requestDeviceCode({ const result = await requestDeviceCode({
clientId: 'test-client', clientId: 'test-client',
scope: 'read:user,models:read', scope: 'read:user,models:read',
fetchImpl: globalThis.fetch, fetchImpl,
}) })
expect(result.device_code).toBe('abc') expect(result.device_code).toBe('abc')
@@ -142,17 +135,11 @@ describe('requestDeviceCode', () => {
}) })
describe('pollAccessToken', () => { describe('pollAccessToken', () => {
const originalFetch = globalThis.fetch
afterEach(() => {
globalThis.fetch = originalFetch
})
test('returns token when GitHub responds with access_token immediately', async () => { test('returns token when GitHub responds with access_token immediately', async () => {
const { pollAccessToken } = await importFreshModule() const { pollAccessToken } = await importFreshModule()
let calls = 0 let calls = 0
globalThis.fetch = mock(() => { const fetchImpl = mock(() => {
calls++ calls++
return Promise.resolve( return Promise.resolve(
new Response(JSON.stringify({ access_token: 'tok-xyz' }), { new Response(JSON.stringify({ access_token: 'tok-xyz' }), {
@@ -163,7 +150,7 @@ describe('pollAccessToken', () => {
const token = await pollAccessToken('dev-code', { const token = await pollAccessToken('dev-code', {
clientId: 'cid', clientId: 'cid',
fetchImpl: globalThis.fetch, fetchImpl,
}) })
expect(token).toBe('tok-xyz') expect(token).toBe('tok-xyz')
expect(calls).toBe(1) expect(calls).toBe(1)
@@ -172,7 +159,7 @@ describe('pollAccessToken', () => {
test('throws on access_denied', async () => { test('throws on access_denied', async () => {
const { pollAccessToken } = await importFreshModule() const { pollAccessToken } = await importFreshModule()
globalThis.fetch = mock(() => const fetchImpl = mock(() =>
Promise.resolve( Promise.resolve(
new Response(JSON.stringify({ error: 'access_denied' }), { new Response(JSON.stringify({ error: 'access_denied' }), {
status: 200, status: 200,
@@ -182,23 +169,17 @@ describe('pollAccessToken', () => {
await expect( await expect(
pollAccessToken('dc', { pollAccessToken('dc', {
clientId: 'c', clientId: 'c',
fetchImpl: globalThis.fetch, fetchImpl,
}), }),
).rejects.toThrow(/denied/) ).rejects.toThrow(/denied/)
}) })
}) })
describe('exchangeForCopilotToken', () => { describe('exchangeForCopilotToken', () => {
const originalFetch = globalThis.fetch
afterEach(() => {
globalThis.fetch = originalFetch
})
test('parses successful Copilot token response', async () => { test('parses successful Copilot token response', async () => {
const { exchangeForCopilotToken } = await importFreshModule() const { exchangeForCopilotToken } = await importFreshModule()
globalThis.fetch = mock(() => const fetchImpl = mock(() =>
Promise.resolve( Promise.resolve(
new Response( new Response(
JSON.stringify({ JSON.stringify({
@@ -214,7 +195,7 @@ describe('exchangeForCopilotToken', () => {
), ),
) )
const result = await exchangeForCopilotToken('oauth-token', globalThis.fetch) const result = await exchangeForCopilotToken('oauth-token', fetchImpl)
expect(result.token).toBe('copilot-token-xyz') expect(result.token).toBe('copilot-token-xyz')
expect(result.expires_at).toBe(1700000000) expect(result.expires_at).toBe(1700000000)
expect(result.refresh_in).toBe(3600) expect(result.refresh_in).toBe(3600)
@@ -225,24 +206,24 @@ describe('exchangeForCopilotToken', () => {
const { exchangeForCopilotToken, GitHubDeviceFlowError } = const { exchangeForCopilotToken, GitHubDeviceFlowError } =
await importFreshModule() await importFreshModule()
globalThis.fetch = mock(() => const fetchImpl = mock(() =>
Promise.resolve(new Response('unauthorized', { status: 401 })), Promise.resolve(new Response('unauthorized', { status: 401 })),
) )
await expect( await expect(
exchangeForCopilotToken('bad-token', globalThis.fetch), exchangeForCopilotToken('bad-token', fetchImpl),
).rejects.toThrow(GitHubDeviceFlowError) ).rejects.toThrow(GitHubDeviceFlowError)
}) })
test('throws on malformed response', async () => { test('throws on malformed response', async () => {
const { exchangeForCopilotToken } = await importFreshModule() const { exchangeForCopilotToken } = await importFreshModule()
globalThis.fetch = mock(() => const fetchImpl = mock(() =>
Promise.resolve( Promise.resolve(
new Response(JSON.stringify({ invalid: 'data' }), { status: 200 }), new Response(JSON.stringify({ invalid: 'data' }), { status: 200 }),
), ),
) )
await expect( await expect(
exchangeForCopilotToken('oauth-token', globalThis.fetch), exchangeForCopilotToken('oauth-token', fetchImpl),
).rejects.toThrow(/Malformed/) ).rejects.toThrow(/Malformed/)
}) })
}) })