From d9ae56bc585a5081e2760be65d16c1e4183db622 Mon Sep 17 00:00:00 2001 From: Kevin Codex Date: Sun, 26 Apr 2026 11:15:25 +0800 Subject: [PATCH] fix provider switch not presistingin session (#903) * fix provider switch not presistingin session * fix broken tests --- src/commands/provider/provider.test.tsx | 27 +++++++++++++ src/commands/provider/provider.tsx | 37 ++++++++++++++---- src/components/ProviderManager.tsx | 50 +++++++++++++++---------- src/constants/promptIdentity.test.ts | 17 +++++++++ src/constants/prompts.ts | 4 +- src/query.ts | 7 +++- src/utils/auth.ts | 31 +++++++++++---- 7 files changed, 134 insertions(+), 39 deletions(-) diff --git a/src/commands/provider/provider.test.tsx b/src/commands/provider/provider.test.tsx index 12113f11..bde1822c 100644 --- a/src/commands/provider/provider.test.tsx +++ b/src/commands/provider/provider.test.tsx @@ -11,6 +11,7 @@ import { buildCodexOAuthProfileEnv, buildCurrentProviderSummary, buildProfileSaveMessage, + buildProviderManagerCompletion, getProviderWizardDefaults, ProviderWizard, TextEntryDialog, @@ -264,6 +265,32 @@ test('wizard step remount prevents a typed API key from leaking into the next fi expect(output).not.toContain('sk-secret-12345678') }) +test('buildProviderManagerCompletion records provider switch event and model-visible reminder', () => { + const completion = buildProviderManagerCompletion({ + action: 'activated', + activeProviderName: 'Sadaf Provider', + activeProviderModel: 'sadaf-model', + message: 'Provider switched to Sadaf Provider (sadaf-model)', + }) + + expect(completion.message).toBe( + 'Provider switched to Sadaf Provider (sadaf-model)', + ) + expect(completion.metaMessages).toEqual([ + 'Provider switched mid-session to Sadaf Provider using model sadaf-model. Use this provider/model for subsequent requests unless the user switches again.', + ]) +}) + +test('buildProviderManagerCompletion skips provider reminder when manager is cancelled', () => { + const completion = buildProviderManagerCompletion({ + action: 'cancelled', + message: 'Provider manager closed', + }) + + expect(completion.message).toBe('Provider manager closed') + expect(completion.metaMessages).toBeUndefined() +}) + test('buildProfileSaveMessage maps provider fields without echoing secrets', () => { const message = buildProfileSaveMessage( 'openai', diff --git a/src/commands/provider/provider.tsx b/src/commands/provider/provider.tsx index 6d954d2a..db6de719 100644 --- a/src/commands/provider/provider.tsx +++ b/src/commands/provider/provider.tsx @@ -2,7 +2,10 @@ import * as React from 'react' import type { LocalJSXCommandCall, LocalJSXCommandOnDone } from '../../types/command.js' import { COMMON_HELP_ARGS, COMMON_INFO_ARGS } from '../../constants/xml.js' -import { ProviderManager } from '../../components/ProviderManager.js' +import { + ProviderManager, + type ProviderManagerResult, +} from '../../components/ProviderManager.js' import TextInput from '../../components/TextInput.js' import { Select, @@ -70,6 +73,29 @@ import { type OllamaGenerationReadiness, } from '../../utils/providerDiscovery.js' +export function buildProviderManagerCompletion(result?: ProviderManagerResult): { + message: string + metaMessages?: string[] +} { + const message = + result?.message ?? + (result?.action === 'saved' + ? 'Provider profile updated' + : 'Provider manager closed') + const metaMessages = + result?.action === 'activated' && result.activeProviderName + ? [ + `Provider switched mid-session to ${result.activeProviderName}${ + result.activeProviderModel + ? ` using model ${result.activeProviderModel}` + : '' + }. Use this provider/model for subsequent requests unless the user switches again.`, + ] + : undefined + + return { message, metaMessages } +} + function describeOllamaReadinessIssue( readiness: OllamaGenerationReadiness, options?: { @@ -1703,13 +1729,8 @@ export const call: LocalJSXCommandCall = async (onDone, _context, args) => { { - const message = - result?.message ?? - (result?.action === 'saved' - ? 'Provider profile updated' - : 'Provider manager closed') - - onDone(message, { display: 'system' }) + const { message, metaMessages } = buildProviderManagerCompletion(result) + onDone(message, { display: 'system', metaMessages }) }} /> ) diff --git a/src/components/ProviderManager.tsx b/src/components/ProviderManager.tsx index e72883e9..73c1bcde 100644 --- a/src/components/ProviderManager.tsx +++ b/src/components/ProviderManager.tsx @@ -58,8 +58,10 @@ import TextInput from './TextInput.js' import { useCodexOAuthFlow } from './useCodexOAuthFlow.js' export type ProviderManagerResult = { - action: 'saved' | 'cancelled' + action: 'saved' | 'cancelled' | 'activated' activeProfileId?: string + activeProviderName?: string + activeProviderModel?: string message?: string } @@ -759,12 +761,14 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode { mainLoopModelForSession: null, })) refreshProfiles() - setAppState(prev => ({ - ...prev, - mainLoopModel: GITHUB_PROVIDER_DEFAULT_MODEL, - })) setStatusMessage(`Active provider: ${GITHUB_PROVIDER_LABEL}`) setIsActivating(false) + onDone({ + action: 'activated', + activeProviderName: GITHUB_PROVIDER_LABEL, + activeProviderModel: GITHUB_PROVIDER_DEFAULT_MODEL, + message: `Provider switched to ${GITHUB_PROVIDER_LABEL} (${GITHUB_PROVIDER_DEFAULT_MODEL})`, + }) returnToMenu() return } @@ -799,23 +803,29 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode { : null refreshProfiles() - setStatusMessage( - isActiveCodexOAuth - ? buildCodexOAuthActivationMessage({ - prefix: `Active provider: ${active.name}`, + const activationMessage = isActiveCodexOAuth + ? buildCodexOAuthActivationMessage({ + prefix: `Active provider: ${active.name}`, + activationWarning, + warnings: [ activationWarning, - warnings: [ - activationWarning, - settingsOverrideError - ? `could not clear startup provider override (${settingsOverrideError})` - : null, - ].filter((warning): warning is string => Boolean(warning)), - }) - : settingsOverrideError - ? `Active provider: ${active.name}. Warning: could not clear startup provider override (${settingsOverrideError}).` - : `Active provider: ${active.name}`, - ) + settingsOverrideError + ? `could not clear startup provider override (${settingsOverrideError})` + : null, + ].filter((warning): warning is string => Boolean(warning)), + }) + : settingsOverrideError + ? `Active provider: ${active.name}. Warning: could not clear startup provider override (${settingsOverrideError}).` + : `Active provider: ${active.name}` + setStatusMessage(activationMessage) setIsActivating(false) + onDone({ + action: 'activated', + activeProfileId: active.id, + activeProviderName: active.name, + activeProviderModel: newModel, + message: `Provider switched to ${active.name} (${newModel})`, + }) returnToMenu() } catch (error) { refreshProfiles() diff --git a/src/constants/promptIdentity.test.ts b/src/constants/promptIdentity.test.ts index 818e7d0e..5c55041b 100644 --- a/src/constants/promptIdentity.test.ts +++ b/src/constants/promptIdentity.test.ts @@ -11,6 +11,7 @@ import { afterEach, expect, test } from 'bun:test' NATIVE_PACKAGE_URL: undefined, } +import { clearSystemPromptSections } from './systemPromptSections.js' import { getSystemPrompt, DEFAULT_AGENT_PROMPT } from './prompts.js' import { CLI_SYSPROMPT_PREFIXES, getCLISyspromptPrefix } from './system.js' import { CLAUDE_CODE_GUIDE_AGENT } from '../tools/AgentTool/built-in/claudeCodeGuideAgent.js' @@ -23,6 +24,7 @@ const originalSimpleEnv = process.env.CLAUDE_CODE_SIMPLE afterEach(() => { process.env.CLAUDE_CODE_SIMPLE = originalSimpleEnv + clearSystemPromptSections() }) test('CLI identity prefixes describe OpenClaude instead of Claude Code', () => { @@ -47,6 +49,21 @@ test('simple mode identity describes OpenClaude instead of Claude Code', async ( expect(prompt[0]).not.toContain("Anthropic's official CLI for Claude") }) +test('system prompt model identity updates when model changes mid-session', async () => { + delete process.env.CLAUDE_CODE_SIMPLE + clearSystemPromptSections() + + const firstPrompt = await getSystemPrompt([], 'old-test-model') + const secondPrompt = await getSystemPrompt([], 'new-test-model') + + const firstText = firstPrompt.join('\n') + const secondText = secondPrompt.join('\n') + + expect(firstText).toContain('You are powered by the model old-test-model.') + expect(secondText).toContain('You are powered by the model new-test-model.') + expect(secondText).not.toContain('You are powered by the model old-test-model.') +}) + test('built-in agent prompts describe OpenClaude instead of Claude Code', () => { expect(DEFAULT_AGENT_PROMPT).toContain('OpenClaude') expect(DEFAULT_AGENT_PROMPT).not.toContain('Claude Code') diff --git a/src/constants/prompts.ts b/src/constants/prompts.ts index 52054df7..53691ab1 100644 --- a/src/constants/prompts.ts +++ b/src/constants/prompts.ts @@ -496,7 +496,7 @@ ${CYBER_RISK_INSTRUCTION}`, systemPromptSection('ant_model_override', () => getAntModelOverrideSection(), ), - systemPromptSection('env_info_simple', () => + systemPromptSection(`env_info_simple:${model}`, () => computeSimpleEnvInfo(model, additionalWorkingDirectories), ), systemPromptSection('language', () => @@ -519,7 +519,7 @@ ${CYBER_RISK_INSTRUCTION}`, 'MCP servers connect/disconnect between turns', ), systemPromptSection('scratchpad', () => getScratchpadInstructions()), - systemPromptSection('frc', () => getFunctionResultClearingSection(model)), + systemPromptSection(`frc:${model}`, () => getFunctionResultClearingSection(model)), systemPromptSection( 'summarize_tool_results', () => SUMMARIZE_TOOL_RESULTS_SECTION, diff --git a/src/query.ts b/src/query.ts index 6187ed61..fe2e84e6 100644 --- a/src/query.ts +++ b/src/query.ts @@ -77,6 +77,7 @@ import { import { notifyCommandLifecycle } from './utils/commandLifecycle.js' import { headlessProfilerCheckpoint } from './utils/headlessProfiler.js' import { + getDefaultMainLoopModelSetting, getRuntimeMainLoopModel, renderModelName, } from './utils/model/model.js' @@ -604,9 +605,13 @@ async function* queryLoop( const appState = toolUseContext.getAppState() const permissionMode = appState.toolPermissionContext.mode + const appStateMainLoopModel = + appState.mainLoopModelForSession ?? + appState.mainLoopModel ?? + getDefaultMainLoopModelSetting() let currentModel = getRuntimeMainLoopModel({ permissionMode, - mainLoopModel: toolUseContext.options.mainLoopModel, + mainLoopModel: appStateMainLoopModel, exceeds200kTokens: permissionMode === 'plan' && doesMostRecentAssistantMessageExceed200k(messagesForQuery), diff --git a/src/utils/auth.ts b/src/utils/auth.ts index b4a67f24..69774d93 100644 --- a/src/utils/auth.ts +++ b/src/utils/auth.ts @@ -130,10 +130,18 @@ export function isAnthropicAuthEnabled(): boolean { apiKeyHelper || process.env.CLAUDE_CODE_API_KEY_FILE_DESCRIPTOR - // Check if API key is from an external source (not managed by /login) - const { source: apiKeySource } = getAnthropicApiKeyWithSource({ - skipRetrievingKeyFromApiKeyHelper: true, - }) + // Check if API key is from an external source (not managed by /login). + // Predicate must not throw: getAnthropicApiKeyWithSource throws under + // CI/NODE_ENV=test when no key is configured, but here we just want to + // know the source — "no key" is a valid answer. + let apiKeySource: ApiKeySource + try { + ;({ source: apiKeySource } = getAnthropicApiKeyWithSource({ + skipRetrievingKeyFromApiKeyHelper: true, + })) + } catch { + apiKeySource = 'none' + } const hasExternalApiKey = apiKeySource === 'ANTHROPIC_API_KEY' || apiKeySource === 'apiKeyHelper' @@ -221,10 +229,17 @@ export function getAnthropicApiKey(): null | string { } export function hasAnthropicApiKeyAuth(): boolean { - const { key, source } = getAnthropicApiKeyWithSource({ - skipRetrievingKeyFromApiKeyHelper: true, - }) - return key !== null && source !== 'none' + // Predicate: never throw. getAnthropicApiKeyWithSource throws under + // CI/NODE_ENV=test when no key is configured — but "do we have auth?" is + // exactly the question that has to answer cleanly in that state. + try { + const { key, source } = getAnthropicApiKeyWithSource({ + skipRetrievingKeyFromApiKeyHelper: true, + }) + return key !== null && source !== 'none' + } catch { + return false + } } export function getAnthropicApiKeyWithSource(