diff --git a/src/entrypoints/cli.tsx b/src/entrypoints/cli.tsx index 58c071d5..c82a0bb2 100644 --- a/src/entrypoints/cli.tsx +++ b/src/entrypoints/cli.tsx @@ -399,6 +399,22 @@ async function main(): Promise { process.env.CLAUDE_CODE_SIMPLE = '1'; } + // --provider: set provider env vars early, before main module loads. + // This mirrors the --bare pattern: env vars must be in place before + // Commander option building and module-level constants are evaluated. + if (args.includes('--provider')) { + const { parseProviderFlag, applyProviderFlag } = await import('../utils/providerFlag.js'); + const provider = parseProviderFlag(args); + if (provider) { + const result = applyProviderFlag(provider, args); + if (result.error) { + // biome-ignore lint/suspicious/noConsole:: intentional error output + console.error(`Error: ${result.error}`); + process.exit(1); + } + } + } + // No special flags detected, load and run the full CLI if (process.env.OPENCLAUDE_ENABLE_EARLY_INPUT === '1') { const { diff --git a/src/main.tsx b/src/main.tsx index a29c7e24..219d7f09 100644 --- a/src/main.tsx +++ b/src/main.tsx @@ -984,7 +984,7 @@ async function run(): Promise { return Number.isFinite(n) ? n : undefined; }).hideHelp()).option('--from-pr [value]', 'Resume a session linked to a PR by PR number/URL, or open interactive picker with optional search term', value => value || true).option('--no-session-persistence', 'Disable session persistence - sessions will not be saved to disk and cannot be resumed (only works with --print)').addOption(new Option('--resume-session-at ', 'When resuming, only messages up to and including the assistant message with (use with --resume in print mode)').argParser(String).hideHelp()).addOption(new Option('--rewind-files ', 'Restore files to state at the specified user message and exit (requires --resume)').hideHelp()) // @[MODEL LAUNCH]: Update the example model ID in the --model help text. - .option('--model ', `Model for the current session. Provide an alias for the latest model (e.g. 'sonnet' or 'opus') or a model's full name (e.g. 'claude-sonnet-4-6').`).addOption(new Option('--effort ', `Effort level for the current session (low, medium, high, max)`).argParser((rawValue: string) => { + .option('--model ', `Model for the current session. Provide an alias for the latest model (e.g. 'sonnet' or 'opus') or a model's full name (e.g. 'claude-sonnet-4-6').`).option('--provider ', `AI provider to use (anthropic, openai, gemini, github, bedrock, vertex, ollama). Reads API keys from environment variables.`).addOption(new Option('--effort ', `Effort level for the current session (low, medium, high, max)`).argParser((rawValue: string) => { const value = rawValue.toLowerCase(); const allowed = ['low', 'medium', 'high', 'max']; if (!allowed.includes(value)) { diff --git a/src/services/api/providerConfig.ts b/src/services/api/providerConfig.ts index 5097f142..0e5fe44d 100644 --- a/src/services/api/providerConfig.ts +++ b/src/services/api/providerConfig.ts @@ -291,8 +291,13 @@ export function resolveProviderRequest(options?: { process.env.OPENAI_BASE_URL ?? process.env.OPENAI_API_BASE ?? undefined + // Use Codex transport only when: + // - the base URL is explicitly the Codex endpoint, OR + // - the model is a Codex alias AND no custom base URL has been set + // A custom OPENAI_BASE_URL (e.g. Azure, OpenRouter) always wins over + // model-name-based Codex detection to prevent auth failures (#200, #203). const transport: ProviderTransport = - isCodexAlias(requestedModel) || isCodexBaseUrl(rawBaseUrl) + isCodexBaseUrl(rawBaseUrl) || (!rawBaseUrl && isCodexAlias(requestedModel)) ? 'codex_responses' : 'chat_completions' diff --git a/src/utils/providerFlag.test.ts b/src/utils/providerFlag.test.ts new file mode 100644 index 00000000..30f392fc --- /dev/null +++ b/src/utils/providerFlag.test.ts @@ -0,0 +1,139 @@ +import { describe, expect, test, afterEach } from 'bun:test' +import { parseProviderFlag, applyProviderFlag, VALID_PROVIDERS } from './providerFlag.js' + +const originalEnv = { ...process.env } + +afterEach(() => { + for (const key of [ + 'CLAUDE_CODE_USE_OPENAI', + 'CLAUDE_CODE_USE_GEMINI', + 'CLAUDE_CODE_USE_GITHUB', + 'CLAUDE_CODE_USE_BEDROCK', + 'CLAUDE_CODE_USE_VERTEX', + 'OPENAI_BASE_URL', + 'OPENAI_API_KEY', + 'OPENAI_MODEL', + 'GEMINI_MODEL', + ]) { + if (originalEnv[key] === undefined) delete process.env[key] + else process.env[key] = originalEnv[key] + } +}) + +// --- parseProviderFlag --- + +describe('parseProviderFlag', () => { + test('returns provider name when --provider flag present', () => { + expect(parseProviderFlag(['--provider', 'openai'])).toBe('openai') + }) + + test('returns provider name with --model alongside', () => { + expect(parseProviderFlag(['--provider', 'gemini', '--model', 'gemini-2.0-flash'])).toBe('gemini') + }) + + test('returns null when --provider flag absent', () => { + expect(parseProviderFlag(['--model', 'gpt-4o'])).toBeNull() + }) + + test('returns null for empty args', () => { + expect(parseProviderFlag([])).toBeNull() + }) + + test('returns null when --provider has no value', () => { + expect(parseProviderFlag(['--provider'])).toBeNull() + }) + + test('returns null when --provider value starts with --', () => { + expect(parseProviderFlag(['--provider', '--model'])).toBeNull() + }) +}) + +// --- applyProviderFlag --- + +describe('applyProviderFlag - anthropic', () => { + test('sets no env vars for anthropic (default)', () => { + const result = applyProviderFlag('anthropic', []) + expect(result.error).toBeUndefined() + expect(process.env.CLAUDE_CODE_USE_OPENAI).toBeUndefined() + expect(process.env.CLAUDE_CODE_USE_GEMINI).toBeUndefined() + }) +}) + +describe('applyProviderFlag - openai', () => { + test('sets CLAUDE_CODE_USE_OPENAI=1', () => { + const result = applyProviderFlag('openai', []) + expect(result.error).toBeUndefined() + expect(process.env.CLAUDE_CODE_USE_OPENAI).toBe('1') + }) + + test('sets OPENAI_MODEL when --model is provided', () => { + applyProviderFlag('openai', ['--model', 'gpt-4o']) + expect(process.env.OPENAI_MODEL).toBe('gpt-4o') + }) +}) + +describe('applyProviderFlag - gemini', () => { + test('sets CLAUDE_CODE_USE_GEMINI=1', () => { + const result = applyProviderFlag('gemini', []) + expect(result.error).toBeUndefined() + expect(process.env.CLAUDE_CODE_USE_GEMINI).toBe('1') + }) + + test('sets GEMINI_MODEL when --model is provided', () => { + applyProviderFlag('gemini', ['--model', 'gemini-2.0-flash']) + expect(process.env.GEMINI_MODEL).toBe('gemini-2.0-flash') + }) +}) + +describe('applyProviderFlag - github', () => { + test('sets CLAUDE_CODE_USE_GITHUB=1', () => { + const result = applyProviderFlag('github', []) + expect(result.error).toBeUndefined() + expect(process.env.CLAUDE_CODE_USE_GITHUB).toBe('1') + }) +}) + +describe('applyProviderFlag - bedrock', () => { + test('sets CLAUDE_CODE_USE_BEDROCK=1', () => { + const result = applyProviderFlag('bedrock', []) + expect(result.error).toBeUndefined() + expect(process.env.CLAUDE_CODE_USE_BEDROCK).toBe('1') + }) +}) + +describe('applyProviderFlag - vertex', () => { + test('sets CLAUDE_CODE_USE_VERTEX=1', () => { + const result = applyProviderFlag('vertex', []) + expect(result.error).toBeUndefined() + expect(process.env.CLAUDE_CODE_USE_VERTEX).toBe('1') + }) +}) + +describe('applyProviderFlag - ollama', () => { + test('sets CLAUDE_CODE_USE_OPENAI=1 with Ollama base URL', () => { + const result = applyProviderFlag('ollama', []) + expect(result.error).toBeUndefined() + expect(process.env.CLAUDE_CODE_USE_OPENAI).toBe('1') + expect(process.env.OPENAI_BASE_URL).toBe('http://localhost:11434/v1') + expect(process.env.OPENAI_API_KEY).toBe('ollama') + }) + + test('sets OPENAI_MODEL when --model is provided', () => { + applyProviderFlag('ollama', ['--model', 'llama3.2']) + expect(process.env.OPENAI_MODEL).toBe('llama3.2') + }) + + test('does not override existing OPENAI_BASE_URL when user set a custom one', () => { + process.env.OPENAI_BASE_URL = 'http://my-ollama:11434/v1' + applyProviderFlag('ollama', []) + expect(process.env.OPENAI_BASE_URL).toBe('http://my-ollama:11434/v1') + }) +}) + +describe('applyProviderFlag - invalid provider', () => { + test('returns error for unknown provider', () => { + const result = applyProviderFlag('unknown-provider', []) + expect(result.error).toContain('unknown-provider') + expect(result.error).toContain(VALID_PROVIDERS.join(', ')) + }) +}) diff --git a/src/utils/providerFlag.ts b/src/utils/providerFlag.ts new file mode 100644 index 00000000..6a56eca0 --- /dev/null +++ b/src/utils/providerFlag.ts @@ -0,0 +1,107 @@ +/** + * --provider CLI flag support. + * + * Maps the user-friendly provider name to the environment variables + * that the rest of the codebase uses for provider detection. + * + * Usage: + * openclaude --provider openai --model gpt-4o + * openclaude --provider gemini --model gemini-2.0-flash + * openclaude --provider ollama --model llama3.2 + * openclaude --provider anthropic (default, no-op) + */ + +export const VALID_PROVIDERS = [ + 'anthropic', + 'openai', + 'gemini', + 'github', + 'bedrock', + 'vertex', + 'ollama', +] as const + +export type ProviderFlagName = (typeof VALID_PROVIDERS)[number] + +/** + * Extract the value of --provider from argv. + * Returns null if the flag is absent or has no value. + */ +export function parseProviderFlag(args: string[]): string | null { + const idx = args.indexOf('--provider') + if (idx === -1) return null + const value = args[idx + 1] + if (!value || value.startsWith('--')) return null + return value +} + +/** + * Extract the value of --model from argv. + * Returns null if absent. + */ +function parseModelFlag(args: string[]): string | null { + const idx = args.indexOf('--model') + if (idx === -1) return null + const value = args[idx + 1] + if (!value || value.startsWith('--')) return null + return value +} + +/** + * Apply a provider name to process.env. + * Sets the required CLAUDE_CODE_USE_* flag and any provider-specific + * defaults (Ollama base URL, model routing). Does NOT overwrite values + * that are already set — explicit env vars always win. + * + * Returns { error } if the provider name is not recognized. + */ +export function applyProviderFlag( + provider: string, + args: string[], +): { error?: string } { + if (!(VALID_PROVIDERS as readonly string[]).includes(provider)) { + return { + error: `Unknown provider "${provider}". Valid providers: ${VALID_PROVIDERS.join(', ')}`, + } + } + + const model = parseModelFlag(args) + + switch (provider as ProviderFlagName) { + case 'anthropic': + // Default — no env vars needed + break + + case 'openai': + process.env.CLAUDE_CODE_USE_OPENAI = '1' + if (model) process.env.OPENAI_MODEL ??= model + break + + case 'gemini': + process.env.CLAUDE_CODE_USE_GEMINI = '1' + if (model) process.env.GEMINI_MODEL ??= model + break + + case 'github': + process.env.CLAUDE_CODE_USE_GITHUB = '1' + if (model) process.env.OPENAI_MODEL ??= model + break + + case 'bedrock': + process.env.CLAUDE_CODE_USE_BEDROCK = '1' + break + + case 'vertex': + process.env.CLAUDE_CODE_USE_VERTEX = '1' + break + + case 'ollama': + process.env.CLAUDE_CODE_USE_OPENAI = '1' + process.env.OPENAI_BASE_URL ??= 'http://localhost:11434/v1' + process.env.OPENAI_API_KEY ??= 'ollama' + if (model) process.env.OPENAI_MODEL ??= model + break + } + + return {} +}