Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c0b8a59a23 | ||
|
|
aab489055c | ||
|
|
7002cb302b | ||
|
|
739b8d1f40 | ||
|
|
f166ec1a4e | ||
|
|
13e9f22a83 | ||
|
|
f828171ef1 | ||
|
|
e6e8d9a248 | ||
|
|
2c98be7002 | ||
|
|
b786b765f0 | ||
|
|
55c5f262a9 | ||
|
|
002a8f1f6d | ||
|
|
3d1979ff06 | ||
|
|
b0d9fe7112 | ||
|
|
651123db1f | ||
|
|
34246635fb | ||
|
|
43ac6dba75 | ||
|
|
80a00acc2c | ||
|
|
eed77e6579 | ||
|
|
b280c740a6 | ||
|
|
2ff5710329 | ||
|
|
d6f5130c20 | ||
|
|
d32a2a1329 | ||
|
|
fbcd928f7f | ||
|
|
77083d769b | ||
|
|
b66633ea4d | ||
|
|
51191d6132 | ||
|
|
6b2121da12 | ||
|
|
c207cdbdcc | ||
|
|
a00b7928de | ||
|
|
12dd3755c6 | ||
|
|
114f772a4a | ||
|
|
7187fc007a | ||
|
|
0ed50ccfe7 |
24
.env.example
24
.env.example
@@ -225,6 +225,30 @@ ANTHROPIC_API_KEY=sk-ant-your-key-here
|
||||
# GOOGLE_CLOUD_PROJECT=your-gcp-project-id
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Option 9: NVIDIA NIM
|
||||
# -----------------------------------------------------------------------------
|
||||
# NVIDIA NIM provides hosted inference endpoints for NVIDIA models.
|
||||
# Get your API key from https://build.nvidia.com/
|
||||
#
|
||||
# CLAUDE_CODE_USE_OPENAI=1
|
||||
# NVIDIA_API_KEY=nvapi-your-key-here
|
||||
# OPENAI_BASE_URL=https://integrate.api.nvidia.com/v1
|
||||
# OPENAI_MODEL=nvidia/llama-3.1-nemotron-70b-instruct
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Option 10: MiniMax
|
||||
# -----------------------------------------------------------------------------
|
||||
# MiniMax API provides text generation models.
|
||||
# Get your API key from https://platform.minimax.io/
|
||||
#
|
||||
# CLAUDE_CODE_USE_OPENAI=1
|
||||
# MINIMAX_API_KEY=your-minimax-key-here
|
||||
# OPENAI_BASE_URL=https://api.minimax.io/v1
|
||||
# OPENAI_MODEL=MiniMax-M2.5
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OPTIONAL TUNING
|
||||
# =============================================================================
|
||||
|
||||
1
.github/workflows/release.yml
vendored
1
.github/workflows/release.yml
vendored
@@ -11,6 +11,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
release-please:
|
||||
if: ${{ github.repository == 'Gitlawb/openclaude' }}
|
||||
name: Release Please
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
{
|
||||
".": "0.3.0"
|
||||
".": "0.5.1"
|
||||
}
|
||||
|
||||
47
CHANGELOG.md
47
CHANGELOG.md
@@ -1,5 +1,52 @@
|
||||
# Changelog
|
||||
|
||||
## [0.5.1](https://github.com/Gitlawb/openclaude/compare/v0.5.0...v0.5.1) (2026-04-20)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* enforce Bash path constraints after sandbox allow ([#777](https://github.com/Gitlawb/openclaude/issues/777)) ([7002cb3](https://github.com/Gitlawb/openclaude/commit/7002cb302b78ea2a19da3f26226de24e2903fa1d))
|
||||
* enforce MCP OAuth callback state before errors ([#775](https://github.com/Gitlawb/openclaude/issues/775)) ([739b8d1](https://github.com/Gitlawb/openclaude/commit/739b8d1f40fde0e401a5cbd2b9a55d88bd5124ad))
|
||||
* require trusted approval for sandbox override ([#778](https://github.com/Gitlawb/openclaude/issues/778)) ([aab4890](https://github.com/Gitlawb/openclaude/commit/aab489055c53dd64369414116fe93226d2656273))
|
||||
|
||||
## [0.5.0](https://github.com/Gitlawb/openclaude/compare/v0.4.0...v0.5.0) (2026-04-20)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add OPENCLAUDE_DISABLE_STRICT_TOOLS env var to opt out of strict MCP tool schema normalization ([#770](https://github.com/Gitlawb/openclaude/issues/770)) ([e6e8d9a](https://github.com/Gitlawb/openclaude/commit/e6e8d9a24897e4c9ef08b72df20fabbf8ef27f38))
|
||||
* mask provider api key input ([#772](https://github.com/Gitlawb/openclaude/issues/772)) ([13e9f22](https://github.com/Gitlawb/openclaude/commit/13e9f22a83a2b0f85f557b1e12c9442ba61241e4))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* allow provider recovery during startup ([#765](https://github.com/Gitlawb/openclaude/issues/765)) ([f828171](https://github.com/Gitlawb/openclaude/commit/f828171ef1ab94e2acf73a28a292799e4e26cc0d))
|
||||
* **api:** drop orphan tool results to satisfy strict role sequence ([#745](https://github.com/Gitlawb/openclaude/issues/745)) ([b786b76](https://github.com/Gitlawb/openclaude/commit/b786b765f01f392652eaf28ed3579a96b7260a53))
|
||||
* **help:** prevent /help tab crash from undefined descriptions ([#732](https://github.com/Gitlawb/openclaude/issues/732)) ([3d1979f](https://github.com/Gitlawb/openclaude/commit/3d1979ff066db32415e0c8321af916d81f5f2621))
|
||||
* **mcp:** sync required array with properties in tool schemas ([#754](https://github.com/Gitlawb/openclaude/issues/754)) ([002a8f1](https://github.com/Gitlawb/openclaude/commit/002a8f1f6de2fcfc917165d828501d3047bad61f))
|
||||
* remove cached mcpClient in diagnostic tracking to prevent stale references ([#727](https://github.com/Gitlawb/openclaude/issues/727)) ([2c98be7](https://github.com/Gitlawb/openclaude/commit/2c98be700274a4241963b5f43530bf3bd8f8963f))
|
||||
* use raw context window for auto-compact percentage display ([#748](https://github.com/Gitlawb/openclaude/issues/748)) ([55c5f26](https://github.com/Gitlawb/openclaude/commit/55c5f262a9a5a8be0aa9ae8dc6c7dafc465eb2c6))
|
||||
|
||||
## [0.4.0](https://github.com/Gitlawb/openclaude/compare/v0.3.0...v0.4.0) (2026-04-17)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add Alibaba Coding Plan (DashScope) provider support ([#509](https://github.com/Gitlawb/openclaude/issues/509)) ([43ac6db](https://github.com/Gitlawb/openclaude/commit/43ac6dba75537282da1e2ad8f855082bc4e25f1e))
|
||||
* add NVIDIA NIM and MiniMax provider support ([#552](https://github.com/Gitlawb/openclaude/issues/552)) ([51191d6](https://github.com/Gitlawb/openclaude/commit/51191d61326e1f8319d70b3a3c0d9229e185a564))
|
||||
* add ripgrep to Dockerfile for faster file searching ([#688](https://github.com/Gitlawb/openclaude/issues/688)) ([12dd375](https://github.com/Gitlawb/openclaude/commit/12dd3755c619cc27af3b151ae8fdb9d425a7b9a2))
|
||||
* **api:** classify openai-compatible provider failures ([#708](https://github.com/Gitlawb/openclaude/issues/708)) ([80a00ac](https://github.com/Gitlawb/openclaude/commit/80a00acc2c6dc4657a78de7366f7a9ebc920bfbb))
|
||||
* **vscode:** add full chat interface to OpenClaude extension ([#608](https://github.com/Gitlawb/openclaude/issues/608)) ([fbcd928](https://github.com/Gitlawb/openclaude/commit/fbcd928f7f8511da795aea3ad318bddf0ab9a1a7))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* focus "Done" option after completing provider manager actions ([#718](https://github.com/Gitlawb/openclaude/issues/718)) ([d6f5130](https://github.com/Gitlawb/openclaude/commit/d6f5130c204d8ffe582212466768706cd7fd6774))
|
||||
* **models:** prevent /models crash from non-string saved model values ([#691](https://github.com/Gitlawb/openclaude/issues/691)) ([6b2121d](https://github.com/Gitlawb/openclaude/commit/6b2121da12189fa7ce1f33394d18abd24cf8a01b))
|
||||
* prevent crash in commands tab when description is undefined ([#730](https://github.com/Gitlawb/openclaude/issues/730)) ([eed77e6](https://github.com/Gitlawb/openclaude/commit/eed77e6579866a98384dcc948a0ad6406614ede3))
|
||||
* strip comments before scanning for missing imports ([#676](https://github.com/Gitlawb/openclaude/issues/676)) ([a00b792](https://github.com/Gitlawb/openclaude/commit/a00b7928de9662ffb7ef6abd8cd040afe6f4f122))
|
||||
* **ui:** show correct endpoint URL in intro screen for custom Anthropic endpoints ([#735](https://github.com/Gitlawb/openclaude/issues/735)) ([3424663](https://github.com/Gitlawb/openclaude/commit/34246635fb9a09499047a52e7f96ca9b36c8a85a))
|
||||
|
||||
## [0.3.0](https://github.com/Gitlawb/openclaude/compare/v0.2.3...v0.3.0) (2026-04-14)
|
||||
|
||||
|
||||
|
||||
@@ -36,14 +36,11 @@ COPY --from=build /app/node_modules/ node_modules/
|
||||
COPY --from=build /app/package.json package.json
|
||||
COPY README.md ./
|
||||
|
||||
# Install git — many CLI tool operations depend on it
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends git \
|
||||
# Install git and ripgrep — many CLI tool operations depend on them
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends git ripgrep \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Run as non-root user
|
||||
RUN groupadd --gid 1000 appuser && useradd --uid 1000 --gid appuser --shell /bin/bash --create-home appuser
|
||||
USER appuser
|
||||
WORKDIR /home/appuser
|
||||
ENV HOME=/home/appuser
|
||||
USER node
|
||||
|
||||
ENTRYPOINT ["node", "/app/dist/cli.mjs"]
|
||||
|
||||
19
README.md
19
README.md
@@ -15,6 +15,10 @@ OpenClaude is also mirrored to GitLawb:
|
||||
|
||||
[Quick Start](#quick-start) | [Setup Guides](#setup-guides) | [Providers](#supported-providers) | [Source Build](#source-build-and-local-development) | [VS Code Extension](#vs-code-extension) | [Community](#community)
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://www.star-history.com/?repos=gitlawb%2Fopenclaude&type=date&legend=top-left)
|
||||
|
||||
## Why OpenClaude
|
||||
|
||||
- Use one CLI across cloud APIs and local model backends
|
||||
@@ -88,6 +92,16 @@ $env:OPENAI_MODEL="qwen2.5-coder:7b"
|
||||
openclaude
|
||||
```
|
||||
|
||||
### Using Ollama's launch command
|
||||
|
||||
If you have [Ollama](https://ollama.com) installed, you can skip the env var setup entirely:
|
||||
|
||||
```bash
|
||||
ollama launch openclaude --model qwen2.5-coder:7b
|
||||
```
|
||||
|
||||
This automatically sets `ANTHROPIC_BASE_URL`, model routing, and auth so all API traffic goes through your local Ollama instance. Works with any model you have pulled — local or cloud.
|
||||
|
||||
## Setup Guides
|
||||
|
||||
Beginner-friendly guides:
|
||||
@@ -110,7 +124,7 @@ Advanced and source-build guides:
|
||||
| GitHub Models | `/onboard-github` | Interactive onboarding with saved credentials |
|
||||
| Codex OAuth | `/provider` | Opens ChatGPT sign-in in your browser and stores Codex credentials securely |
|
||||
| Codex | `/provider` | Uses existing Codex CLI auth, OpenClaude secure storage, or env credentials |
|
||||
| Ollama | `/provider` or env vars | Local inference with no API key |
|
||||
| Ollama | `/provider`, env vars, or `ollama launch` | Local inference with no API key |
|
||||
| Atomic Chat | advanced setup | Local Apple Silicon backend |
|
||||
| Bedrock / Vertex / Foundry | env vars | Additional provider integrations for supported environments |
|
||||
|
||||
@@ -317,7 +331,8 @@ For larger changes, open an issue first so the scope is clear before implementat
|
||||
- `bun run build`
|
||||
- `bun run test:coverage`
|
||||
- `bun run smoke`
|
||||
- focused `bun test ...` runs for touched areas
|
||||
- focused `bun test ...` runs for files and flows you changed
|
||||
|
||||
|
||||
## Disclaimer
|
||||
|
||||
|
||||
@@ -84,6 +84,16 @@ OpenRouter model availability changes over time. If a model stops working, try a
|
||||
|
||||
### Ollama
|
||||
|
||||
Using `ollama launch` (recommended if you have Ollama installed):
|
||||
|
||||
```bash
|
||||
ollama launch openclaude --model llama3.3:70b
|
||||
```
|
||||
|
||||
This handles all environment setup automatically — no env vars needed. Works with any local or cloud model available in your Ollama instance.
|
||||
|
||||
Using environment variables manually:
|
||||
|
||||
```bash
|
||||
ollama pull llama3.3:70b
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@gitlawb/openclaude",
|
||||
"version": "0.3.0",
|
||||
"version": "0.5.1",
|
||||
"description": "Claude Code opened to any LLM — OpenAI, Gemini, DeepSeek, Ollama, and 200+ models",
|
||||
"type": "module",
|
||||
"bin": {
|
||||
|
||||
@@ -367,9 +367,17 @@ export const SeverityNumber = {};
|
||||
const full = pathMod.join(dir, ent.name)
|
||||
if (ent.isDirectory()) { walk(full); continue }
|
||||
if (!/\.(ts|tsx)$/.test(ent.name)) continue
|
||||
const code: string = fs.readFileSync(full, 'utf-8')
|
||||
const rawCode: string = fs.readFileSync(full, 'utf-8')
|
||||
const fileDir = pathMod.dirname(full)
|
||||
|
||||
// Strip comments before scanning for imports/requires.
|
||||
// The regex scanner matches require()/import() patterns
|
||||
// inside JSDoc comments, causing false-positive missing
|
||||
// module detection that breaks the build with noop stubs.
|
||||
const code = rawCode
|
||||
.replace(/\/\*[\s\S]*?\*\//g, '') // block comments
|
||||
.replace(/\/\/.*$/gm, '') // line comments
|
||||
|
||||
// Collect static imports: import { X } from '...'
|
||||
for (const m of code.matchAll(/import\s+(?:\{([^}]*)\}|(\w+))?\s*(?:,\s*\{([^}]*)\})?\s*from\s+['"](.*?)['"]/g)) {
|
||||
checkAndRegister(m[4], fileDir, m[1] || m[3] || '')
|
||||
|
||||
@@ -11,7 +11,12 @@ import { MCPServerDesktopImportDialog } from '../../components/MCPServerDesktopI
|
||||
import { render } from '../../ink.js';
|
||||
import { KeybindingSetup } from '../../keybindings/KeybindingProviderSetup.js';
|
||||
import { type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS, logEvent } from '../../services/analytics/index.js';
|
||||
import { clearMcpClientConfig, clearServerTokensFromLocalStorage, readClientSecret, saveMcpClientSecret } from '../../services/mcp/auth.js';
|
||||
import {
|
||||
clearMcpClientConfig,
|
||||
clearServerTokensFromSecureStorage,
|
||||
readClientSecret,
|
||||
saveMcpClientSecret,
|
||||
} from '../../services/mcp/auth.js'
|
||||
import { doctorAllServers, doctorServer, type McpDoctorReport, type McpDoctorScopeFilter } from '../../services/mcp/doctor.js';
|
||||
import { connectToServer, getMcpServerConnectionBatchSize } from '../../services/mcp/client.js';
|
||||
import { addMcpConfig, getAllMcpConfigs, getMcpConfigByName, getMcpConfigsByScope, removeMcpConfig } from '../../services/mcp/config.js';
|
||||
|
||||
30
src/commands.test.ts
Normal file
30
src/commands.test.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import { formatDescriptionWithSource } from './commands.js'
|
||||
|
||||
describe('formatDescriptionWithSource', () => {
|
||||
test('returns empty text for prompt commands missing a description', () => {
|
||||
const command = {
|
||||
name: 'example',
|
||||
type: 'prompt',
|
||||
source: 'builtin',
|
||||
description: undefined,
|
||||
} as any
|
||||
|
||||
expect(formatDescriptionWithSource(command)).toBe('')
|
||||
})
|
||||
|
||||
test('formats plugin commands with missing description safely', () => {
|
||||
const command = {
|
||||
name: 'example',
|
||||
type: 'prompt',
|
||||
source: 'plugin',
|
||||
description: undefined,
|
||||
pluginInfo: {
|
||||
pluginManifest: {
|
||||
name: 'MyPlugin',
|
||||
},
|
||||
},
|
||||
} as any
|
||||
|
||||
expect(formatDescriptionWithSource(command)).toBe('(MyPlugin) ')
|
||||
})
|
||||
})
|
||||
@@ -740,23 +740,23 @@ export function getCommand(commandName: string, commands: Command[]): Command {
|
||||
*/
|
||||
export function formatDescriptionWithSource(cmd: Command): string {
|
||||
if (cmd.type !== 'prompt') {
|
||||
return cmd.description
|
||||
return cmd.description ?? ''
|
||||
}
|
||||
|
||||
if (cmd.kind === 'workflow') {
|
||||
return `${cmd.description} (workflow)`
|
||||
return `${cmd.description ?? ''} (workflow)`
|
||||
}
|
||||
|
||||
if (cmd.source === 'plugin') {
|
||||
const pluginName = cmd.pluginInfo?.pluginManifest.name
|
||||
if (pluginName) {
|
||||
return `(${pluginName}) ${cmd.description}`
|
||||
return `(${pluginName}) ${cmd.description ?? ''}`
|
||||
}
|
||||
return `${cmd.description} (plugin)`
|
||||
return `${cmd.description ?? ''} (plugin)`
|
||||
}
|
||||
|
||||
if (cmd.source === 'builtin' || cmd.source === 'mcp') {
|
||||
return cmd.description
|
||||
return cmd.description ?? ''
|
||||
}
|
||||
|
||||
if (cmd.source === 'bundled') {
|
||||
|
||||
@@ -401,7 +401,7 @@ test('buildCodexProfileEnv derives oauth source from secure storage when no expl
|
||||
})
|
||||
})
|
||||
|
||||
test('applySavedProfileToCurrentSession switches the current env to the saved Codex profile', async () => {
|
||||
test('explicitly declared env takes precedence over applySavedProfileToCurrentSession', async () => {
|
||||
// @ts-expect-error cache-busting query string for Bun module mocks
|
||||
const { applySavedProfileToCurrentSession } = await import(
|
||||
'../../utils/providerProfile.js?apply-saved-profile-codex'
|
||||
@@ -430,18 +430,18 @@ test('applySavedProfileToCurrentSession switches the current env to the saved Co
|
||||
|
||||
expect(warning).toBeNull()
|
||||
expect(processEnv.CLAUDE_CODE_USE_OPENAI).toBe('1')
|
||||
expect(processEnv.OPENAI_MODEL).toBe('codexplan')
|
||||
expect(processEnv.OPENAI_MODEL).toBe('gpt-4o')
|
||||
expect(processEnv.OPENAI_BASE_URL).toBe(
|
||||
'https://chatgpt.com/backend-api/codex',
|
||||
"https://api.openai.com/v1",
|
||||
)
|
||||
expect(processEnv.CODEX_API_KEY).toBe('codex-live')
|
||||
expect(processEnv.CHATGPT_ACCOUNT_ID).toBe('acct_codex')
|
||||
expect(processEnv.OPENAI_API_KEY).toBeUndefined()
|
||||
expect(processEnv.CODEX_API_KEY).toBeUndefined()
|
||||
expect(processEnv.CHATGPT_ACCOUNT_ID).toBeUndefined()
|
||||
expect(processEnv.OPENAI_API_KEY).toBe("sk-openai")
|
||||
expect(processEnv.CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED).toBeUndefined()
|
||||
expect(processEnv.CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED_ID).toBeUndefined()
|
||||
})
|
||||
|
||||
test('applySavedProfileToCurrentSession ignores stale Codex env overrides for OAuth-backed profiles', async () => {
|
||||
test('explicitly declared env takes precedence over applySavedProfileToCurrentSession', async () => {
|
||||
// @ts-expect-error cache-busting query string for Bun module mocks
|
||||
const { applySavedProfileToCurrentSession } = await import(
|
||||
'../../utils/providerProfile.js?apply-saved-profile-codex-oauth'
|
||||
@@ -465,13 +465,13 @@ test('applySavedProfileToCurrentSession ignores stale Codex env overrides for OA
|
||||
processEnv,
|
||||
})
|
||||
|
||||
expect(warning).toBeNull()
|
||||
expect(processEnv.OPENAI_MODEL).toBe('codexplan')
|
||||
expect(warning).not.toBeUndefined()
|
||||
expect(processEnv.OPENAI_MODEL).toBe('gpt-4o')
|
||||
expect(processEnv.OPENAI_BASE_URL).toBe(
|
||||
'https://chatgpt.com/backend-api/codex',
|
||||
"https://api.openai.com/v1",
|
||||
)
|
||||
expect(processEnv.CODEX_API_KEY).toBeUndefined()
|
||||
expect(processEnv.CHATGPT_ACCOUNT_ID).not.toBe('acct_stale')
|
||||
expect(processEnv.CODEX_API_KEY).toBe("stale-codex-key")
|
||||
expect(processEnv.CHATGPT_ACCOUNT_ID).toBe('acct_stale')
|
||||
expect(processEnv.CHATGPT_ACCOUNT_ID).toBeTruthy()
|
||||
})
|
||||
|
||||
@@ -487,8 +487,8 @@ test('buildCurrentProviderSummary redacts poisoned model and endpoint values', (
|
||||
})
|
||||
|
||||
expect(summary.providerLabel).toBe('OpenAI-compatible')
|
||||
expect(summary.modelLabel).toBe('sk-...5678')
|
||||
expect(summary.endpointLabel).toBe('sk-...5678')
|
||||
expect(summary.modelLabel).toBe('sk-...678')
|
||||
expect(summary.endpointLabel).toBe('sk-...678')
|
||||
})
|
||||
|
||||
test('buildCurrentProviderSummary labels generic local openai-compatible providers', () => {
|
||||
|
||||
@@ -3,12 +3,14 @@ import * as React from 'react'
|
||||
import { DEFAULT_CODEX_BASE_URL } from '../services/api/providerConfig.js'
|
||||
import { Box, Text } from '../ink.js'
|
||||
import { useKeybinding } from '../keybindings/useKeybinding.js'
|
||||
import { useSetAppState } from '../state/AppState.js'
|
||||
import type { ProviderProfile } from '../utils/config.js'
|
||||
import {
|
||||
clearCodexCredentials,
|
||||
readCodexCredentialsAsync,
|
||||
} from '../utils/codexCredentials.js'
|
||||
import { isBareMode, isEnvTruthy } from '../utils/envUtils.js'
|
||||
import { getPrimaryModel, hasMultipleModels, parseModelList } from '../utils/providerModels.js'
|
||||
import {
|
||||
applySavedProfileToCurrentSession,
|
||||
buildCodexOAuthProfileEnv,
|
||||
@@ -50,6 +52,7 @@ import {
|
||||
import { Pane } from './design-system/Pane.js'
|
||||
import TextInput from './TextInput.js'
|
||||
import { useCodexOAuthFlow } from './useCodexOAuthFlow.js'
|
||||
import { useSetAppState } from '../state/AppState.js'
|
||||
|
||||
export type ProviderManagerResult = {
|
||||
action: 'saved' | 'cancelled'
|
||||
@@ -108,8 +111,8 @@ const FORM_STEPS: Array<{
|
||||
{
|
||||
key: 'model',
|
||||
label: 'Default model',
|
||||
placeholder: 'e.g. llama3.1:8b',
|
||||
helpText: 'Model name to use when this provider is active.',
|
||||
placeholder: 'e.g. llama3.1:8b or glm-4.7, glm-4.7-flash',
|
||||
helpText: 'Model name(s) to use. Separate multiple with commas; first is default.',
|
||||
},
|
||||
{
|
||||
key: 'apiKey',
|
||||
@@ -153,7 +156,12 @@ function profileSummary(profile: ProviderProfile, isActive: boolean): string {
|
||||
const keyInfo = profile.apiKey ? 'key set' : 'no key'
|
||||
const providerKind =
|
||||
profile.provider === 'anthropic' ? 'anthropic' : 'openai-compatible'
|
||||
return `${providerKind} · ${profile.baseUrl} · ${profile.model} · ${keyInfo}${activeSuffix}`
|
||||
const models = parseModelList(profile.model)
|
||||
const modelDisplay =
|
||||
models.length <= 3
|
||||
? models.join(', ')
|
||||
: `${models[0]}, ${models[1]} + ${models.length - 2} more`
|
||||
return `${providerKind} · ${profile.baseUrl} · ${modelDisplay} · ${keyInfo}${activeSuffix}`
|
||||
}
|
||||
|
||||
function getGithubCredentialSourceFromEnv(
|
||||
@@ -320,6 +328,7 @@ function CodexOAuthSetup({
|
||||
}
|
||||
|
||||
export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
const setAppState = useSetAppState()
|
||||
const initialGithubCredentialSource = getGithubCredentialSourceFromEnv()
|
||||
const initialIsGithubActive = isEnvTruthy(process.env.CLAUDE_CODE_USE_GITHUB)
|
||||
const initialHasGithubCredential = initialGithubCredentialSource !== 'none'
|
||||
@@ -353,6 +362,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
const [cursorOffset, setCursorOffset] = React.useState(0)
|
||||
const [statusMessage, setStatusMessage] = React.useState<string | undefined>()
|
||||
const [errorMessage, setErrorMessage] = React.useState<string | undefined>()
|
||||
const [menuFocusValue, setMenuFocusValue] = React.useState<string | undefined>()
|
||||
const [hasStoredCodexOAuthCredentials, setHasStoredCodexOAuthCredentials] =
|
||||
React.useState(false)
|
||||
const [storedCodexOAuthProfileId, setStoredCodexOAuthProfileId] =
|
||||
@@ -568,24 +578,48 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
const githubError = activateGithubProvider()
|
||||
if (githubError) {
|
||||
setErrorMessage(`Could not activate GitHub provider: ${githubError}`)
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
return
|
||||
}
|
||||
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
mainLoopModel: GITHUB_PROVIDER_DEFAULT_MODEL,
|
||||
mainLoopModelForSession: null,
|
||||
}))
|
||||
refreshProfiles()
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
mainLoopModel: GITHUB_PROVIDER_DEFAULT_MODEL,
|
||||
}))
|
||||
setStatusMessage(`Active provider: ${GITHUB_PROVIDER_LABEL}`)
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
return
|
||||
}
|
||||
|
||||
const active = setActiveProviderProfile(profileId)
|
||||
if (!active) {
|
||||
setErrorMessage('Could not change active provider.')
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
return
|
||||
}
|
||||
|
||||
// Update the session model to the new provider's first model.
|
||||
// persistActiveProviderProfileModel (called by onChangeAppState) will
|
||||
// not overwrite the multi-model list because it checks if the model
|
||||
// is already in the profile's comma-separated model list.
|
||||
const newModel = getPrimaryModel(active.model)
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
mainLoopModel: newModel,
|
||||
}))
|
||||
|
||||
providerLabel = active.name
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
mainLoopModel: active.model,
|
||||
mainLoopModelForSession: null,
|
||||
}))
|
||||
const settingsOverrideError =
|
||||
clearStartupProviderOverrideFromUserSettings()
|
||||
const isActiveCodexOAuth = isCodexOAuthProfile(
|
||||
@@ -613,16 +647,21 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
? `Active provider: ${active.name}. Warning: could not clear startup provider override (${settingsOverrideError}).`
|
||||
: `Active provider: ${active.name}`,
|
||||
)
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
} catch (error) {
|
||||
refreshProfiles()
|
||||
setStatusMessage(undefined)
|
||||
const detail = error instanceof Error ? error.message : String(error)
|
||||
setErrorMessage(`Could not finish activating ${providerLabel}: ${detail}`)
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
}
|
||||
}
|
||||
|
||||
function returnToMenu(): void {
|
||||
setMenuFocusValue('done')
|
||||
setScreen('menu')
|
||||
}
|
||||
|
||||
function closeWithCancelled(message: string): void {
|
||||
onDone({ action: 'cancelled', message })
|
||||
}
|
||||
@@ -773,6 +812,13 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
}
|
||||
|
||||
const isActiveSavedProfile = getActiveProviderProfile()?.id === saved.id
|
||||
if (isActiveSavedProfile) {
|
||||
setAppState(prev => ({
|
||||
...prev,
|
||||
mainLoopModel: saved.model,
|
||||
mainLoopModelForSession: null,
|
||||
}))
|
||||
}
|
||||
const settingsOverrideError = isActiveSavedProfile
|
||||
? clearStartupProviderOverrideFromUserSettings()
|
||||
: null
|
||||
@@ -800,7 +846,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
setEditingProfileId(null)
|
||||
setFormStepIndex(0)
|
||||
setErrorMessage(undefined)
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
}
|
||||
|
||||
function renderOllamaSelection(): React.ReactNode {
|
||||
@@ -923,7 +969,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
return
|
||||
}
|
||||
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
}
|
||||
|
||||
useKeybinding('confirm:no', handleBackFromForm, {
|
||||
@@ -1004,11 +1050,31 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
label: 'LM Studio',
|
||||
description: 'Local LM Studio endpoint',
|
||||
},
|
||||
{
|
||||
value: 'dashscope-cn',
|
||||
label: 'Alibaba Coding Plan (China)',
|
||||
description: 'Alibaba DashScope China endpoint',
|
||||
},
|
||||
{
|
||||
value: 'dashscope-intl',
|
||||
label: 'Alibaba Coding Plan',
|
||||
description: 'Alibaba DashScope International endpoint',
|
||||
},
|
||||
{
|
||||
value: 'custom',
|
||||
label: 'Custom',
|
||||
description: 'Any OpenAI-compatible provider',
|
||||
},
|
||||
{
|
||||
value: 'nvidia-nim',
|
||||
label: 'NVIDIA NIM',
|
||||
description: 'NVIDIA NIM endpoint',
|
||||
},
|
||||
{
|
||||
value: 'minimax',
|
||||
label: 'MiniMax',
|
||||
description: 'MiniMax API endpoint',
|
||||
},
|
||||
...(mode === 'first-run'
|
||||
? [
|
||||
{
|
||||
@@ -1046,7 +1112,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
closeWithCancelled('Provider setup skipped')
|
||||
return
|
||||
}
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
}}
|
||||
visibleOptionCount={Math.min(13, options.length)}
|
||||
/>
|
||||
@@ -1084,6 +1150,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
focus={true}
|
||||
showCursor={true}
|
||||
placeholder={`${currentStep.placeholder}${figures.ellipsis}`}
|
||||
mask={currentStepKey === 'apiKey' ? '*' : undefined}
|
||||
columns={80}
|
||||
cursorOffset={cursorOffset}
|
||||
onChangeCursorOffset={setCursorOffset}
|
||||
@@ -1246,6 +1313,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
}
|
||||
}}
|
||||
onCancel={() => closeWithCancelled('Provider manager closed')}
|
||||
defaultFocusValue={menuFocusValue}
|
||||
visibleOptionCount={options.length}
|
||||
/>
|
||||
</Box>
|
||||
@@ -1293,8 +1361,8 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
description: 'Return to provider manager',
|
||||
},
|
||||
]}
|
||||
onChange={() => setScreen('menu')}
|
||||
onCancel={() => setScreen('menu')}
|
||||
onChange={() => returnToMenu()}
|
||||
onCancel={() => returnToMenu()}
|
||||
visibleOptionCount={1}
|
||||
/>
|
||||
</Box>
|
||||
@@ -1309,7 +1377,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
<Select
|
||||
options={selectOptions}
|
||||
onChange={onSelect}
|
||||
onCancel={() => setScreen('menu')}
|
||||
onCancel={() => returnToMenu()}
|
||||
visibleOptionCount={Math.min(10, Math.max(2, selectOptions.length))}
|
||||
/>
|
||||
</Box>
|
||||
@@ -1350,7 +1418,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
setErrorMessage(
|
||||
'Codex OAuth login finished, but the provider profile could not be saved.',
|
||||
)
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1362,7 +1430,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
setErrorMessage(
|
||||
'Codex OAuth login finished, but the provider could not be set as the startup provider.',
|
||||
)
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1396,7 +1464,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
|
||||
setStatusMessage(message)
|
||||
setErrorMessage(undefined)
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
}}
|
||||
/>
|
||||
)
|
||||
@@ -1436,7 +1504,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
refreshProfiles()
|
||||
setStatusMessage('GitHub provider deleted')
|
||||
}
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1471,7 +1539,7 @@ export function ProviderManager({ mode, onDone }: Props): React.ReactNode {
|
||||
: 'Provider deleted',
|
||||
)
|
||||
}
|
||||
setScreen('menu')
|
||||
returnToMenu()
|
||||
},
|
||||
{ includeGithub: true },
|
||||
)
|
||||
|
||||
@@ -117,17 +117,28 @@ function detectProvider(): { name: string; model: string; baseUrl: string; isLoc
|
||||
const baseUrl = resolvedRequest.baseUrl
|
||||
const isLocal = isLocalProviderUrl(baseUrl)
|
||||
let name = 'OpenAI'
|
||||
// Override to Codex when resolved endpoint is Codex
|
||||
if (resolvedRequest.transport === 'codex_responses' || baseUrl.includes('chatgpt.com/backend-api/codex')) {
|
||||
if (/nvidia/i.test(baseUrl) || /nvidia/i.test(rawModel) || process.env.NVIDIA_NIM)
|
||||
name = 'NVIDIA NIM'
|
||||
else if (/minimax/i.test(baseUrl) || /minimax/i.test(rawModel) || process.env.MINIMAX_API_KEY)
|
||||
name = 'MiniMax'
|
||||
else if (resolvedRequest.transport === 'codex_responses' || baseUrl.includes('chatgpt.com/backend-api/codex'))
|
||||
name = 'Codex'
|
||||
} else if (/deepseek/i.test(baseUrl) || /deepseek/i.test(rawModel)) name = 'DeepSeek'
|
||||
else if (/openrouter/i.test(baseUrl)) name = 'OpenRouter'
|
||||
else if (/together/i.test(baseUrl)) name = 'Together AI'
|
||||
else if (/groq/i.test(baseUrl)) name = 'Groq'
|
||||
else if (/mistral/i.test(baseUrl) || /mistral/i.test(rawModel)) name = 'Mistral'
|
||||
else if (/azure/i.test(baseUrl)) name = 'Azure OpenAI'
|
||||
else if (/llama/i.test(rawModel)) name = 'Meta Llama'
|
||||
else if (isLocal) name = getLocalOpenAICompatibleProviderLabel(baseUrl)
|
||||
else if (/deepseek/i.test(baseUrl) || /deepseek/i.test(rawModel))
|
||||
name = 'DeepSeek'
|
||||
else if (/openrouter/i.test(baseUrl))
|
||||
name = 'OpenRouter'
|
||||
else if (/together/i.test(baseUrl))
|
||||
name = 'Together AI'
|
||||
else if (/groq/i.test(baseUrl))
|
||||
name = 'Groq'
|
||||
else if (/mistral/i.test(baseUrl) || /mistral/i.test(rawModel))
|
||||
name = 'Mistral'
|
||||
else if (/azure/i.test(baseUrl))
|
||||
name = 'Azure OpenAI'
|
||||
else if (/llama/i.test(rawModel))
|
||||
name = 'Meta Llama'
|
||||
else if (isLocal)
|
||||
name = getLocalOpenAICompatibleProviderLabel(baseUrl)
|
||||
|
||||
// Resolve model alias to actual model name + reasoning effort
|
||||
let displayModel = resolvedRequest.resolvedModel
|
||||
@@ -142,7 +153,9 @@ function detectProvider(): { name: string; model: string; baseUrl: string; isLoc
|
||||
const settings = getSettings_DEPRECATED() || {}
|
||||
const modelSetting = settings.model || process.env.ANTHROPIC_MODEL || process.env.CLAUDE_MODEL || 'claude-sonnet-4-6'
|
||||
const resolvedModel = parseUserSpecifiedModel(modelSetting)
|
||||
return { name: 'Anthropic', model: resolvedModel, baseUrl: 'https://api.anthropic.com', isLocal: false }
|
||||
const baseUrl = process.env.ANTHROPIC_BASE_URL ?? 'https://api.anthropic.com'
|
||||
const isLocal = isLocalProviderUrl(baseUrl)
|
||||
return { name: 'Anthropic', model: resolvedModel, baseUrl, isLocal }
|
||||
}
|
||||
|
||||
// ─── Box drawing ──────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -6,6 +6,7 @@ import stripAnsi from 'strip-ansi'
|
||||
|
||||
import { createRoot } from '../ink.js'
|
||||
import { AppStateProvider } from '../state/AppState.js'
|
||||
import { maskTextWithVisibleEdges } from '../utils/Cursor.js'
|
||||
import TextInput from './TextInput.js'
|
||||
import VimTextInput from './VimTextInput.js'
|
||||
|
||||
@@ -199,6 +200,13 @@ test('TextInput renders typed characters before delayed parent value commits', a
|
||||
expect(output).not.toContain('Type here...')
|
||||
})
|
||||
|
||||
test('maskTextWithVisibleEdges preserves only the first and last three chars', () => {
|
||||
expect(maskTextWithVisibleEdges('sk-secret-12345678', '*')).toBe(
|
||||
'sk-************678',
|
||||
)
|
||||
expect(maskTextWithVisibleEdges('abcdef', '*')).toBe('******')
|
||||
})
|
||||
|
||||
test('VimTextInput preserves rapid typed characters before delayed parent value commits', async () => {
|
||||
const { stdout, stdin, getOutput } = createTestStreams()
|
||||
const root = await createRoot({
|
||||
|
||||
@@ -1,5 +1,16 @@
|
||||
import { afterEach, expect, test } from 'bun:test'
|
||||
|
||||
// MACRO is replaced at build time by Bun.define but not in test mode.
|
||||
// Define it globally so tests that import modules using MACRO don't crash.
|
||||
;(globalThis as Record<string, unknown>).MACRO = {
|
||||
VERSION: '99.0.0',
|
||||
DISPLAY_VERSION: '0.0.0-test',
|
||||
BUILD_TIME: new Date().toISOString(),
|
||||
ISSUES_EXPLAINER: 'report the issue at https://github.com/anthropics/claude-code/issues',
|
||||
PACKAGE_URL: '@gitlawb/openclaude',
|
||||
NATIVE_PACKAGE_URL: undefined,
|
||||
}
|
||||
|
||||
import { getSystemPrompt, DEFAULT_AGENT_PROMPT } from './prompts.js'
|
||||
import { CLI_SYSPROMPT_PREFIXES, getCLISyspromptPrefix } from './system.js'
|
||||
import { CLAUDE_CODE_GUIDE_AGENT } from '../tools/AgentTool/built-in/claudeCodeGuideAgent.js'
|
||||
|
||||
@@ -5,7 +5,7 @@ import {
|
||||
} from '../utils/providerProfile.js'
|
||||
import {
|
||||
getProviderValidationError,
|
||||
validateProviderEnvOrExit,
|
||||
validateProviderEnvForStartupOrExit,
|
||||
} from '../utils/providerValidation.js'
|
||||
|
||||
// OpenClaude: polyfill globalThis.File for Node < 20.
|
||||
@@ -132,7 +132,7 @@ async function main(): Promise<void> {
|
||||
hydrateGithubModelsTokenFromSecureStorage()
|
||||
}
|
||||
|
||||
await validateProviderEnvOrExit()
|
||||
await validateProviderEnvForStartupOrExit()
|
||||
|
||||
// Print the gradient startup screen before the Ink UI loads
|
||||
const { printStartupScreen } = await import('../components/StartupScreen.js')
|
||||
|
||||
75
src/entrypoints/mcp.test.ts
Normal file
75
src/entrypoints/mcp.test.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { describe, it, expect, mock } from 'bun:test'
|
||||
import { getCombinedTools, loadReexposedMcpTools } from './mcp.js'
|
||||
import type { Tool as InternalTool } from '../Tool.js'
|
||||
import type { MCPServerConnection } from '../services/mcp/types.js'
|
||||
import type { Tool } from '@modelcontextprotocol/sdk/types.js'
|
||||
|
||||
// Mock the MCP client service to control the tools and connections returned
|
||||
const mockGetMcpToolsCommandsAndResources = mock(async (onConnectionAttempt: any) => {})
|
||||
mock.module('../services/mcp/client.js', () => ({
|
||||
getMcpToolsCommandsAndResources: mockGetMcpToolsCommandsAndResources
|
||||
}))
|
||||
|
||||
describe('getCombinedTools', () => {
|
||||
it('deduplicates builtins when mcpTools have the same name, prioritizing mcpTools', () => {
|
||||
const builtinBash = { name: 'Bash', isMcp: false } as unknown as InternalTool
|
||||
const builtinRead = { name: 'Read', isMcp: false } as unknown as InternalTool
|
||||
const mcpBash = { name: 'Bash', isMcp: true } as unknown as InternalTool
|
||||
|
||||
const builtins = [builtinBash, builtinRead]
|
||||
const mcpTools = [mcpBash]
|
||||
|
||||
const result = getCombinedTools(builtins, mcpTools)
|
||||
|
||||
expect(result).toHaveLength(2)
|
||||
expect(result[0]).toBe(mcpBash)
|
||||
expect(result[1]).toBe(builtinRead)
|
||||
})
|
||||
})
|
||||
|
||||
describe('loadReexposedMcpTools', () => {
|
||||
it('loads tools and clients regardless of connection state (including needs-auth)', async () => {
|
||||
// Setup the mock to simulate yielding a needs-auth server and a connected server
|
||||
mockGetMcpToolsCommandsAndResources.mockImplementation(async (onConnectionAttempt) => {
|
||||
const needsAuthClient = {
|
||||
name: 'auth-server',
|
||||
type: 'needs-auth',
|
||||
config: {}
|
||||
} as MCPServerConnection
|
||||
|
||||
const authTool = {
|
||||
name: 'mcp__auth-server__authenticate',
|
||||
isMcp: true
|
||||
} as unknown as InternalTool
|
||||
|
||||
const connectedClient = {
|
||||
name: 'connected-server',
|
||||
type: 'connected',
|
||||
config: {},
|
||||
client: {}
|
||||
} as MCPServerConnection
|
||||
|
||||
const connectedTool = {
|
||||
name: 'mcp__connected-server__do_thing',
|
||||
isMcp: true
|
||||
} as unknown as InternalTool
|
||||
|
||||
// Simulate the callback behavior
|
||||
onConnectionAttempt({ client: needsAuthClient, tools: [authTool], commands: [] })
|
||||
onConnectionAttempt({ client: connectedClient, tools: [connectedTool], commands: [] })
|
||||
})
|
||||
|
||||
const { mcpClients, mcpTools } = await loadReexposedMcpTools()
|
||||
|
||||
expect(mcpClients).toHaveLength(2)
|
||||
expect(mcpClients[0].type).toBe('needs-auth')
|
||||
expect(mcpClients[1].type).toBe('connected')
|
||||
|
||||
expect(mcpTools).toHaveLength(2)
|
||||
expect(mcpTools[0].name).toBe('mcp__auth-server__authenticate')
|
||||
expect(mcpTools[1].name).toBe('mcp__connected-server__do_thing')
|
||||
|
||||
// Reset mock for other tests
|
||||
mockGetMcpToolsCommandsAndResources.mockReset()
|
||||
})
|
||||
})
|
||||
@@ -7,6 +7,7 @@ process.env.CLAUDE_CODE_DISABLE_EXPERIMENTAL_BETAS ??= 'true'
|
||||
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
||||
import { ZodError } from 'zod'
|
||||
import {
|
||||
CallToolRequestSchema,
|
||||
type CallToolResult,
|
||||
@@ -17,9 +18,12 @@ import {
|
||||
import { getDefaultAppState } from 'src/state/AppStateStore.js'
|
||||
import review from '../commands/review.js'
|
||||
import type { Command } from '../commands.js'
|
||||
import { getMcpToolsCommandsAndResources } from '../services/mcp/client.js'
|
||||
import type { MCPServerConnection } from '../services/mcp/types.js'
|
||||
import {
|
||||
findToolByName,
|
||||
getEmptyToolPermissionContext,
|
||||
type Tool as InternalTool,
|
||||
type ToolUseContext,
|
||||
} from '../Tool.js'
|
||||
import { getTools } from '../tools.js'
|
||||
@@ -39,6 +43,32 @@ type ToolOutput = Tool['outputSchema']
|
||||
|
||||
const MCP_COMMANDS: Command[] = [review]
|
||||
|
||||
export function getCombinedTools(
|
||||
builtins: InternalTool[],
|
||||
mcpTools: InternalTool[],
|
||||
): InternalTool[] {
|
||||
const mcpToolNames = new Set(mcpTools.map(t => t.name))
|
||||
const deduplicatedBuiltins = builtins.filter(t => !mcpToolNames.has(t.name))
|
||||
|
||||
return [...mcpTools, ...deduplicatedBuiltins]
|
||||
}
|
||||
|
||||
export async function loadReexposedMcpTools(): Promise<{
|
||||
mcpClients: MCPServerConnection[]
|
||||
mcpTools: InternalTool[]
|
||||
}> {
|
||||
const mcpClients: MCPServerConnection[] = []
|
||||
const mcpTools: InternalTool[] = []
|
||||
|
||||
// Load configured MCP clients and their tools
|
||||
await getMcpToolsCommandsAndResources(({ client, tools: clientTools }) => {
|
||||
mcpClients.push(client)
|
||||
mcpTools.push(...clientTools)
|
||||
})
|
||||
|
||||
return { mcpClients, mcpTools }
|
||||
}
|
||||
|
||||
export async function startMCPServer(
|
||||
cwd: string,
|
||||
debug: boolean,
|
||||
@@ -63,12 +93,13 @@ export async function startMCPServer(
|
||||
},
|
||||
)
|
||||
|
||||
const { mcpClients, mcpTools } = await loadReexposedMcpTools()
|
||||
|
||||
server.setRequestHandler(
|
||||
ListToolsRequestSchema,
|
||||
async (): Promise<ListToolsResult> => {
|
||||
// TODO: Also re-expose any MCP tools
|
||||
const toolPermissionContext = getEmptyToolPermissionContext()
|
||||
const tools = getTools(toolPermissionContext)
|
||||
const tools = getCombinedTools(getTools(toolPermissionContext), mcpTools)
|
||||
return {
|
||||
tools: await Promise.all(
|
||||
tools.map(async tool => {
|
||||
@@ -94,7 +125,7 @@ export async function startMCPServer(
|
||||
tools,
|
||||
agents: [],
|
||||
}),
|
||||
inputSchema: zodToJsonSchema(tool.inputSchema) as ToolInput,
|
||||
inputSchema: (tool.inputJSONSchema ?? zodToJsonSchema(tool.inputSchema)) as ToolInput,
|
||||
outputSchema,
|
||||
}
|
||||
}),
|
||||
@@ -107,8 +138,7 @@ export async function startMCPServer(
|
||||
CallToolRequestSchema,
|
||||
async ({ params: { name, arguments: args } }): Promise<CallToolResult> => {
|
||||
const toolPermissionContext = getEmptyToolPermissionContext()
|
||||
// TODO: Also re-expose any MCP tools
|
||||
const tools = getTools(toolPermissionContext)
|
||||
const tools = getCombinedTools(getTools(toolPermissionContext), mcpTools)
|
||||
const tool = findToolByName(tools, name)
|
||||
if (!tool) {
|
||||
throw new Error(`Tool ${name} not found`)
|
||||
@@ -123,7 +153,7 @@ export async function startMCPServer(
|
||||
tools,
|
||||
mainLoopModel: getMainLoopModel(),
|
||||
thinkingConfig: { type: 'disabled' },
|
||||
mcpClients: [],
|
||||
mcpClients,
|
||||
mcpResources: {},
|
||||
isNonInteractiveSession: true,
|
||||
debug,
|
||||
@@ -140,13 +170,16 @@ export async function startMCPServer(
|
||||
updateAttributionState: () => {},
|
||||
}
|
||||
|
||||
// TODO: validate input types with zod
|
||||
try {
|
||||
if (!tool.isEnabled()) {
|
||||
throw new Error(`Tool ${name} is not enabled`)
|
||||
}
|
||||
|
||||
// Validate input types with zod
|
||||
const parsedArgs = tool.inputSchema.parse(args ?? {})
|
||||
|
||||
const validationResult = await tool.validateInput?.(
|
||||
(args as never) ?? {},
|
||||
(parsedArgs as never) ?? {},
|
||||
toolUseContext,
|
||||
)
|
||||
if (validationResult && !validationResult.result) {
|
||||
@@ -155,7 +188,7 @@ export async function startMCPServer(
|
||||
)
|
||||
}
|
||||
const finalResult = await tool.call(
|
||||
(args ?? {}) as never,
|
||||
(parsedArgs ?? {}) as never,
|
||||
toolUseContext,
|
||||
hasPermissionsToUseTool,
|
||||
createAssistantMessage({
|
||||
@@ -163,20 +196,50 @@ export async function startMCPServer(
|
||||
}),
|
||||
)
|
||||
|
||||
let content: CallToolResult['content']
|
||||
const data = finalResult.data as string | { type: string; text?: string; source?: { type: string; media_type: string; data: string } }[] | unknown
|
||||
|
||||
if (typeof data === 'string') {
|
||||
content = [{ type: 'text', text: data }]
|
||||
} else if (Array.isArray(data)) {
|
||||
content = data.map((block: any) => {
|
||||
if (block.type === 'text') {
|
||||
return { type: 'text', text: block.text || '' }
|
||||
} else if (block.type === 'image' && block.source) {
|
||||
return {
|
||||
type: 'image',
|
||||
data: block.source.data,
|
||||
mimeType: block.source.media_type,
|
||||
}
|
||||
} else {
|
||||
// eslint-disable-next-line custom-rules/no-top-level-side-effects, no-console
|
||||
console.warn(`Unmapped content block type from tool ${name}: ${block.type || 'unknown'}`)
|
||||
return { type: 'text', text: jsonStringify(block) }
|
||||
}
|
||||
}) as CallToolResult['content']
|
||||
} else {
|
||||
content = [{ type: 'text', text: jsonStringify(data) }]
|
||||
}
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text' as const,
|
||||
text:
|
||||
typeof finalResult === 'string'
|
||||
? finalResult
|
||||
: jsonStringify(finalResult.data),
|
||||
},
|
||||
],
|
||||
content,
|
||||
isError: !!(finalResult as any).isError,
|
||||
}
|
||||
} catch (error) {
|
||||
logError(error)
|
||||
|
||||
if (error instanceof ZodError) {
|
||||
return {
|
||||
isError: true,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: `Tool ${name} input is invalid:\n${error.errors.map(e => `- ${e.path.join('.')}: ${e.message}`).join('\n')}`,
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
const parts =
|
||||
error instanceof Error ? getErrorParts(error) : [String(error)]
|
||||
const errorText = parts.filter(Boolean).join('\n').trim() || 'Error'
|
||||
@@ -201,3 +264,4 @@ export async function startMCPServer(
|
||||
|
||||
return await runServer()
|
||||
}
|
||||
|
||||
|
||||
@@ -114,8 +114,8 @@ export const SandboxSettingsSchema = lazySchema(() =>
|
||||
.boolean()
|
||||
.optional()
|
||||
.describe(
|
||||
'Allow commands to run outside the sandbox via the dangerouslyDisableSandbox parameter. ' +
|
||||
'When false, the dangerouslyDisableSandbox parameter is completely ignored and all commands must run sandboxed. ' +
|
||||
'Allow trusted, user-initiated commands to run outside the sandbox. ' +
|
||||
'When false, sandbox override requests are ignored and all commands must run sandboxed. ' +
|
||||
'Default: true.',
|
||||
),
|
||||
network: SandboxNetworkConfigSchema(),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { APIError } from '@anthropic-ai/sdk'
|
||||
import { fetchWithProxyRetry } from './fetchWithProxyRetry.js'
|
||||
import type {
|
||||
ResolvedCodexCredentials,
|
||||
ResolvedProviderRequest,
|
||||
@@ -559,12 +560,15 @@ export async function performCodexRequest(options: {
|
||||
}
|
||||
headers.originator ??= 'openclaude'
|
||||
|
||||
const response = await fetch(`${options.request.baseUrl}/responses`, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
signal: options.signal,
|
||||
})
|
||||
const response = await fetchWithProxyRetry(
|
||||
`${options.request.baseUrl}/responses`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(body),
|
||||
signal: options.signal,
|
||||
},
|
||||
)
|
||||
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.text().catch(() => 'unknown error')
|
||||
|
||||
44
src/services/api/errors.openaiCompatibility.test.ts
Normal file
44
src/services/api/errors.openaiCompatibility.test.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import { APIError } from '@anthropic-ai/sdk'
|
||||
import { expect, test } from 'bun:test'
|
||||
|
||||
import { getAssistantMessageFromError } from './errors.js'
|
||||
|
||||
function getFirstText(message: ReturnType<typeof getAssistantMessageFromError>): string {
|
||||
const first = message.message.content[0]
|
||||
if (!first || typeof first !== 'object' || !('text' in first)) {
|
||||
return ''
|
||||
}
|
||||
return typeof first.text === 'string' ? first.text : ''
|
||||
}
|
||||
|
||||
test('maps endpoint_not_found category markers to actionable setup guidance', () => {
|
||||
const error = APIError.generate(
|
||||
404,
|
||||
undefined,
|
||||
'OpenAI API error 404: Not Found [openai_category=endpoint_not_found] Hint: Confirm OPENAI_BASE_URL includes /v1.',
|
||||
new Headers(),
|
||||
)
|
||||
|
||||
const message = getAssistantMessageFromError(error, 'qwen2.5-coder:7b')
|
||||
const text = getFirstText(message)
|
||||
|
||||
expect(message.isApiErrorMessage).toBe(true)
|
||||
expect(text).toContain('Provider endpoint was not found')
|
||||
expect(text).toContain('OPENAI_BASE_URL')
|
||||
expect(text).toContain('/v1')
|
||||
})
|
||||
|
||||
test('maps tool_call_incompatible category markers to model/tool guidance', () => {
|
||||
const error = APIError.generate(
|
||||
400,
|
||||
undefined,
|
||||
'OpenAI API error 400: tool_calls are not supported [openai_category=tool_call_incompatible]',
|
||||
new Headers(),
|
||||
)
|
||||
|
||||
const message = getAssistantMessageFromError(error, 'qwen2.5-coder:7b')
|
||||
const text = getFirstText(message)
|
||||
|
||||
expect(text).toContain('rejected tool-calling payloads')
|
||||
expect(text).toContain('/model')
|
||||
})
|
||||
@@ -50,9 +50,110 @@ import {
|
||||
} from '../claudeAiLimits.js'
|
||||
import { shouldProcessRateLimits } from '../rateLimitMocking.js' // Used for /mock-limits command
|
||||
import { extractConnectionErrorDetails, formatAPIError } from './errorUtils.js'
|
||||
import {
|
||||
extractOpenAICategoryMarker,
|
||||
type OpenAICompatibilityFailureCategory,
|
||||
} from './openaiErrorClassification.js'
|
||||
|
||||
export const API_ERROR_MESSAGE_PREFIX = 'API Error'
|
||||
|
||||
function stripOpenAICompatibilityMetadata(message: string): string {
|
||||
return message
|
||||
.replace(/\s*\[openai_category=[a-z_]+\]\s*/g, ' ')
|
||||
.replace(/\s{2,}/g, ' ')
|
||||
.trim()
|
||||
}
|
||||
|
||||
function mapOpenAICompatibilityFailureToAssistantMessage(options: {
|
||||
category: OpenAICompatibilityFailureCategory
|
||||
model: string
|
||||
rawMessage: string
|
||||
}): AssistantMessage {
|
||||
const switchCmd = getIsNonInteractiveSession() ? '--model' : '/model'
|
||||
const compactHint = getIsNonInteractiveSession()
|
||||
? 'Reduce prompt size or start a new session.'
|
||||
: 'Run /compact or start a new session with /new.'
|
||||
|
||||
switch (options.category) {
|
||||
case 'localhost_resolution_failed':
|
||||
case 'connection_refused':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content:
|
||||
'Could not connect to the local OpenAI-compatible provider. Ensure the local server is running, then use OPENAI_BASE_URL=http://127.0.0.1:11434/v1 for Ollama.',
|
||||
error: 'unknown',
|
||||
})
|
||||
|
||||
case 'endpoint_not_found':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content:
|
||||
'Provider endpoint was not found. Confirm OPENAI_BASE_URL targets an OpenAI-compatible /v1 endpoint (for Ollama: http://127.0.0.1:11434/v1).',
|
||||
error: 'invalid_request',
|
||||
})
|
||||
|
||||
case 'model_not_found':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `The selected model (${options.model}) is not available on this provider. Run ${switchCmd} to choose another model, or verify installed local models (for Ollama: ollama list).`,
|
||||
error: 'invalid_request',
|
||||
})
|
||||
|
||||
case 'auth_invalid':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `${API_ERROR_MESSAGE_PREFIX}: Authentication failed for your OpenAI-compatible provider. Verify OPENAI_API_KEY and endpoint-specific auth requirements.`,
|
||||
error: 'authentication_failed',
|
||||
})
|
||||
|
||||
case 'rate_limited':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `${API_ERROR_MESSAGE_PREFIX}: Provider rate limit reached. Retry in a few seconds.`,
|
||||
error: 'rate_limit',
|
||||
})
|
||||
|
||||
case 'request_timeout':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `${API_ERROR_MESSAGE_PREFIX}: Provider request timed out. Local models may be loading or overloaded; retry shortly or increase API_TIMEOUT_MS.`,
|
||||
error: 'unknown',
|
||||
})
|
||||
|
||||
case 'context_overflow':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `The conversation exceeded the provider context limit. ${compactHint}`,
|
||||
error: 'invalid_request',
|
||||
})
|
||||
|
||||
case 'tool_call_incompatible':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `The selected provider/model rejected tool-calling payloads. Try ${switchCmd} to pick a tool-capable model or continue without tools.`,
|
||||
error: 'invalid_request',
|
||||
})
|
||||
|
||||
case 'malformed_provider_response':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `${API_ERROR_MESSAGE_PREFIX}: Provider returned a malformed response. Confirm endpoint compatibility and check local proxy/network middleware.`,
|
||||
error: 'unknown',
|
||||
errorDetails: stripOpenAICompatibilityMetadata(options.rawMessage),
|
||||
})
|
||||
|
||||
case 'provider_unavailable':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `${API_ERROR_MESSAGE_PREFIX}: Provider is temporarily unavailable. Retry in a moment.`,
|
||||
error: 'unknown',
|
||||
})
|
||||
|
||||
case 'network_error':
|
||||
case 'unknown':
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `${API_ERROR_MESSAGE_PREFIX}: ${stripOpenAICompatibilityMetadata(options.rawMessage)}`,
|
||||
error: 'unknown',
|
||||
})
|
||||
|
||||
default:
|
||||
return createAssistantAPIErrorMessage({
|
||||
content: `${API_ERROR_MESSAGE_PREFIX}: ${stripOpenAICompatibilityMetadata(options.rawMessage)}`,
|
||||
error: 'unknown',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export function startsWithApiErrorPrefix(text: string): boolean {
|
||||
return (
|
||||
text.startsWith(API_ERROR_MESSAGE_PREFIX) ||
|
||||
@@ -457,6 +558,19 @@ export function getAssistantMessageFromError(
|
||||
})
|
||||
}
|
||||
|
||||
// OpenAI-compatible transport and HTTP failures include structured category
|
||||
// markers from openaiShim.ts for actionable end-user remediation.
|
||||
if (error instanceof APIError) {
|
||||
const openaiCategory = extractOpenAICategoryMarker(error.message)
|
||||
if (openaiCategory) {
|
||||
return mapOpenAICompatibilityFailureToAssistantMessage({
|
||||
category: openaiCategory,
|
||||
model,
|
||||
rawMessage: error.message,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Check for emergency capacity off switch for Opus PAYG users
|
||||
if (
|
||||
error instanceof Error &&
|
||||
|
||||
86
src/services/api/fetchWithProxyRetry.test.ts
Normal file
86
src/services/api/fetchWithProxyRetry.test.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
import { afterEach, beforeEach, expect, test } from 'bun:test'
|
||||
|
||||
import { _resetKeepAliveForTesting } from '../../utils/proxy.js'
|
||||
import {
|
||||
fetchWithProxyRetry,
|
||||
isRetryableFetchError,
|
||||
} from './fetchWithProxyRetry.js'
|
||||
|
||||
type FetchType = typeof globalThis.fetch
|
||||
|
||||
const originalFetch = globalThis.fetch
|
||||
const originalEnv = {
|
||||
HTTP_PROXY: process.env.HTTP_PROXY,
|
||||
HTTPS_PROXY: process.env.HTTPS_PROXY,
|
||||
}
|
||||
|
||||
function restoreEnv(key: 'HTTP_PROXY' | 'HTTPS_PROXY', value: string | undefined): void {
|
||||
if (value === undefined) {
|
||||
delete process.env[key]
|
||||
} else {
|
||||
process.env[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
process.env.HTTP_PROXY = 'http://127.0.0.1:15236'
|
||||
delete process.env.HTTPS_PROXY
|
||||
_resetKeepAliveForTesting()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.fetch = originalFetch
|
||||
restoreEnv('HTTP_PROXY', originalEnv.HTTP_PROXY)
|
||||
restoreEnv('HTTPS_PROXY', originalEnv.HTTPS_PROXY)
|
||||
_resetKeepAliveForTesting()
|
||||
})
|
||||
|
||||
test('isRetryableFetchError matches Bun socket-closed failures', () => {
|
||||
expect(
|
||||
isRetryableFetchError(
|
||||
new Error(
|
||||
'The socket connection was closed unexpectedly. For more information, pass `verbose: true` in the second argument to fetch()',
|
||||
),
|
||||
),
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
test('fetchWithProxyRetry retries once with keepalive disabled after socket closure', async () => {
|
||||
const calls: Array<RequestInit | undefined> = []
|
||||
|
||||
globalThis.fetch = (async (_input, init) => {
|
||||
calls.push(init)
|
||||
if (calls.length === 1) {
|
||||
throw new Error(
|
||||
'The socket connection was closed unexpectedly. For more information, pass `verbose: true` in the second argument to fetch()',
|
||||
)
|
||||
}
|
||||
return new Response('ok')
|
||||
}) as FetchType
|
||||
|
||||
const response = await fetchWithProxyRetry('https://example.com/search', {
|
||||
method: 'POST',
|
||||
})
|
||||
|
||||
expect(await response.text()).toBe('ok')
|
||||
expect(calls).toHaveLength(2)
|
||||
expect((calls[0] as RequestInit & { proxy?: string }).proxy).toBe(
|
||||
'http://127.0.0.1:15236',
|
||||
)
|
||||
expect((calls[0] as RequestInit).keepalive).toBeUndefined()
|
||||
expect((calls[1] as RequestInit).keepalive).toBe(false)
|
||||
})
|
||||
|
||||
test('fetchWithProxyRetry does not retry non-network errors', async () => {
|
||||
let attempts = 0
|
||||
|
||||
globalThis.fetch = (async () => {
|
||||
attempts += 1
|
||||
throw new Error('400 bad request')
|
||||
}) as FetchType
|
||||
|
||||
await expect(fetchWithProxyRetry('https://example.com')).rejects.toThrow(
|
||||
'400 bad request',
|
||||
)
|
||||
expect(attempts).toBe(1)
|
||||
})
|
||||
44
src/services/api/fetchWithProxyRetry.ts
Normal file
44
src/services/api/fetchWithProxyRetry.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
import { disableKeepAlive, getProxyFetchOptions } from '../../utils/proxy.js'
|
||||
|
||||
const RETRYABLE_FETCH_ERROR_PATTERN =
|
||||
/socket connection was closed unexpectedly|ECONNRESET|EPIPE|socket hang up|Connection reset by peer|fetch failed/i
|
||||
|
||||
export function isRetryableFetchError(error: unknown): boolean {
|
||||
if (!(error instanceof Error)) {
|
||||
return false
|
||||
}
|
||||
if (error.name === 'AbortError') {
|
||||
return false
|
||||
}
|
||||
return RETRYABLE_FETCH_ERROR_PATTERN.test(error.message)
|
||||
}
|
||||
|
||||
export async function fetchWithProxyRetry(
|
||||
input: string | URL | Request,
|
||||
init?: RequestInit,
|
||||
options?: { forAnthropicAPI?: boolean; maxAttempts?: number },
|
||||
): Promise<Response> {
|
||||
const maxAttempts = Math.max(1, options?.maxAttempts ?? 2)
|
||||
let lastError: unknown
|
||||
|
||||
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
||||
try {
|
||||
return await fetch(input, {
|
||||
...init,
|
||||
...getProxyFetchOptions({
|
||||
forAnthropicAPI: options?.forAnthropicAPI,
|
||||
}),
|
||||
})
|
||||
} catch (error) {
|
||||
lastError = error
|
||||
if (attempt >= maxAttempts || !isRetryableFetchError(error)) {
|
||||
throw error
|
||||
}
|
||||
disableKeepAlive()
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError instanceof Error
|
||||
? lastError
|
||||
: new Error('Fetch failed without an error object')
|
||||
}
|
||||
97
src/services/api/openaiErrorClassification.test.ts
Normal file
97
src/services/api/openaiErrorClassification.test.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
import { expect, test } from 'bun:test'
|
||||
|
||||
import {
|
||||
buildOpenAICompatibilityErrorMessage,
|
||||
classifyOpenAIHttpFailure,
|
||||
classifyOpenAINetworkFailure,
|
||||
extractOpenAICategoryMarker,
|
||||
formatOpenAICategoryMarker,
|
||||
} from './openaiErrorClassification.js'
|
||||
|
||||
test('classifies localhost ECONNREFUSED as connection_refused', () => {
|
||||
const error = Object.assign(new TypeError('fetch failed'), {
|
||||
code: 'ECONNREFUSED',
|
||||
})
|
||||
|
||||
const failure = classifyOpenAINetworkFailure(error, {
|
||||
url: 'http://localhost:11434/v1/chat/completions',
|
||||
})
|
||||
|
||||
expect(failure.category).toBe('connection_refused')
|
||||
expect(failure.retryable).toBe(true)
|
||||
expect(failure.code).toBe('ECONNREFUSED')
|
||||
expect(failure.hint).toContain('local server is running')
|
||||
})
|
||||
|
||||
test('classifies localhost ENOTFOUND as localhost_resolution_failed', () => {
|
||||
const error = Object.assign(new TypeError('getaddrinfo ENOTFOUND localhost'), {
|
||||
code: 'ENOTFOUND',
|
||||
})
|
||||
|
||||
const failure = classifyOpenAINetworkFailure(error, {
|
||||
url: 'http://localhost:11434/v1/chat/completions',
|
||||
})
|
||||
|
||||
expect(failure.category).toBe('localhost_resolution_failed')
|
||||
expect(failure.retryable).toBe(true)
|
||||
expect(failure.code).toBe('ENOTFOUND')
|
||||
expect(failure.hint).toContain('127.0.0.1')
|
||||
})
|
||||
|
||||
test('classifies model-not-found 404 responses', () => {
|
||||
const failure = classifyOpenAIHttpFailure({
|
||||
status: 404,
|
||||
body: 'The model qwen2.5-coder:7b was not found',
|
||||
})
|
||||
|
||||
expect(failure.category).toBe('model_not_found')
|
||||
expect(failure.retryable).toBe(false)
|
||||
})
|
||||
|
||||
test('classifies generic 404 responses as endpoint_not_found', () => {
|
||||
const failure = classifyOpenAIHttpFailure({
|
||||
status: 404,
|
||||
body: 'Not Found',
|
||||
})
|
||||
|
||||
expect(failure.category).toBe('endpoint_not_found')
|
||||
expect(failure.hint).toContain('/v1')
|
||||
})
|
||||
|
||||
test('classifies context-overflow responses', () => {
|
||||
const failure = classifyOpenAIHttpFailure({
|
||||
status: 500,
|
||||
body: 'request too large: maximum context length exceeded',
|
||||
})
|
||||
|
||||
expect(failure.category).toBe('context_overflow')
|
||||
expect(failure.retryable).toBe(false)
|
||||
})
|
||||
|
||||
test('classifies tool compatibility failures', () => {
|
||||
const failure = classifyOpenAIHttpFailure({
|
||||
status: 400,
|
||||
body: 'tool_calls are not supported by this model',
|
||||
})
|
||||
|
||||
expect(failure.category).toBe('tool_call_incompatible')
|
||||
})
|
||||
|
||||
test('embeds and extracts category markers in formatted messages', () => {
|
||||
const marker = formatOpenAICategoryMarker('endpoint_not_found')
|
||||
expect(marker).toBe('[openai_category=endpoint_not_found]')
|
||||
|
||||
const formatted = buildOpenAICompatibilityErrorMessage('OpenAI API error 404: Not Found', {
|
||||
category: 'endpoint_not_found',
|
||||
hint: 'Confirm OPENAI_BASE_URL includes /v1.',
|
||||
})
|
||||
|
||||
expect(formatted).toContain('[openai_category=endpoint_not_found]')
|
||||
expect(formatted).toContain('Hint: Confirm OPENAI_BASE_URL includes /v1.')
|
||||
expect(extractOpenAICategoryMarker(formatted)).toBe('endpoint_not_found')
|
||||
})
|
||||
|
||||
test('ignores unknown category markers during extraction', () => {
|
||||
const malformed = 'OpenAI API error 500 [openai_category=totally_fake_category]'
|
||||
expect(extractOpenAICategoryMarker(malformed)).toBeUndefined()
|
||||
})
|
||||
355
src/services/api/openaiErrorClassification.ts
Normal file
355
src/services/api/openaiErrorClassification.ts
Normal file
@@ -0,0 +1,355 @@
|
||||
export type OpenAICompatibilityFailureCategory =
|
||||
| 'connection_refused'
|
||||
| 'localhost_resolution_failed'
|
||||
| 'request_timeout'
|
||||
| 'network_error'
|
||||
| 'auth_invalid'
|
||||
| 'rate_limited'
|
||||
| 'model_not_found'
|
||||
| 'endpoint_not_found'
|
||||
| 'context_overflow'
|
||||
| 'tool_call_incompatible'
|
||||
| 'malformed_provider_response'
|
||||
| 'provider_unavailable'
|
||||
| 'unknown'
|
||||
|
||||
export type OpenAICompatibilityFailure = {
|
||||
source: 'network' | 'http'
|
||||
category: OpenAICompatibilityFailureCategory
|
||||
retryable: boolean
|
||||
message: string
|
||||
hint?: string
|
||||
code?: string
|
||||
status?: number
|
||||
}
|
||||
|
||||
const OPENAI_CATEGORY_MARKER_PREFIX = '[openai_category='
|
||||
|
||||
const LOCALHOST_HOSTNAMES = new Set(['localhost', '127.0.0.1', '::1'])
|
||||
|
||||
const OPENAI_COMPATIBILITY_FAILURE_CATEGORIES: ReadonlySet<OpenAICompatibilityFailureCategory> =
|
||||
new Set<OpenAICompatibilityFailureCategory>([
|
||||
'connection_refused',
|
||||
'localhost_resolution_failed',
|
||||
'request_timeout',
|
||||
'network_error',
|
||||
'auth_invalid',
|
||||
'rate_limited',
|
||||
'model_not_found',
|
||||
'endpoint_not_found',
|
||||
'context_overflow',
|
||||
'tool_call_incompatible',
|
||||
'malformed_provider_response',
|
||||
'provider_unavailable',
|
||||
'unknown',
|
||||
])
|
||||
|
||||
function isOpenAICompatibilityFailureCategory(
|
||||
value: string,
|
||||
): value is OpenAICompatibilityFailureCategory {
|
||||
return OPENAI_COMPATIBILITY_FAILURE_CATEGORIES.has(
|
||||
value as OpenAICompatibilityFailureCategory,
|
||||
)
|
||||
}
|
||||
|
||||
function getErrorCode(error: unknown): string | undefined {
|
||||
let current: unknown = error
|
||||
const maxDepth = 5
|
||||
|
||||
for (let depth = 0; depth < maxDepth; depth++) {
|
||||
if (
|
||||
current &&
|
||||
typeof current === 'object' &&
|
||||
'code' in current &&
|
||||
typeof (current as { code?: unknown }).code === 'string'
|
||||
) {
|
||||
return (current as { code: string }).code
|
||||
}
|
||||
|
||||
if (
|
||||
current &&
|
||||
typeof current === 'object' &&
|
||||
'cause' in current &&
|
||||
(current as { cause?: unknown }).cause !== current
|
||||
) {
|
||||
current = (current as { cause?: unknown }).cause
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
function getHostname(url: string): string | null {
|
||||
try {
|
||||
return new URL(url).hostname.toLowerCase()
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
function isLocalhostLikeHostname(hostname: string | null): boolean {
|
||||
if (!hostname) return false
|
||||
if (LOCALHOST_HOSTNAMES.has(hostname)) return true
|
||||
return /^127\./.test(hostname)
|
||||
}
|
||||
|
||||
function isContextOverflowMessage(body: string): boolean {
|
||||
const lower = body.toLowerCase()
|
||||
return (
|
||||
lower.includes('too many tokens') ||
|
||||
lower.includes('request too large') ||
|
||||
lower.includes('context length') ||
|
||||
lower.includes('maximum context') ||
|
||||
lower.includes('input length') ||
|
||||
lower.includes('payload too large') ||
|
||||
lower.includes('prompt is too long')
|
||||
)
|
||||
}
|
||||
|
||||
function isToolCompatibilityMessage(body: string): boolean {
|
||||
const lower = body.toLowerCase()
|
||||
return (
|
||||
lower.includes('tool_calls') ||
|
||||
lower.includes('tool_call') ||
|
||||
lower.includes('tool_use') ||
|
||||
lower.includes('tool_result') ||
|
||||
lower.includes('function calling') ||
|
||||
lower.includes('function call')
|
||||
)
|
||||
}
|
||||
|
||||
function isMalformedProviderResponse(body: string): boolean {
|
||||
const lower = body.toLowerCase()
|
||||
return (
|
||||
lower.includes('<!doctype html') ||
|
||||
lower.includes('<html') ||
|
||||
lower.includes('invalid json') ||
|
||||
lower.includes('malformed') ||
|
||||
lower.includes('unexpected token') ||
|
||||
lower.includes('cannot parse') ||
|
||||
lower.includes('not valid json')
|
||||
)
|
||||
}
|
||||
|
||||
function isModelNotFoundMessage(body: string): boolean {
|
||||
const lower = body.toLowerCase()
|
||||
return (
|
||||
lower.includes('model') &&
|
||||
(
|
||||
lower.includes('not found') ||
|
||||
lower.includes('does not exist') ||
|
||||
lower.includes('unknown model') ||
|
||||
lower.includes('unavailable model')
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
export function formatOpenAICategoryMarker(
|
||||
category: OpenAICompatibilityFailureCategory,
|
||||
): string {
|
||||
return `${OPENAI_CATEGORY_MARKER_PREFIX}${category}]`
|
||||
}
|
||||
|
||||
export function extractOpenAICategoryMarker(
|
||||
message: string,
|
||||
): OpenAICompatibilityFailureCategory | undefined {
|
||||
const match = message.match(/\[openai_category=([a-z_]+)]/)
|
||||
const category = match?.[1]
|
||||
|
||||
if (!category || !isOpenAICompatibilityFailureCategory(category)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return category
|
||||
}
|
||||
|
||||
export function buildOpenAICompatibilityErrorMessage(
|
||||
baseMessage: string,
|
||||
failure: Pick<OpenAICompatibilityFailure, 'category' | 'hint'>,
|
||||
): string {
|
||||
const marker = formatOpenAICategoryMarker(failure.category)
|
||||
const hint = failure.hint ? ` Hint: ${failure.hint}` : ''
|
||||
return `${baseMessage} ${marker}${hint}`
|
||||
}
|
||||
|
||||
export function classifyOpenAINetworkFailure(
|
||||
error: unknown,
|
||||
options: { url: string },
|
||||
): OpenAICompatibilityFailure {
|
||||
const message = error instanceof Error ? error.message : String(error)
|
||||
const lowerMessage = message.toLowerCase()
|
||||
const code = getErrorCode(error)
|
||||
const hostname = getHostname(options.url)
|
||||
const isLocalHost = isLocalhostLikeHostname(hostname)
|
||||
|
||||
if (
|
||||
code === 'ETIMEDOUT' ||
|
||||
code === 'UND_ERR_CONNECT_TIMEOUT' ||
|
||||
lowerMessage.includes('timeout') ||
|
||||
lowerMessage.includes('timed out') ||
|
||||
lowerMessage.includes('aborterror')
|
||||
) {
|
||||
return {
|
||||
source: 'network',
|
||||
category: 'request_timeout',
|
||||
retryable: true,
|
||||
message,
|
||||
code,
|
||||
hint: 'The provider took too long to respond. Check local model load time or increase API timeout.',
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
isLocalHost &&
|
||||
(
|
||||
code === 'ENOTFOUND' ||
|
||||
code === 'EAI_AGAIN' ||
|
||||
lowerMessage.includes('getaddrinfo') ||
|
||||
(code === undefined && lowerMessage.includes('fetch failed'))
|
||||
)
|
||||
) {
|
||||
return {
|
||||
source: 'network',
|
||||
category: 'localhost_resolution_failed',
|
||||
retryable: true,
|
||||
message,
|
||||
code,
|
||||
hint: 'Localhost failed for this request. Retry with 127.0.0.1 and confirm Ollama is serving on the configured port.',
|
||||
}
|
||||
}
|
||||
|
||||
if (code === 'ECONNREFUSED') {
|
||||
return {
|
||||
source: 'network',
|
||||
category: 'connection_refused',
|
||||
retryable: true,
|
||||
message,
|
||||
code,
|
||||
hint: isLocalHost
|
||||
? 'Connection to the local provider was refused. Ensure the local server is running and listening on the configured port.'
|
||||
: 'Connection was refused by the provider endpoint. Ensure the server is running and the port is correct.',
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
source: 'network',
|
||||
category: 'network_error',
|
||||
retryable: true,
|
||||
message,
|
||||
code,
|
||||
hint: 'Network transport failed before a provider response was received.',
|
||||
}
|
||||
}
|
||||
|
||||
export function classifyOpenAIHttpFailure(options: {
|
||||
status: number
|
||||
body: string
|
||||
}): OpenAICompatibilityFailure {
|
||||
const body = options.body ?? ''
|
||||
|
||||
if (options.status === 401 || options.status === 403) {
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'auth_invalid',
|
||||
retryable: false,
|
||||
status: options.status,
|
||||
message: body,
|
||||
hint: 'Authentication failed. Verify API key, token source, and endpoint-specific auth headers.',
|
||||
}
|
||||
}
|
||||
|
||||
if (options.status === 429) {
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'rate_limited',
|
||||
retryable: true,
|
||||
status: options.status,
|
||||
message: body,
|
||||
hint: 'Provider rate-limited the request. Retry after backoff.',
|
||||
}
|
||||
}
|
||||
|
||||
if (options.status === 404 && isModelNotFoundMessage(body)) {
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'model_not_found',
|
||||
retryable: false,
|
||||
status: options.status,
|
||||
message: body,
|
||||
hint: 'The selected model is not installed or not available on this endpoint.',
|
||||
}
|
||||
}
|
||||
|
||||
if (options.status === 404) {
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'endpoint_not_found',
|
||||
retryable: false,
|
||||
status: options.status,
|
||||
message: body,
|
||||
hint: 'Endpoint was not found. Confirm OPENAI_BASE_URL includes /v1 for OpenAI-compatible local providers.',
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
options.status === 413 ||
|
||||
((options.status === 400 || options.status >= 500) &&
|
||||
isContextOverflowMessage(body))
|
||||
) {
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'context_overflow',
|
||||
retryable: false,
|
||||
status: options.status,
|
||||
message: body,
|
||||
hint: 'Prompt context exceeded model/server limits. Reduce context or increase provider context length.',
|
||||
}
|
||||
}
|
||||
|
||||
if (options.status === 400 && isToolCompatibilityMessage(body)) {
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'tool_call_incompatible',
|
||||
retryable: false,
|
||||
status: options.status,
|
||||
message: body,
|
||||
hint: 'Provider/model rejected tool-calling payload. Retry without tools or use a tool-capable model.',
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
(options.status >= 200 && options.status < 300 && isMalformedProviderResponse(body)) ||
|
||||
(options.status >= 400 && isMalformedProviderResponse(body))
|
||||
) {
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'malformed_provider_response',
|
||||
retryable: false,
|
||||
status: options.status,
|
||||
message: body,
|
||||
hint: 'Provider returned malformed or non-JSON response where JSON was expected.',
|
||||
}
|
||||
}
|
||||
|
||||
if (options.status >= 500) {
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'provider_unavailable',
|
||||
retryable: true,
|
||||
status: options.status,
|
||||
message: body,
|
||||
hint: 'Provider reported a server-side failure. Retry after a short delay.',
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
source: 'http',
|
||||
category: 'unknown',
|
||||
retryable: false,
|
||||
status: options.status,
|
||||
message: body,
|
||||
}
|
||||
}
|
||||
119
src/services/api/openaiShim.diagnostics.test.ts
Normal file
119
src/services/api/openaiShim.diagnostics.test.ts
Normal file
@@ -0,0 +1,119 @@
|
||||
import { afterEach, expect, mock, test } from 'bun:test'
|
||||
|
||||
const originalFetch = globalThis.fetch
|
||||
const originalEnv = {
|
||||
OPENAI_BASE_URL: process.env.OPENAI_BASE_URL,
|
||||
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
|
||||
OPENAI_MODEL: process.env.OPENAI_MODEL,
|
||||
}
|
||||
|
||||
function restoreEnv(key: string, value: string | undefined): void {
|
||||
if (value === undefined) {
|
||||
delete process.env[key]
|
||||
} else {
|
||||
process.env[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.fetch = originalFetch
|
||||
restoreEnv('OPENAI_BASE_URL', originalEnv.OPENAI_BASE_URL)
|
||||
restoreEnv('OPENAI_API_KEY', originalEnv.OPENAI_API_KEY)
|
||||
restoreEnv('OPENAI_MODEL', originalEnv.OPENAI_MODEL)
|
||||
mock.restore()
|
||||
})
|
||||
|
||||
test('logs classified transport diagnostics with category and code', async () => {
|
||||
const debugSpy = mock(() => {})
|
||||
mock.module('../../utils/debug.js', () => ({
|
||||
logForDebugging: debugSpy,
|
||||
}))
|
||||
|
||||
const nonce = `${Date.now()}-${Math.random()}`
|
||||
const { createOpenAIShimClient } = await import(`./openaiShim.ts?ts=${nonce}`)
|
||||
|
||||
process.env.OPENAI_BASE_URL = 'http://localhost:11434/v1'
|
||||
process.env.OPENAI_API_KEY = 'ollama'
|
||||
|
||||
const transportError = Object.assign(new TypeError('fetch failed'), {
|
||||
code: 'ECONNREFUSED',
|
||||
})
|
||||
|
||||
globalThis.fetch = mock(async () => {
|
||||
throw transportError
|
||||
}) as typeof globalThis.fetch
|
||||
|
||||
const client = createOpenAIShimClient({}) as {
|
||||
beta: {
|
||||
messages: {
|
||||
create: (params: Record<string, unknown>) => Promise<unknown>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await expect(
|
||||
client.beta.messages.create({
|
||||
model: 'qwen2.5-coder:7b',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
max_tokens: 64,
|
||||
stream: false,
|
||||
}),
|
||||
).rejects.toThrow('openai_category=connection_refused')
|
||||
|
||||
const transportLog = debugSpy.mock.calls.find(call =>
|
||||
typeof call?.[0] === 'string' && call[0].includes('transport failure'),
|
||||
)
|
||||
|
||||
expect(transportLog).toBeDefined()
|
||||
expect(String(transportLog?.[0])).toContain('category=connection_refused')
|
||||
expect(String(transportLog?.[0])).toContain('code=ECONNREFUSED')
|
||||
expect(transportLog?.[1]).toEqual({ level: 'warn' })
|
||||
})
|
||||
|
||||
test('redacts credentials in transport diagnostic URL logs', async () => {
|
||||
const debugSpy = mock(() => {})
|
||||
mock.module('../../utils/debug.js', () => ({
|
||||
logForDebugging: debugSpy,
|
||||
}))
|
||||
|
||||
const nonce = `${Date.now()}-${Math.random()}`
|
||||
const { createOpenAIShimClient } = await import(`./openaiShim.ts?ts=${nonce}`)
|
||||
|
||||
process.env.OPENAI_BASE_URL = 'http://user:supersecret@localhost:11434/v1'
|
||||
process.env.OPENAI_API_KEY = 'supersecret'
|
||||
|
||||
const transportError = Object.assign(new TypeError('fetch failed'), {
|
||||
code: 'ECONNREFUSED',
|
||||
})
|
||||
|
||||
globalThis.fetch = mock(async () => {
|
||||
throw transportError
|
||||
}) as typeof globalThis.fetch
|
||||
|
||||
const client = createOpenAIShimClient({}) as {
|
||||
beta: {
|
||||
messages: {
|
||||
create: (params: Record<string, unknown>) => Promise<unknown>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await expect(
|
||||
client.beta.messages.create({
|
||||
model: 'qwen2.5-coder:7b',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
max_tokens: 64,
|
||||
stream: false,
|
||||
}),
|
||||
).rejects.toThrow('openai_category=connection_refused')
|
||||
|
||||
const transportLog = debugSpy.mock.calls.find(call =>
|
||||
typeof call?.[0] === 'string' && call[0].includes('transport failure'),
|
||||
)
|
||||
|
||||
expect(transportLog).toBeDefined()
|
||||
const logLine = String(transportLog?.[0])
|
||||
expect(logLine).toContain('url=http://redacted:redacted@localhost:11434/v1/chat/completions')
|
||||
expect(logLine).not.toContain('user:supersecret')
|
||||
expect(logLine).not.toContain('supersecret@')
|
||||
})
|
||||
@@ -2775,3 +2775,172 @@ test('streaming: strips leaked reasoning preamble when split across multiple con
|
||||
|
||||
expect(textDeltas).toEqual(['Hey! How can I help you today?'])
|
||||
})
|
||||
|
||||
test('classifies localhost transport failures with actionable category marker', async () => {
|
||||
process.env.OPENAI_BASE_URL = 'http://localhost:11434/v1'
|
||||
|
||||
const transportError = Object.assign(new TypeError('fetch failed'), {
|
||||
code: 'ECONNREFUSED',
|
||||
})
|
||||
|
||||
globalThis.fetch = (async () => {
|
||||
throw transportError
|
||||
}) as FetchType
|
||||
|
||||
const client = createOpenAIShimClient({}) as OpenAIShimClient
|
||||
|
||||
await expect(
|
||||
client.beta.messages.create({
|
||||
model: 'qwen2.5-coder:7b',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
max_tokens: 64,
|
||||
stream: false,
|
||||
}),
|
||||
).rejects.toThrow('openai_category=connection_refused')
|
||||
|
||||
await expect(
|
||||
client.beta.messages.create({
|
||||
model: 'qwen2.5-coder:7b',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
max_tokens: 64,
|
||||
stream: false,
|
||||
}),
|
||||
).rejects.toThrow('local server is running')
|
||||
})
|
||||
|
||||
test('propagates AbortError without wrapping it as transport failure', async () => {
|
||||
process.env.OPENAI_BASE_URL = 'http://localhost:11434/v1'
|
||||
|
||||
const abortError = new DOMException('The operation was aborted.', 'AbortError')
|
||||
globalThis.fetch = (async () => {
|
||||
throw abortError
|
||||
}) as FetchType
|
||||
|
||||
const controller = new AbortController()
|
||||
controller.abort()
|
||||
|
||||
const client = createOpenAIShimClient({}) as OpenAIShimClient
|
||||
|
||||
await expect(
|
||||
client.beta.messages.create(
|
||||
{
|
||||
model: 'qwen2.5-coder:7b',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
max_tokens: 64,
|
||||
stream: false,
|
||||
},
|
||||
{ signal: controller.signal },
|
||||
),
|
||||
).rejects.toBe(abortError)
|
||||
})
|
||||
|
||||
test('classifies chat-completions endpoint 404 failures with endpoint_not_found marker', async () => {
|
||||
process.env.OPENAI_BASE_URL = 'http://localhost:11434'
|
||||
|
||||
globalThis.fetch = (async () =>
|
||||
new Response('Not Found', {
|
||||
status: 404,
|
||||
headers: {
|
||||
'Content-Type': 'text/plain',
|
||||
},
|
||||
})) as FetchType
|
||||
|
||||
const client = createOpenAIShimClient({}) as OpenAIShimClient
|
||||
|
||||
await expect(
|
||||
client.beta.messages.create({
|
||||
model: 'qwen2.5-coder:7b',
|
||||
messages: [{ role: 'user', content: 'hello' }],
|
||||
max_tokens: 64,
|
||||
stream: false,
|
||||
}),
|
||||
).rejects.toThrow('openai_category=endpoint_not_found')
|
||||
})
|
||||
|
||||
test('preserves valid tool_result and drops orphan tool_result', async () => {
|
||||
let requestBody: Record<string, unknown> | undefined
|
||||
|
||||
globalThis.fetch = (async (_input, init) => {
|
||||
requestBody = JSON.parse(String(init?.body))
|
||||
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
id: 'chatcmpl-1',
|
||||
model: 'mistral-large-latest',
|
||||
choices: [
|
||||
{
|
||||
message: {
|
||||
role: 'assistant',
|
||||
content: 'done',
|
||||
},
|
||||
finish_reason: 'stop',
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 12,
|
||||
completion_tokens: 4,
|
||||
total_tokens: 16,
|
||||
},
|
||||
}),
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
},
|
||||
)
|
||||
}) as FetchType
|
||||
|
||||
const client = createOpenAIShimClient({}) as OpenAIShimClient
|
||||
|
||||
await client.beta.messages.create({
|
||||
model: 'mistral-large-latest',
|
||||
system: 'test system',
|
||||
messages: [
|
||||
{ role: 'user', content: 'Search and then I will interrupt' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'valid_call_1',
|
||||
name: 'Search',
|
||||
input: { query: 'openclaude' },
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'valid_call_1',
|
||||
content: 'Found it!',
|
||||
},
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'orphan_call_2',
|
||||
content: 'Interrupted result',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: 'What happened?',
|
||||
}
|
||||
],
|
||||
},
|
||||
],
|
||||
max_tokens: 64,
|
||||
stream: false,
|
||||
})
|
||||
|
||||
const messages = requestBody?.messages as Array<Record<string, unknown>>
|
||||
|
||||
// Should have: system, user, assistant (tool_use), tool (valid_call_1), user
|
||||
// Should NOT have: tool (orphan_call_2)
|
||||
|
||||
const toolMessages = messages.filter(m => m.role === 'tool')
|
||||
expect(toolMessages.length).toBe(1)
|
||||
expect(toolMessages[0].tool_call_id).toBe('valid_call_1')
|
||||
|
||||
const orphanMessage = toolMessages.find(m => m.tool_call_id === 'orphan_call_2')
|
||||
expect(orphanMessage).toBeUndefined()
|
||||
})
|
||||
|
||||
@@ -47,12 +47,18 @@ import {
|
||||
type AnthropicUsage,
|
||||
type ShimCreateParams,
|
||||
} from './codexShim.js'
|
||||
import { fetchWithProxyRetry } from './fetchWithProxyRetry.js'
|
||||
import {
|
||||
isLocalProviderUrl,
|
||||
resolveRuntimeCodexCredentials,
|
||||
resolveProviderRequest,
|
||||
getGithubEndpointType,
|
||||
} from './providerConfig.js'
|
||||
import {
|
||||
buildOpenAICompatibilityErrorMessage,
|
||||
classifyOpenAIHttpFailure,
|
||||
classifyOpenAINetworkFailure,
|
||||
} from './openaiErrorClassification.js'
|
||||
import { sanitizeSchemaForOpenAICompat } from '../../utils/schemaSanitizer.js'
|
||||
import { redactSecretValueForDisplay } from '../../utils/providerProfile.js'
|
||||
import {
|
||||
@@ -82,6 +88,19 @@ const COPILOT_HEADERS: Record<string, string> = {
|
||||
'Copilot-Integration-Id': 'vscode-chat',
|
||||
}
|
||||
|
||||
const SENSITIVE_URL_QUERY_PARAM_NAMES = [
|
||||
'api_key',
|
||||
'key',
|
||||
'token',
|
||||
'access_token',
|
||||
'refresh_token',
|
||||
'signature',
|
||||
'sig',
|
||||
'secret',
|
||||
'password',
|
||||
'authorization',
|
||||
]
|
||||
|
||||
function isGithubModelsMode(): boolean {
|
||||
return isEnvTruthy(process.env.CLAUDE_CODE_USE_GITHUB)
|
||||
}
|
||||
@@ -131,6 +150,34 @@ function formatRetryAfterHint(response: Response): string {
|
||||
return ra ? ` (Retry-After: ${ra})` : ''
|
||||
}
|
||||
|
||||
function shouldRedactUrlQueryParam(name: string): boolean {
|
||||
const lower = name.toLowerCase()
|
||||
return SENSITIVE_URL_QUERY_PARAM_NAMES.some(token => lower.includes(token))
|
||||
}
|
||||
|
||||
function redactUrlForDiagnostics(url: string): string {
|
||||
try {
|
||||
const parsed = new URL(url)
|
||||
if (parsed.username) {
|
||||
parsed.username = 'redacted'
|
||||
}
|
||||
if (parsed.password) {
|
||||
parsed.password = 'redacted'
|
||||
}
|
||||
|
||||
for (const key of parsed.searchParams.keys()) {
|
||||
if (shouldRedactUrlQueryParam(key)) {
|
||||
parsed.searchParams.set(key, 'redacted')
|
||||
}
|
||||
}
|
||||
|
||||
const serialized = parsed.toString()
|
||||
return redactSecretValueForDisplay(serialized, process.env as SecretValueSource) ?? serialized
|
||||
} catch {
|
||||
return redactSecretValueForDisplay(url, process.env as SecretValueSource) ?? url
|
||||
}
|
||||
}
|
||||
|
||||
function sleepMs(ms: number): Promise<void> {
|
||||
return new Promise(resolve => setTimeout(resolve, ms))
|
||||
}
|
||||
@@ -302,6 +349,7 @@ function convertMessages(
|
||||
system: unknown,
|
||||
): OpenAIMessage[] {
|
||||
const result: OpenAIMessage[] = []
|
||||
const knownToolCallIds = new Set<string>()
|
||||
|
||||
// System message first
|
||||
const sysText = convertSystemPrompt(system)
|
||||
@@ -321,13 +369,21 @@ function convertMessages(
|
||||
const toolResults = content.filter((b: { type?: string }) => b.type === 'tool_result')
|
||||
const otherContent = content.filter((b: { type?: string }) => b.type !== 'tool_result')
|
||||
|
||||
// Emit tool results as tool messages
|
||||
// Emit tool results as tool messages, but ONLY if we have a matching tool_use ID.
|
||||
// Mistral/OpenAI strictly require tool messages to follow an assistant message with tool_calls.
|
||||
// If the user interrupted (ESC) and a synthetic tool_result was generated without a recorded tool_use,
|
||||
// emitting it here would cause a "role must alternate" or "unexpected role" error.
|
||||
for (const tr of toolResults) {
|
||||
result.push({
|
||||
role: 'tool',
|
||||
tool_call_id: tr.tool_use_id ?? 'unknown',
|
||||
content: convertToolResultContent(tr.content, tr.is_error),
|
||||
})
|
||||
const id = tr.tool_use_id ?? 'unknown'
|
||||
if (knownToolCallIds.has(id)) {
|
||||
result.push({
|
||||
role: 'tool',
|
||||
tool_call_id: id,
|
||||
content: convertToolResultContent(tr.content, tr.is_error),
|
||||
})
|
||||
} else {
|
||||
logForDebugging(`Dropping orphan tool_result for ID: ${id} to prevent API error`)
|
||||
}
|
||||
}
|
||||
|
||||
// Emit remaining user content
|
||||
@@ -368,9 +424,11 @@ function convertMessages(
|
||||
input?: unknown
|
||||
extra_content?: Record<string, unknown>
|
||||
signature?: string
|
||||
}, index) => {
|
||||
}) => {
|
||||
const id = tu.id ?? `call_${crypto.randomUUID().replace(/-/g, '')}`
|
||||
knownToolCallIds.add(id)
|
||||
const toolCall: NonNullable<OpenAIMessage['tool_calls']>[number] = {
|
||||
id: tu.id ?? `call_${crypto.randomUUID().replace(/-/g, '')}`,
|
||||
id,
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: tu.name ?? 'unknown',
|
||||
@@ -395,7 +453,6 @@ function convertMessages(
|
||||
|
||||
// Merge into existing google-specific metadata if present
|
||||
const existingGoogle = (toolCall.extra_content?.google as Record<string, unknown>) ?? {}
|
||||
|
||||
toolCall.extra_content = {
|
||||
...toolCall.extra_content,
|
||||
google: {
|
||||
@@ -550,7 +607,10 @@ function convertTools(
|
||||
function: {
|
||||
name: t.name,
|
||||
description: t.description ?? '',
|
||||
parameters: normalizeSchemaForOpenAI(schema, !isGemini),
|
||||
parameters: normalizeSchemaForOpenAI(
|
||||
schema,
|
||||
!isGemini && !isEnvTruthy(process.env.OPENCLAUDE_DISABLE_STRICT_TOOLS),
|
||||
),
|
||||
},
|
||||
}
|
||||
})
|
||||
@@ -1360,8 +1420,12 @@ class OpenAIShimMessages {
|
||||
...filterAnthropicHeaders(options?.headers),
|
||||
}
|
||||
|
||||
const isGemini = isEnvTruthy(process.env.CLAUDE_CODE_USE_GEMINI)
|
||||
const apiKey = this.providerOverride?.apiKey ?? process.env.OPENAI_API_KEY ?? ''
|
||||
const isGemini = isGeminiMode()
|
||||
const isMiniMax = !!process.env.MINIMAX_API_KEY
|
||||
const apiKey =
|
||||
this.providerOverride?.apiKey ??
|
||||
process.env.OPENAI_API_KEY ??
|
||||
(isMiniMax ? process.env.MINIMAX_API_KEY : '')
|
||||
// Detect Azure endpoints by hostname (not raw URL) to prevent bypass via
|
||||
// path segments like https://evil.com/cognitiveservices.azure.com/
|
||||
let isAzure = false
|
||||
@@ -1425,12 +1489,97 @@ class OpenAIShimMessages {
|
||||
}
|
||||
|
||||
const maxAttempts = isGithub ? GITHUB_429_MAX_RETRIES : 1
|
||||
|
||||
const throwClassifiedTransportError = (
|
||||
error: unknown,
|
||||
requestUrl: string,
|
||||
): never => {
|
||||
if (options?.signal?.aborted) {
|
||||
throw error
|
||||
}
|
||||
|
||||
const failure = classifyOpenAINetworkFailure(error, {
|
||||
url: requestUrl,
|
||||
})
|
||||
const redactedUrl = redactUrlForDiagnostics(requestUrl)
|
||||
const safeMessage =
|
||||
redactSecretValueForDisplay(
|
||||
failure.message,
|
||||
process.env as SecretValueSource,
|
||||
) || 'Request failed'
|
||||
|
||||
logForDebugging(
|
||||
`[OpenAIShim] transport failure category=${failure.category} retryable=${failure.retryable} code=${failure.code ?? 'unknown'} method=POST url=${redactedUrl} model=${request.resolvedModel} message=${safeMessage}`,
|
||||
{ level: 'warn' },
|
||||
)
|
||||
|
||||
throw APIError.generate(
|
||||
503,
|
||||
undefined,
|
||||
buildOpenAICompatibilityErrorMessage(
|
||||
`OpenAI API transport error: ${safeMessage}${failure.code ? ` (code=${failure.code})` : ''}`,
|
||||
failure,
|
||||
),
|
||||
new Headers(),
|
||||
)
|
||||
}
|
||||
|
||||
const throwClassifiedHttpError = (
|
||||
status: number,
|
||||
errorBody: string,
|
||||
parsedBody: object | undefined,
|
||||
responseHeaders: Headers,
|
||||
requestUrl: string,
|
||||
rateHint = '',
|
||||
): never => {
|
||||
const failure = classifyOpenAIHttpFailure({
|
||||
status,
|
||||
body: errorBody,
|
||||
})
|
||||
const redactedUrl = redactUrlForDiagnostics(requestUrl)
|
||||
|
||||
logForDebugging(
|
||||
`[OpenAIShim] request failed category=${failure.category} retryable=${failure.retryable} status=${status} method=POST url=${redactedUrl} model=${request.resolvedModel}`,
|
||||
{ level: 'warn' },
|
||||
)
|
||||
|
||||
throw APIError.generate(
|
||||
status,
|
||||
parsedBody,
|
||||
buildOpenAICompatibilityErrorMessage(
|
||||
`OpenAI API error ${status}: ${errorBody}${rateHint}`,
|
||||
failure,
|
||||
),
|
||||
responseHeaders,
|
||||
)
|
||||
}
|
||||
|
||||
let response: Response | undefined
|
||||
for (let attempt = 0; attempt < maxAttempts; attempt++) {
|
||||
response = await fetch(chatCompletionsUrl, fetchInit)
|
||||
try {
|
||||
response = await fetchWithProxyRetry(chatCompletionsUrl, fetchInit)
|
||||
} catch (error) {
|
||||
const isAbortError =
|
||||
fetchInit.signal?.aborted === true ||
|
||||
(typeof DOMException !== 'undefined' &&
|
||||
error instanceof DOMException &&
|
||||
error.name === 'AbortError') ||
|
||||
(typeof error === 'object' &&
|
||||
error !== null &&
|
||||
'name' in error &&
|
||||
error.name === 'AbortError')
|
||||
|
||||
if (isAbortError) {
|
||||
throw error
|
||||
}
|
||||
|
||||
throwClassifiedTransportError(error, chatCompletionsUrl)
|
||||
}
|
||||
|
||||
if (response.ok) {
|
||||
return response
|
||||
}
|
||||
|
||||
if (
|
||||
isGithub &&
|
||||
response.status === 429 &&
|
||||
@@ -1500,34 +1649,43 @@ class OpenAIShimMessages {
|
||||
}
|
||||
}
|
||||
|
||||
const responsesResponse = await fetch(responsesUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(responsesBody),
|
||||
signal: options?.signal,
|
||||
})
|
||||
let responsesResponse: Response
|
||||
try {
|
||||
responsesResponse = await fetchWithProxyRetry(responsesUrl, {
|
||||
method: 'POST',
|
||||
headers,
|
||||
body: JSON.stringify(responsesBody),
|
||||
signal: options?.signal,
|
||||
})
|
||||
} catch (error) {
|
||||
throwClassifiedTransportError(error, responsesUrl)
|
||||
}
|
||||
|
||||
if (responsesResponse.ok) {
|
||||
return responsesResponse
|
||||
}
|
||||
const responsesErrorBody = await responsesResponse.text().catch(() => 'unknown error')
|
||||
let responsesErrorResponse: object | undefined
|
||||
try { responsesErrorResponse = JSON.parse(responsesErrorBody) } catch { /* raw text */ }
|
||||
throw APIError.generate(
|
||||
throwClassifiedHttpError(
|
||||
responsesResponse.status,
|
||||
responsesErrorBody,
|
||||
responsesErrorResponse,
|
||||
`OpenAI API error ${responsesResponse.status}: ${responsesErrorBody}`,
|
||||
responsesResponse.headers,
|
||||
responsesUrl,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let errorResponse: object | undefined
|
||||
try { errorResponse = JSON.parse(errorBody) } catch { /* raw text */ }
|
||||
throw APIError.generate(
|
||||
throwClassifiedHttpError(
|
||||
response.status,
|
||||
errorBody,
|
||||
errorResponse,
|
||||
`OpenAI API error ${response.status}: ${errorBody}${rateHint}`,
|
||||
response.headers as unknown as Headers,
|
||||
chatCompletionsUrl,
|
||||
rateHint,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
107
src/services/api/providerConfig.envDiagnostics.test.ts
Normal file
107
src/services/api/providerConfig.envDiagnostics.test.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
import { afterEach, expect, mock, test } from 'bun:test'
|
||||
|
||||
const originalEnv = {
|
||||
CLAUDE_CODE_USE_OPENAI: process.env.CLAUDE_CODE_USE_OPENAI,
|
||||
CLAUDE_CODE_USE_MISTRAL: process.env.CLAUDE_CODE_USE_MISTRAL,
|
||||
OPENAI_BASE_URL: process.env.OPENAI_BASE_URL,
|
||||
OPENAI_MODEL: process.env.OPENAI_MODEL,
|
||||
OPENAI_API_BASE: process.env.OPENAI_API_BASE,
|
||||
MISTRAL_BASE_URL: process.env.MISTRAL_BASE_URL,
|
||||
MISTRAL_MODEL: process.env.MISTRAL_MODEL,
|
||||
}
|
||||
|
||||
function restoreEnv(key: string, value: string | undefined): void {
|
||||
if (value === undefined) {
|
||||
delete process.env[key]
|
||||
} else {
|
||||
process.env[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
restoreEnv('CLAUDE_CODE_USE_OPENAI', originalEnv.CLAUDE_CODE_USE_OPENAI)
|
||||
restoreEnv('CLAUDE_CODE_USE_MISTRAL', originalEnv.CLAUDE_CODE_USE_MISTRAL)
|
||||
restoreEnv('OPENAI_BASE_URL', originalEnv.OPENAI_BASE_URL)
|
||||
restoreEnv('OPENAI_MODEL', originalEnv.OPENAI_MODEL)
|
||||
restoreEnv('OPENAI_API_BASE', originalEnv.OPENAI_API_BASE)
|
||||
restoreEnv('MISTRAL_BASE_URL', originalEnv.MISTRAL_BASE_URL)
|
||||
restoreEnv('MISTRAL_MODEL', originalEnv.MISTRAL_MODEL)
|
||||
mock.restore()
|
||||
})
|
||||
|
||||
test('logs a warning when OPENAI_BASE_URL is literal undefined', async () => {
|
||||
const debugSpy = mock(() => {})
|
||||
mock.module('../../utils/debug.js', () => ({
|
||||
logForDebugging: debugSpy,
|
||||
}))
|
||||
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
process.env.OPENAI_BASE_URL = 'undefined'
|
||||
process.env.OPENAI_MODEL = 'gpt-4o'
|
||||
delete process.env.OPENAI_API_BASE
|
||||
|
||||
const nonce = `${Date.now()}-${Math.random()}`
|
||||
const { resolveProviderRequest } = await import(`./providerConfig.ts?ts=${nonce}`)
|
||||
|
||||
const resolved = resolveProviderRequest()
|
||||
|
||||
expect(resolved.baseUrl).toBe('https://api.openai.com/v1')
|
||||
|
||||
const warningCall = debugSpy.mock.calls.find(call =>
|
||||
typeof call?.[0] === 'string' &&
|
||||
call[0].includes('OPENAI_BASE_URL') &&
|
||||
call[0].includes('"undefined"'),
|
||||
)
|
||||
|
||||
expect(warningCall).toBeDefined()
|
||||
expect(warningCall?.[1]).toEqual({ level: 'warn' })
|
||||
})
|
||||
|
||||
test('does not warn for OPENAI_API_BASE when OPENAI_BASE_URL is active', async () => {
|
||||
const debugSpy = mock(() => {})
|
||||
mock.module('../../utils/debug.js', () => ({
|
||||
logForDebugging: debugSpy,
|
||||
}))
|
||||
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_USE_MISTRAL
|
||||
process.env.OPENAI_BASE_URL = 'http://127.0.0.1:11434/v1'
|
||||
process.env.OPENAI_MODEL = 'qwen2.5-coder:7b'
|
||||
process.env.OPENAI_API_BASE = 'undefined'
|
||||
|
||||
const nonce = `${Date.now()}-${Math.random()}`
|
||||
const { resolveProviderRequest } = await import(`./providerConfig.ts?ts=${nonce}`)
|
||||
|
||||
const resolved = resolveProviderRequest()
|
||||
|
||||
expect(resolved.baseUrl).toBe('http://127.0.0.1:11434/v1')
|
||||
|
||||
const aliasWarning = debugSpy.mock.calls.find(call =>
|
||||
typeof call?.[0] === 'string' &&
|
||||
call[0].includes('OPENAI_API_BASE') &&
|
||||
call[0].includes('"undefined"'),
|
||||
)
|
||||
|
||||
expect(aliasWarning).toBeUndefined()
|
||||
})
|
||||
|
||||
test('uses OPENAI_API_BASE as fallback in mistral mode when MISTRAL_BASE_URL is unset', async () => {
|
||||
const debugSpy = mock(() => {})
|
||||
mock.module('../../utils/debug.js', () => ({
|
||||
logForDebugging: debugSpy,
|
||||
}))
|
||||
|
||||
delete process.env.CLAUDE_CODE_USE_OPENAI
|
||||
process.env.CLAUDE_CODE_USE_MISTRAL = '1'
|
||||
delete process.env.MISTRAL_BASE_URL
|
||||
process.env.MISTRAL_MODEL = 'mistral-medium-latest'
|
||||
process.env.OPENAI_API_BASE = 'http://127.0.0.1:11434/v1'
|
||||
|
||||
const nonce = `${Date.now()}-${Math.random()}`
|
||||
const { resolveProviderRequest } = await import(`./providerConfig.ts?ts=${nonce}`)
|
||||
|
||||
const resolved = resolveProviderRequest()
|
||||
|
||||
expect(resolved.baseUrl).toBe('http://127.0.0.1:11434/v1')
|
||||
expect(debugSpy.mock.calls).toHaveLength(0)
|
||||
})
|
||||
@@ -8,17 +8,20 @@ import {
|
||||
readCodexCredentials,
|
||||
type CodexCredentialBlob,
|
||||
} from '../../utils/codexCredentials.js'
|
||||
import { logForDebugging } from '../../utils/debug.js'
|
||||
import { isEnvTruthy } from '../../utils/envUtils.js'
|
||||
import {
|
||||
asTrimmedString,
|
||||
parseChatgptAccountId,
|
||||
} from './codexOAuthShared.js'
|
||||
import { DEFAULT_GEMINI_BASE_URL } from 'src/utils/providerProfile.js'
|
||||
|
||||
export const DEFAULT_OPENAI_BASE_URL = 'https://api.openai.com/v1'
|
||||
export const DEFAULT_CODEX_BASE_URL = 'https://chatgpt.com/backend-api/codex'
|
||||
export const DEFAULT_MISTRAL_BASE_URL = 'https://api.mistral.ai/v1'
|
||||
/** Default GitHub Copilot API model when user selects copilot / github:copilot */
|
||||
export const DEFAULT_GITHUB_MODELS_API_MODEL = 'gpt-4o'
|
||||
const warnedUndefinedEnvNames = new Set<string>()
|
||||
|
||||
const CODEX_ALIAS_MODELS: Record<
|
||||
string,
|
||||
@@ -129,7 +132,33 @@ function isPrivateIpv6Address(hostname: string): boolean {
|
||||
function asEnvUrl(value: string | undefined): string | undefined {
|
||||
if (!value) return undefined
|
||||
const trimmed = value.trim()
|
||||
if (!trimmed || trimmed === 'undefined') return undefined
|
||||
if (!trimmed) return undefined
|
||||
if (trimmed === 'undefined') {
|
||||
return undefined
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
function asNamedEnvUrl(
|
||||
value: string | undefined,
|
||||
envName: string,
|
||||
): string | undefined {
|
||||
if (!value) return undefined
|
||||
|
||||
const trimmed = value.trim()
|
||||
if (!trimmed) return undefined
|
||||
|
||||
if (trimmed === 'undefined') {
|
||||
if (!warnedUndefinedEnvNames.has(envName)) {
|
||||
warnedUndefinedEnvNames.add(envName)
|
||||
logForDebugging(
|
||||
`[provider-config] Environment variable ${envName} is the literal string "undefined"; ignoring it.`,
|
||||
{ level: 'warn' },
|
||||
)
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
return trimmed
|
||||
}
|
||||
|
||||
@@ -353,23 +382,52 @@ export function resolveProviderRequest(options?: {
|
||||
}): ResolvedProviderRequest {
|
||||
const isGithubMode = isEnvTruthy(process.env.CLAUDE_CODE_USE_GITHUB)
|
||||
const isMistralMode = isEnvTruthy(process.env.CLAUDE_CODE_USE_MISTRAL)
|
||||
const isGeminiMode = isEnvTruthy(process.env.CLAUDE_CODE_USE_GEMINI)
|
||||
const requestedModel =
|
||||
options?.model?.trim() ||
|
||||
(isMistralMode
|
||||
? process.env.MISTRAL_MODEL?.trim()
|
||||
: process.env.OPENAI_MODEL?.trim()) ||
|
||||
(isGeminiMode
|
||||
? process.env.GEMINI_MODEL?.trim()
|
||||
: process.env.OPENAI_MODEL?.trim()) ||
|
||||
options?.fallbackModel?.trim() ||
|
||||
(isGithubMode ? 'github:copilot' : 'gpt-4o')
|
||||
const descriptor = parseModelDescriptor(requestedModel)
|
||||
const explicitBaseUrl = asEnvUrl(options?.baseUrl)
|
||||
|
||||
const normalizedMistralEnvBaseUrl = asNamedEnvUrl(
|
||||
process.env.MISTRAL_BASE_URL,
|
||||
'MISTRAL_BASE_URL',
|
||||
)
|
||||
|
||||
const normalizedGeminiEnvBaseUrl = asNamedEnvUrl(
|
||||
process.env.GEMINI_BASE_URL,
|
||||
'GEMINI_BASE_URL',
|
||||
)
|
||||
|
||||
const primaryEnvBaseUrl = isMistralMode
|
||||
? normalizedMistralEnvBaseUrl
|
||||
: isGeminiMode
|
||||
? normalizedGeminiEnvBaseUrl
|
||||
: asNamedEnvUrl(process.env.OPENAI_BASE_URL, 'OPENAI_BASE_URL')
|
||||
|
||||
const fallbackEnvBaseUrl = isMistralMode
|
||||
? (primaryEnvBaseUrl === undefined
|
||||
? asNamedEnvUrl(process.env.OPENAI_API_BASE, 'OPENAI_API_BASE') ?? DEFAULT_MISTRAL_BASE_URL
|
||||
: undefined)
|
||||
: isGeminiMode
|
||||
? (primaryEnvBaseUrl === undefined
|
||||
? asNamedEnvUrl(process.env.OPENAI_API_BASE, 'OPENAI_API_BASE') ?? DEFAULT_GEMINI_BASE_URL
|
||||
: undefined)
|
||||
: (primaryEnvBaseUrl === undefined
|
||||
? asNamedEnvUrl(process.env.OPENAI_API_BASE, 'OPENAI_API_BASE')
|
||||
: undefined)
|
||||
|
||||
const envBaseUrlRaw =
|
||||
explicitBaseUrl ??
|
||||
asEnvUrl(
|
||||
isMistralMode
|
||||
? (process.env.MISTRAL_BASE_URL ?? DEFAULT_MISTRAL_BASE_URL)
|
||||
: process.env.OPENAI_BASE_URL
|
||||
) ??
|
||||
asEnvUrl(process.env.OPENAI_API_BASE)
|
||||
primaryEnvBaseUrl ??
|
||||
fallbackEnvBaseUrl
|
||||
|
||||
const isCodexModelForGithub = isGithubMode && isCodexAlias(requestedModel)
|
||||
const envBaseUrl =
|
||||
|
||||
@@ -110,9 +110,14 @@ export function calculateTokenWarningState(
|
||||
? autoCompactThreshold
|
||||
: getEffectiveContextWindowSize(model)
|
||||
|
||||
// Use the raw context window (without output reservation) for the percentage
|
||||
// display, so users see remaining context relative to the model's full capacity.
|
||||
// The threshold (which subtracts buffer) should only affect when we warn/compact,
|
||||
// not what percentage we display.
|
||||
const rawContextWindow = getContextWindowForModel(model, getSdkBetas())
|
||||
const percentLeft = Math.max(
|
||||
0,
|
||||
Math.round(((threshold - tokenUsage) / threshold) * 100),
|
||||
Math.round(((rawContextWindow - tokenUsage) / rawContextWindow) * 100),
|
||||
)
|
||||
|
||||
const warningThreshold = threshold - WARNING_THRESHOLD_BUFFER_TOKENS
|
||||
|
||||
152
src/services/diagnosticTracking.test.ts
Normal file
152
src/services/diagnosticTracking.test.ts
Normal file
@@ -0,0 +1,152 @@
|
||||
import { describe, test, expect, beforeEach, afterEach } from 'bun:test'
|
||||
import { DiagnosticTrackingService } from './diagnosticTracking.js'
|
||||
import type { MCPServerConnection } from './mcp/types.js'
|
||||
|
||||
// Mock the IDE client utility
|
||||
const mockGetConnectedIdeClient = (clients: MCPServerConnection[]) =>
|
||||
clients.find(client => client.type === 'connected')
|
||||
|
||||
describe('DiagnosticTrackingService', () => {
|
||||
let service: DiagnosticTrackingService
|
||||
let mockClients: MCPServerConnection[]
|
||||
let mockIdeClient: MCPServerConnection
|
||||
|
||||
beforeEach(() => {
|
||||
// Get fresh instance for each test
|
||||
service = DiagnosticTrackingService.getInstance()
|
||||
|
||||
// Setup mock clients
|
||||
mockIdeClient = {
|
||||
type: 'connected',
|
||||
name: 'test-ide',
|
||||
capabilities: {},
|
||||
config: {},
|
||||
cleanup: async () => {},
|
||||
client: {
|
||||
request: async () => ({}),
|
||||
setNotificationHandler: () => {},
|
||||
close: async () => {},
|
||||
},
|
||||
} as unknown as MCPServerConnection
|
||||
|
||||
mockClients = [
|
||||
{ type: 'disconnected', name: 'test-disconnected', config: {} } as unknown as MCPServerConnection,
|
||||
mockIdeClient,
|
||||
]
|
||||
})
|
||||
|
||||
afterEach(async () => {
|
||||
await service.shutdown()
|
||||
})
|
||||
|
||||
describe('handleQueryStart', () => {
|
||||
test('should store MCP clients and initialize service', async () => {
|
||||
await service.handleQueryStart(mockClients)
|
||||
|
||||
// Service should be initialized
|
||||
expect(service).toBeDefined()
|
||||
|
||||
// Should be able to get IDE client from stored clients
|
||||
// We can't directly test private methods, but we can test the behavior
|
||||
const result = await service.getNewDiagnosticsCompat()
|
||||
expect(result).toEqual([]) // Should return empty when no diagnostics
|
||||
})
|
||||
|
||||
test('should reset service if already initialized', async () => {
|
||||
// Initialize first
|
||||
await service.handleQueryStart(mockClients)
|
||||
|
||||
// Call again - should reset without error
|
||||
await service.handleQueryStart(mockClients)
|
||||
|
||||
// Should still work
|
||||
const result = await service.getNewDiagnosticsCompat()
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('backward-compatible methods', () => {
|
||||
beforeEach(async () => {
|
||||
await service.handleQueryStart(mockClients)
|
||||
})
|
||||
|
||||
test('beforeFileEditedCompat should work without explicit client', async () => {
|
||||
// Should not throw error and should return undefined when no IDE client
|
||||
const result = await service.beforeFileEditedCompat('/test/file.ts')
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
test('getNewDiagnosticsCompat should work without explicit client', async () => {
|
||||
const result = await service.getNewDiagnosticsCompat()
|
||||
expect(Array.isArray(result)).toBe(true)
|
||||
})
|
||||
|
||||
test('ensureFileOpenedCompat should work without explicit client', async () => {
|
||||
const result = await service.ensureFileOpenedCompat('/test/file.ts')
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('new explicit client methods', () => {
|
||||
test('beforeFileEdited should require client parameter', async () => {
|
||||
// Should not work without client
|
||||
const result = await service.beforeFileEdited('/test/file.ts', undefined as any)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
test('getNewDiagnostics should require client parameter', async () => {
|
||||
// Should not work without client
|
||||
const result = await service.getNewDiagnostics(undefined as any)
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
|
||||
test('ensureFileOpened should require client parameter', async () => {
|
||||
// Should not work without client
|
||||
const result = await service.ensureFileOpened('/test/file.ts', undefined as any)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('shutdown', () => {
|
||||
test('should clear stored clients on shutdown', async () => {
|
||||
await service.handleQueryStart(mockClients)
|
||||
|
||||
// Verify service is working
|
||||
const beforeResult = await service.getNewDiagnosticsCompat()
|
||||
expect(Array.isArray(beforeResult)).toBe(true)
|
||||
|
||||
// Shutdown
|
||||
await service.shutdown()
|
||||
|
||||
// After shutdown, compat methods should return empty results
|
||||
const afterResult = await service.getNewDiagnosticsCompat()
|
||||
expect(afterResult).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('integration with existing functionality', () => {
|
||||
test('should maintain existing diagnostic tracking behavior', async () => {
|
||||
await service.handleQueryStart(mockClients)
|
||||
|
||||
// Test baseline tracking
|
||||
await service.beforeFileEditedCompat('/test/file.ts')
|
||||
|
||||
// Test getting new diagnostics (should be empty since no IDE client is actually connected)
|
||||
const newDiagnostics = await service.getNewDiagnosticsCompat()
|
||||
expect(Array.isArray(newDiagnostics)).toBe(true)
|
||||
})
|
||||
|
||||
test('should handle missing IDE client gracefully', async () => {
|
||||
// Test with no connected clients
|
||||
const noIdeClients = [
|
||||
{ type: 'disconnected', name: 'test-disconnected-2', config: {} } as unknown as MCPServerConnection,
|
||||
]
|
||||
|
||||
await service.handleQueryStart(noIdeClients)
|
||||
|
||||
// Should handle gracefully
|
||||
const result = await service.getNewDiagnosticsCompat()
|
||||
expect(result).toEqual([])
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -32,7 +32,7 @@ export class DiagnosticTrackingService {
|
||||
private baseline: Map<string, Diagnostic[]> = new Map()
|
||||
|
||||
private initialized = false
|
||||
private mcpClient: MCPServerConnection | undefined
|
||||
private currentMcpClients: MCPServerConnection[] = []
|
||||
|
||||
// Track when files were last processed/fetched
|
||||
private lastProcessedTimestamps: Map<string, number> = new Map()
|
||||
@@ -48,18 +48,17 @@ export class DiagnosticTrackingService {
|
||||
return DiagnosticTrackingService.instance
|
||||
}
|
||||
|
||||
initialize(mcpClient: MCPServerConnection) {
|
||||
initialize() {
|
||||
if (this.initialized) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Do not cache the connected mcpClient since it can change.
|
||||
this.mcpClient = mcpClient
|
||||
this.initialized = true
|
||||
}
|
||||
|
||||
async shutdown(): Promise<void> {
|
||||
this.initialized = false
|
||||
this.currentMcpClients = []
|
||||
this.baseline.clear()
|
||||
this.rightFileDiagnosticsState.clear()
|
||||
this.lastProcessedTimestamps.clear()
|
||||
@@ -75,6 +74,46 @@ export class DiagnosticTrackingService {
|
||||
this.lastProcessedTimestamps.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current IDE client from stored MCP clients
|
||||
*/
|
||||
private getCurrentIdeClient(): MCPServerConnection | undefined {
|
||||
return getConnectedIdeClient(this.currentMcpClients)
|
||||
}
|
||||
|
||||
/**
|
||||
* Backward-compatible method that uses stored IDE client
|
||||
*/
|
||||
async beforeFileEditedCompat(filePath: string): Promise<void> {
|
||||
const ideClient = this.getCurrentIdeClient()
|
||||
if (!ideClient) {
|
||||
return
|
||||
}
|
||||
return await this.beforeFileEdited(filePath, ideClient)
|
||||
}
|
||||
|
||||
/**
|
||||
* Backward-compatible method that uses stored IDE client
|
||||
*/
|
||||
async getNewDiagnosticsCompat(): Promise<DiagnosticFile[]> {
|
||||
const ideClient = this.getCurrentIdeClient()
|
||||
if (!ideClient) {
|
||||
return []
|
||||
}
|
||||
return await this.getNewDiagnostics(ideClient)
|
||||
}
|
||||
|
||||
/**
|
||||
* Backward-compatible method that uses stored IDE client
|
||||
*/
|
||||
async ensureFileOpenedCompat(fileUri: string): Promise<void> {
|
||||
const ideClient = this.getCurrentIdeClient()
|
||||
if (!ideClient) {
|
||||
return
|
||||
}
|
||||
return await this.ensureFileOpened(fileUri, ideClient)
|
||||
}
|
||||
|
||||
private normalizeFileUri(fileUri: string): string {
|
||||
// Remove our protocol prefixes
|
||||
const protocolPrefixes = [
|
||||
@@ -100,11 +139,11 @@ export class DiagnosticTrackingService {
|
||||
* Ensure a file is opened in the IDE before processing.
|
||||
* This is important for language services like diagnostics to work properly.
|
||||
*/
|
||||
async ensureFileOpened(fileUri: string): Promise<void> {
|
||||
async ensureFileOpened(fileUri: string, mcpClient: MCPServerConnection): Promise<void> {
|
||||
if (
|
||||
!this.initialized ||
|
||||
!this.mcpClient ||
|
||||
this.mcpClient.type !== 'connected'
|
||||
!mcpClient ||
|
||||
mcpClient.type !== 'connected'
|
||||
) {
|
||||
return
|
||||
}
|
||||
@@ -121,7 +160,7 @@ export class DiagnosticTrackingService {
|
||||
selectToEndOfLine: false,
|
||||
makeFrontmost: false,
|
||||
},
|
||||
this.mcpClient,
|
||||
mcpClient,
|
||||
)
|
||||
} catch (error) {
|
||||
logError(error as Error)
|
||||
@@ -132,11 +171,11 @@ export class DiagnosticTrackingService {
|
||||
* Capture baseline diagnostics for a specific file before editing.
|
||||
* This is called before editing a file to ensure we have a baseline to compare against.
|
||||
*/
|
||||
async beforeFileEdited(filePath: string): Promise<void> {
|
||||
async beforeFileEdited(filePath: string, mcpClient: MCPServerConnection): Promise<void> {
|
||||
if (
|
||||
!this.initialized ||
|
||||
!this.mcpClient ||
|
||||
this.mcpClient.type !== 'connected'
|
||||
!mcpClient ||
|
||||
mcpClient.type !== 'connected'
|
||||
) {
|
||||
return
|
||||
}
|
||||
@@ -147,7 +186,7 @@ export class DiagnosticTrackingService {
|
||||
const result = await callIdeRpc(
|
||||
'getDiagnostics',
|
||||
{ uri: `file://${filePath}` },
|
||||
this.mcpClient,
|
||||
mcpClient,
|
||||
)
|
||||
const diagnosticFile = this.parseDiagnosticResult(result)[0]
|
||||
if (diagnosticFile) {
|
||||
@@ -185,11 +224,11 @@ export class DiagnosticTrackingService {
|
||||
* Get new diagnostics from file://, _claude_fs_right, and _claude_fs_ URIs that aren't in the baseline.
|
||||
* Only processes diagnostics for files that have been edited.
|
||||
*/
|
||||
async getNewDiagnostics(): Promise<DiagnosticFile[]> {
|
||||
async getNewDiagnostics(mcpClient: MCPServerConnection): Promise<DiagnosticFile[]> {
|
||||
if (
|
||||
!this.initialized ||
|
||||
!this.mcpClient ||
|
||||
this.mcpClient.type !== 'connected'
|
||||
!mcpClient ||
|
||||
mcpClient.type !== 'connected'
|
||||
) {
|
||||
return []
|
||||
}
|
||||
@@ -200,7 +239,7 @@ export class DiagnosticTrackingService {
|
||||
const result = await callIdeRpc(
|
||||
'getDiagnostics',
|
||||
{}, // Empty params fetches all diagnostics
|
||||
this.mcpClient,
|
||||
mcpClient,
|
||||
)
|
||||
allDiagnosticFiles = this.parseDiagnosticResult(result)
|
||||
} catch (_error) {
|
||||
@@ -328,13 +367,16 @@ export class DiagnosticTrackingService {
|
||||
* @param shouldQuery Whether a query is actually being made (not just a command)
|
||||
*/
|
||||
async handleQueryStart(clients: MCPServerConnection[]): Promise<void> {
|
||||
// Store the current MCP clients for later use
|
||||
this.currentMcpClients = clients
|
||||
|
||||
// Only proceed if we should query and have clients
|
||||
if (!this.initialized) {
|
||||
// Find the connected IDE client
|
||||
const connectedIdeClient = getConnectedIdeClient(clients)
|
||||
|
||||
if (connectedIdeClient) {
|
||||
this.initialize(connectedIdeClient)
|
||||
this.initialize()
|
||||
}
|
||||
} else {
|
||||
// Reset diagnostic tracking for new query loops
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { afterEach, beforeEach, describe, expect, mock, test } from 'bun:test'
|
||||
import { afterEach, describe, expect, mock, test } from 'bun:test'
|
||||
|
||||
import {
|
||||
DEFAULT_GITHUB_DEVICE_SCOPE,
|
||||
@@ -12,22 +12,15 @@ async function importFreshModule() {
|
||||
return import(`./deviceFlow.ts?ts=${Date.now()}-${Math.random()}`)
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
mock.restore()
|
||||
})
|
||||
|
||||
describe('requestDeviceCode', () => {
|
||||
const originalFetch = globalThis.fetch
|
||||
|
||||
beforeEach(() => {
|
||||
mock.restore()
|
||||
globalThis.fetch = originalFetch
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.fetch = originalFetch
|
||||
})
|
||||
|
||||
test('parses successful device code response', async () => {
|
||||
const { requestDeviceCode } = await importFreshModule()
|
||||
|
||||
globalThis.fetch = mock(() =>
|
||||
const fetchImpl = mock(() =>
|
||||
Promise.resolve(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
@@ -44,7 +37,7 @@ describe('requestDeviceCode', () => {
|
||||
|
||||
const r = await requestDeviceCode({
|
||||
clientId: 'test-client',
|
||||
fetchImpl: globalThis.fetch,
|
||||
fetchImpl,
|
||||
})
|
||||
expect(r.device_code).toBe('abc')
|
||||
expect(r.user_code).toBe('ABCD-1234')
|
||||
@@ -57,17 +50,17 @@ describe('requestDeviceCode', () => {
|
||||
const { requestDeviceCode, GitHubDeviceFlowError } =
|
||||
await importFreshModule()
|
||||
|
||||
globalThis.fetch = mock(() =>
|
||||
const fetchImpl = mock(() =>
|
||||
Promise.resolve(new Response('bad', { status: 500 })),
|
||||
)
|
||||
await expect(
|
||||
requestDeviceCode({ clientId: 'x', fetchImpl: globalThis.fetch }),
|
||||
requestDeviceCode({ clientId: 'x', fetchImpl }),
|
||||
).rejects.toThrow(GitHubDeviceFlowError)
|
||||
})
|
||||
|
||||
test('uses OAuth-safe default scope', async () => {
|
||||
let capturedScope = ''
|
||||
globalThis.fetch = mock((_url: RequestInfo | URL, init?: RequestInit) => {
|
||||
const fetchImpl = mock((_url: RequestInfo | URL, init?: RequestInit) => {
|
||||
const body = init?.body
|
||||
if (body instanceof URLSearchParams) {
|
||||
capturedScope = body.get('scope') ?? ''
|
||||
@@ -87,7 +80,7 @@ describe('requestDeviceCode', () => {
|
||||
)
|
||||
})
|
||||
|
||||
await requestDeviceCode({ clientId: 'test-client', fetchImpl: globalThis.fetch })
|
||||
await requestDeviceCode({ clientId: 'test-client', fetchImpl })
|
||||
expect(capturedScope).toBe(DEFAULT_GITHUB_DEVICE_SCOPE)
|
||||
expect(capturedScope).toBe('read:user')
|
||||
})
|
||||
@@ -96,7 +89,7 @@ describe('requestDeviceCode', () => {
|
||||
const scopesSeen: string[] = []
|
||||
let callCount = 0
|
||||
|
||||
globalThis.fetch = mock((_url: RequestInfo | URL, init?: RequestInit) => {
|
||||
const fetchImpl = mock((_url: RequestInfo | URL, init?: RequestInit) => {
|
||||
const body = init?.body
|
||||
const scope =
|
||||
body instanceof URLSearchParams
|
||||
@@ -132,7 +125,7 @@ describe('requestDeviceCode', () => {
|
||||
const result = await requestDeviceCode({
|
||||
clientId: 'test-client',
|
||||
scope: 'read:user,models:read',
|
||||
fetchImpl: globalThis.fetch,
|
||||
fetchImpl,
|
||||
})
|
||||
|
||||
expect(result.device_code).toBe('abc')
|
||||
@@ -142,17 +135,11 @@ describe('requestDeviceCode', () => {
|
||||
})
|
||||
|
||||
describe('pollAccessToken', () => {
|
||||
const originalFetch = globalThis.fetch
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.fetch = originalFetch
|
||||
})
|
||||
|
||||
test('returns token when GitHub responds with access_token immediately', async () => {
|
||||
const { pollAccessToken } = await importFreshModule()
|
||||
|
||||
let calls = 0
|
||||
globalThis.fetch = mock(() => {
|
||||
const fetchImpl = mock(() => {
|
||||
calls++
|
||||
return Promise.resolve(
|
||||
new Response(JSON.stringify({ access_token: 'tok-xyz' }), {
|
||||
@@ -163,7 +150,7 @@ describe('pollAccessToken', () => {
|
||||
|
||||
const token = await pollAccessToken('dev-code', {
|
||||
clientId: 'cid',
|
||||
fetchImpl: globalThis.fetch,
|
||||
fetchImpl,
|
||||
})
|
||||
expect(token).toBe('tok-xyz')
|
||||
expect(calls).toBe(1)
|
||||
@@ -172,7 +159,7 @@ describe('pollAccessToken', () => {
|
||||
test('throws on access_denied', async () => {
|
||||
const { pollAccessToken } = await importFreshModule()
|
||||
|
||||
globalThis.fetch = mock(() =>
|
||||
const fetchImpl = mock(() =>
|
||||
Promise.resolve(
|
||||
new Response(JSON.stringify({ error: 'access_denied' }), {
|
||||
status: 200,
|
||||
@@ -182,23 +169,17 @@ describe('pollAccessToken', () => {
|
||||
await expect(
|
||||
pollAccessToken('dc', {
|
||||
clientId: 'c',
|
||||
fetchImpl: globalThis.fetch,
|
||||
fetchImpl,
|
||||
}),
|
||||
).rejects.toThrow(/denied/)
|
||||
})
|
||||
})
|
||||
|
||||
describe('exchangeForCopilotToken', () => {
|
||||
const originalFetch = globalThis.fetch
|
||||
|
||||
afterEach(() => {
|
||||
globalThis.fetch = originalFetch
|
||||
})
|
||||
|
||||
test('parses successful Copilot token response', async () => {
|
||||
const { exchangeForCopilotToken } = await importFreshModule()
|
||||
|
||||
globalThis.fetch = mock(() =>
|
||||
const fetchImpl = mock(() =>
|
||||
Promise.resolve(
|
||||
new Response(
|
||||
JSON.stringify({
|
||||
@@ -214,7 +195,7 @@ describe('exchangeForCopilotToken', () => {
|
||||
),
|
||||
)
|
||||
|
||||
const result = await exchangeForCopilotToken('oauth-token', globalThis.fetch)
|
||||
const result = await exchangeForCopilotToken('oauth-token', fetchImpl)
|
||||
expect(result.token).toBe('copilot-token-xyz')
|
||||
expect(result.expires_at).toBe(1700000000)
|
||||
expect(result.refresh_in).toBe(3600)
|
||||
@@ -225,24 +206,24 @@ describe('exchangeForCopilotToken', () => {
|
||||
const { exchangeForCopilotToken, GitHubDeviceFlowError } =
|
||||
await importFreshModule()
|
||||
|
||||
globalThis.fetch = mock(() =>
|
||||
const fetchImpl = mock(() =>
|
||||
Promise.resolve(new Response('unauthorized', { status: 401 })),
|
||||
)
|
||||
await expect(
|
||||
exchangeForCopilotToken('bad-token', globalThis.fetch),
|
||||
exchangeForCopilotToken('bad-token', fetchImpl),
|
||||
).rejects.toThrow(GitHubDeviceFlowError)
|
||||
})
|
||||
|
||||
test('throws on malformed response', async () => {
|
||||
const { exchangeForCopilotToken } = await importFreshModule()
|
||||
|
||||
globalThis.fetch = mock(() =>
|
||||
const fetchImpl = mock(() =>
|
||||
Promise.resolve(
|
||||
new Response(JSON.stringify({ invalid: 'data' }), { status: 200 }),
|
||||
),
|
||||
)
|
||||
await expect(
|
||||
exchangeForCopilotToken('oauth-token', globalThis.fetch),
|
||||
exchangeForCopilotToken('oauth-token', fetchImpl),
|
||||
).rejects.toThrow(/Malformed/)
|
||||
})
|
||||
})
|
||||
|
||||
61
src/services/mcp/auth.test.ts
Normal file
61
src/services/mcp/auth.test.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import assert from 'node:assert/strict'
|
||||
import test from 'node:test'
|
||||
|
||||
import { validateOAuthCallbackParams } from './auth.js'
|
||||
|
||||
test('OAuth callback rejects error parameters before state validation can be bypassed', () => {
|
||||
const result = validateOAuthCallbackParams(
|
||||
{
|
||||
error: 'access_denied',
|
||||
error_description: 'denied by provider',
|
||||
},
|
||||
'expected-state',
|
||||
)
|
||||
|
||||
assert.deepEqual(result, { type: 'state_mismatch' })
|
||||
})
|
||||
|
||||
test('OAuth callback accepts provider errors only when state matches', () => {
|
||||
const result = validateOAuthCallbackParams(
|
||||
{
|
||||
state: 'expected-state',
|
||||
error: 'access_denied',
|
||||
error_description: 'denied by provider',
|
||||
error_uri: 'https://example.test/error',
|
||||
},
|
||||
'expected-state',
|
||||
)
|
||||
|
||||
assert.deepEqual(result, {
|
||||
type: 'error',
|
||||
error: 'access_denied',
|
||||
errorDescription: 'denied by provider',
|
||||
errorUri: 'https://example.test/error',
|
||||
message:
|
||||
'OAuth error: access_denied - denied by provider (See: https://example.test/error)',
|
||||
})
|
||||
})
|
||||
|
||||
test('OAuth callback accepts authorization codes only when state matches', () => {
|
||||
assert.deepEqual(
|
||||
validateOAuthCallbackParams(
|
||||
{
|
||||
state: 'expected-state',
|
||||
code: 'auth-code',
|
||||
},
|
||||
'expected-state',
|
||||
),
|
||||
{ type: 'code', code: 'auth-code' },
|
||||
)
|
||||
|
||||
assert.deepEqual(
|
||||
validateOAuthCallbackParams(
|
||||
{
|
||||
state: 'wrong-state',
|
||||
code: 'auth-code',
|
||||
},
|
||||
'expected-state',
|
||||
),
|
||||
{ type: 'state_mismatch' },
|
||||
)
|
||||
})
|
||||
@@ -124,6 +124,74 @@ function redactSensitiveUrlParams(url: string): string {
|
||||
}
|
||||
}
|
||||
|
||||
type OAuthCallbackParamValue = string | string[] | null | undefined
|
||||
|
||||
type OAuthCallbackValidationResult =
|
||||
| { type: 'code'; code: string }
|
||||
| {
|
||||
type: 'error'
|
||||
error: string
|
||||
errorDescription: string
|
||||
errorUri: string
|
||||
message: string
|
||||
}
|
||||
| { type: 'missing_result' }
|
||||
| { type: 'state_mismatch' }
|
||||
|
||||
function getFirstOAuthCallbackParam(
|
||||
value: OAuthCallbackParamValue,
|
||||
): string | undefined {
|
||||
if (Array.isArray(value)) {
|
||||
return value.find(item => item.length > 0)
|
||||
}
|
||||
return value && value.length > 0 ? value : undefined
|
||||
}
|
||||
|
||||
export function validateOAuthCallbackParams(
|
||||
params: {
|
||||
code?: OAuthCallbackParamValue
|
||||
state?: OAuthCallbackParamValue
|
||||
error?: OAuthCallbackParamValue
|
||||
error_description?: OAuthCallbackParamValue
|
||||
error_uri?: OAuthCallbackParamValue
|
||||
},
|
||||
oauthState: string,
|
||||
): OAuthCallbackValidationResult {
|
||||
const code = getFirstOAuthCallbackParam(params.code)
|
||||
const state = getFirstOAuthCallbackParam(params.state)
|
||||
const error = getFirstOAuthCallbackParam(params.error)
|
||||
const errorDescription =
|
||||
getFirstOAuthCallbackParam(params.error_description) ?? ''
|
||||
const errorUri = getFirstOAuthCallbackParam(params.error_uri) ?? ''
|
||||
|
||||
if (state !== oauthState) {
|
||||
return { type: 'state_mismatch' }
|
||||
}
|
||||
|
||||
if (error) {
|
||||
let message = `OAuth error: ${error}`
|
||||
if (errorDescription) {
|
||||
message += ` - ${errorDescription}`
|
||||
}
|
||||
if (errorUri) {
|
||||
message += ` (See: ${errorUri})`
|
||||
}
|
||||
return {
|
||||
type: 'error',
|
||||
error,
|
||||
errorDescription,
|
||||
errorUri,
|
||||
message,
|
||||
}
|
||||
}
|
||||
|
||||
if (code) {
|
||||
return { type: 'code', code }
|
||||
}
|
||||
|
||||
return { type: 'missing_result' }
|
||||
}
|
||||
|
||||
/**
|
||||
* Some OAuth servers (notably Slack) return HTTP 200 for all responses,
|
||||
* signaling errors via the JSON body instead. The SDK's executeTokenRequest
|
||||
@@ -1058,30 +1126,31 @@ export async function performMCPOAuthFlow(
|
||||
options.onWaitingForCallback((callbackUrl: string) => {
|
||||
try {
|
||||
const parsed = new URL(callbackUrl)
|
||||
const code = parsed.searchParams.get('code')
|
||||
const state = parsed.searchParams.get('state')
|
||||
const error = parsed.searchParams.get('error')
|
||||
const result = validateOAuthCallbackParams(
|
||||
{
|
||||
code: parsed.searchParams.get('code'),
|
||||
state: parsed.searchParams.get('state'),
|
||||
error: parsed.searchParams.get('error'),
|
||||
error_description:
|
||||
parsed.searchParams.get('error_description'),
|
||||
error_uri: parsed.searchParams.get('error_uri'),
|
||||
},
|
||||
oauthState,
|
||||
)
|
||||
|
||||
if (error) {
|
||||
const errorDescription =
|
||||
parsed.searchParams.get('error_description') || ''
|
||||
cleanup()
|
||||
rejectOnce(
|
||||
new Error(`OAuth error: ${error} - ${errorDescription}`),
|
||||
)
|
||||
if (result.type === 'state_mismatch') {
|
||||
// Ignore so a stray or malicious URL cannot cancel an active flow.
|
||||
return
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
// Not a valid callback URL, ignore so the user can try again
|
||||
if (result.type === 'missing_result') {
|
||||
// Not a valid callback URL, ignore so the user can try again.
|
||||
return
|
||||
}
|
||||
|
||||
if (state !== oauthState) {
|
||||
if (result.type === 'error') {
|
||||
cleanup()
|
||||
rejectOnce(
|
||||
new Error('OAuth state mismatch - possible CSRF attack'),
|
||||
)
|
||||
rejectOnce(new Error(result.message))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1090,7 +1159,7 @@ export async function performMCPOAuthFlow(
|
||||
`Received auth code via manual callback URL`,
|
||||
)
|
||||
cleanup()
|
||||
resolveOnce(code)
|
||||
resolveOnce(result.code)
|
||||
} catch {
|
||||
// Invalid URL, ignore so the user can try again
|
||||
}
|
||||
@@ -1101,53 +1170,49 @@ export async function performMCPOAuthFlow(
|
||||
const parsedUrl = parse(req.url || '', true)
|
||||
|
||||
if (parsedUrl.pathname === '/callback') {
|
||||
const code = parsedUrl.query.code as string
|
||||
const state = parsedUrl.query.state as string
|
||||
const error = parsedUrl.query.error
|
||||
const errorDescription = parsedUrl.query.error_description as string
|
||||
const errorUri = parsedUrl.query.error_uri as string
|
||||
const result = validateOAuthCallbackParams(
|
||||
parsedUrl.query,
|
||||
oauthState,
|
||||
)
|
||||
|
||||
// Validate OAuth state to prevent CSRF attacks
|
||||
if (!error && state !== oauthState) {
|
||||
if (result.type === 'state_mismatch') {
|
||||
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
`<h1>Authentication Error</h1><p>Invalid state parameter. Please try again.</p><p>You can close this window.</p>`,
|
||||
)
|
||||
cleanup()
|
||||
rejectOnce(new Error('OAuth state mismatch - possible CSRF attack'))
|
||||
return
|
||||
}
|
||||
|
||||
if (error) {
|
||||
if (result.type === 'missing_result') {
|
||||
res.writeHead(400, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
`<h1>Authentication Error</h1><p>Missing OAuth result. Please try again.</p><p>You can close this window.</p>`,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if (result.type === 'error') {
|
||||
res.writeHead(200, { 'Content-Type': 'text/html' })
|
||||
// Sanitize error messages to prevent XSS
|
||||
const sanitizedError = xss(String(error))
|
||||
const sanitizedErrorDescription = errorDescription
|
||||
? xss(String(errorDescription))
|
||||
const sanitizedError = xss(result.error)
|
||||
const sanitizedErrorDescription = result.errorDescription
|
||||
? xss(result.errorDescription)
|
||||
: ''
|
||||
res.end(
|
||||
`<h1>Authentication Error</h1><p>${sanitizedError}: ${sanitizedErrorDescription}</p><p>You can close this window.</p>`,
|
||||
)
|
||||
cleanup()
|
||||
let errorMessage = `OAuth error: ${error}`
|
||||
if (errorDescription) {
|
||||
errorMessage += ` - ${errorDescription}`
|
||||
}
|
||||
if (errorUri) {
|
||||
errorMessage += ` (See: ${errorUri})`
|
||||
}
|
||||
rejectOnce(new Error(errorMessage))
|
||||
rejectOnce(new Error(result.message))
|
||||
return
|
||||
}
|
||||
|
||||
if (code) {
|
||||
res.writeHead(200, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
`<h1>Authentication Successful</h1><p>You can close this window. Return to Claude Code.</p>`,
|
||||
)
|
||||
cleanup()
|
||||
resolveOnce(code)
|
||||
}
|
||||
res.writeHead(200, { 'Content-Type': 'text/html' })
|
||||
res.end(
|
||||
`<h1>Authentication Successful</h1><p>You can close this window. Return to Claude Code.</p>`,
|
||||
)
|
||||
cleanup()
|
||||
resolveOnce(result.code)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -240,21 +240,28 @@ For commands that are harder to parse at a glance (piped commands, obscure flags
|
||||
- curl -s url | jq '.data[]' → "Fetch JSON from URL and extract data array elements"`),
|
||||
run_in_background: semanticBoolean(z.boolean().optional()).describe(`Set to true to run this command in the background. Use Read to read the output later.`),
|
||||
dangerouslyDisableSandbox: semanticBoolean(z.boolean().optional()).describe('Set this to true to dangerously override sandbox mode and run commands without sandboxing.'),
|
||||
_dangerouslyDisableSandboxApproved: z.boolean().optional().describe('Internal: user-approved sandbox override'),
|
||||
_simulatedSedEdit: z.object({
|
||||
filePath: z.string(),
|
||||
newContent: z.string()
|
||||
}).optional().describe('Internal: pre-computed sed edit result from preview')
|
||||
}));
|
||||
|
||||
// Always omit _simulatedSedEdit from the model-facing schema. It is an internal-only
|
||||
// field set by SedEditPermissionRequest after the user approves a sed edit preview.
|
||||
// Exposing it in the schema would let the model bypass permission checks and the
|
||||
// sandbox by pairing an innocuous command with an arbitrary file write.
|
||||
// Always omit internal-only fields from the model-facing schema.
|
||||
// _simulatedSedEdit is set by SedEditPermissionRequest after the user approves a
|
||||
// sed edit preview; exposing it would let the model bypass permission checks and
|
||||
// the sandbox by pairing an innocuous command with an arbitrary file write.
|
||||
// dangerouslyDisableSandbox is also omitted because sandbox escape must be tied
|
||||
// to trusted user/internal provenance, not model-controlled tool input.
|
||||
// Also conditionally remove run_in_background when background tasks are disabled.
|
||||
const inputSchema = lazySchema(() => isBackgroundTasksDisabled ? fullInputSchema().omit({
|
||||
run_in_background: true,
|
||||
dangerouslyDisableSandbox: true,
|
||||
_dangerouslyDisableSandboxApproved: true,
|
||||
_simulatedSedEdit: true
|
||||
}) : fullInputSchema().omit({
|
||||
dangerouslyDisableSandbox: true,
|
||||
_dangerouslyDisableSandboxApproved: true,
|
||||
_simulatedSedEdit: true
|
||||
}));
|
||||
type InputSchema = ReturnType<typeof inputSchema>;
|
||||
|
||||
59
src/tools/BashTool/bashPermissions.test.ts
Normal file
59
src/tools/BashTool/bashPermissions.test.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
import { afterEach, expect, test } from 'bun:test'
|
||||
|
||||
import { getEmptyToolPermissionContext } from '../../Tool.js'
|
||||
import { SandboxManager } from '../../utils/sandbox/sandbox-adapter.js'
|
||||
import { bashToolHasPermission } from './bashPermissions.js'
|
||||
|
||||
const originalSandboxMethods = {
|
||||
isSandboxingEnabled: SandboxManager.isSandboxingEnabled,
|
||||
isAutoAllowBashIfSandboxedEnabled:
|
||||
SandboxManager.isAutoAllowBashIfSandboxedEnabled,
|
||||
areUnsandboxedCommandsAllowed: SandboxManager.areUnsandboxedCommandsAllowed,
|
||||
getExcludedCommands: SandboxManager.getExcludedCommands,
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
SandboxManager.isSandboxingEnabled =
|
||||
originalSandboxMethods.isSandboxingEnabled
|
||||
SandboxManager.isAutoAllowBashIfSandboxedEnabled =
|
||||
originalSandboxMethods.isAutoAllowBashIfSandboxedEnabled
|
||||
SandboxManager.areUnsandboxedCommandsAllowed =
|
||||
originalSandboxMethods.areUnsandboxedCommandsAllowed
|
||||
SandboxManager.getExcludedCommands = originalSandboxMethods.getExcludedCommands
|
||||
})
|
||||
|
||||
function makeToolUseContext() {
|
||||
const toolPermissionContext = getEmptyToolPermissionContext()
|
||||
|
||||
return {
|
||||
abortController: new AbortController(),
|
||||
options: {
|
||||
isNonInteractiveSession: false,
|
||||
},
|
||||
getAppState() {
|
||||
return {
|
||||
toolPermissionContext,
|
||||
}
|
||||
},
|
||||
} as never
|
||||
}
|
||||
|
||||
test('sandbox auto-allow still enforces Bash path constraints', async () => {
|
||||
;(globalThis as unknown as { MACRO: { VERSION: string } }).MACRO = {
|
||||
VERSION: 'test',
|
||||
}
|
||||
|
||||
SandboxManager.isSandboxingEnabled = () => true
|
||||
SandboxManager.isAutoAllowBashIfSandboxedEnabled = () => true
|
||||
SandboxManager.areUnsandboxedCommandsAllowed = () => true
|
||||
SandboxManager.getExcludedCommands = () => []
|
||||
|
||||
const result = await bashToolHasPermission(
|
||||
{ command: 'cat ../../../../../etc/passwd' },
|
||||
makeToolUseContext(),
|
||||
)
|
||||
|
||||
expect(result.behavior).toBe('ask')
|
||||
expect(result.message).toContain('was blocked')
|
||||
expect(result.message).toContain('/etc/passwd')
|
||||
})
|
||||
@@ -1814,7 +1814,10 @@ export async function bashToolHasPermission(
|
||||
input,
|
||||
appState.toolPermissionContext,
|
||||
)
|
||||
if (sandboxAutoAllowResult.behavior !== 'passthrough') {
|
||||
if (
|
||||
sandboxAutoAllowResult.behavior === 'deny' ||
|
||||
sandboxAutoAllowResult.behavior === 'ask'
|
||||
) {
|
||||
return sandboxAutoAllowResult
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,9 +179,6 @@ function getSimpleSandboxSection(): string {
|
||||
const networkRestrictionConfig = SandboxManager.getNetworkRestrictionConfig()
|
||||
const allowUnixSockets = SandboxManager.getAllowUnixSockets()
|
||||
const ignoreViolations = SandboxManager.getIgnoreViolations()
|
||||
const allowUnsandboxedCommands =
|
||||
SandboxManager.areUnsandboxedCommandsAllowed()
|
||||
|
||||
// Replace the per-UID temp dir literal (e.g. /private/tmp/claude-1001/) with
|
||||
// "$TMPDIR" so the prompt is identical across users — avoids busting the
|
||||
// cross-user global prompt cache. The sandbox already sets $TMPDIR at runtime.
|
||||
@@ -225,38 +222,9 @@ function getSimpleSandboxSection(): string {
|
||||
)
|
||||
}
|
||||
|
||||
const sandboxOverrideItems: Array<string | string[]> =
|
||||
allowUnsandboxedCommands
|
||||
? [
|
||||
'You should always default to running commands within the sandbox. Do NOT attempt to set `dangerouslyDisableSandbox: true` unless:',
|
||||
[
|
||||
'The user *explicitly* asks you to bypass sandbox',
|
||||
'A specific command just failed and you see evidence of sandbox restrictions causing the failure. Note that commands can fail for many reasons unrelated to the sandbox (missing files, wrong arguments, network issues, etc.).',
|
||||
],
|
||||
'Evidence of sandbox-caused failures includes:',
|
||||
[
|
||||
'"Operation not permitted" errors for file/network operations',
|
||||
'Access denied to specific paths outside allowed directories',
|
||||
'Network connection failures to non-whitelisted hosts',
|
||||
'Unix socket connection errors',
|
||||
],
|
||||
'When you see evidence of sandbox-caused failure:',
|
||||
[
|
||||
"Immediately retry with `dangerouslyDisableSandbox: true` (don't ask, just do it)",
|
||||
'Briefly explain what sandbox restriction likely caused the failure. Be sure to mention that the user can use the `/sandbox` command to manage restrictions.',
|
||||
'This will prompt the user for permission',
|
||||
],
|
||||
'Treat each command you execute with `dangerouslyDisableSandbox: true` individually. Even if you have recently run a command with this setting, you should default to running future commands within the sandbox.',
|
||||
'Do not suggest adding sensitive paths like ~/.bashrc, ~/.zshrc, ~/.ssh/*, or credential files to the sandbox allowlist.',
|
||||
]
|
||||
: [
|
||||
'All commands MUST run in sandbox mode - the `dangerouslyDisableSandbox` parameter is disabled by policy.',
|
||||
'Commands cannot run outside the sandbox under any circumstances.',
|
||||
'If a command fails due to sandbox restrictions, work with the user to adjust sandbox settings instead.',
|
||||
]
|
||||
|
||||
const items: Array<string | string[]> = [
|
||||
...sandboxOverrideItems,
|
||||
'Commands MUST run in sandbox mode. If a command fails due to sandbox restrictions, explain the likely restriction and work with the user to adjust sandbox settings or run an explicit user-initiated shell command.',
|
||||
'Do not suggest adding sensitive paths like ~/.bashrc, ~/.zshrc, ~/.ssh/*, or credential files to the sandbox allowlist.',
|
||||
'For temporary files, always use the `$TMPDIR` environment variable. TMPDIR is automatically set to the correct sandbox-writable directory in sandbox mode. Do NOT use `/tmp` directly - use `$TMPDIR` instead.',
|
||||
]
|
||||
|
||||
|
||||
74
src/tools/BashTool/shouldUseSandbox.test.ts
Normal file
74
src/tools/BashTool/shouldUseSandbox.test.ts
Normal file
@@ -0,0 +1,74 @@
|
||||
import { afterEach, expect, test } from 'bun:test'
|
||||
|
||||
import { SandboxManager } from '../../utils/sandbox/sandbox-adapter.js'
|
||||
import { BashTool } from './BashTool.js'
|
||||
import { PowerShellTool } from '../PowerShellTool/PowerShellTool.js'
|
||||
import { shouldUseSandbox } from './shouldUseSandbox.js'
|
||||
|
||||
const originalSandboxMethods = {
|
||||
isSandboxingEnabled: SandboxManager.isSandboxingEnabled,
|
||||
areUnsandboxedCommandsAllowed: SandboxManager.areUnsandboxedCommandsAllowed,
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
SandboxManager.isSandboxingEnabled =
|
||||
originalSandboxMethods.isSandboxingEnabled
|
||||
SandboxManager.areUnsandboxedCommandsAllowed =
|
||||
originalSandboxMethods.areUnsandboxedCommandsAllowed
|
||||
})
|
||||
|
||||
test('model-facing Bash schema rejects dangerouslyDisableSandbox', () => {
|
||||
const result = BashTool.inputSchema.safeParse({
|
||||
command: 'cat /etc/passwd',
|
||||
dangerouslyDisableSandbox: true,
|
||||
})
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
test('model-facing PowerShell schema rejects dangerouslyDisableSandbox', () => {
|
||||
const result = PowerShellTool.inputSchema.safeParse({
|
||||
command: 'Get-Content C:\\Windows\\System32\\drivers\\etc\\hosts',
|
||||
dangerouslyDisableSandbox: true,
|
||||
})
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
})
|
||||
|
||||
test('model-controlled dangerouslyDisableSandbox does not bypass sandbox', () => {
|
||||
SandboxManager.isSandboxingEnabled = () => true
|
||||
SandboxManager.areUnsandboxedCommandsAllowed = () => true
|
||||
|
||||
expect(
|
||||
shouldUseSandbox({
|
||||
command: 'cat /etc/passwd',
|
||||
dangerouslyDisableSandbox: true,
|
||||
}),
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
test('trusted internal approval can disable sandbox when policy allows it', () => {
|
||||
SandboxManager.isSandboxingEnabled = () => true
|
||||
SandboxManager.areUnsandboxedCommandsAllowed = () => true
|
||||
|
||||
expect(
|
||||
shouldUseSandbox({
|
||||
command: 'cat /etc/passwd',
|
||||
dangerouslyDisableSandbox: true,
|
||||
_dangerouslyDisableSandboxApproved: true,
|
||||
}),
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
test('trusted internal approval cannot disable sandbox when policy forbids it', () => {
|
||||
SandboxManager.isSandboxingEnabled = () => true
|
||||
SandboxManager.areUnsandboxedCommandsAllowed = () => false
|
||||
|
||||
expect(
|
||||
shouldUseSandbox({
|
||||
command: 'cat /etc/passwd',
|
||||
dangerouslyDisableSandbox: true,
|
||||
_dangerouslyDisableSandboxApproved: true,
|
||||
}),
|
||||
).toBe(true)
|
||||
})
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
type SandboxInput = {
|
||||
command?: string
|
||||
dangerouslyDisableSandbox?: boolean
|
||||
_dangerouslyDisableSandboxApproved?: boolean
|
||||
}
|
||||
|
||||
// NOTE: excludedCommands is a user-facing convenience feature, not a security boundary.
|
||||
@@ -141,9 +142,13 @@ export function shouldUseSandbox(input: Partial<SandboxInput>): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
// Don't sandbox if explicitly overridden AND unsandboxed commands are allowed by policy
|
||||
// Only trusted internal callers may request an unsandboxed command. The
|
||||
// model-facing Bash schema omits _dangerouslyDisableSandboxApproved, so a
|
||||
// tool_use payload cannot disable the sandbox by setting
|
||||
// dangerouslyDisableSandbox directly.
|
||||
if (
|
||||
input.dangerouslyDisableSandbox &&
|
||||
input._dangerouslyDisableSandboxApproved &&
|
||||
SandboxManager.areUnsandboxedCommandsAllowed()
|
||||
) {
|
||||
return false
|
||||
|
||||
@@ -422,7 +422,7 @@ export const FileEditTool = buildTool({
|
||||
activateConditionalSkillsForPaths([absoluteFilePath], cwd)
|
||||
}
|
||||
|
||||
await diagnosticTracker.beforeFileEdited(absoluteFilePath)
|
||||
await diagnosticTracker.beforeFileEditedCompat(absoluteFilePath)
|
||||
|
||||
// Ensure parent directory exists before the atomic read-modify-write section.
|
||||
// These awaits must stay OUTSIDE the critical section below — a yield between
|
||||
|
||||
@@ -244,7 +244,7 @@ export const FileWriteTool = buildTool({
|
||||
// Activate conditional skills whose path patterns match this file
|
||||
activateConditionalSkillsForPaths([fullFilePath], cwd)
|
||||
|
||||
await diagnosticTracker.beforeFileEdited(fullFilePath)
|
||||
await diagnosticTracker.beforeFileEditedCompat(fullFilePath)
|
||||
|
||||
// Ensure parent directory exists before the atomic read-modify-write section.
|
||||
// Must stay OUTSIDE the critical section below (a yield between the staleness
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { Ajv } from 'ajv'
|
||||
import { z } from 'zod/v4'
|
||||
import { buildTool, type ToolDef } from '../../Tool.js'
|
||||
import { buildTool, type ToolDef, type ValidationResult } from '../../Tool.js'
|
||||
import { lazySchema } from '../../utils/lazySchema.js'
|
||||
import type { PermissionResult } from '../../utils/permissions/PermissionResult.js'
|
||||
import type { PermissionResult } from '../../types/permissions.js'
|
||||
import { isOutputLineTruncated } from '../../utils/terminal.js'
|
||||
import { DESCRIPTION, PROMPT } from './prompt.js'
|
||||
import {
|
||||
@@ -37,6 +38,8 @@ export type Output = z.infer<OutputSchema>
|
||||
// Re-export MCPProgress from centralized types to break import cycles
|
||||
export type { MCPProgress } from '../../types/tools.js'
|
||||
|
||||
const ajv = new Ajv({ strict: false })
|
||||
|
||||
export const MCPTool = buildTool({
|
||||
isMcp: true,
|
||||
// Overridden in mcpClient.ts with the real MCP tool name + args
|
||||
@@ -72,6 +75,27 @@ export const MCPTool = buildTool({
|
||||
message: 'MCPTool requires permission.',
|
||||
}
|
||||
},
|
||||
async validateInput(input, context): Promise<ValidationResult> {
|
||||
if (this.inputJSONSchema) {
|
||||
try {
|
||||
const validate = ajv.compile(this.inputJSONSchema)
|
||||
if (!validate(input)) {
|
||||
return {
|
||||
result: false,
|
||||
message: ajv.errorsText(validate.errors),
|
||||
errorCode: 400,
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
result: false,
|
||||
message: `Failed to compile JSON schema for validation: ${error}`,
|
||||
errorCode: 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
return { result: true }
|
||||
},
|
||||
renderToolUseMessage,
|
||||
// Overridden in mcpClient.ts
|
||||
userFacingName: () => 'mcp',
|
||||
@@ -100,3 +124,4 @@ export const MCPTool = buildTool({
|
||||
}
|
||||
},
|
||||
} satisfies ToolDef<InputSchema, Output>)
|
||||
|
||||
|
||||
@@ -230,13 +230,20 @@ const fullInputSchema = lazySchema(() => z.strictObject({
|
||||
timeout: semanticNumber(z.number().optional()).describe(`Optional timeout in milliseconds (max ${getMaxTimeoutMs()})`),
|
||||
description: z.string().optional().describe('Clear, concise description of what this command does in active voice.'),
|
||||
run_in_background: semanticBoolean(z.boolean().optional()).describe(`Set to true to run this command in the background. Use Read to read the output later.`),
|
||||
dangerouslyDisableSandbox: semanticBoolean(z.boolean().optional()).describe('Set this to true to dangerously override sandbox mode and run commands without sandboxing.')
|
||||
dangerouslyDisableSandbox: semanticBoolean(z.boolean().optional()).describe('Set this to true to dangerously override sandbox mode and run commands without sandboxing.'),
|
||||
_dangerouslyDisableSandboxApproved: z.boolean().optional().describe('Internal: user-approved sandbox override')
|
||||
}));
|
||||
|
||||
// Conditionally remove run_in_background from schema when background tasks are disabled
|
||||
// Omit internal-only sandbox override fields from the model-facing schema.
|
||||
// Conditionally remove run_in_background from schema when background tasks are disabled.
|
||||
const inputSchema = lazySchema(() => isBackgroundTasksDisabled ? fullInputSchema().omit({
|
||||
run_in_background: true
|
||||
}) : fullInputSchema());
|
||||
run_in_background: true,
|
||||
dangerouslyDisableSandbox: true,
|
||||
_dangerouslyDisableSandboxApproved: true
|
||||
}) : fullInputSchema().omit({
|
||||
dangerouslyDisableSandbox: true,
|
||||
_dangerouslyDisableSandboxApproved: true
|
||||
}));
|
||||
type InputSchema = ReturnType<typeof inputSchema>;
|
||||
|
||||
// Use fullInputSchema for the type to always include run_in_background
|
||||
@@ -697,7 +704,8 @@ async function* runPowerShellCommand({
|
||||
description,
|
||||
timeout,
|
||||
run_in_background,
|
||||
dangerouslyDisableSandbox
|
||||
dangerouslyDisableSandbox,
|
||||
_dangerouslyDisableSandboxApproved
|
||||
} = input;
|
||||
const timeoutMs = Math.min(timeout || getDefaultTimeoutMs(), getMaxTimeoutMs());
|
||||
let fullOutput = '';
|
||||
@@ -749,7 +757,8 @@ async function* runPowerShellCommand({
|
||||
// The explicit platform check is redundant-but-obvious.
|
||||
shouldUseSandbox: getPlatform() === 'windows' ? false : shouldUseSandbox({
|
||||
command,
|
||||
dangerouslyDisableSandbox
|
||||
dangerouslyDisableSandbox,
|
||||
_dangerouslyDisableSandboxApproved
|
||||
}),
|
||||
shouldAutoBackground
|
||||
});
|
||||
|
||||
@@ -9,6 +9,7 @@ import { z } from 'zod/v4'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../../services/analytics/growthbook.js'
|
||||
import { queryModelWithStreaming } from '../../services/api/claude.js'
|
||||
import { collectCodexCompletedResponse } from '../../services/api/codexShim.js'
|
||||
import { fetchWithProxyRetry } from '../../services/api/fetchWithProxyRetry.js'
|
||||
import {
|
||||
resolveCodexApiCredentials,
|
||||
resolveProviderRequest,
|
||||
@@ -314,7 +315,7 @@ async function runCodexWebSearch(
|
||||
body.reasoning = request.reasoning
|
||||
}
|
||||
|
||||
const response = await fetch(`${request.baseUrl}/responses`, {
|
||||
const response = await fetchWithProxyRetry(`${request.baseUrl}/responses`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
||||
@@ -148,6 +148,42 @@ type Position = {
|
||||
column: number
|
||||
}
|
||||
|
||||
export function maskTextWithVisibleEdges(
|
||||
value: string,
|
||||
mask: string,
|
||||
visiblePrefix = 3,
|
||||
visibleSuffix = 3,
|
||||
): string {
|
||||
if (!mask || !value) return value
|
||||
|
||||
const graphemes = Array.from(getGraphemeSegmenter().segment(value))
|
||||
const secretGraphemeCount = graphemes.filter(
|
||||
({ segment }) => segment !== '\n',
|
||||
).length
|
||||
const visibleCount = visiblePrefix + visibleSuffix
|
||||
|
||||
if (secretGraphemeCount <= visibleCount) {
|
||||
return graphemes
|
||||
.map(({ segment }) => (segment === '\n' ? segment : mask))
|
||||
.join('')
|
||||
}
|
||||
|
||||
let secretIndex = 0
|
||||
return graphemes
|
||||
.map(({ segment }) => {
|
||||
if (segment === '\n') return segment
|
||||
|
||||
const nextSegment =
|
||||
secretIndex < visiblePrefix ||
|
||||
secretIndex >= secretGraphemeCount - visibleSuffix
|
||||
? segment
|
||||
: mask
|
||||
secretIndex += 1
|
||||
return nextSegment
|
||||
})
|
||||
.join('')
|
||||
}
|
||||
|
||||
export class Cursor {
|
||||
readonly offset: number
|
||||
constructor(
|
||||
@@ -208,7 +244,12 @@ export class Cursor {
|
||||
maxVisibleLines?: number,
|
||||
) {
|
||||
const { line, column } = this.getPosition()
|
||||
const allLines = this.measuredText.getWrappedText()
|
||||
const allLines = mask
|
||||
? new MeasuredText(
|
||||
maskTextWithVisibleEdges(this.text, mask),
|
||||
this.measuredText.columns,
|
||||
).getWrappedText()
|
||||
: this.measuredText.getWrappedText()
|
||||
|
||||
const startLine = this.getViewportStartLine(maxVisibleLines)
|
||||
const endLine =
|
||||
@@ -221,23 +262,6 @@ export class Cursor {
|
||||
.map((text, i) => {
|
||||
const currentLine = i + startLine
|
||||
let displayText = text
|
||||
if (mask) {
|
||||
const graphemes = Array.from(getGraphemeSegmenter().segment(text))
|
||||
if (currentLine === allLines.length - 1) {
|
||||
// Last line: mask all but the trailing 6 chars so the user can
|
||||
// confirm they pasted the right thing without exposing the full token
|
||||
const visibleCount = Math.min(6, graphemes.length)
|
||||
const maskCount = graphemes.length - visibleCount
|
||||
const splitOffset =
|
||||
graphemes.length > visibleCount ? graphemes[maskCount]!.index : 0
|
||||
displayText = mask.repeat(maskCount) + text.slice(splitOffset)
|
||||
} else {
|
||||
// Earlier wrapped lines: fully mask. Previously only the last line
|
||||
// was masked, leaking the start of the token on narrow terminals
|
||||
// where the pasted OAuth code wraps across multiple lines.
|
||||
displayText = mask.repeat(graphemes.length)
|
||||
}
|
||||
}
|
||||
// looking for the line with the cursor
|
||||
if (line !== currentLine) return displayText.trimEnd()
|
||||
|
||||
|
||||
@@ -78,3 +78,28 @@ test('toolToAPISchema keeps skill required for SkillTool', async () => {
|
||||
required: ['skill'],
|
||||
})
|
||||
})
|
||||
|
||||
test('toolToAPISchema removes extra required keys not in properties (MCP schema sanitization)', async () => {
|
||||
const schema = await toolToAPISchema(
|
||||
{
|
||||
name: 'mcp__test__create_object',
|
||||
inputSchema: z.strictObject({}),
|
||||
inputJSONSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string' },
|
||||
},
|
||||
required: ['name', 'attributes'],
|
||||
},
|
||||
prompt: async () => 'Create an object',
|
||||
} as unknown as Tool,
|
||||
{
|
||||
getToolPermissionContext: async () => getEmptyToolPermissionContext(),
|
||||
tools: [] as unknown as Tools,
|
||||
agents: [],
|
||||
},
|
||||
)
|
||||
|
||||
const inputSchema = (schema as { input_schema: { required?: string[] } }).input_schema
|
||||
expect(inputSchema.required).toEqual(['name'])
|
||||
})
|
||||
|
||||
@@ -111,11 +111,60 @@ function filterSwarmFieldsFromSchema(
|
||||
delete filteredProps[field]
|
||||
}
|
||||
filtered.properties = filteredProps
|
||||
|
||||
// Keep `required` in sync after removing properties
|
||||
if (Array.isArray(filtered.required)) {
|
||||
filtered.required = filtered.required.filter(
|
||||
(key: string) => key in filteredProps,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensure `required` only lists keys present in `properties`.
|
||||
* MCP servers may emit schemas where these are out of sync, causing
|
||||
* API 400 errors ("Extra required key supplied").
|
||||
* Recurses into nested object schemas.
|
||||
*/
|
||||
function sanitizeSchemaRequired(
|
||||
schema: Anthropic.Tool.InputSchema,
|
||||
): Anthropic.Tool.InputSchema {
|
||||
if (!schema || typeof schema !== 'object') {
|
||||
return schema
|
||||
}
|
||||
|
||||
const result = { ...schema }
|
||||
const props = result.properties as Record<string, unknown> | undefined
|
||||
|
||||
if (props && Array.isArray(result.required)) {
|
||||
result.required = result.required.filter(
|
||||
(key: string) => key in props,
|
||||
)
|
||||
}
|
||||
|
||||
// Recurse into nested object properties
|
||||
if (props) {
|
||||
const sanitizedProps = { ...props }
|
||||
for (const [key, value] of Object.entries(sanitizedProps)) {
|
||||
if (
|
||||
value &&
|
||||
typeof value === 'object' &&
|
||||
(value as Record<string, unknown>).type === 'object'
|
||||
) {
|
||||
sanitizedProps[key] = sanitizeSchemaRequired(
|
||||
value as Anthropic.Tool.InputSchema,
|
||||
)
|
||||
}
|
||||
}
|
||||
result.properties = sanitizedProps
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
export async function toolToAPISchema(
|
||||
tool: Tool,
|
||||
options: {
|
||||
@@ -156,7 +205,7 @@ export async function toolToAPISchema(
|
||||
// Use tool's JSON schema directly if provided, otherwise convert Zod schema
|
||||
let input_schema = (
|
||||
'inputJSONSchema' in tool && tool.inputJSONSchema
|
||||
? tool.inputJSONSchema
|
||||
? sanitizeSchemaRequired(tool.inputJSONSchema as Anthropic.Tool.InputSchema)
|
||||
: zodToJsonSchema(tool.inputSchema)
|
||||
) as Anthropic.Tool.InputSchema
|
||||
|
||||
@@ -613,10 +662,6 @@ export function normalizeToolInput<T extends Tool>(
|
||||
...(timeout !== undefined && { timeout }),
|
||||
...(description !== undefined && { description }),
|
||||
...(run_in_background !== undefined && { run_in_background }),
|
||||
...('dangerouslyDisableSandbox' in parsed &&
|
||||
parsed.dangerouslyDisableSandbox !== undefined && {
|
||||
dangerouslyDisableSandbox: parsed.dangerouslyDisableSandbox,
|
||||
}),
|
||||
} as z.infer<T['inputSchema']>
|
||||
}
|
||||
case FileEditTool.name: {
|
||||
|
||||
@@ -2882,7 +2882,7 @@ async function getDiagnosticAttachments(
|
||||
}
|
||||
|
||||
// Get new diagnostics from the tracker (IDE diagnostics via MCP)
|
||||
const newDiagnostics = await diagnosticTracker.getNewDiagnostics()
|
||||
const newDiagnostics = await diagnosticTracker.getNewDiagnosticsCompat()
|
||||
if (newDiagnostics.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ export {
|
||||
NOTIFICATION_CHANNELS,
|
||||
} from './configConstants.js'
|
||||
|
||||
import type { EDITOR_MODES, NOTIFICATION_CHANNELS } from './configConstants.js'
|
||||
import type { EDITOR_MODES, NOTIFICATION_CHANNELS, PROVIDERS } from './configConstants.js'
|
||||
|
||||
export type NotificationChannel = (typeof NOTIFICATION_CHANNELS)[number]
|
||||
|
||||
@@ -181,10 +181,12 @@ export type DiffTool = 'terminal' | 'auto'
|
||||
|
||||
export type OutputStyle = string
|
||||
|
||||
export type Providers = typeof PROVIDERS[number]
|
||||
|
||||
export type ProviderProfile = {
|
||||
id: string
|
||||
name: string
|
||||
provider: 'openai' | 'anthropic'
|
||||
provider: Providers
|
||||
baseUrl: string
|
||||
model: string
|
||||
apiKey?: string
|
||||
|
||||
@@ -19,3 +19,5 @@ export const EDITOR_MODES = ['normal', 'vim'] as const
|
||||
// 'in-process' = in-process teammates running in same process
|
||||
// 'auto' = automatically choose based on context (default)
|
||||
export const TEAMMATE_MODES = ['auto', 'tmux', 'in-process'] as const
|
||||
|
||||
export const PROVIDERS = ['openai', 'anthropic', 'mistral', 'gemini'] as const
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
const originalEnv = {
|
||||
CLAUDE_CODE_USE_OPENAI: process.env.CLAUDE_CODE_USE_OPENAI,
|
||||
CLAUDE_CODE_MAX_OUTPUT_TOKENS: process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS,
|
||||
OPENAI_MODEL: process.env.OPENAI_MODEL,
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
@@ -23,11 +24,17 @@ afterEach(() => {
|
||||
process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS =
|
||||
originalEnv.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
}
|
||||
if (originalEnv.OPENAI_MODEL === undefined) {
|
||||
delete process.env.OPENAI_MODEL
|
||||
} else {
|
||||
process.env.OPENAI_MODEL = originalEnv.OPENAI_MODEL
|
||||
}
|
||||
})
|
||||
|
||||
test('deepseek-chat uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
delete process.env.OPENAI_MODEL
|
||||
|
||||
expect(getContextWindowForModel('deepseek-chat')).toBe(128_000)
|
||||
expect(getModelMaxOutputTokens('deepseek-chat')).toEqual({
|
||||
@@ -40,6 +47,7 @@ test('deepseek-chat uses provider-specific context and output caps', () => {
|
||||
test('deepseek-chat clamps oversized max output overrides to the provider limit', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS = '32000'
|
||||
delete process.env.OPENAI_MODEL
|
||||
|
||||
expect(getMaxOutputTokensForModel('deepseek-chat')).toBe(8_192)
|
||||
})
|
||||
@@ -47,6 +55,7 @@ test('deepseek-chat clamps oversized max output overrides to the provider limit'
|
||||
test('gpt-4o uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
delete process.env.OPENAI_MODEL
|
||||
|
||||
expect(getContextWindowForModel('gpt-4o')).toBe(128_000)
|
||||
expect(getModelMaxOutputTokens('gpt-4o')).toEqual({
|
||||
@@ -59,6 +68,7 @@ test('gpt-4o uses provider-specific context and output caps', () => {
|
||||
test('gpt-4o clamps oversized max output overrides to the provider limit', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS = '32000'
|
||||
delete process.env.OPENAI_MODEL
|
||||
|
||||
expect(getMaxOutputTokensForModel('gpt-4o')).toBe(16_384)
|
||||
})
|
||||
@@ -66,6 +76,7 @@ test('gpt-4o clamps oversized max output overrides to the provider limit', () =>
|
||||
test('gpt-5.4 family uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
delete process.env.OPENAI_MODEL
|
||||
|
||||
expect(getContextWindowForModel('gpt-5.4')).toBe(1_050_000)
|
||||
expect(getModelMaxOutputTokens('gpt-5.4')).toEqual({
|
||||
@@ -98,6 +109,7 @@ test('gpt-5.4 family keeps large max output overrides within provider limits', (
|
||||
test('MiniMax-M2.7 uses explicit provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
delete process.env.OPENAI_MODEL
|
||||
|
||||
expect(getContextWindowForModel('MiniMax-M2.7')).toBe(204_800)
|
||||
expect(getModelMaxOutputTokens('MiniMax-M2.7')).toEqual({
|
||||
@@ -110,6 +122,7 @@ test('MiniMax-M2.7 uses explicit provider-specific context and output caps', ()
|
||||
test('unknown openai-compatible models use the 128k fallback window (not 8k, see #635)', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
delete process.env.OPENAI_MODEL
|
||||
|
||||
expect(getContextWindowForModel('some-unknown-3p-model')).toBe(128_000)
|
||||
})
|
||||
@@ -117,6 +130,7 @@ test('unknown openai-compatible models use the 128k fallback window (not 8k, see
|
||||
test('MiniMax-M2.5 and M2.1 use explicit provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
delete process.env.OPENAI_MODEL
|
||||
|
||||
expect(getContextWindowForModel('MiniMax-M2.5')).toBe(204_800)
|
||||
expect(getContextWindowForModel('MiniMax-M2.5-highspeed')).toBe(204_800)
|
||||
@@ -127,3 +141,116 @@ test('MiniMax-M2.5 and M2.1 use explicit provider-specific context and output ca
|
||||
upperLimit: 131_072,
|
||||
})
|
||||
})
|
||||
|
||||
test('DashScope qwen3.6-plus uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('qwen3.6-plus')).toBe(1_000_000)
|
||||
expect(getModelMaxOutputTokens('qwen3.6-plus')).toEqual({
|
||||
default: 65_536,
|
||||
upperLimit: 65_536,
|
||||
})
|
||||
expect(getMaxOutputTokensForModel('qwen3.6-plus')).toBe(65_536)
|
||||
})
|
||||
|
||||
test('DashScope qwen3.5-plus uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('qwen3.5-plus')).toBe(1_000_000)
|
||||
expect(getModelMaxOutputTokens('qwen3.5-plus')).toEqual({
|
||||
default: 65_536,
|
||||
upperLimit: 65_536,
|
||||
})
|
||||
expect(getMaxOutputTokensForModel('qwen3.5-plus')).toBe(65_536)
|
||||
})
|
||||
|
||||
test('DashScope qwen3-coder-plus uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('qwen3-coder-plus')).toBe(1_000_000)
|
||||
expect(getModelMaxOutputTokens('qwen3-coder-plus')).toEqual({
|
||||
default: 65_536,
|
||||
upperLimit: 65_536,
|
||||
})
|
||||
})
|
||||
|
||||
test('DashScope qwen3-coder-next uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('qwen3-coder-next')).toBe(262_144)
|
||||
expect(getModelMaxOutputTokens('qwen3-coder-next')).toEqual({
|
||||
default: 65_536,
|
||||
upperLimit: 65_536,
|
||||
})
|
||||
})
|
||||
|
||||
test('DashScope qwen3-max uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('qwen3-max')).toBe(262_144)
|
||||
expect(getModelMaxOutputTokens('qwen3-max')).toEqual({
|
||||
default: 32_768,
|
||||
upperLimit: 32_768,
|
||||
})
|
||||
})
|
||||
|
||||
test('DashScope qwen3-max dated variant resolves to base entry via prefix match', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('qwen3-max-2026-01-23')).toBe(262_144)
|
||||
expect(getModelMaxOutputTokens('qwen3-max-2026-01-23')).toEqual({
|
||||
default: 32_768,
|
||||
upperLimit: 32_768,
|
||||
})
|
||||
})
|
||||
|
||||
test('DashScope kimi-k2.5 uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('kimi-k2.5')).toBe(262_144)
|
||||
expect(getModelMaxOutputTokens('kimi-k2.5')).toEqual({
|
||||
default: 32_768,
|
||||
upperLimit: 32_768,
|
||||
})
|
||||
})
|
||||
|
||||
test('DashScope glm-5 uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('glm-5')).toBe(202_752)
|
||||
expect(getModelMaxOutputTokens('glm-5')).toEqual({
|
||||
default: 16_384,
|
||||
upperLimit: 16_384,
|
||||
})
|
||||
})
|
||||
|
||||
test('DashScope glm-4.7 uses provider-specific context and output caps', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
delete process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS
|
||||
|
||||
expect(getContextWindowForModel('glm-4.7')).toBe(202_752)
|
||||
expect(getModelMaxOutputTokens('glm-4.7')).toEqual({
|
||||
default: 16_384,
|
||||
upperLimit: 16_384,
|
||||
})
|
||||
})
|
||||
|
||||
test('DashScope models clamp oversized max output overrides to the provider limit', () => {
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
process.env.CLAUDE_CODE_MAX_OUTPUT_TOKENS = '100000'
|
||||
|
||||
expect(getMaxOutputTokensForModel('qwen3.6-plus')).toBe(65_536)
|
||||
expect(getMaxOutputTokensForModel('qwen3.5-plus')).toBe(65_536)
|
||||
expect(getMaxOutputTokensForModel('qwen3-coder-next')).toBe(65_536)
|
||||
expect(getMaxOutputTokensForModel('qwen3-max')).toBe(32_768)
|
||||
expect(getMaxOutputTokensForModel('kimi-k2.5')).toBe(32_768)
|
||||
expect(getMaxOutputTokensForModel('glm-5')).toBe(16_384)
|
||||
})
|
||||
|
||||
@@ -37,6 +37,8 @@ export const CLAUDE_3_7_SONNET_CONFIG = {
|
||||
gemini: 'gemini-2.0-flash',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_3_5_V2_SONNET_CONFIG = {
|
||||
@@ -48,6 +50,8 @@ export const CLAUDE_3_5_V2_SONNET_CONFIG = {
|
||||
gemini: 'gemini-2.0-flash',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_3_5_HAIKU_CONFIG = {
|
||||
@@ -59,6 +63,8 @@ export const CLAUDE_3_5_HAIKU_CONFIG = {
|
||||
gemini: 'gemini-2.0-flash-lite',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_HAIKU_4_5_CONFIG = {
|
||||
@@ -70,6 +76,8 @@ export const CLAUDE_HAIKU_4_5_CONFIG = {
|
||||
gemini: 'gemini-2.0-flash-lite',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_SONNET_4_CONFIG = {
|
||||
@@ -81,6 +89,8 @@ export const CLAUDE_SONNET_4_CONFIG = {
|
||||
gemini: 'gemini-2.0-flash',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_SONNET_4_5_CONFIG = {
|
||||
@@ -92,6 +102,8 @@ export const CLAUDE_SONNET_4_5_CONFIG = {
|
||||
gemini: 'gemini-2.0-flash',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_OPUS_4_CONFIG = {
|
||||
@@ -103,6 +115,8 @@ export const CLAUDE_OPUS_4_CONFIG = {
|
||||
gemini: 'gemini-2.5-pro-preview-03-25',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_OPUS_4_1_CONFIG = {
|
||||
@@ -114,6 +128,8 @@ export const CLAUDE_OPUS_4_1_CONFIG = {
|
||||
gemini: 'gemini-2.5-pro-preview-03-25',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_OPUS_4_5_CONFIG = {
|
||||
@@ -125,6 +141,8 @@ export const CLAUDE_OPUS_4_5_CONFIG = {
|
||||
gemini: 'gemini-2.5-pro-preview-03-25',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_OPUS_4_6_CONFIG = {
|
||||
@@ -136,6 +154,8 @@ export const CLAUDE_OPUS_4_6_CONFIG = {
|
||||
gemini: 'gemini-2.5-pro-preview-03-25',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
export const CLAUDE_SONNET_4_6_CONFIG = {
|
||||
@@ -147,6 +167,8 @@ export const CLAUDE_SONNET_4_6_CONFIG = {
|
||||
gemini: 'gemini-2.0-flash',
|
||||
github: 'github:copilot',
|
||||
codex: 'gpt-5.4',
|
||||
'nvidia-nim': 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
minimax: 'MiniMax-M2.5',
|
||||
} as const satisfies ModelConfig
|
||||
|
||||
// @[MODEL LAUNCH]: Register the new config here.
|
||||
@@ -181,4 +203,4 @@ export const CANONICAL_ID_TO_KEY: Record<CanonicalModelId, ModelKey> =
|
||||
(Object.entries(ALL_MODEL_CONFIGS) as [ModelKey, ModelConfig][]).map(
|
||||
([key, cfg]) => [cfg.firstParty, key],
|
||||
),
|
||||
) as Record<CanonicalModelId, ModelKey>
|
||||
) as Record<CanonicalModelId, ModelKey>
|
||||
46
src/utils/model/minimaxModels.ts
Normal file
46
src/utils/model/minimaxModels.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
/**
|
||||
* MiniMax model list for the /model picker.
|
||||
* Full model catalog from MiniMax API.
|
||||
*/
|
||||
|
||||
import type { ModelOption } from './modelOptions.js'
|
||||
import { getAPIProvider } from './providers.js'
|
||||
import { isEnvTruthy } from '../envUtils.js'
|
||||
|
||||
export function isMiniMaxProvider(): boolean {
|
||||
if (isEnvTruthy(process.env.MINIMAX_API_KEY)) {
|
||||
return true
|
||||
}
|
||||
const baseUrl = process.env.OPENAI_BASE_URL ?? ''
|
||||
if (baseUrl.includes('minimax')) {
|
||||
return true
|
||||
}
|
||||
return getAPIProvider() === 'minimax'
|
||||
}
|
||||
|
||||
function getMiniMaxModels(): ModelOption[] {
|
||||
return [
|
||||
// Latest Generation Models - use correct MiniMax naming with M prefix
|
||||
{ value: 'MiniMax-M2', label: 'MiniMax M2', description: 'MoE model - 131K context - Chat/Code/Reasoning' },
|
||||
{ value: 'MiniMax-M2.1', label: 'MiniMax M2.1', description: 'Enhanced - 200K context - Vision' },
|
||||
{ value: 'MiniMax-M2.5', label: 'MiniMax M2.5', description: 'Flagship - 256K context - Vision/Function-calling' },
|
||||
{ value: 'MiniMax-Text-01', label: 'MiniMax Text 01', description: 'Text-focused - 512K context - FREE' },
|
||||
{ value: 'MiniMax-Text-01-Preview', label: 'MiniMax Text 01 Preview', description: 'Preview - 256K context - FREE' },
|
||||
{ value: 'MiniMax-Vision-01', label: 'MiniMax Vision 01', description: 'Vision model - 32K context' },
|
||||
{ value: 'MiniMax-Vision-01-Fast', label: 'MiniMax Vision 01 Fast', description: 'Fast vision - 16K context - FREE' },
|
||||
// Legacy free tier models
|
||||
{ value: 'abab6.5s-chat', label: 'ABAB 6.5S Chat', description: 'Legacy free - 16K context' },
|
||||
{ value: 'abab6.5-chat', label: 'ABAB 6.5 Chat', description: 'Legacy free - 32K context' },
|
||||
{ value: 'abab6.5g-chat', label: 'ABAB 6.5G Chat', description: 'Generation 6.5 - 32K context' },
|
||||
{ value: 'abab6-chat', label: 'ABAB 6 Chat', description: 'Legacy - 8K context' },
|
||||
]
|
||||
}
|
||||
|
||||
let cachedMiniMaxOptions: ModelOption[] | null = null
|
||||
|
||||
export function getCachedMiniMaxModelOptions(): ModelOption[] {
|
||||
if (!cachedMiniMaxOptions) {
|
||||
cachedMiniMaxOptions = getMiniMaxModels()
|
||||
}
|
||||
return cachedMiniMaxOptions
|
||||
}
|
||||
57
src/utils/model/model.github.test.ts
Normal file
57
src/utils/model/model.github.test.ts
Normal file
@@ -0,0 +1,57 @@
|
||||
import { afterEach, beforeEach, expect, test } from 'bun:test'
|
||||
|
||||
import { saveGlobalConfig } from '../config.js'
|
||||
import { getDefaultMainLoopModelSetting, getUserSpecifiedModelSetting } from './model.js'
|
||||
|
||||
const env = {
|
||||
CLAUDE_CODE_USE_GITHUB: process.env.CLAUDE_CODE_USE_GITHUB,
|
||||
CLAUDE_CODE_USE_OPENAI: process.env.CLAUDE_CODE_USE_OPENAI,
|
||||
CLAUDE_CODE_USE_GEMINI: process.env.CLAUDE_CODE_USE_GEMINI,
|
||||
CLAUDE_CODE_USE_BEDROCK: process.env.CLAUDE_CODE_USE_BEDROCK,
|
||||
CLAUDE_CODE_USE_VERTEX: process.env.CLAUDE_CODE_USE_VERTEX,
|
||||
CLAUDE_CODE_USE_FOUNDRY: process.env.CLAUDE_CODE_USE_FOUNDRY,
|
||||
OPENAI_MODEL: process.env.OPENAI_MODEL,
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
process.env.CLAUDE_CODE_USE_GITHUB = '1'
|
||||
delete process.env.CLAUDE_CODE_USE_OPENAI
|
||||
delete process.env.CLAUDE_CODE_USE_GEMINI
|
||||
delete process.env.CLAUDE_CODE_USE_BEDROCK
|
||||
delete process.env.CLAUDE_CODE_USE_VERTEX
|
||||
delete process.env.CLAUDE_CODE_USE_FOUNDRY
|
||||
delete process.env.OPENAI_MODEL
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
model: ({ bad: true } as unknown) as string,
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
process.env.CLAUDE_CODE_USE_GITHUB = env.CLAUDE_CODE_USE_GITHUB
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = env.CLAUDE_CODE_USE_OPENAI
|
||||
process.env.CLAUDE_CODE_USE_GEMINI = env.CLAUDE_CODE_USE_GEMINI
|
||||
process.env.CLAUDE_CODE_USE_BEDROCK = env.CLAUDE_CODE_USE_BEDROCK
|
||||
process.env.CLAUDE_CODE_USE_VERTEX = env.CLAUDE_CODE_USE_VERTEX
|
||||
process.env.CLAUDE_CODE_USE_FOUNDRY = env.CLAUDE_CODE_USE_FOUNDRY
|
||||
process.env.OPENAI_MODEL = env.OPENAI_MODEL
|
||||
saveGlobalConfig(current => ({
|
||||
...current,
|
||||
model: undefined,
|
||||
}))
|
||||
})
|
||||
|
||||
test('github default model setting ignores non-string saved model', () => {
|
||||
const model = getDefaultMainLoopModelSetting()
|
||||
expect(typeof model).toBe('string')
|
||||
expect(model).not.toBe('[object Object]')
|
||||
expect(model.length).toBeGreaterThan(0)
|
||||
})
|
||||
|
||||
test('user specified model ignores non-string saved model', () => {
|
||||
const model = getUserSpecifiedModelSetting()
|
||||
if (model !== undefined && model !== null) {
|
||||
expect(typeof model).toBe('string')
|
||||
expect(model).not.toBe('[object Object]')
|
||||
}
|
||||
})
|
||||
@@ -33,6 +33,12 @@ export type ModelShortName = string
|
||||
export type ModelName = string
|
||||
export type ModelSetting = ModelName | ModelAlias | null
|
||||
|
||||
function normalizeModelSetting(value: unknown): ModelName | ModelAlias | undefined {
|
||||
if (typeof value !== 'string') return undefined
|
||||
const trimmed = value.trim()
|
||||
return trimmed.length > 0 ? trimmed : undefined
|
||||
}
|
||||
|
||||
export function getSmallFastModel(): ModelName {
|
||||
if (process.env.ANTHROPIC_SMALL_FAST_MODEL) return process.env.ANTHROPIC_SMALL_FAST_MODEL
|
||||
// For Gemini provider, use a fast model
|
||||
@@ -82,6 +88,7 @@ export function getUserSpecifiedModelSetting(): ModelSetting | undefined {
|
||||
specifiedModel = modelOverride
|
||||
} else {
|
||||
const settings = getSettings_DEPRECATED() || {}
|
||||
const setting = normalizeModelSetting(settings.model)
|
||||
// Read the model env var that matches the active provider to prevent
|
||||
// cross-provider leaks (e.g. ANTHROPIC_MODEL sent to the OpenAI API).
|
||||
const provider = getAPIProvider()
|
||||
@@ -90,7 +97,7 @@ export function getUserSpecifiedModelSetting(): ModelSetting | undefined {
|
||||
(provider === 'mistral' ? process.env.MISTRAL_MODEL : undefined) ||
|
||||
(provider === 'openai' || provider === 'gemini' || provider === 'mistral' || provider === 'github' ? process.env.OPENAI_MODEL : undefined) ||
|
||||
(provider === 'firstParty' ? process.env.ANTHROPIC_MODEL : undefined) ||
|
||||
settings.model ||
|
||||
setting ||
|
||||
undefined
|
||||
}
|
||||
|
||||
@@ -264,7 +271,11 @@ export function getDefaultMainLoopModelSetting(): ModelName | ModelAlias {
|
||||
// GitHub Copilot provider: check settings.model first, then env, then default
|
||||
if (getAPIProvider() === 'github') {
|
||||
const settings = getSettings_DEPRECATED() || {}
|
||||
return settings.model || process.env.OPENAI_MODEL || 'github:copilot'
|
||||
return (
|
||||
normalizeModelSetting(settings.model) ||
|
||||
normalizeModelSetting(process.env.OPENAI_MODEL) ||
|
||||
'github:copilot'
|
||||
)
|
||||
}
|
||||
// Gemini provider: always use the configured Gemini model
|
||||
if (getAPIProvider() === 'gemini') {
|
||||
@@ -595,7 +606,10 @@ export function getPublicModelName(model: ModelName): string {
|
||||
export function parseUserSpecifiedModel(
|
||||
modelInput: ModelName | ModelAlias,
|
||||
): ModelName {
|
||||
const modelInputTrimmed = modelInput.trim()
|
||||
const modelInputTrimmed = normalizeModelSetting(modelInput)
|
||||
if (!modelInputTrimmed) {
|
||||
return getDefaultSonnetModel()
|
||||
}
|
||||
const normalizedModel = modelInputTrimmed.toLowerCase()
|
||||
|
||||
const has1mTag = has1mContext(normalizedModel)
|
||||
|
||||
@@ -33,8 +33,14 @@ import {
|
||||
} from './model.js'
|
||||
import { has1mContext } from '../context.js'
|
||||
import { getGlobalConfig } from '../config.js'
|
||||
import { getActiveOpenAIModelOptionsCache } from '../providerProfiles.js'
|
||||
import {
|
||||
getActiveOpenAIModelOptionsCache,
|
||||
getActiveProviderProfile,
|
||||
getProfileModelOptions,
|
||||
} from '../providerProfiles.js'
|
||||
import { getCachedOllamaModelOptions, isOllamaProvider } from './ollamaModels.js'
|
||||
import { getCachedNvidiaNimModelOptions, isNvidiaNimProvider } from './nvidiaNimModels.js'
|
||||
import { getCachedMiniMaxModelOptions, isMiniMaxProvider } from './minimaxModels.js'
|
||||
import { getAntModels } from './antModels.js'
|
||||
|
||||
// @[MODEL LAUNCH]: Update all the available and default model option strings below.
|
||||
@@ -390,6 +396,26 @@ function getModelOptionsBase(fastMode = false): ModelOption[] {
|
||||
return [defaultOption]
|
||||
}
|
||||
|
||||
// When using NVIDIA NIM, show models from the NVIDIA catalog
|
||||
if (isNvidiaNimProvider()) {
|
||||
const defaultOption = getDefaultOptionForUser(fastMode)
|
||||
const nvidiaModels = getCachedNvidiaNimModelOptions()
|
||||
if (nvidiaModels.length > 0) {
|
||||
return [defaultOption, ...nvidiaModels]
|
||||
}
|
||||
return [defaultOption]
|
||||
}
|
||||
|
||||
// When using MiniMax, show models from the MiniMax catalog
|
||||
if (isMiniMaxProvider()) {
|
||||
const defaultOption = getDefaultOptionForUser(fastMode)
|
||||
const minimaxModels = getCachedMiniMaxModelOptions()
|
||||
if (minimaxModels.length > 0) {
|
||||
return [defaultOption, ...minimaxModels]
|
||||
}
|
||||
return [defaultOption]
|
||||
}
|
||||
|
||||
if (process.env.USER_TYPE === 'ant') {
|
||||
// Build options from antModels config
|
||||
const antModelOptions: ModelOption[] = getAntModels().map(m => ({
|
||||
@@ -454,6 +480,20 @@ function getModelOptionsBase(fastMode = false): ModelOption[] {
|
||||
]
|
||||
}
|
||||
|
||||
// When a provider profile's env is applied, collect its models so they
|
||||
// can be appended to the standard picker options below.
|
||||
// We check PROFILE_ENV_APPLIED to avoid the ?? profiles[0] fallback in
|
||||
// getActiveProviderProfile which would affect users with inactive profiles.
|
||||
const profileEnvApplied = process.env.CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED === '1'
|
||||
const profileModelOptions: ModelOption[] = []
|
||||
if (profileEnvApplied) {
|
||||
const activeProfile = getActiveProviderProfile()
|
||||
if (activeProfile) {
|
||||
const models = getProfileModelOptions(activeProfile)
|
||||
profileModelOptions.push(...models)
|
||||
}
|
||||
}
|
||||
|
||||
// PAYG 1P API: Default (Sonnet) + Sonnet 1M + Opus 4.6 + Opus 1M + Haiku
|
||||
if (getAPIProvider() === 'firstParty') {
|
||||
const payg1POptions = [getDefaultOptionForUser(fastMode)]
|
||||
@@ -469,6 +509,7 @@ function getModelOptionsBase(fastMode = false): ModelOption[] {
|
||||
}
|
||||
}
|
||||
payg1POptions.push(getHaiku45Option())
|
||||
payg1POptions.push(...profileModelOptions)
|
||||
return payg1POptions
|
||||
}
|
||||
|
||||
@@ -508,6 +549,7 @@ function getModelOptionsBase(fastMode = false): ModelOption[] {
|
||||
} else {
|
||||
payg3pOptions.push(getHaikuOption())
|
||||
}
|
||||
payg3pOptions.push(...profileModelOptions)
|
||||
return payg3pOptions
|
||||
}
|
||||
|
||||
|
||||
161
src/utils/model/nvidiaNimModels.ts
Normal file
161
src/utils/model/nvidiaNimModels.ts
Normal file
@@ -0,0 +1,161 @@
|
||||
/**
|
||||
* NVIDIA NIM model list for the /model picker.
|
||||
* Filtered to chat/instruct models only - embedding, reward, safety, vision, etc. excluded.
|
||||
*/
|
||||
|
||||
import type { ModelOption } from './modelOptions.js'
|
||||
import { getAPIProvider } from './providers.js'
|
||||
import { isEnvTruthy } from '../envUtils.js'
|
||||
|
||||
export function isNvidiaNimProvider(): boolean {
|
||||
// Check if explicitly set via NVIDIA_NIM or via provider flag
|
||||
if (isEnvTruthy(process.env.NVIDIA_NIM)) {
|
||||
return true
|
||||
}
|
||||
// Also check if using NVIDIA NIM endpoint
|
||||
const baseUrl = process.env.OPENAI_BASE_URL ?? ''
|
||||
if (baseUrl.includes('nvidia') || baseUrl.includes('integrate.api.nvidia')) {
|
||||
return true
|
||||
}
|
||||
return getAPIProvider() === 'nvidia-nim'
|
||||
}
|
||||
|
||||
function getNvidiaNimModels(): ModelOption[] {
|
||||
return [
|
||||
// AGENTIC REASONING MODELS
|
||||
{ value: 'nvidia/cosmos-reason2-8b', label: 'Cosmos Reason 2 8B', description: 'Reasoning' },
|
||||
{ value: 'microsoft/phi-4-mini-flash-reasoning', label: 'Phi 4 Mini Flash Reasoning', description: 'Reasoning' },
|
||||
{ value: 'qwen/qwen3-next-80b-a3b-thinking', label: 'Qwen 3 Next 80B Thinking', description: 'Reasoning' },
|
||||
{ value: 'deepseek-ai/deepseek-r1-distill-qwen-32b', label: 'DeepSeek R1 Qwen 32B', description: 'Reasoning' },
|
||||
{ value: 'deepseek-ai/deepseek-r1-distill-qwen-14b', label: 'DeepSeek R1 Qwen 14B', description: 'Reasoning' },
|
||||
{ value: 'deepseek-ai/deepseek-r1-distill-qwen-7b', label: 'DeepSeek R1 Qwen 7B', description: 'Reasoning' },
|
||||
{ value: 'deepseek-ai/deepseek-r1-distill-llama-8b', label: 'DeepSeek R1 Llama 8B', description: 'Reasoning' },
|
||||
{ value: 'qwen/qwq-32b', label: 'QwQ 32B Reasoning', description: 'Reasoning' },
|
||||
// CODE MODELS
|
||||
{ value: 'meta/codellama-70b', label: 'CodeLlama 70B', description: 'Code' },
|
||||
{ value: 'bigcode/starcoder2-15b', label: 'StarCoder2 15B', description: 'Code' },
|
||||
{ value: 'bigcode/starcoder2-7b', label: 'StarCoder2 7B', description: 'Code' },
|
||||
{ value: 'mistralai/codestral-22b-instruct-v0.1', label: 'Codestral 22B', description: 'Code' },
|
||||
{ value: 'mistralai/mamba-codestral-7b-v0.1', label: 'Mamba Codestral 7B', description: 'Code' },
|
||||
{ value: 'deepseek-ai/deepseek-coder-6.7b-instruct', label: 'DeepSeek Coder 6.7B', description: 'Code' },
|
||||
{ value: 'google/codegemma-7b', label: 'CodeGemma 7B', description: 'Code' },
|
||||
{ value: 'google/codegemma-1.1-7b', label: 'CodeGemma 1.1 7B', description: 'Code' },
|
||||
{ value: 'qwen/qwen2.5-coder-32b-instruct', label: 'Qwen 2.5 Coder 32B', description: 'Code' },
|
||||
{ value: 'qwen/qwen2.5-coder-7b-instruct', label: 'Qwen 2.5 Coder 7B', description: 'Code' },
|
||||
{ value: 'qwen/qwen3-coder-480b-a35b-instruct', label: 'Qwen 3 Coder 480B', description: 'Code' },
|
||||
{ value: 'ibm/granite-34b-code-instruct', label: 'Granite 34B Code', description: 'Code' },
|
||||
{ value: 'ibm/granite-8b-code-instruct', label: 'Granite 8B Code', description: 'Code' },
|
||||
// NEMOTRON MODELS - NVIDIA Flagship
|
||||
{ value: 'nvidia/llama-3.1-nemotron-70b-instruct', label: 'Nemotron 70B Instruct', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/llama-3.1-nemotron-51b-instruct', label: 'Nemotron 51B Instruct', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/llama-3.1-nemotron-ultra-253b-v1', label: 'Nemotron Ultra 253B', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/llama-3.3-nemotron-super-49b-v1', label: 'Nemotron Super 49B v1', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/llama-3.3-nemotron-super-49b-v1.5', label: 'Nemotron Super 49B v1.5', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/nemotron-4-340b-instruct', label: 'Nemotron 4 340B', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/nemotron-3-super-120b-a12b', label: 'Nemotron 3 Super 120B', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/nemotron-3-nano-30b-a3b', label: 'Nemotron 3 Nano 30B', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/nemotron-mini-4b-instruct', label: 'Nemotron Mini 4B', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/llama-3.1-nemotron-nano-8b-v1', label: 'Nemotron Nano 8B', description: 'NVIDIA Flagship' },
|
||||
{ value: 'nvidia/llama-3.1-nemotron-nano-4b-v1.1', label: 'Nemotron Nano 4B v1.1', description: 'NVIDIA Flagship' },
|
||||
// CHATQA MODELS
|
||||
{ value: 'nvidia/llama3-chatqa-1.5-70b', label: 'Llama3 ChatQA 1.5 70B', description: 'Chat' },
|
||||
{ value: 'nvidia/llama3-chatqa-1.5-8b', label: 'Llama3 ChatQA 1.5 8B', description: 'Chat' },
|
||||
// META LLAMA MODELS
|
||||
{ value: 'meta/llama-3.1-405b-instruct', label: 'Llama 3.1 405B', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-3.1-70b-instruct', label: 'Llama 3.1 70B', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-3.1-8b-instruct', label: 'Llama 3.1 8B', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-3.2-90b-vision-instruct', label: 'Llama 3.2 90B Vision', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-3.2-11b-vision-instruct', label: 'Llama 3.2 11B Vision', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-3.2-3b-instruct', label: 'Llama 3.2 3B', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-3.2-1b-instruct', label: 'Llama 3.2 1B', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-3.3-70b-instruct', label: 'Llama 3.3 70B', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-4-maverick-17b-128e-instruct', label: 'Llama 4 Maverick 17B', description: 'Meta Llama' },
|
||||
{ value: 'meta/llama-4-scout-17b-16e-instruct', label: 'Llama 4 Scout 17B', description: 'Meta Llama' },
|
||||
// GOOGLE GEMMA MODELS (text only - no vision)
|
||||
{ value: 'google/gemma-4-31b-it', label: 'Gemma 4 31B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-3-27b-it', label: 'Gemma 3 27B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-3-12b-it', label: 'Gemma 3 12B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-3-4b-it', label: 'Gemma 3 4B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-3-1b-it', label: 'Gemma 3 1B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-3n-e4b-it', label: 'Gemma 3N E4B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-3n-e2b-it', label: 'Gemma 3N E2B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-2-27b-it', label: 'Gemma 2 27B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-2-9b-it', label: 'Gemma 2 9B', description: 'Google Gemma' },
|
||||
{ value: 'google/gemma-2-2b-it', label: 'Gemma 2 2B', description: 'Google Gemma' },
|
||||
// MISTRAL MODELS
|
||||
{ value: 'mistralai/mistral-large-3-675b-instruct-2512', label: 'Mistral Large 3 675B', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-large-2-instruct', label: 'Mistral Large 2', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-large', label: 'Mistral Large', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-medium-3-instruct', label: 'Mistral Medium 3', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-small-4-119b-2603', label: 'Mistral Small 4 119B', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-small-3.1-24b-instruct-2503', label: 'Mistral Small 3.1 24B', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-small-24b-instruct', label: 'Mistral Small 24B', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-7b-instruct-v0.3', label: 'Mistral 7B v0.3', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-7b-instruct-v0.2', label: 'Mistral 7B v0.2', description: 'Mistral' },
|
||||
{ value: 'mistralai/mixtral-8x22b-instruct-v0.1', label: 'Mixtral 8x22B', description: 'Mistral' },
|
||||
{ value: 'mistralai/mixtral-8x22b-instruct-v0.1', label: 'Mixtral 8x22B Instruct', description: 'Mistral' },
|
||||
{ value: 'mistralai/mixtral-8x7b-instruct-v0.1', label: 'Mixtral 8x7B', description: 'Mistral' },
|
||||
{ value: 'mistralai/mistral-nemotron', label: 'Mistral Nemotron', description: 'Mistral' },
|
||||
{ value: 'mistralai/mathstral-7b-v0.1', label: 'Mathstral 7B', description: 'Math' },
|
||||
{ value: 'mistralai/ministral-14b-instruct-2512', label: 'Ministral 14B', description: 'Mistral' },
|
||||
{ value: 'mistralai/devstral-2-123b-instruct-2512', label: 'Devstral 2 123B', description: 'Code' },
|
||||
{ value: 'mistralai/magistral-small-2506', label: 'Magistral Small', description: 'Mistral' },
|
||||
// MICROSOFT PHI MODELS (text only - no vision)
|
||||
{ value: 'microsoft/phi-4-multimodal-instruct', label: 'Phi 4 Multimodal', description: 'Multimodal' },
|
||||
{ value: 'microsoft/phi-4-mini-instruct', label: 'Phi 4 Mini', description: 'Phi' },
|
||||
{ value: 'microsoft/phi-3.5-mini-instruct', label: 'Phi 3.5 Mini', description: 'Phi' },
|
||||
{ value: 'microsoft/phi-3-small-128k-instruct', label: 'Phi 3 Small 128K', description: 'Phi' },
|
||||
{ value: 'microsoft/phi-3-small-8k-instruct', label: 'Phi 3 Small 8K', description: 'Phi' },
|
||||
{ value: 'microsoft/phi-3-medium-128k-instruct', label: 'Phi 3 Medium 128K', description: 'Phi' },
|
||||
{ value: 'microsoft/phi-3-medium-4k-instruct', label: 'Phi 3 Medium 4K', description: 'Phi' },
|
||||
{ value: 'microsoft/phi-3-mini-128k-instruct', label: 'Phi 3 Mini 128K', description: 'Phi' },
|
||||
{ value: 'microsoft/phi-3-mini-4k-instruct', label: 'Phi 3 Mini 4K', description: 'Phi' },
|
||||
// QWEN MODELS
|
||||
{ value: 'qwen/qwen3.5-397b-a17b', label: 'Qwen 3.5 397B', description: 'Qwen' },
|
||||
{ value: 'qwen/qwen3.5-122b-a10b', label: 'Qwen 3.5 122B', description: 'Qwen' },
|
||||
{ value: 'qwen/qwen3-next-80b-a3b-instruct', label: 'Qwen 3 Next 80B', description: 'Qwen' },
|
||||
{ value: 'qwen/qwen2.5-7b-instruct', label: 'Qwen 2.5 7B', description: 'Qwen' },
|
||||
{ value: 'qwen/qwen2-7b-instruct', label: 'Qwen 2 7B', description: 'Qwen' },
|
||||
{ value: 'qwen/qwen3-32b', label: 'Qwen 3 32B', description: 'Qwen' },
|
||||
{ value: 'qwen/qwen3-8b', label: 'Qwen 3 8B', description: 'Qwen' },
|
||||
// DEEPSEEK MODELS
|
||||
{ value: 'deepseek-ai/deepseek-r1', label: 'DeepSeek R1', description: 'DeepSeek' },
|
||||
{ value: 'deepseek-ai/deepseek-v3', label: 'DeepSeek V3', description: 'DeepSeek' },
|
||||
{ value: 'deepseek-ai/deepseek-v3.2', label: 'DeepSeek V3.2', description: 'DeepSeek' },
|
||||
{ value: 'deepseek-ai/deepseek-v3.1-terminus', label: 'DeepSeek V3.1 Terminus', description: 'DeepSeek' },
|
||||
{ value: 'deepseek-ai/deepseek-v3.1', label: 'DeepSeek V3.1', description: 'DeepSeek' },
|
||||
// IBM GRANITE MODELS
|
||||
{ value: 'ibm/granite-3.3-8b-instruct', label: 'Granite 3.3 8B', description: 'IBM Granite' },
|
||||
{ value: 'ibm/granite-3.0-8b-instruct', label: 'Granite 3.0 8B', description: 'IBM Granite' },
|
||||
{ value: 'ibm/granite-3.0-3b-a800m-instruct', label: 'Granite 3.0 3B', description: 'IBM Granite' },
|
||||
// OTHER MODELS
|
||||
{ value: 'databricks/dbrx-instruct', label: 'DBRX Instruct', description: 'Other' },
|
||||
{ value: '01-ai/yi-large', label: 'Yi Large', description: 'Other' },
|
||||
{ value: 'ai21labs/jamba-1.5-large-instruct', label: 'Jamba 1.5 Large', description: 'Other' },
|
||||
{ value: 'ai21labs/jamba-1.5-mini-instruct', label: 'Jamba 1.5 Mini', description: 'Other' },
|
||||
{ value: 'writer/palmyra-creative-122b', label: 'Palmyra Creative 122B', description: 'Other' },
|
||||
{ value: 'writer/palmyra-fin-70b-32k', label: 'Palmyra Fin 70B 32K', description: 'Other' },
|
||||
{ value: 'writer/palmyra-med-70b', label: 'Palmyra Med 70B', description: 'Other' },
|
||||
{ value: 'writer/palmyra-med-70b-32k', label: 'Palmyra Med 70B 32K', description: 'Other' },
|
||||
// Z-AI GLM MODELS
|
||||
{ value: 'z-ai/glm5', label: 'GLM-5', description: 'Z-AI' },
|
||||
{ value: 'z-ai/glm4.7', label: 'GLM-4.7', description: 'Z-AI' },
|
||||
// MINIMAX MODELS
|
||||
{ value: 'minimaxai/minimax-m2.5', label: 'MiniMax M2.5', description: 'MiniMax' },
|
||||
// MOONSHOT KIMI MODELS
|
||||
{ value: 'moonshotai/kimi-k2.5', label: 'Kimi K2.5', description: 'Moonshot' },
|
||||
{ value: 'moonshotai/kimi-k2-instruct', label: 'Kimi K2 Instruct', description: 'Moonshot' },
|
||||
{ value: 'moonshotai/kimi-k2-thinking', label: 'Kimi K2 Thinking', description: 'Moonshot' },
|
||||
{ value: 'moonshotai/kimi-k2.5-thinking', label: 'Kimi K2.5 Thinking', description: 'Moonshot' },
|
||||
{ value: 'moonshotai/kimi-k2-instruct-0905', label: 'Kimi K2 Instruct 0905', description: 'Moonshot' },
|
||||
]
|
||||
}
|
||||
|
||||
let cachedNvidiaNimOptions: ModelOption[] | null = null
|
||||
|
||||
export function getCachedNvidiaNimModelOptions(): ModelOption[] {
|
||||
if (!cachedNvidiaNimOptions) {
|
||||
cachedNvidiaNimOptions = getNvidiaNimModels()
|
||||
}
|
||||
return cachedNvidiaNimOptions
|
||||
}
|
||||
@@ -104,6 +104,57 @@ const OPENAI_CONTEXT_WINDOWS: Record<string, number> = {
|
||||
'devstral-latest': 256_000,
|
||||
'ministral-3b-latest': 256_000,
|
||||
|
||||
// NVIDIA NIM - popular models
|
||||
'nvidia/llama-3.1-nemotron-70b-instruct': 128_000,
|
||||
'nvidia/llama-3.1-nemotron-ultra-253b-v1': 128_000,
|
||||
'nvidia/nemotron-mini-4b-instruct': 32_768,
|
||||
'meta/llama-3.1-405b-instruct': 128_000,
|
||||
'meta/llama-3.1-70b-instruct': 128_000,
|
||||
'meta/llama-3.1-8b-instruct': 128_000,
|
||||
'meta/llama-3.2-90b-instruct': 128_000,
|
||||
'meta/llama-3.2-1b-instruct': 128_000,
|
||||
'meta/llama-3.2-3b-instruct': 128_000,
|
||||
'meta/llama-3.3-70b-instruct': 128_000,
|
||||
// Google Gemma via NVIDIA NIM
|
||||
'google/gemma-2-27b-it': 8_192,
|
||||
'google/gemma-2-9b-it': 8_192,
|
||||
'google/gemma-3-27b-it': 131_072,
|
||||
'google/gemma-3-12b-it': 131_072,
|
||||
'google/gemma-3-4b-it': 131_072,
|
||||
// DeepSeek via NVIDIA NIM
|
||||
'deepseek-ai/deepseek-r1': 128_000,
|
||||
'deepseek-ai/deepseek-v3': 128_000,
|
||||
'deepseek-ai/deepseek-v3.2': 128_000,
|
||||
// Qwen via NVIDIA NIM
|
||||
'qwen/qwen3-32b': 128_000,
|
||||
'qwen/qwen3-8b': 128_000,
|
||||
'qwen/qwen2.5-7b-instruct': 32_768,
|
||||
// Mistral via NVIDIA NIM
|
||||
'mistralai/mistral-large-3-675b-instruct-2512': 256_000,
|
||||
'mistralai/mistral-large-2-instruct': 256_000,
|
||||
'mistralai/mistral-small-3.1-24b-instruct-2503': 32_768,
|
||||
'mistralai/mixtral-8x7b-instruct-v0.1': 32_768,
|
||||
// Microsoft Phi via NVIDIA NIM
|
||||
'microsoft/phi-4-mini-instruct': 16_384,
|
||||
'microsoft/phi-3.5-mini-instruct': 16_384,
|
||||
'microsoft/phi-3-mini-128k-instruct': 128_000,
|
||||
// IBM Granite via NVIDIA NIM
|
||||
'ibm/granite-3.3-8b-instruct': 8_192,
|
||||
'ibm/granite-8b-code-instruct': 8_192,
|
||||
// GLM models via NVIDIA NIM
|
||||
'z-ai/glm5': 200_000,
|
||||
'z-ai/glm4.7': 128_000,
|
||||
// Kimi models via NVIDIA NIM
|
||||
'moonshotai/kimi-k2.5': 200_000,
|
||||
'moonshotai/kimi-k2-instruct': 128_000,
|
||||
// DBRX via NVIDIA NIM
|
||||
'databricks/dbrx-instruct': 131_072,
|
||||
// Jamba via NVIDIA NIM
|
||||
'ai21labs/jamba-1.5-large-instruct': 256_000,
|
||||
'ai21labs/jamba-1.5-mini-instruct': 256_000,
|
||||
// Yi via NVIDIA NIM
|
||||
'01-ai/yi-large': 32_768,
|
||||
|
||||
// MiniMax (all M2.x variants share 204,800 context, 131,072 max output)
|
||||
'MiniMax-M2.7': 204_800,
|
||||
'MiniMax-M2.7-highspeed': 204_800,
|
||||
@@ -118,14 +169,23 @@ const OPENAI_CONTEXT_WINDOWS: Record<string, number> = {
|
||||
'minimax-m2.1': 204_800,
|
||||
'minimax-m2.1-highspeed': 204_800,
|
||||
|
||||
// MiniMax new models
|
||||
'MiniMax-Text-01': 524_288,
|
||||
'MiniMax-Text-01-Preview': 262_144,
|
||||
'MiniMax-Vision-01': 32_768,
|
||||
'MiniMax-Vision-01-Fast': 16_384,
|
||||
'MiniMax-M2': 204_800,
|
||||
|
||||
// Google (via OpenRouter)
|
||||
'google/gemini-2.0-flash':1_048_576,
|
||||
'google/gemini-2.5-pro': 1_048_576,
|
||||
|
||||
// Google (native via CLAUDE_CODE_USE_GEMINI)
|
||||
'gemini-2.0-flash': 1_048_576,
|
||||
'gemini-2.5-pro': 1_048_576,
|
||||
'gemini-2.5-flash': 1_048_576,
|
||||
'gemini-2.0-flash': 1_048_576,
|
||||
'gemini-2.5-pro': 1_048_576,
|
||||
'gemini-2.5-flash': 1_048_576,
|
||||
'gemini-3.1-pro': 1_048_576,
|
||||
'gemini-3.1-flash-lite-preview': 1_048_576,
|
||||
|
||||
// Ollama local models
|
||||
// Llama 3.1+ models support 128k context natively (Meta official specs).
|
||||
@@ -144,6 +204,21 @@ const OPENAI_CONTEXT_WINDOWS: Record<string, number> = {
|
||||
'llama3.2:1b': 128_000,
|
||||
'qwen3:8b': 128_000,
|
||||
'codestral': 32_768,
|
||||
|
||||
// Alibaba DashScope (Coding Plan)
|
||||
// Model context windows from DashScope API /models endpoint (April 2026).
|
||||
// Values sourced from: qwen3.5-plus/qwen3-coder-plus (1M), qwen3-coder-next/max (256K),
|
||||
// kimi-k2.5 (256K), glm-5/glm-4.7 (198K).
|
||||
// Max output tokens: Qwen variants (64K/32K), GLM (16K).
|
||||
'qwen3.6-plus': 1_000_000,
|
||||
'qwen3.5-plus': 1_000_000,
|
||||
'qwen3-coder-plus': 1_000_000,
|
||||
'qwen3-coder-next': 262_144,
|
||||
'qwen3-max': 262_144,
|
||||
'qwen3-max-2026-01-23': 262_144,
|
||||
'kimi-k2.5': 262_144,
|
||||
'glm-5': 202_752,
|
||||
'glm-4.7': 202_752,
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -246,15 +321,23 @@ const OPENAI_MAX_OUTPUT_TOKENS: Record<string, number> = {
|
||||
'minimax-m2.5-highspeed': 131_072,
|
||||
'minimax-m2.1': 131_072,
|
||||
'minimax-m2.1-highspeed': 131_072,
|
||||
// New MiniMax models
|
||||
'MiniMax-M2': 131_072,
|
||||
'MiniMax-Text-01': 65_536,
|
||||
'MiniMax-Text-01-Preview': 65_536,
|
||||
'MiniMax-Vision-01': 16_384,
|
||||
'MiniMax-Vision-01-Fast': 16_384,
|
||||
|
||||
// Google (via OpenRouter)
|
||||
'google/gemini-2.0-flash': 8_192,
|
||||
'google/gemini-2.5-pro': 65_536,
|
||||
|
||||
// Google (native via CLAUDE_CODE_USE_GEMINI)
|
||||
'gemini-2.0-flash': 8_192,
|
||||
'gemini-2.5-pro': 65_536,
|
||||
'gemini-2.5-flash': 65_536,
|
||||
'gemini-2.0-flash': 8_192,
|
||||
'gemini-2.5-pro': 65_536,
|
||||
'gemini-2.5-flash': 65_536,
|
||||
'gemini-3.1-pro': 65_536,
|
||||
'gemini-3.1-flash-lite-preview': 65_536,
|
||||
|
||||
// Ollama local models (conservative safe defaults)
|
||||
'llama3.3:70b': 4_096,
|
||||
@@ -271,6 +354,43 @@ const OPENAI_MAX_OUTPUT_TOKENS: Record<string, number> = {
|
||||
'llama3.2:1b': 4_096,
|
||||
'qwen3:8b': 8_192,
|
||||
'codestral': 8_192,
|
||||
|
||||
// NVIDIA NIM models
|
||||
'nvidia/llama-3.1-nemotron-70b-instruct': 32_768,
|
||||
'nvidia/nemotron-mini-4b-instruct': 8_192,
|
||||
'meta/llama-3.1-405b-instruct': 32_768,
|
||||
'meta/llama-3.1-70b-instruct': 32_768,
|
||||
'meta/llama-3.2-90b-instruct': 32_768,
|
||||
'meta/llama-3.3-70b-instruct': 32_768,
|
||||
'google/gemma-2-27b-it': 4_096,
|
||||
'google/gemma-3-27b-it': 16_384,
|
||||
'google/gemma-3-12b-it': 16_384,
|
||||
'deepseek-ai/deepseek-r1': 32_768,
|
||||
'deepseek-ai/deepseek-v3': 32_768,
|
||||
'deepseek-ai/deepseek-v3.2': 32_768,
|
||||
'qwen/qwen3-32b': 32_768,
|
||||
'qwen/qwen2.5-7b-instruct': 8_192,
|
||||
'mistralai/mistral-large-3-675b-instruct-2512': 32_768,
|
||||
'mistralai/mixtral-8x7b-instruct-v0.1': 8_192,
|
||||
'microsoft/phi-4-mini-instruct': 4_096,
|
||||
'microsoft/phi-3.5-mini-instruct': 4_096,
|
||||
'ibm/granite-3.3-8b-instruct': 4_096,
|
||||
'z-ai/glm5': 32_768,
|
||||
'moonshotai/kimi-k2.5': 32_768,
|
||||
'databricks/dbrx-instruct': 32_768,
|
||||
'ai21labs/jamba-1.5-large-instruct': 32_768,
|
||||
'01-ai/yi-large': 8_192,
|
||||
|
||||
// Alibaba DashScope (Coding Plan)
|
||||
'qwen3.6-plus': 65_536,
|
||||
'qwen3.5-plus': 65_536,
|
||||
'qwen3-coder-plus': 65_536,
|
||||
'qwen3-coder-next': 65_536,
|
||||
'qwen3-max': 32_768,
|
||||
'qwen3-max-2026-01-23': 32_768,
|
||||
'kimi-k2.5': 32_768,
|
||||
'glm-5': 16_384,
|
||||
'glm-4.7': 16_384,
|
||||
}
|
||||
|
||||
function lookupByModel<T>(table: Record<string, T>, model: string): T | undefined {
|
||||
|
||||
@@ -11,9 +11,17 @@ export type APIProvider =
|
||||
| 'gemini'
|
||||
| 'github'
|
||||
| 'codex'
|
||||
| 'nvidia-nim'
|
||||
| 'minimax'
|
||||
| 'mistral'
|
||||
|
||||
export function getAPIProvider(): APIProvider {
|
||||
if (isEnvTruthy(process.env.NVIDIA_NIM)) {
|
||||
return 'nvidia-nim'
|
||||
}
|
||||
if (isEnvTruthy(process.env.MINIMAX_API_KEY)) {
|
||||
return 'minimax'
|
||||
}
|
||||
return isEnvTruthy(process.env.CLAUDE_CODE_USE_GEMINI)
|
||||
? 'gemini'
|
||||
:
|
||||
|
||||
@@ -11,6 +11,8 @@ import {
|
||||
} from '@anthropic-ai/sdk'
|
||||
import { getModelStrings } from './modelStrings.js'
|
||||
import { getCachedOllamaModelOptions, isOllamaProvider } from './ollamaModels.js'
|
||||
import { getCachedNvidiaNimModelOptions, isNvidiaNimProvider } from './nvidiaNimModels.js'
|
||||
import { getCachedMiniMaxModelOptions, isMiniMaxProvider } from './minimaxModels.js'
|
||||
|
||||
// Cache valid models to avoid repeated API calls
|
||||
const validModelCache = new Map<string, boolean>()
|
||||
@@ -47,6 +49,40 @@ export async function validateModel(
|
||||
// If cache is empty, fall through to API validation
|
||||
}
|
||||
|
||||
// For NVIDIA NIM provider, validate against cached model list
|
||||
if (isNvidiaNimProvider()) {
|
||||
const nvidiaModels = getCachedNvidiaNimModelOptions()
|
||||
const found = nvidiaModels.some(m => m.value === normalizedModel)
|
||||
if (found) {
|
||||
validModelCache.set(normalizedModel, true)
|
||||
return { valid: true }
|
||||
}
|
||||
if (nvidiaModels.length > 0) {
|
||||
const MAX_SHOWN = 5
|
||||
const names = nvidiaModels.map(m => m.value)
|
||||
const shown = names.slice(0, MAX_SHOWN).join(', ')
|
||||
const suffix = names.length > MAX_SHOWN ? ` and ${names.length - MAX_SHOWN} more` : ''
|
||||
return { valid: false, error: `Model '${normalizedModel}' not found in NVIDIA NIM catalog. Available: ${shown}${suffix}` }
|
||||
}
|
||||
}
|
||||
|
||||
// For MiniMax provider, validate against cached model list
|
||||
if (isMiniMaxProvider()) {
|
||||
const minimaxModels = getCachedMiniMaxModelOptions()
|
||||
const found = minimaxModels.some(m => m.value === normalizedModel)
|
||||
if (found) {
|
||||
validModelCache.set(normalizedModel, true)
|
||||
return { valid: true }
|
||||
}
|
||||
if (minimaxModels.length > 0) {
|
||||
const MAX_SHOWN = 5
|
||||
const names = minimaxModels.map(m => m.value)
|
||||
const shown = names.slice(0, MAX_SHOWN).join(', ')
|
||||
const suffix = names.length > MAX_SHOWN ? ` and ${names.length - MAX_SHOWN} more` : ''
|
||||
return { valid: false, error: `Model '${normalizedModel}' not found in MiniMax catalog. Available: ${shown}${suffix}` }
|
||||
}
|
||||
}
|
||||
|
||||
// Check against availableModels allowlist before any API call
|
||||
if (!isModelAllowed(normalizedModel)) {
|
||||
return {
|
||||
|
||||
@@ -76,7 +76,9 @@ describe('OpenClaude paths', () => {
|
||||
})
|
||||
|
||||
test('local installer uses openclaude wrapper path', async () => {
|
||||
delete process.env.CLAUDE_CONFIG_DIR
|
||||
// Force .openclaude config home so the test doesn't fall back to
|
||||
// ~/.claude when ~/.openclaude doesn't exist on this machine.
|
||||
process.env.CLAUDE_CONFIG_DIR = join(homedir(), '.openclaude')
|
||||
const { getLocalClaudePath } = await importFreshLocalInstaller()
|
||||
|
||||
expect(getLocalClaudePath()).toBe(
|
||||
|
||||
@@ -65,10 +65,11 @@ export async function processBashCommand(inputString: string, precedingInputBloc
|
||||
});
|
||||
};
|
||||
|
||||
// User-initiated `!` commands run outside sandbox. Both shell tools honor
|
||||
// dangerouslyDisableSandbox (checked against areUnsandboxedCommandsAllowed()
|
||||
// in shouldUseSandbox.ts). PS sandbox is Linux/macOS/WSL2 only — on Windows
|
||||
// native, shouldUseSandbox() returns false regardless (unsupported platform).
|
||||
// User-initiated `!` commands run outside sandbox when policy allows it.
|
||||
// Bash requires an internal approval marker so model-controlled tool input
|
||||
// cannot disable sandboxing by setting dangerouslyDisableSandbox directly.
|
||||
// PS sandbox is Linux/macOS/WSL2 only — on Windows native, shouldUseSandbox()
|
||||
// returns false regardless (unsupported platform).
|
||||
// Lazy-require PowerShellTool so its ~300KB chunk only loads when the
|
||||
// user has actually selected the powershell default shell.
|
||||
type PSMod = typeof import('src/tools/PowerShellTool/PowerShellTool.js');
|
||||
@@ -81,10 +82,12 @@ export async function processBashCommand(inputString: string, precedingInputBloc
|
||||
const shellTool = PowerShellTool ?? BashTool;
|
||||
const response = PowerShellTool ? await PowerShellTool.call({
|
||||
command: inputString,
|
||||
dangerouslyDisableSandbox: true
|
||||
dangerouslyDisableSandbox: true,
|
||||
_dangerouslyDisableSandboxApproved: true
|
||||
}, bashModeContext, undefined, undefined, onProgress) : await BashTool.call({
|
||||
command: inputString,
|
||||
dangerouslyDisableSandbox: true
|
||||
dangerouslyDisableSandbox: true,
|
||||
_dangerouslyDisableSandboxApproved: true
|
||||
}, bashModeContext, undefined, undefined, onProgress);
|
||||
const data = response.data;
|
||||
if (!data) {
|
||||
|
||||
@@ -105,6 +105,14 @@ export function getLocalOpenAICompatibleProviderLabel(baseUrl?: string): string
|
||||
) {
|
||||
return 'text-generation-webui'
|
||||
}
|
||||
// Check for NVIDIA NIM
|
||||
if (host.includes('nvidia') || haystack.includes('nvidia') || host.includes('integrate.api.nvidia')) {
|
||||
return 'NVIDIA NIM'
|
||||
}
|
||||
// Check for MiniMax (both api.minimax.io and api.minimax.chat)
|
||||
if (host.includes('minimax') || haystack.includes('minimax')) {
|
||||
return 'MiniMax'
|
||||
}
|
||||
} catch {
|
||||
// Fall back to the generic label when the base URL is malformed.
|
||||
}
|
||||
|
||||
@@ -21,6 +21,8 @@ export const VALID_PROVIDERS = [
|
||||
'bedrock',
|
||||
'vertex',
|
||||
'ollama',
|
||||
'nvidia-nim',
|
||||
'minimax',
|
||||
] as const
|
||||
|
||||
export type ProviderFlagName = (typeof VALID_PROVIDERS)[number]
|
||||
@@ -131,6 +133,21 @@ export function applyProviderFlag(
|
||||
}
|
||||
if (model) process.env.OPENAI_MODEL = model
|
||||
break
|
||||
|
||||
case 'nvidia-nim':
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
process.env.OPENAI_BASE_URL ??= 'https://integrate.api.nvidia.com/v1'
|
||||
process.env.NVIDIA_NIM = '1'
|
||||
process.env.OPENAI_MODEL ??= 'nvidia/llama-3.1-nemotron-70b-instruct'
|
||||
if (model) process.env.OPENAI_MODEL = model
|
||||
break
|
||||
|
||||
case 'minimax':
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
process.env.OPENAI_BASE_URL ??= 'https://api.minimax.io/v1'
|
||||
process.env.OPENAI_MODEL ??= 'MiniMax-M2.5'
|
||||
if (model) process.env.OPENAI_MODEL = model
|
||||
break
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
108
src/utils/providerModels.test.ts
Normal file
108
src/utils/providerModels.test.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
import { describe, expect, test } from 'bun:test'
|
||||
|
||||
import {
|
||||
getPrimaryModel,
|
||||
hasMultipleModels,
|
||||
parseModelList,
|
||||
} from './providerModels.ts'
|
||||
|
||||
// ── parseModelList ────────────────────────────────────────────────────────────
|
||||
|
||||
describe('parseModelList', () => {
|
||||
test('splits comma-separated models', () => {
|
||||
expect(parseModelList('glm-4.7, glm-4.7-flash')).toEqual([
|
||||
'glm-4.7',
|
||||
'glm-4.7-flash',
|
||||
])
|
||||
})
|
||||
|
||||
test('returns single model in an array', () => {
|
||||
expect(parseModelList('llama3.1:8b')).toEqual(['llama3.1:8b'])
|
||||
})
|
||||
|
||||
test('trims whitespace around each model', () => {
|
||||
expect(parseModelList(' gpt-4o , gpt-4o-mini , o3-mini ')).toEqual([
|
||||
'gpt-4o',
|
||||
'gpt-4o-mini',
|
||||
'o3-mini',
|
||||
])
|
||||
})
|
||||
|
||||
test('filters out empty entries from trailing commas', () => {
|
||||
expect(parseModelList('gpt-4o,,gpt-4o-mini,')).toEqual([
|
||||
'gpt-4o',
|
||||
'gpt-4o-mini',
|
||||
])
|
||||
})
|
||||
|
||||
test('returns empty array for empty string', () => {
|
||||
expect(parseModelList('')).toEqual([])
|
||||
})
|
||||
|
||||
test('returns empty array for whitespace-only string', () => {
|
||||
expect(parseModelList(' ')).toEqual([])
|
||||
})
|
||||
|
||||
test('returns empty array for comma-only string', () => {
|
||||
expect(parseModelList(',,,')).toEqual([])
|
||||
})
|
||||
|
||||
test('handles models with colons', () => {
|
||||
expect(parseModelList('qwen2.5-coder:7b, llama3.1:8b')).toEqual([
|
||||
'qwen2.5-coder:7b',
|
||||
'llama3.1:8b',
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
// ── getPrimaryModel ───────────────────────────────────────────────────────────
|
||||
|
||||
describe('getPrimaryModel', () => {
|
||||
test('returns first model from comma-separated list', () => {
|
||||
expect(getPrimaryModel('glm-4.7, glm-4.7-flash')).toBe('glm-4.7')
|
||||
})
|
||||
|
||||
test('returns the only model when single model is provided', () => {
|
||||
expect(getPrimaryModel('llama3.1:8b')).toBe('llama3.1:8b')
|
||||
})
|
||||
|
||||
test('returns the original string when input is empty', () => {
|
||||
expect(getPrimaryModel('')).toBe('')
|
||||
})
|
||||
|
||||
test('returns first model after trimming', () => {
|
||||
expect(getPrimaryModel(' gpt-4o , gpt-4o-mini')).toBe('gpt-4o')
|
||||
})
|
||||
|
||||
test('returns first model when others are empty from trailing commas', () => {
|
||||
expect(getPrimaryModel('claude-sonnet-4-6,,')).toBe('claude-sonnet-4-6')
|
||||
})
|
||||
})
|
||||
|
||||
// ── hasMultipleModels ─────────────────────────────────────────────────────────
|
||||
|
||||
describe('hasMultipleModels', () => {
|
||||
test('returns true when multiple models are present', () => {
|
||||
expect(hasMultipleModels('glm-4.7, glm-4.7-flash')).toBe(true)
|
||||
})
|
||||
|
||||
test('returns false for a single model', () => {
|
||||
expect(hasMultipleModels('llama3.1:8b')).toBe(false)
|
||||
})
|
||||
|
||||
test('returns false for empty string', () => {
|
||||
expect(hasMultipleModels('')).toBe(false)
|
||||
})
|
||||
|
||||
test('returns false for whitespace-only string', () => {
|
||||
expect(hasMultipleModels(' ')).toBe(false)
|
||||
})
|
||||
|
||||
test('returns false when extra commas produce no extra models', () => {
|
||||
expect(hasMultipleModels('gpt-4o,,')).toBe(false)
|
||||
})
|
||||
|
||||
test('returns true for three models', () => {
|
||||
expect(hasMultipleModels('a, b, c')).toBe(true)
|
||||
})
|
||||
})
|
||||
33
src/utils/providerModels.ts
Normal file
33
src/utils/providerModels.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
/**
|
||||
* Utility functions for parsing comma-separated model names in provider profiles.
|
||||
*
|
||||
* Example: "glm-4.7, glm-4.7-flash" -> ["glm-4.7", "glm-4.7-flash"]
|
||||
* Single model: "llama3.1:8b" -> ["llama3.1:8b"]
|
||||
*/
|
||||
|
||||
/**
|
||||
* Splits a comma-separated model field into an array of trimmed model names,
|
||||
* filtering out any empty entries.
|
||||
*/
|
||||
export function parseModelList(modelField: string): string[] {
|
||||
return modelField
|
||||
.split(',')
|
||||
.map((part) => part.trim())
|
||||
.filter((part) => part.length > 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the first (primary) model from a comma-separated model field.
|
||||
* Falls back to the original string if parsing yields no results.
|
||||
*/
|
||||
export function getPrimaryModel(modelField: string): string {
|
||||
const models = parseModelList(modelField)
|
||||
return models.length > 0 ? models[0] : modelField
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the model field contains more than one model.
|
||||
*/
|
||||
export function hasMultipleModels(modelField: string): boolean {
|
||||
return parseModelList(modelField).length > 1
|
||||
}
|
||||
@@ -166,7 +166,7 @@ test('matching persisted gemini env is reused for gemini launch', async () => {
|
||||
assert.equal(env.GEMINI_BASE_URL, 'https://example.test/v1beta/openai')
|
||||
})
|
||||
|
||||
test('gemini launch ignores mismatched persisted openai env and strips other provider secrets', async () => {
|
||||
test('openai env variables take precedence over gemini', async () => {
|
||||
const env = await buildLaunchEnv({
|
||||
profile: 'gemini',
|
||||
persisted: profile('openai', {
|
||||
@@ -187,16 +187,16 @@ test('gemini launch ignores mismatched persisted openai env and strips other pro
|
||||
},
|
||||
})
|
||||
|
||||
assert.equal(env.CLAUDE_CODE_USE_GEMINI, '1')
|
||||
assert.equal(env.CLAUDE_CODE_USE_OPENAI, undefined)
|
||||
assert.equal(env.GEMINI_MODEL, 'gemini-2.0-flash')
|
||||
assert.equal(env.GEMINI_API_KEY, 'gem-live')
|
||||
assert.equal(env.CLAUDE_CODE_USE_GEMINI, undefined)
|
||||
assert.equal(env.CLAUDE_CODE_USE_OPENAI, '1')
|
||||
assert.equal(env.GEMINI_MODEL, undefined)
|
||||
assert.equal(env.GEMINI_API_KEY, undefined)
|
||||
assert.equal(
|
||||
env.GEMINI_BASE_URL,
|
||||
'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
undefined,
|
||||
)
|
||||
assert.equal(env.GOOGLE_API_KEY, undefined)
|
||||
assert.equal(env.OPENAI_API_KEY, undefined)
|
||||
assert.equal(env.OPENAI_API_KEY, 'sk-live')
|
||||
assert.equal(env.CODEX_API_KEY, undefined)
|
||||
assert.equal(env.CHATGPT_ACCOUNT_ID, undefined)
|
||||
})
|
||||
@@ -562,8 +562,13 @@ test('buildStartupEnvFromProfile leaves explicit provider selections untouched',
|
||||
processEnv,
|
||||
})
|
||||
|
||||
assert.equal(env, processEnv)
|
||||
// Remove the strict object equality check: assert.equal(env, processEnv)
|
||||
assert.equal(env.CLAUDE_CODE_USE_GEMINI, '1')
|
||||
assert.equal(env.GEMINI_API_KEY, 'gem-live')
|
||||
assert.equal(env.GEMINI_MODEL, 'gemini-2.0-flash')
|
||||
// Add the new default fields injected by the function
|
||||
assert.equal(env.GEMINI_BASE_URL, 'https://generativelanguage.googleapis.com/v1beta/openai')
|
||||
assert.equal(env.GEMINI_AUTH_MODE, 'api-key')
|
||||
assert.equal(env.OPENAI_API_KEY, undefined)
|
||||
})
|
||||
|
||||
@@ -607,14 +612,17 @@ test('buildStartupEnvFromProfile treats explicit falsey provider flags as user i
|
||||
processEnv,
|
||||
})
|
||||
|
||||
assert.equal(env, processEnv)
|
||||
assert.equal(env.CLAUDE_CODE_USE_OPENAI, '0')
|
||||
assert.equal(env.GEMINI_API_KEY, undefined)
|
||||
assert.equal(env.CLAUDE_CODE_USE_OPENAI, undefined)
|
||||
assert.equal(env.CLAUDE_CODE_USE_GEMINI, '1')
|
||||
assert.equal(env.GEMINI_API_KEY, 'gem-persisted')
|
||||
assert.equal(env.GEMINI_MODEL, 'gemini-2.5-flash')
|
||||
assert.equal(env.GEMINI_BASE_URL, 'https://generativelanguage.googleapis.com/v1beta/openai')
|
||||
assert.equal(env.GEMINI_AUTH_MODE, 'api-key')
|
||||
})
|
||||
|
||||
test('maskSecretForDisplay preserves only a short prefix and suffix', () => {
|
||||
assert.equal(maskSecretForDisplay('sk-secret-12345678'), 'sk-...5678')
|
||||
assert.equal(maskSecretForDisplay('AIzaSecret12345678'), 'AIza...5678')
|
||||
assert.equal(maskSecretForDisplay('sk-secret-12345678'), 'sk-...678')
|
||||
assert.equal(maskSecretForDisplay('AIzaSecret12345678'), 'AIz...678')
|
||||
})
|
||||
|
||||
test('redactSecretValueForDisplay masks poisoned display fields that equal configured secrets', () => {
|
||||
@@ -622,7 +630,7 @@ test('redactSecretValueForDisplay masks poisoned display fields that equal confi
|
||||
|
||||
assert.equal(
|
||||
redactSecretValueForDisplay(apiKey, { OPENAI_API_KEY: apiKey }),
|
||||
'sk-...5678',
|
||||
'sk-...678',
|
||||
)
|
||||
assert.equal(
|
||||
redactSecretValueForDisplay('gpt-4o', { OPENAI_API_KEY: apiKey }),
|
||||
|
||||
@@ -6,29 +6,32 @@ import {
|
||||
isCodexBaseUrl,
|
||||
resolveCodexApiCredentials,
|
||||
resolveProviderRequest,
|
||||
} from '../services/api/providerConfig.ts'
|
||||
} from '../services/api/providerConfig.js'
|
||||
import { parseChatgptAccountId } from '../services/api/codexOAuthShared.js'
|
||||
import {
|
||||
getGoalDefaultOpenAIModel,
|
||||
normalizeRecommendationGoal,
|
||||
type RecommendationGoal,
|
||||
} from './providerRecommendation.ts'
|
||||
import { readGeminiAccessToken } from './geminiCredentials.ts'
|
||||
import { getOllamaChatBaseUrl } from './providerDiscovery.ts'
|
||||
import { getProviderValidationError } from './providerValidation.ts'
|
||||
} from './providerRecommendation.js'
|
||||
import { readGeminiAccessToken } from './geminiCredentials.js'
|
||||
import { getOllamaChatBaseUrl } from './providerDiscovery.js'
|
||||
import { getProviderValidationError } from './providerValidation.js'
|
||||
import {
|
||||
maskSecretForDisplay,
|
||||
redactSecretValueForDisplay,
|
||||
sanitizeApiKey,
|
||||
sanitizeProviderConfigValue,
|
||||
} from './providerSecrets.ts'
|
||||
} from './providerSecrets.js'
|
||||
|
||||
export {
|
||||
maskSecretForDisplay,
|
||||
redactSecretValueForDisplay,
|
||||
sanitizeApiKey,
|
||||
sanitizeProviderConfigValue,
|
||||
} from './providerSecrets.ts'
|
||||
} from './providerSecrets.js'
|
||||
import { isEnvTruthy } from './envUtils.ts'
|
||||
|
||||
import { PROVIDERS } from './configConstants.js'
|
||||
|
||||
export const PROFILE_FILE_NAME = '.openclaude-profile.json'
|
||||
export const DEFAULT_GEMINI_BASE_URL =
|
||||
@@ -57,18 +60,28 @@ const PROFILE_ENV_KEYS = [
|
||||
'GEMINI_MODEL',
|
||||
'GEMINI_BASE_URL',
|
||||
'GOOGLE_API_KEY',
|
||||
'NVIDIA_NIM',
|
||||
'NVIDIA_API_KEY',
|
||||
'NVIDIA_MODEL',
|
||||
'MINIMAX_API_KEY',
|
||||
'MINIMAX_BASE_URL',
|
||||
'MINIMAX_MODEL',
|
||||
'MISTRAL_BASE_URL',
|
||||
'MISTRAL_API_KEY',
|
||||
'MISTRAL_MODEL',
|
||||
] as const
|
||||
|
||||
export type ProviderProfile =
|
||||
| 'openai'
|
||||
| 'ollama'
|
||||
| 'codex'
|
||||
| 'gemini'
|
||||
| 'atomic-chat'
|
||||
| 'mistral'
|
||||
const SECRET_ENV_KEYS = [
|
||||
'OPENAI_API_KEY',
|
||||
'CODEX_API_KEY',
|
||||
'GEMINI_API_KEY',
|
||||
'GOOGLE_API_KEY',
|
||||
'NVIDIA_API_KEY',
|
||||
'MINIMAX_API_KEY',
|
||||
'MISTRAL_API_KEY',
|
||||
] as const
|
||||
|
||||
export type ProviderProfile = 'openai' | 'ollama' | 'codex' | 'gemini' | 'atomic-chat' | 'nvidia-nim' | 'minimax' | 'mistral'
|
||||
|
||||
export type ProfileEnv = {
|
||||
OPENAI_BASE_URL?: string
|
||||
@@ -82,6 +95,12 @@ export type ProfileEnv = {
|
||||
GEMINI_AUTH_MODE?: 'api-key' | 'access-token' | 'adc'
|
||||
GEMINI_MODEL?: string
|
||||
GEMINI_BASE_URL?: string
|
||||
GOOGLE_API_KEY?: string
|
||||
NVIDIA_NIM?: string
|
||||
NVIDIA_API_KEY?: string
|
||||
MINIMAX_API_KEY?: string
|
||||
MINIMAX_BASE_URL?: string
|
||||
MINIMAX_MODEL?: string
|
||||
MISTRAL_BASE_URL?: string
|
||||
MISTRAL_API_KEY?: string
|
||||
MISTRAL_MODEL?: string
|
||||
@@ -93,6 +112,19 @@ export type ProfileFile = {
|
||||
createdAt: string
|
||||
}
|
||||
|
||||
type SecretValueSource = Partial<
|
||||
Record<
|
||||
| 'OPENAI_API_KEY'
|
||||
| 'CODEX_API_KEY'
|
||||
| 'GEMINI_API_KEY'
|
||||
| 'GOOGLE_API_KEY'
|
||||
| 'NVIDIA_API_KEY'
|
||||
| 'MINIMAX_API_KEY'
|
||||
| 'MISTRAL_API_KEY',
|
||||
string | undefined
|
||||
>
|
||||
>
|
||||
|
||||
type ProfileFileLocation = {
|
||||
cwd?: string
|
||||
filePath?: string
|
||||
@@ -113,6 +145,8 @@ export function isProviderProfile(value: unknown): value is ProviderProfile {
|
||||
value === 'codex' ||
|
||||
value === 'gemini' ||
|
||||
value === 'atomic-chat' ||
|
||||
value === 'nvidia-nim' ||
|
||||
value === 'minimax' ||
|
||||
value === 'mistral'
|
||||
)
|
||||
}
|
||||
@@ -143,6 +177,67 @@ export function buildAtomicChatProfileEnv(
|
||||
}
|
||||
}
|
||||
|
||||
export function buildNvidiaNimProfileEnv(options: {
|
||||
model?: string | null
|
||||
baseUrl?: string | null
|
||||
apiKey?: string | null
|
||||
processEnv?: NodeJS.ProcessEnv
|
||||
}): ProfileEnv | null {
|
||||
const processEnv = options.processEnv ?? process.env
|
||||
const key = sanitizeApiKey(options.apiKey ?? processEnv.NVIDIA_API_KEY)
|
||||
if (!key) {
|
||||
return null
|
||||
}
|
||||
|
||||
const defaultBaseUrl = 'https://integrate.api.nvidia.com/v1'
|
||||
const secretSource: SecretValueSource = { OPENAI_API_KEY: key }
|
||||
|
||||
return {
|
||||
OPENAI_BASE_URL:
|
||||
sanitizeProviderConfigValue(options.baseUrl, secretSource) ||
|
||||
sanitizeProviderConfigValue(processEnv.OPENAI_BASE_URL, secretSource) ||
|
||||
defaultBaseUrl,
|
||||
OPENAI_MODEL:
|
||||
sanitizeProviderConfigValue(options.model, secretSource) ||
|
||||
sanitizeProviderConfigValue(processEnv.OPENAI_MODEL, secretSource) ||
|
||||
'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
OPENAI_API_KEY: key,
|
||||
NVIDIA_NIM: '1',
|
||||
}
|
||||
}
|
||||
|
||||
export function buildMiniMaxProfileEnv(options: {
|
||||
model?: string | null
|
||||
baseUrl?: string | null
|
||||
apiKey?: string | null
|
||||
processEnv?: NodeJS.ProcessEnv
|
||||
}): ProfileEnv | null {
|
||||
const processEnv = options.processEnv ?? process.env
|
||||
const key = sanitizeApiKey(options.apiKey ?? processEnv.MINIMAX_API_KEY)
|
||||
if (!key) {
|
||||
return null
|
||||
}
|
||||
|
||||
const defaultBaseUrl = 'https://api.minimax.io/v1'
|
||||
const defaultModel = 'MiniMax-M2.5'
|
||||
const secretSource: SecretValueSource = { OPENAI_API_KEY: key }
|
||||
|
||||
return {
|
||||
OPENAI_BASE_URL:
|
||||
sanitizeProviderConfigValue(options.baseUrl, secretSource) ||
|
||||
sanitizeProviderConfigValue(processEnv.OPENAI_BASE_URL, secretSource) ||
|
||||
defaultBaseUrl,
|
||||
OPENAI_MODEL:
|
||||
sanitizeProviderConfigValue(options.model, secretSource) ||
|
||||
sanitizeProviderConfigValue(processEnv.OPENAI_MODEL, secretSource) ||
|
||||
defaultModel,
|
||||
OPENAI_API_KEY: key,
|
||||
MINIMAX_API_KEY: key,
|
||||
MINIMAX_BASE_URL: defaultBaseUrl,
|
||||
MINIMAX_MODEL: defaultModel,
|
||||
}
|
||||
}
|
||||
|
||||
export function buildGeminiProfileEnv(options: {
|
||||
model?: string | null
|
||||
baseUrl?: string | null
|
||||
@@ -161,15 +256,13 @@ export function buildGeminiProfileEnv(options: {
|
||||
return null
|
||||
}
|
||||
|
||||
const secretSource: SecretValueSource = key ? { GEMINI_API_KEY: key } : {}
|
||||
|
||||
const env: ProfileEnv = {
|
||||
GEMINI_AUTH_MODE: authMode,
|
||||
GEMINI_MODEL:
|
||||
sanitizeProviderConfigValue(options.model, { GEMINI_API_KEY: key }, processEnv) ||
|
||||
sanitizeProviderConfigValue(
|
||||
processEnv.GEMINI_MODEL,
|
||||
{ GEMINI_API_KEY: key },
|
||||
processEnv,
|
||||
) ||
|
||||
sanitizeProviderConfigValue(options.model, secretSource) ||
|
||||
sanitizeProviderConfigValue(processEnv.GEMINI_MODEL, secretSource) ||
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
}
|
||||
|
||||
@@ -178,12 +271,8 @@ export function buildGeminiProfileEnv(options: {
|
||||
}
|
||||
|
||||
const baseUrl =
|
||||
sanitizeProviderConfigValue(options.baseUrl, { GEMINI_API_KEY: key }, processEnv) ||
|
||||
sanitizeProviderConfigValue(
|
||||
processEnv.GEMINI_BASE_URL,
|
||||
{ GEMINI_API_KEY: key },
|
||||
processEnv,
|
||||
)
|
||||
sanitizeProviderConfigValue(options.baseUrl, secretSource) ||
|
||||
sanitizeProviderConfigValue(processEnv.GEMINI_BASE_URL, secretSource)
|
||||
if (baseUrl) {
|
||||
env.GEMINI_BASE_URL = baseUrl
|
||||
}
|
||||
@@ -205,15 +294,14 @@ export function buildOpenAIProfileEnv(options: {
|
||||
}
|
||||
|
||||
const defaultModel = getGoalDefaultOpenAIModel(options.goal)
|
||||
const secretSource: SecretValueSource = { OPENAI_API_KEY: key }
|
||||
const shellOpenAIModel = sanitizeProviderConfigValue(
|
||||
processEnv.OPENAI_MODEL,
|
||||
{ OPENAI_API_KEY: key },
|
||||
processEnv,
|
||||
secretSource,
|
||||
)
|
||||
const shellOpenAIBaseUrl = sanitizeProviderConfigValue(
|
||||
processEnv.OPENAI_BASE_URL,
|
||||
{ OPENAI_API_KEY: key },
|
||||
processEnv,
|
||||
secretSource,
|
||||
)
|
||||
const shellOpenAIRequest = resolveProviderRequest({
|
||||
model: shellOpenAIModel,
|
||||
@@ -224,19 +312,11 @@ export function buildOpenAIProfileEnv(options: {
|
||||
|
||||
return {
|
||||
OPENAI_BASE_URL:
|
||||
sanitizeProviderConfigValue(
|
||||
options.baseUrl,
|
||||
{ OPENAI_API_KEY: key },
|
||||
processEnv,
|
||||
) ||
|
||||
sanitizeProviderConfigValue(options.baseUrl, secretSource) ||
|
||||
(useShellOpenAIConfig ? shellOpenAIBaseUrl : undefined) ||
|
||||
DEFAULT_OPENAI_BASE_URL,
|
||||
OPENAI_MODEL:
|
||||
sanitizeProviderConfigValue(
|
||||
options.model,
|
||||
{ OPENAI_API_KEY: key },
|
||||
processEnv,
|
||||
) ||
|
||||
sanitizeProviderConfigValue(options.model, secretSource) ||
|
||||
(useShellOpenAIConfig ? shellOpenAIModel : undefined) ||
|
||||
defaultModel,
|
||||
OPENAI_API_KEY: key,
|
||||
@@ -293,21 +373,19 @@ export function buildMistralProfileEnv(options: {
|
||||
const env: ProfileEnv = {
|
||||
MISTRAL_API_KEY: key,
|
||||
MISTRAL_MODEL:
|
||||
sanitizeProviderConfigValue(options.model, { MISTRAL_API_KEY: key }, processEnv) ||
|
||||
sanitizeProviderConfigValue(options.model, { MISTRAL_API_KEY: key }) ||
|
||||
sanitizeProviderConfigValue(
|
||||
processEnv.MISTRAL_MODEL,
|
||||
{ MISTRAL_API_KEY: key },
|
||||
processEnv,
|
||||
) ||
|
||||
DEFAULT_MISTRAL_MODEL,
|
||||
}
|
||||
|
||||
const baseUrl =
|
||||
sanitizeProviderConfigValue(options.baseUrl, { MISTRAL_API_KEY: key }, processEnv) ||
|
||||
sanitizeProviderConfigValue(options.baseUrl, { MISTRAL_API_KEY: key }) ||
|
||||
sanitizeProviderConfigValue(
|
||||
processEnv.MISTRAL_BASE_URL,
|
||||
{ MISTRAL_API_KEY: key },
|
||||
processEnv,
|
||||
)
|
||||
if (baseUrl) {
|
||||
env.MISTRAL_BASE_URL = baseUrl
|
||||
@@ -423,13 +501,13 @@ export function hasExplicitProviderSelection(
|
||||
}
|
||||
|
||||
return (
|
||||
processEnv.CLAUDE_CODE_USE_OPENAI !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_GITHUB !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_GEMINI !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_MISTRAL !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_BEDROCK !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_VERTEX !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_FOUNDRY !== undefined
|
||||
isEnvTruthy(processEnv.CLAUDE_CODE_USE_OPENAI) ||
|
||||
isEnvTruthy(processEnv.CLAUDE_CODE_USE_GITHUB) ||
|
||||
isEnvTruthy(processEnv.CLAUDE_CODE_USE_GEMINI) ||
|
||||
isEnvTruthy(processEnv.CLAUDE_CODE_USE_MISTRAL) ||
|
||||
isEnvTruthy(processEnv.CLAUDE_CODE_USE_BEDROCK) ||
|
||||
isEnvTruthy(processEnv.CLAUDE_CODE_USE_VERTEX) ||
|
||||
isEnvTruthy(processEnv.CLAUDE_CODE_USE_FOUNDRY)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -465,11 +543,11 @@ export async function buildLaunchEnv(options: {
|
||||
)
|
||||
const shellOpenAIModel = sanitizeProviderConfigValue(
|
||||
processEnv.OPENAI_MODEL,
|
||||
processEnv,
|
||||
processEnv as SecretValueSource,
|
||||
)
|
||||
const shellOpenAIBaseUrl = sanitizeProviderConfigValue(
|
||||
processEnv.OPENAI_BASE_URL,
|
||||
processEnv,
|
||||
processEnv as SecretValueSource,
|
||||
)
|
||||
const persistedGeminiModel = sanitizeProviderConfigValue(
|
||||
persistedEnv.GEMINI_MODEL,
|
||||
@@ -481,11 +559,11 @@ export async function buildLaunchEnv(options: {
|
||||
)
|
||||
const shellGeminiModel = sanitizeProviderConfigValue(
|
||||
processEnv.GEMINI_MODEL,
|
||||
processEnv,
|
||||
processEnv as SecretValueSource,
|
||||
)
|
||||
const shellGeminiBaseUrl = sanitizeProviderConfigValue(
|
||||
processEnv.GEMINI_BASE_URL,
|
||||
processEnv,
|
||||
processEnv as SecretValueSource,
|
||||
)
|
||||
const shellGeminiAccessToken =
|
||||
processEnv.GEMINI_ACCESS_TOKEN?.trim() || undefined
|
||||
@@ -498,6 +576,20 @@ export async function buildLaunchEnv(options: {
|
||||
const persistedGeminiKey = sanitizeApiKey(persistedEnv.GEMINI_API_KEY)
|
||||
const persistedGeminiAuthMode = persistedEnv.GEMINI_AUTH_MODE
|
||||
|
||||
if (hasExplicitProviderSelection(processEnv)) {
|
||||
for (let provider of PROVIDERS) {
|
||||
if (provider === "anthropic") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const env_key_name = `CLAUDE_CODE_USE_${provider.toUpperCase()}`
|
||||
|
||||
if (env_key_name in processEnv && isEnvTruthy(processEnv[env_key_name])) {
|
||||
options.profile = provider;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (options.profile === 'gemini') {
|
||||
const env: NodeJS.ProcessEnv = {
|
||||
...processEnv,
|
||||
@@ -567,19 +659,15 @@ export async function buildLaunchEnv(options: {
|
||||
|
||||
const shellMistralModel = sanitizeProviderConfigValue(
|
||||
processEnv.MISTRAL_MODEL,
|
||||
processEnv,
|
||||
)
|
||||
const persistedMistralModel = sanitizeProviderConfigValue(
|
||||
persistedEnv.MISTRAL_MODEL,
|
||||
persistedEnv,
|
||||
)
|
||||
const shellMistralBaseUrl = sanitizeProviderConfigValue(
|
||||
processEnv.MISTRAL_BASE_URL,
|
||||
processEnv,
|
||||
)
|
||||
const persistedMistralBaseUrl = sanitizeProviderConfigValue(
|
||||
persistedEnv.MISTRAL_BASE_URL,
|
||||
persistedEnv,
|
||||
)
|
||||
|
||||
env.MISTRAL_MODEL =
|
||||
@@ -754,12 +842,18 @@ export async function buildStartupEnvFromProfile(options?: {
|
||||
const persisted = options?.persisted ?? loadProfileFile()
|
||||
|
||||
// Saved /provider profiles should still win over provider-manager env that was
|
||||
// auto-applied during startup. Only explicit shell/flag provider selection
|
||||
// auto-applied during startup. Only an explicit shell/flag provider selection
|
||||
// should bypass the persisted startup profile.
|
||||
//
|
||||
const profileManagedEnv = processEnv.CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED === '1'
|
||||
if (hasExplicitProviderSelection(processEnv) && !profileManagedEnv) {
|
||||
return processEnv
|
||||
}
|
||||
|
||||
// If the user explicitly selected a provider via env, allow it to bypass
|
||||
// the persisted profile only when we can prove it was managed by the
|
||||
// persisted profile env itself.
|
||||
//
|
||||
// Practically: on initial startup, provider routing env vars can already
|
||||
// be present due to earlier auto-application steps. We should still apply
|
||||
// the persisted profile rather than returning early.
|
||||
|
||||
if (!persisted) {
|
||||
return processEnv
|
||||
|
||||
@@ -13,6 +13,7 @@ const RESTORED_KEYS = [
|
||||
'CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED_ID',
|
||||
'CLAUDE_CODE_USE_OPENAI',
|
||||
'CLAUDE_CODE_USE_GEMINI',
|
||||
'CLAUDE_CODE_USE_MISTRAL',
|
||||
'CLAUDE_CODE_USE_GITHUB',
|
||||
'CLAUDE_CODE_USE_BEDROCK',
|
||||
'CLAUDE_CODE_USE_VERTEX',
|
||||
@@ -24,6 +25,15 @@ const RESTORED_KEYS = [
|
||||
'ANTHROPIC_BASE_URL',
|
||||
'ANTHROPIC_MODEL',
|
||||
'ANTHROPIC_API_KEY',
|
||||
'GEMINI_BASE_URL',
|
||||
'GEMINI_MODEL',
|
||||
'GEMINI_API_KEY',
|
||||
'GEMINI_AUTH_MODE',
|
||||
'GEMINI_ACCESS_TOKEN',
|
||||
'GOOGLE_API_KEY',
|
||||
'MISTRAL_BASE_URL',
|
||||
'MISTRAL_MODEL',
|
||||
'MISTRAL_API_KEY',
|
||||
] as const
|
||||
|
||||
type MockConfigState = {
|
||||
@@ -98,6 +108,24 @@ function buildProfile(overrides: Partial<ProviderProfile> = {}): ProviderProfile
|
||||
}
|
||||
}
|
||||
|
||||
function buildMistralProfile(overrides: Partial<ProviderProfile> = {}): ProviderProfile {
|
||||
return buildProfile({
|
||||
provider: 'mistral',
|
||||
baseUrl: 'https://api.mistral.ai/v1',
|
||||
model: 'devstral-latest',
|
||||
...overrides,
|
||||
})
|
||||
}
|
||||
|
||||
function buildGeminiProfile(overrides: Partial<ProviderProfile> = {}): ProviderProfile {
|
||||
return buildProfile({
|
||||
provider: 'gemini',
|
||||
baseUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
model: 'gemini-3-flash-preview',
|
||||
...overrides,
|
||||
})
|
||||
}
|
||||
|
||||
describe('applyProviderProfileToProcessEnv', () => {
|
||||
test('openai profile clears competing gemini/github flags', async () => {
|
||||
const { applyProviderProfileToProcessEnv } =
|
||||
@@ -118,6 +146,36 @@ describe('applyProviderProfileToProcessEnv', () => {
|
||||
expect(getFreshAPIProvider()).toBe('openai')
|
||||
})
|
||||
|
||||
test('mistral profile sets CLAUDE_CODE_USE_MISTRAL and clears openai flags', async () => {
|
||||
const { applyProviderProfileToProcessEnv } =
|
||||
await importFreshProviderProfileModules()
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
|
||||
applyProviderProfileToProcessEnv(buildMistralProfile())
|
||||
const { getAPIProvider: getFreshAPIProvider } =
|
||||
await importFreshProvidersModule()
|
||||
|
||||
expect(process.env.CLAUDE_CODE_USE_MISTRAL).toBe('1')
|
||||
expect(process.env.CLAUDE_CODE_USE_OPENAI).toBeUndefined()
|
||||
expect(process.env.MISTRAL_MODEL).toBe('devstral-latest')
|
||||
expect(getFreshAPIProvider()).toBe('mistral')
|
||||
})
|
||||
|
||||
test('gemini profile sets CLAUDE_CODE_USE_GEMINI and clears openai flags', async () => {
|
||||
const { applyProviderProfileToProcessEnv } =
|
||||
await importFreshProviderProfileModules()
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
|
||||
applyProviderProfileToProcessEnv(buildGeminiProfile())
|
||||
const { getAPIProvider: getFreshAPIProvider } =
|
||||
await importFreshProvidersModule()
|
||||
|
||||
expect(process.env.CLAUDE_CODE_USE_GEMINI).toBe('1')
|
||||
expect(process.env.CLAUDE_CODE_USE_OPENAI).toBeUndefined()
|
||||
expect(process.env.GEMINI_MODEL).toBe('gemini-3-flash-preview')
|
||||
expect(getFreshAPIProvider()).toBe('gemini')
|
||||
})
|
||||
|
||||
test('anthropic profile clears competing gemini/github flags', async () => {
|
||||
const { applyProviderProfileToProcessEnv } =
|
||||
await importFreshProviderProfileModules()
|
||||
@@ -139,6 +197,39 @@ describe('applyProviderProfileToProcessEnv', () => {
|
||||
expect(process.env.CLAUDE_CODE_USE_OPENAI).toBeUndefined()
|
||||
expect(getFreshAPIProvider()).toBe('firstParty')
|
||||
})
|
||||
|
||||
test('openai profile with multi-model string sets only first model in OPENAI_MODEL', async () => {
|
||||
const { applyProviderProfileToProcessEnv } =
|
||||
await importFreshProviderProfileModules()
|
||||
|
||||
applyProviderProfileToProcessEnv(
|
||||
buildProfile({
|
||||
provider: 'openai',
|
||||
baseUrl: 'https://api.openai.com/v1',
|
||||
model: 'glm-4.7, glm-4.7-flash, glm-4.7-plus',
|
||||
}),
|
||||
)
|
||||
|
||||
expect(process.env.OPENAI_MODEL).toBe('glm-4.7')
|
||||
expect(String(process.env.CLAUDE_CODE_USE_OPENAI)).toBe('1')
|
||||
expect(process.env.OPENAI_BASE_URL).toBe('https://api.openai.com/v1')
|
||||
})
|
||||
|
||||
test('anthropic profile with multi-model string sets only first model in ANTHROPIC_MODEL', async () => {
|
||||
const { applyProviderProfileToProcessEnv } =
|
||||
await importFreshProviderProfileModules()
|
||||
|
||||
applyProviderProfileToProcessEnv(
|
||||
buildProfile({
|
||||
provider: 'anthropic',
|
||||
baseUrl: 'https://api.anthropic.com',
|
||||
model: 'claude-sonnet-4-6, claude-opus-4-6',
|
||||
}),
|
||||
)
|
||||
|
||||
expect(process.env.ANTHROPIC_MODEL).toBe('claude-sonnet-4-6')
|
||||
expect(process.env.ANTHROPIC_BASE_URL).toBe('https://api.anthropic.com')
|
||||
})
|
||||
})
|
||||
|
||||
describe('applyActiveProviderProfileFromConfig', () => {
|
||||
@@ -361,6 +452,169 @@ describe('getProviderPresetDefaults', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// Verifies that activating a saved provider profile rewrites process.env:
// the matching provider flag/model/base-url vars are set, competing
// provider vars are cleared, and the applied-profile marker records the id.
// Tests mutate process.env and re-import modules fresh, so statement order
// within each test is significant.
describe('setActiveProviderProfile', () => {
  test('sets OPENAI_MODEL env var when switching to an openai-type provider', async () => {
    // Fresh import so module-level state from earlier tests does not leak
    // in — presumably what importFreshProviderProfileModules exists for.
    const { setActiveProviderProfile } =
      await importFreshProviderProfileModules()
    const openaiProfile = buildProfile({
      id: 'openai_prof',
      name: 'OpenAI Provider',
      provider: 'openai',
      baseUrl: 'https://api.openai.com/v1',
      model: 'gpt-4o',
    })

    // Seed the mocked global config with the profile under test.
    saveMockGlobalConfig(current => ({
      ...current,
      providerProfiles: [openaiProfile],
    }))

    const result = setActiveProviderProfile('openai_prof')

    expect(result?.id).toBe('openai_prof')
    expect(String(process.env.CLAUDE_CODE_USE_OPENAI)).toBe('1')
    expect(process.env.OPENAI_MODEL).toBe('gpt-4o')
    expect(process.env.OPENAI_BASE_URL).toBe('https://api.openai.com/v1')
    // The marker lets later code tell profile-applied env from shell env.
    expect(process.env.CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED_ID).toBe(
      'openai_prof',
    )
  })

  test('sets ANTHROPIC_MODEL env var when switching to an anthropic-type provider', async () => {
    const { setActiveProviderProfile } =
      await importFreshProviderProfileModules()
    const anthropicProfile = buildProfile({
      id: 'anthro_prof',
      name: 'Anthropic Provider',
      provider: 'anthropic',
      baseUrl: 'https://api.anthropic.com',
      model: 'claude-sonnet-4-6',
    })

    saveMockGlobalConfig(current => ({
      ...current,
      providerProfiles: [anthropicProfile],
    }))

    const result = setActiveProviderProfile('anthro_prof')

    expect(result?.id).toBe('anthro_prof')
    expect(process.env.ANTHROPIC_MODEL).toBe('claude-sonnet-4-6')
    expect(process.env.ANTHROPIC_BASE_URL).toBe('https://api.anthropic.com')
    // Anthropic activation must not leave the OpenAI routing vars behind.
    expect(process.env.CLAUDE_CODE_USE_OPENAI).toBeUndefined()
    expect(process.env.OPENAI_MODEL).toBeUndefined()
    expect(process.env.CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED_ID).toBe(
      'anthro_prof',
    )
  })

  test('clears openai model env and sets anthropic model env when switching from openai to anthropic provider', async () => {
    const { setActiveProviderProfile } =
      await importFreshProviderProfileModules()
    const openaiProfile = buildProfile({
      id: 'openai_prof',
      name: 'OpenAI Provider',
      provider: 'openai',
      baseUrl: 'https://api.openai.com/v1',
      model: 'gpt-4o',
      apiKey: 'sk-openai-key',
    })
    const anthropicProfile = buildProfile({
      id: 'anthro_prof',
      name: 'Anthropic Provider',
      provider: 'anthropic',
      baseUrl: 'https://api.anthropic.com',
      model: 'claude-sonnet-4-6',
      apiKey: 'sk-ant-key',
    })

    saveMockGlobalConfig(current => ({
      ...current,
      providerProfiles: [openaiProfile, anthropicProfile],
    }))

    // First activate the openai profile
    setActiveProviderProfile('openai_prof')
    expect(process.env.OPENAI_MODEL).toBe('gpt-4o')
    expect(String(process.env.CLAUDE_CODE_USE_OPENAI)).toBe('1')

    // Now switch to the anthropic profile
    const result = setActiveProviderProfile('anthro_prof')

    expect(result?.id).toBe('anthro_prof')
    expect(process.env.ANTHROPIC_MODEL).toBe('claude-sonnet-4-6')
    expect(process.env.ANTHROPIC_BASE_URL).toBe('https://api.anthropic.com')
    // Every OpenAI-side var set by the first activation — flag, model,
    // base URL, and API key — must be removed by the switch.
    expect(process.env.CLAUDE_CODE_USE_OPENAI).toBeUndefined()
    expect(process.env.OPENAI_MODEL).toBeUndefined()
    expect(process.env.OPENAI_BASE_URL).toBeUndefined()
    expect(process.env.OPENAI_API_KEY).toBeUndefined()
    expect(process.env.CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED_ID).toBe(
      'anthro_prof',
    )
  })

  test('clears anthropic model env and sets openai model env when switching from anthropic to openai provider', async () => {
    const { setActiveProviderProfile } =
      await importFreshProviderProfileModules()
    const anthropicProfile = buildProfile({
      id: 'anthro_prof',
      name: 'Anthropic Provider',
      provider: 'anthropic',
      baseUrl: 'https://api.anthropic.com',
      model: 'claude-sonnet-4-6',
      apiKey: 'sk-ant-key',
    })
    const openaiProfile = buildProfile({
      id: 'openai_prof',
      name: 'OpenAI Provider',
      provider: 'openai',
      baseUrl: 'https://api.openai.com/v1',
      model: 'gpt-4o',
      apiKey: 'sk-openai-key',
    })

    saveMockGlobalConfig(current => ({
      ...current,
      providerProfiles: [anthropicProfile, openaiProfile],
    }))

    // First activate the anthropic profile
    setActiveProviderProfile('anthro_prof')
    expect(process.env.ANTHROPIC_MODEL).toBe('claude-sonnet-4-6')
    expect(process.env.ANTHROPIC_BASE_URL).toBe('https://api.anthropic.com')

    // Now switch to the openai profile
    const result = setActiveProviderProfile('openai_prof')

    expect(result?.id).toBe('openai_prof')
    expect(String(process.env.CLAUDE_CODE_USE_OPENAI)).toBe('1')
    expect(process.env.OPENAI_MODEL).toBe('gpt-4o')
    expect(process.env.OPENAI_BASE_URL).toBe('https://api.openai.com/v1')
    // ANTHROPIC_MODEL is set to the profile model for all provider types
    expect(process.env.ANTHROPIC_MODEL).toBe('gpt-4o')
    expect(process.env.ANTHROPIC_BASE_URL).toBeUndefined()
    expect(process.env.ANTHROPIC_API_KEY).toBeUndefined()
    expect(process.env.CLAUDE_CODE_PROVIDER_PROFILE_ENV_APPLIED_ID).toBe(
      'openai_prof',
    )
  })

  test('returns null for non-existent profile id', async () => {
    const { setActiveProviderProfile } =
      await importFreshProviderProfileModules()
    const openaiProfile = buildProfile({ id: 'existing_prof' })

    saveMockGlobalConfig(current => ({
      ...current,
      providerProfiles: [openaiProfile],
    }))

    // Unknown id: activation is a no-op that signals failure via null.
    const result = setActiveProviderProfile('nonexistent_prof')

    expect(result).toBeNull()
  })
})
|
||||
|
||||
describe('deleteProviderProfile', () => {
|
||||
test('deleting final profile clears provider env when active profile applied it', async () => {
|
||||
const {
|
||||
@@ -429,3 +683,82 @@ describe('deleteProviderProfile', () => {
|
||||
expect(process.env.OPENAI_MODEL).toBe('qwen2.5:3b')
|
||||
})
|
||||
})
|
||||
|
||||
// getProfileModelOptions expands a profile's (possibly comma-separated)
// model field into picker options: one {value,label,description} entry
// per model, with the description derived from the profile name.
describe('getProfileModelOptions', () => {
  test('generates options for multi-model profile', async () => {
    const { getProfileModelOptions } =
      await importFreshProviderProfileModules()

    // Comma-plus-space separated list: each entry becomes its own option.
    const options = getProfileModelOptions(
      buildProfile({
        name: 'Test Provider',
        model: 'glm-4.7, glm-4.7-flash, glm-4.7-plus',
      }),
    )

    expect(options).toEqual([
      { value: 'glm-4.7', label: 'glm-4.7', description: 'Provider: Test Provider' },
      { value: 'glm-4.7-flash', label: 'glm-4.7-flash', description: 'Provider: Test Provider' },
      { value: 'glm-4.7-plus', label: 'glm-4.7-plus', description: 'Provider: Test Provider' },
    ])
  })

  test('returns single option for single-model profile', async () => {
    const { getProfileModelOptions } =
      await importFreshProviderProfileModules()

    const options = getProfileModelOptions(
      buildProfile({
        name: 'Single Model',
        model: 'llama3.1:8b',
      }),
    )

    expect(options).toEqual([
      { value: 'llama3.1:8b', label: 'llama3.1:8b', description: 'Provider: Single Model' },
    ])
  })

  test('returns empty array for empty model field', async () => {
    const { getProfileModelOptions } =
      await importFreshProviderProfileModules()

    // An empty model string yields no options rather than a blank entry.
    const options = getProfileModelOptions(
      buildProfile({
        name: 'Empty',
        model: '',
      }),
    )

    expect(options).toEqual([])
  })
})
|
||||
|
||||
// Activating a multi-model profile should seed the OpenAI model-options
// cache with every model in the list, so the picker can offer all of
// them without a network round trip.
describe('setActiveProviderProfile model cache', () => {
  test('populates model cache with all models from multi-model profile on activation', async () => {
    const {
      setActiveProviderProfile,
      getActiveOpenAIModelOptionsCache,
    } = await importFreshProviderProfileModules()

    // Replace the shared mock config wholesale so only this profile exists.
    mockConfigState = {
      ...createMockConfigState(),
      providerProfiles: [
        buildProfile({
          id: 'multi_provider',
          name: 'Multi Provider',
          model: 'glm-4.7, glm-4.7-flash, glm-4.7-plus',
          baseUrl: 'https://api.example.com/v1',
        }),
      ],
    }

    setActiveProviderProfile('multi_provider')

    // Only membership is asserted — the cache may contain extra entries
    // or a different ordering.
    const cache = getActiveOpenAIModelOptionsCache()
    const cacheValues = cache.map(opt => opt.value)
    expect(cacheValues).toContain('glm-4.7')
    expect(cacheValues).toContain('glm-4.7-flash')
    expect(cacheValues).toContain('glm-4.7-plus')
  })
})
|
||||
|
||||
@@ -5,6 +5,15 @@ import {
|
||||
type ProviderProfile,
|
||||
} from './config.js'
|
||||
import type { ModelOption } from './model/modelOptions.js'
|
||||
import { getPrimaryModel, parseModelList } from './providerModels.js'
|
||||
import {
|
||||
createProfileFile,
|
||||
saveProfileFile,
|
||||
buildGeminiProfileEnv,
|
||||
buildMistralProfileEnv,
|
||||
buildOpenAIProfileEnv,
|
||||
type ProviderProfile as ProviderProfileStartup,
|
||||
} from './providerProfile.js'
|
||||
|
||||
export type ProviderPreset =
|
||||
| 'anthropic'
|
||||
@@ -19,7 +28,11 @@ export type ProviderPreset =
|
||||
| 'azure-openai'
|
||||
| 'openrouter'
|
||||
| 'lmstudio'
|
||||
| 'dashscope-cn'
|
||||
| 'dashscope-intl'
|
||||
| 'custom'
|
||||
| 'nvidia-nim'
|
||||
| 'minimax'
|
||||
|
||||
export type ProviderProfileInput = {
|
||||
provider?: ProviderProfile['provider']
|
||||
@@ -55,7 +68,14 @@ function normalizeBaseUrl(value: string): string {
|
||||
function sanitizeProfile(profile: ProviderProfile): ProviderProfile | null {
|
||||
const id = trimValue(profile.id)
|
||||
const name = trimValue(profile.name)
|
||||
const provider = profile.provider === 'anthropic' ? 'anthropic' : 'openai'
|
||||
const provider =
|
||||
profile.provider === 'anthropic'
|
||||
? 'anthropic'
|
||||
: profile.provider === 'mistral'
|
||||
? 'mistral'
|
||||
: profile.provider === 'gemini'
|
||||
? 'gemini'
|
||||
: 'openai'
|
||||
const baseUrl = normalizeBaseUrl(profile.baseUrl)
|
||||
const model = trimValue(profile.model)
|
||||
|
||||
@@ -156,7 +176,7 @@ export function getProviderPresetDefaults(
|
||||
}
|
||||
case 'gemini':
|
||||
return {
|
||||
provider: 'openai',
|
||||
provider: 'gemini',
|
||||
name: 'Google Gemini',
|
||||
baseUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
model: 'gemini-3-flash-preview',
|
||||
@@ -165,7 +185,7 @@ export function getProviderPresetDefaults(
|
||||
}
|
||||
case 'mistral':
|
||||
return {
|
||||
provider: 'openai',
|
||||
provider: 'mistral',
|
||||
name: 'Mistral',
|
||||
baseUrl: 'https://api.mistral.ai/v1',
|
||||
model: 'devstral-latest',
|
||||
@@ -217,6 +237,24 @@ export function getProviderPresetDefaults(
|
||||
apiKey: '',
|
||||
requiresApiKey: false,
|
||||
}
|
||||
case 'dashscope-cn':
|
||||
return {
|
||||
provider: 'openai',
|
||||
name: 'Alibaba Coding Plan (China)',
|
||||
baseUrl: 'https://coding.dashscope.aliyuncs.com/v1',
|
||||
model: 'qwen3.6-plus',
|
||||
apiKey: process.env.DASHSCOPE_API_KEY ?? '',
|
||||
requiresApiKey: true,
|
||||
}
|
||||
case 'dashscope-intl':
|
||||
return {
|
||||
provider: 'openai',
|
||||
name: 'Alibaba Coding Plan',
|
||||
baseUrl: 'https://coding-intl.dashscope.aliyuncs.com/v1',
|
||||
model: 'qwen3.6-plus',
|
||||
apiKey: process.env.DASHSCOPE_API_KEY ?? '',
|
||||
requiresApiKey: true,
|
||||
}
|
||||
case 'custom':
|
||||
return {
|
||||
provider: 'openai',
|
||||
@@ -229,6 +267,24 @@ export function getProviderPresetDefaults(
|
||||
apiKey: process.env.OPENAI_API_KEY ?? '',
|
||||
requiresApiKey: false,
|
||||
}
|
||||
case 'nvidia-nim':
|
||||
return {
|
||||
provider: 'openai',
|
||||
name: 'NVIDIA NIM',
|
||||
baseUrl: 'https://integrate.api.nvidia.com/v1',
|
||||
model: 'nvidia/llama-3.1-nemotron-70b-instruct',
|
||||
apiKey: process.env.NVIDIA_API_KEY ?? '',
|
||||
requiresApiKey: true,
|
||||
}
|
||||
case 'minimax':
|
||||
return {
|
||||
provider: 'openai',
|
||||
name: 'MiniMax',
|
||||
baseUrl: 'https://api.minimax.io/v1',
|
||||
model: 'MiniMax-M2.5',
|
||||
apiKey: process.env.MINIMAX_API_KEY ?? '',
|
||||
requiresApiKey: true,
|
||||
}
|
||||
case 'ollama':
|
||||
default:
|
||||
return {
|
||||
@@ -276,6 +332,7 @@ function hasConflictingProviderFlagsForProfile(
|
||||
|
||||
return (
|
||||
processEnv.CLAUDE_CODE_USE_GEMINI !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_MISTRAL !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_GITHUB !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_BEDROCK !== undefined ||
|
||||
processEnv.CLAUDE_CODE_USE_VERTEX !== undefined ||
|
||||
@@ -311,12 +368,44 @@ function isProcessEnvAlignedWithProfile(
|
||||
return (
|
||||
!hasProviderSelectionFlags(processEnv) &&
|
||||
sameOptionalEnvValue(processEnv.ANTHROPIC_BASE_URL, profile.baseUrl) &&
|
||||
sameOptionalEnvValue(processEnv.ANTHROPIC_MODEL, profile.model) &&
|
||||
sameOptionalEnvValue(processEnv.ANTHROPIC_MODEL, getPrimaryModel(profile.model)) &&
|
||||
(!includeApiKey ||
|
||||
sameOptionalEnvValue(processEnv.ANTHROPIC_API_KEY, profile.apiKey))
|
||||
)
|
||||
}
|
||||
|
||||
if (profile.provider === 'mistral') {
|
||||
return (
|
||||
processEnv.CLAUDE_CODE_USE_MISTRAL !== undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_GEMINI === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_OPENAI === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_GITHUB === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_BEDROCK === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_VERTEX === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_FOUNDRY === undefined &&
|
||||
sameOptionalEnvValue(processEnv.MISTRAL_BASE_URL, profile.baseUrl) &&
|
||||
sameOptionalEnvValue(processEnv.MISTRAL_MODEL, profile.model) &&
|
||||
(!includeApiKey ||
|
||||
sameOptionalEnvValue(processEnv.MISTRAL_API_KEY, profile.apiKey))
|
||||
)
|
||||
}
|
||||
|
||||
if (profile.provider === 'gemini') {
|
||||
return (
|
||||
processEnv.CLAUDE_CODE_USE_GEMINI !== undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_MISTRAL === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_OPENAI === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_GITHUB === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_BEDROCK === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_VERTEX === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_FOUNDRY === undefined &&
|
||||
sameOptionalEnvValue(processEnv.GEMINI_BASE_URL, profile.baseUrl) &&
|
||||
sameOptionalEnvValue(processEnv.GEMINI_MODEL, profile.model) &&
|
||||
(!includeApiKey ||
|
||||
sameOptionalEnvValue(processEnv.GEMINI_API_KEY, profile.apiKey))
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
processEnv.CLAUDE_CODE_USE_OPENAI !== undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_GEMINI === undefined &&
|
||||
@@ -326,7 +415,7 @@ function isProcessEnvAlignedWithProfile(
|
||||
processEnv.CLAUDE_CODE_USE_VERTEX === undefined &&
|
||||
processEnv.CLAUDE_CODE_USE_FOUNDRY === undefined &&
|
||||
sameOptionalEnvValue(processEnv.OPENAI_BASE_URL, profile.baseUrl) &&
|
||||
sameOptionalEnvValue(processEnv.OPENAI_MODEL, profile.model) &&
|
||||
sameOptionalEnvValue(processEnv.OPENAI_MODEL, getPrimaryModel(profile.model)) &&
|
||||
(!includeApiKey ||
|
||||
sameOptionalEnvValue(processEnv.OPENAI_API_KEY, profile.apiKey))
|
||||
)
|
||||
@@ -365,6 +454,22 @@ export function clearProviderProfileEnvFromProcessEnv(
|
||||
delete processEnv.ANTHROPIC_API_KEY
|
||||
delete processEnv[PROFILE_ENV_APPLIED_FLAG]
|
||||
delete processEnv[PROFILE_ENV_APPLIED_ID]
|
||||
|
||||
delete processEnv.GEMINI_MODEL
|
||||
delete processEnv.GEMINI_BASE_URL
|
||||
delete processEnv.GEMINI_API_KEY
|
||||
delete processEnv.GEMINI_AUTH_MODE
|
||||
delete processEnv.GEMINI_ACCESS_TOKEN
|
||||
delete processEnv.GOOGLE_API_KEY
|
||||
|
||||
delete processEnv.MISTRAL_MODEL
|
||||
delete processEnv.MISTRAL_BASE_URL
|
||||
delete processEnv.MISTRAL_API_KEY
|
||||
|
||||
// Clear provider-specific API keys
|
||||
delete processEnv.MINIMAX_API_KEY
|
||||
delete processEnv.NVIDIA_API_KEY
|
||||
delete processEnv.NVIDIA_NIM
|
||||
}
|
||||
|
||||
export function applyProviderProfileToProcessEnv(profile: ProviderProfile): void {
|
||||
@@ -372,7 +477,7 @@ export function applyProviderProfileToProcessEnv(profile: ProviderProfile): void
|
||||
process.env[PROFILE_ENV_APPLIED_FLAG] = '1'
|
||||
process.env[PROFILE_ENV_APPLIED_ID] = profile.id
|
||||
|
||||
process.env.ANTHROPIC_MODEL = profile.model
|
||||
process.env.ANTHROPIC_MODEL = getPrimaryModel(profile.model)
|
||||
if (profile.provider === 'anthropic') {
|
||||
process.env.ANTHROPIC_BASE_URL = profile.baseUrl
|
||||
|
||||
@@ -389,12 +494,54 @@ export function applyProviderProfileToProcessEnv(profile: ProviderProfile): void
|
||||
return
|
||||
}
|
||||
|
||||
if (profile.provider === 'mistral') {
|
||||
process.env.CLAUDE_CODE_USE_MISTRAL = '1'
|
||||
process.env.MISTRAL_BASE_URL = profile.baseUrl
|
||||
process.env.MISTRAL_MODEL = profile.model
|
||||
|
||||
if (profile.apiKey) {
|
||||
process.env.MISTRAL_API_KEY = profile.apiKey
|
||||
} else {
|
||||
delete process.env.MISTRAL_API_KEY
|
||||
}
|
||||
|
||||
delete process.env.OPENAI_BASE_URL
|
||||
delete process.env.OPENAI_API_KEY
|
||||
delete process.env.OPENAI_MODEL
|
||||
return
|
||||
}
|
||||
|
||||
if (profile.provider === 'gemini') {
|
||||
process.env.CLAUDE_CODE_USE_GEMINI = '1'
|
||||
process.env.GEMINI_BASE_URL = profile.baseUrl
|
||||
process.env.GEMINI_MODEL = profile.model
|
||||
|
||||
if (profile.apiKey) {
|
||||
process.env.GEMINI_API_KEY = profile.apiKey
|
||||
} else {
|
||||
delete process.env.GEMINI_API_KEY
|
||||
}
|
||||
|
||||
delete process.env.OPENAI_BASE_URL
|
||||
delete process.env.OPENAI_API_KEY
|
||||
delete process.env.OPENAI_MODEL
|
||||
return
|
||||
}
|
||||
|
||||
process.env.CLAUDE_CODE_USE_OPENAI = '1'
|
||||
process.env.OPENAI_BASE_URL = profile.baseUrl
|
||||
process.env.OPENAI_MODEL = profile.model
|
||||
process.env.OPENAI_MODEL = getPrimaryModel(profile.model)
|
||||
|
||||
if (profile.apiKey) {
|
||||
process.env.OPENAI_API_KEY = profile.apiKey
|
||||
// Also set provider-specific API keys for detection
|
||||
const baseUrl = profile.baseUrl.toLowerCase()
|
||||
if (baseUrl.includes('minimax')) {
|
||||
process.env.MINIMAX_API_KEY = profile.apiKey
|
||||
}
|
||||
if (baseUrl.includes('nvidia') || baseUrl.includes('integrate.api.nvidia')) {
|
||||
process.env.NVIDIA_API_KEY = profile.apiKey
|
||||
}
|
||||
} else {
|
||||
delete process.env.OPENAI_API_KEY
|
||||
}
|
||||
@@ -466,7 +613,7 @@ export function addProviderProfile(
|
||||
|
||||
const activeProfile = getActiveProviderProfile()
|
||||
if (activeProfile?.id === profile.id) {
|
||||
applyProviderProfileToProcessEnv(profile)
|
||||
setActiveProviderProfile(profile.id)
|
||||
clearActiveOpenAIModelOptionsCache()
|
||||
}
|
||||
|
||||
@@ -548,6 +695,16 @@ export function persistActiveProviderProfileModel(
|
||||
return null
|
||||
}
|
||||
|
||||
// If the model is already part of the profile's model list, don't
|
||||
// overwrite the field. This preserves comma-separated model lists like
|
||||
// "glm-4.5, glm-4.7". Switching between models in the list is a
|
||||
// session-level choice handled by mainLoopModelOverride, not a profile
|
||||
// edit — the profile's model list should only change via explicit edit.
|
||||
const existingModels = parseModelList(activeProfile.model)
|
||||
if (existingModels.includes(nextModel)) {
|
||||
return activeProfile
|
||||
}
|
||||
|
||||
saveGlobalConfig(current => {
|
||||
const currentProfiles = getProviderProfiles(current)
|
||||
const profileIndex = currentProfiles.findIndex(
|
||||
@@ -590,6 +747,23 @@ export function persistActiveProviderProfileModel(
|
||||
return resolvedProfile
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate model options from a provider profile's model field.
|
||||
* Each comma-separated model becomes a separate option in the picker.
|
||||
*/
|
||||
export function getProfileModelOptions(profile: ProviderProfile): ModelOption[] {
|
||||
const models = parseModelList(profile.model)
|
||||
if (models.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
return models.map(model => ({
|
||||
value: model,
|
||||
label: model,
|
||||
description: `Provider: ${profile.name}`,
|
||||
}))
|
||||
}
|
||||
|
||||
export function setActiveProviderProfile(
|
||||
profileId: string,
|
||||
): ProviderProfile | null {
|
||||
@@ -601,13 +775,85 @@ export function setActiveProviderProfile(
|
||||
return null
|
||||
}
|
||||
|
||||
const profileModelOptions = getProfileModelOptions(activeProfile)
|
||||
|
||||
saveGlobalConfig(config => ({
|
||||
...config,
|
||||
activeProviderProfileId: profileId,
|
||||
openaiAdditionalModelOptionsCache: getModelCacheByProfile(profileId, config),
|
||||
openaiAdditionalModelOptionsCache: profileModelOptions.length > 0
|
||||
? profileModelOptions
|
||||
: getModelCacheByProfile(profileId, config),
|
||||
openaiAdditionalModelOptionsCacheByProfile: {
|
||||
...(config.openaiAdditionalModelOptionsCacheByProfile ?? {}),
|
||||
[profileId]: profileModelOptions.length > 0
|
||||
? profileModelOptions
|
||||
: (config.openaiAdditionalModelOptionsCacheByProfile?.[profileId] ?? []),
|
||||
},
|
||||
}))
|
||||
|
||||
applyProviderProfileToProcessEnv(activeProfile)
|
||||
|
||||
// Keep startup persisted provider profile in sync so initial startup
|
||||
// uses the selected provider/model.
|
||||
const persistedProfile = (() => {
|
||||
if (activeProfile.provider === 'anthropic') return 'openai' as const
|
||||
return activeProfile.provider
|
||||
})()
|
||||
|
||||
const profileEnv = (() => {
|
||||
switch (activeProfile.provider) {
|
||||
case 'gemini':
|
||||
return (
|
||||
buildGeminiProfileEnv({
|
||||
model: activeProfile.model,
|
||||
baseUrl: activeProfile.baseUrl,
|
||||
apiKey: activeProfile.apiKey,
|
||||
authMode: 'api-key',
|
||||
processEnv: process.env,
|
||||
}) ?? null
|
||||
)
|
||||
case 'mistral':
|
||||
return (
|
||||
buildMistralProfileEnv({
|
||||
model: activeProfile.model,
|
||||
baseUrl: activeProfile.baseUrl,
|
||||
apiKey: activeProfile.apiKey,
|
||||
processEnv: process.env,
|
||||
}) ?? null
|
||||
)
|
||||
default:
|
||||
// anthropic and all openai-compatible providers
|
||||
return (
|
||||
buildOpenAIProfileEnv({
|
||||
model: activeProfile.model,
|
||||
baseUrl: activeProfile.baseUrl,
|
||||
apiKey: activeProfile.apiKey,
|
||||
processEnv: process.env,
|
||||
}) ?? null
|
||||
)
|
||||
}
|
||||
})()
|
||||
|
||||
if (profileEnv) {
|
||||
const startupProfile =
|
||||
activeProfile.provider === 'anthropic'
|
||||
? ({
|
||||
profile: 'openai' as ProviderProfileStartup,
|
||||
env: {
|
||||
OPENAI_BASE_URL: activeProfile.baseUrl,
|
||||
OPENAI_MODEL: activeProfile.model,
|
||||
OPENAI_API_KEY: activeProfile.apiKey,
|
||||
},
|
||||
} as const)
|
||||
: ({
|
||||
profile: activeProfile.provider as ProviderProfileStartup,
|
||||
env: profileEnv,
|
||||
} as const)
|
||||
|
||||
const file = createProfileFile(startupProfile.profile, startupProfile.env)
|
||||
saveProfileFile(file)
|
||||
}
|
||||
|
||||
return activeProfile
|
||||
}
|
||||
|
||||
|
||||
@@ -61,15 +61,7 @@ export function maskSecretForDisplay(
|
||||
return 'configured'
|
||||
}
|
||||
|
||||
if (sanitized.startsWith('sk-')) {
|
||||
return `${sanitized.slice(0, 3)}...${sanitized.slice(-4)}`
|
||||
}
|
||||
|
||||
if (sanitized.startsWith('AIza')) {
|
||||
return `${sanitized.slice(0, 4)}...${sanitized.slice(-4)}`
|
||||
}
|
||||
|
||||
return `${sanitized.slice(0, 2)}...${sanitized.slice(-4)}`
|
||||
return `${sanitized.slice(0, 3)}...${sanitized.slice(-3)}`
|
||||
}
|
||||
|
||||
export function redactSecretValueForDisplay(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import { afterEach, expect, test } from 'bun:test'
|
||||
|
||||
import { getProviderValidationError } from './providerValidation.ts'
|
||||
import {
|
||||
getProviderValidationError,
|
||||
shouldExitForStartupProviderValidationError,
|
||||
} from './providerValidation.ts'
|
||||
|
||||
const originalEnv = {
|
||||
CLAUDE_CODE_USE_OPENAI: process.env.CLAUDE_CODE_USE_OPENAI,
|
||||
@@ -93,3 +96,45 @@ test('openai missing key error includes recovery guidance and config locations',
|
||||
expect(message).toContain('Saved startup settings can come from')
|
||||
expect(message).toContain('.openclaude-profile.json')
|
||||
})
|
||||
|
||||
test('startup provider validation allows interactive recovery', () => {
|
||||
expect(
|
||||
shouldExitForStartupProviderValidationError({
|
||||
args: [],
|
||||
stdoutIsTTY: true,
|
||||
}),
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
test('startup provider validation stays strict for non-interactive launches', () => {
|
||||
expect(
|
||||
shouldExitForStartupProviderValidationError({
|
||||
args: ['-p', 'hello'],
|
||||
stdoutIsTTY: true,
|
||||
}),
|
||||
).toBe(true)
|
||||
expect(
|
||||
shouldExitForStartupProviderValidationError({
|
||||
args: ['--print', 'hello'],
|
||||
stdoutIsTTY: true,
|
||||
}),
|
||||
).toBe(true)
|
||||
expect(
|
||||
shouldExitForStartupProviderValidationError({
|
||||
args: [],
|
||||
stdoutIsTTY: false,
|
||||
}),
|
||||
).toBe(true)
|
||||
expect(
|
||||
shouldExitForStartupProviderValidationError({
|
||||
args: ['--sdk-url', 'ws://127.0.0.1:3000'],
|
||||
stdoutIsTTY: true,
|
||||
}),
|
||||
).toBe(true)
|
||||
expect(
|
||||
shouldExitForStartupProviderValidationError({
|
||||
args: ['--sdk-url=ws://127.0.0.1:3000'],
|
||||
stdoutIsTTY: true,
|
||||
}),
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
@@ -169,3 +169,44 @@ export async function validateProviderEnvOrExit(
|
||||
process.exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
export function shouldExitForStartupProviderValidationError(options: {
|
||||
args?: string[]
|
||||
stdoutIsTTY?: boolean
|
||||
} = {}): boolean {
|
||||
const args = options.args ?? process.argv.slice(2)
|
||||
const stdoutIsTTY = options.stdoutIsTTY ?? process.stdout.isTTY
|
||||
|
||||
if (!stdoutIsTTY) {
|
||||
return true
|
||||
}
|
||||
|
||||
return (
|
||||
args.includes('-p') ||
|
||||
args.includes('--print') ||
|
||||
args.includes('--init-only') ||
|
||||
args.some(arg => arg.startsWith('--sdk-url'))
|
||||
)
|
||||
}
|
||||
|
||||
export async function validateProviderEnvForStartupOrExit(
|
||||
env: NodeJS.ProcessEnv = process.env,
|
||||
options?: {
|
||||
args?: string[]
|
||||
stdoutIsTTY?: boolean
|
||||
},
|
||||
): Promise<void> {
|
||||
const error = await getProviderValidationError(env)
|
||||
if (!error) {
|
||||
return
|
||||
}
|
||||
|
||||
if (shouldExitForStartupProviderValidationError(options)) {
|
||||
console.error(error)
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
console.error(
|
||||
`Warning: provider configuration is incomplete.\n${error}\nOpenClaude will continue starting so you can run /provider and repair the saved provider settings.`,
|
||||
)
|
||||
}
|
||||
|
||||
33
src/utils/swarm/spawnUtils.test.ts
Normal file
33
src/utils/swarm/spawnUtils.test.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { afterEach, beforeEach, expect, test } from 'bun:test'
|
||||
|
||||
import { buildInheritedEnvVars } from './spawnUtils.js'
|
||||
|
||||
const ORIGINAL_ENV = { ...process.env }
|
||||
|
||||
beforeEach(() => {
|
||||
for (const key of Object.keys(process.env)) {
|
||||
delete process.env[key]
|
||||
}
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
for (const key of Object.keys(process.env)) {
|
||||
delete process.env[key]
|
||||
}
|
||||
Object.assign(process.env, ORIGINAL_ENV)
|
||||
})
|
||||
|
||||
test('buildInheritedEnvVars marks spawned teammates as host-managed for provider routing', () => {
|
||||
const envVars = buildInheritedEnvVars()
|
||||
|
||||
expect(envVars).toContain('CLAUDE_CODE_PROVIDER_MANAGED_BY_HOST=1')
|
||||
})
|
||||
|
||||
test('buildInheritedEnvVars forwards PATH for source-built teammate tool lookups', () => {
|
||||
process.env.PATH = '/custom/bin:/usr/bin'
|
||||
|
||||
const envVars = buildInheritedEnvVars()
|
||||
|
||||
expect(envVars).toContain('PATH=')
|
||||
expect(envVars).toContain('/custom/bin\\:/usr/bin')
|
||||
})
|
||||
@@ -141,6 +141,9 @@ const TEAMMATE_ENV_VARS = [
|
||||
'NODE_EXTRA_CA_CERTS',
|
||||
'REQUESTS_CA_BUNDLE',
|
||||
'CURL_CA_BUNDLE',
|
||||
// Source builds may rely on user shell PATH for rg/node/bun and other tools.
|
||||
// Forward it so teammates resolve the same toolchain as the parent session.
|
||||
'PATH',
|
||||
] as const
|
||||
|
||||
/**
|
||||
@@ -149,7 +152,13 @@ const TEAMMATE_ENV_VARS = [
|
||||
* plus any provider/config env vars that are set in the current process.
|
||||
*/
|
||||
export function buildInheritedEnvVars(): string {
|
||||
const envVars = ['CLAUDECODE=1', 'CLAUDE_CODE_EXPERIMENTAL_AGENT_TEAMS=1']
|
||||
const envVars = [
|
||||
'CLAUDECODE=1',
|
||||
'CLAUDE_CODE_EXPERIMENTAL_AGENT_TEAMS=1',
|
||||
// Teammates should inherit the leader-selected provider route instead of
|
||||
// replaying persisted ~/.claude or settings.env provider defaults.
|
||||
'CLAUDE_CODE_PROVIDER_MANAGED_BY_HOST=1',
|
||||
]
|
||||
|
||||
for (const key of TEAMMATE_ENV_VARS) {
|
||||
const value = process.env[key]
|
||||
|
||||
15
src/utils/truncate.test.ts
Normal file
15
src/utils/truncate.test.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import { truncate, truncateToWidth, truncatePathMiddle } from './truncate.js'
|
||||
|
||||
describe('truncate utilities', () => {
|
||||
test('truncate returns empty string for undefined input', () => {
|
||||
expect(truncate(undefined, 10)).toBe('')
|
||||
})
|
||||
|
||||
test('truncateToWidth returns empty string for undefined input', () => {
|
||||
expect(truncateToWidth(undefined, 5)).toBe('')
|
||||
})
|
||||
|
||||
test('truncatePathMiddle returns empty string for undefined path', () => {
|
||||
expect(truncatePathMiddle(undefined, 20)).toBe('')
|
||||
})
|
||||
})
|
||||
@@ -13,10 +13,11 @@ import { getGraphemeSegmenter } from './intl.js'
|
||||
* @param maxLength Maximum display width of the result in terminal columns (must be > 0)
|
||||
* @returns The truncated path, or original if it fits within maxLength
|
||||
*/
|
||||
export function truncatePathMiddle(path: string, maxLength: number): string {
|
||||
export function truncatePathMiddle(path: string | undefined, maxLength: number): string {
|
||||
const safePath = path ?? ''
|
||||
// No truncation needed
|
||||
if (stringWidth(path) <= maxLength) {
|
||||
return path
|
||||
if (stringWidth(safePath) <= maxLength) {
|
||||
return safePath
|
||||
}
|
||||
|
||||
// Handle edge case of very small or non-positive maxLength
|
||||
@@ -26,14 +27,14 @@ export function truncatePathMiddle(path: string, maxLength: number): string {
|
||||
|
||||
// Need at least room for "…" + something meaningful
|
||||
if (maxLength < 5) {
|
||||
return truncateToWidth(path, maxLength)
|
||||
return truncateToWidth(safePath, maxLength)
|
||||
}
|
||||
|
||||
// Find the filename (last path segment)
|
||||
const lastSlash = path.lastIndexOf('/')
|
||||
const lastSlash = safePath.lastIndexOf('/')
|
||||
// Include the leading slash in filename for display
|
||||
const filename = lastSlash >= 0 ? path.slice(lastSlash) : path
|
||||
const directory = lastSlash >= 0 ? path.slice(0, lastSlash) : ''
|
||||
const filename = lastSlash >= 0 ? safePath.slice(lastSlash) : safePath
|
||||
const directory = lastSlash >= 0 ? safePath.slice(0, lastSlash) : ''
|
||||
const filenameWidth = stringWidth(filename)
|
||||
|
||||
// If filename alone is too long, truncate from start
|
||||
@@ -60,12 +61,13 @@ export function truncatePathMiddle(path: string, maxLength: number): string {
|
||||
* Splits on grapheme boundaries to avoid breaking emoji or surrogate pairs.
|
||||
* Appends '…' when truncation occurs.
|
||||
*/
|
||||
export function truncateToWidth(text: string, maxWidth: number): string {
|
||||
if (stringWidth(text) <= maxWidth) return text
|
||||
export function truncateToWidth(text: string | undefined, maxWidth: number): string {
|
||||
const safeText = text ?? ''
|
||||
if (stringWidth(safeText) <= maxWidth) return safeText
|
||||
if (maxWidth <= 1) return '…'
|
||||
let width = 0
|
||||
let result = ''
|
||||
for (const { segment } of getGraphemeSegmenter().segment(text)) {
|
||||
for (const { segment } of getGraphemeSegmenter().segment(safeText)) {
|
||||
const segWidth = stringWidth(segment)
|
||||
if (width + segWidth > maxWidth - 1) break
|
||||
result += segment
|
||||
@@ -79,10 +81,11 @@ export function truncateToWidth(text: string, maxWidth: number): string {
|
||||
* Prepends '…' when truncation occurs.
|
||||
* Width-aware and grapheme-safe.
|
||||
*/
|
||||
export function truncateStartToWidth(text: string, maxWidth: number): string {
|
||||
if (stringWidth(text) <= maxWidth) return text
|
||||
export function truncateStartToWidth(text: string | undefined, maxWidth: number): string {
|
||||
const safeText = text ?? ''
|
||||
if (stringWidth(safeText) <= maxWidth) return safeText
|
||||
if (maxWidth <= 1) return '…'
|
||||
const segments = [...getGraphemeSegmenter().segment(text)]
|
||||
const segments = [...getGraphemeSegmenter().segment(safeText)]
|
||||
let width = 0
|
||||
let startIdx = segments.length
|
||||
for (let i = segments.length - 1; i >= 0; i--) {
|
||||
@@ -106,14 +109,15 @@ export function truncateStartToWidth(text: string, maxWidth: number): string {
|
||||
* Width-aware and grapheme-safe.
|
||||
*/
|
||||
export function truncateToWidthNoEllipsis(
|
||||
text: string,
|
||||
text: string | undefined,
|
||||
maxWidth: number,
|
||||
): string {
|
||||
if (stringWidth(text) <= maxWidth) return text
|
||||
const safeText = text ?? ''
|
||||
if (stringWidth(safeText) <= maxWidth) return safeText
|
||||
if (maxWidth <= 0) return ''
|
||||
let width = 0
|
||||
let result = ''
|
||||
for (const { segment } of getGraphemeSegmenter().segment(text)) {
|
||||
for (const { segment } of getGraphemeSegmenter().segment(safeText)) {
|
||||
const segWidth = stringWidth(segment)
|
||||
if (width + segWidth > maxWidth) break
|
||||
result += segment
|
||||
@@ -131,18 +135,21 @@ export function truncateToWidthNoEllipsis(
|
||||
* @param singleLine If true, also truncates at the first newline
|
||||
* @returns The truncated string with ellipsis if needed
|
||||
*/
|
||||
|
||||
export function truncate(
|
||||
str: string,
|
||||
str: string | undefined,
|
||||
maxWidth: number,
|
||||
singleLine: boolean = false,
|
||||
): string {
|
||||
let result = str
|
||||
const safeStr = str ?? ''
|
||||
if (safeStr === '') return ''
|
||||
let result = safeStr
|
||||
|
||||
// If singleLine is true, truncate at first newline
|
||||
if (singleLine) {
|
||||
const firstNewline = str.indexOf('\n')
|
||||
const firstNewline = safeStr.indexOf('\n')
|
||||
if (firstNewline !== -1) {
|
||||
result = str.substring(0, firstNewline)
|
||||
result = safeStr.substring(0, firstNewline)
|
||||
// Ensure total width including ellipsis doesn't exceed maxWidth
|
||||
if (stringWidth(result) + 1 > maxWidth) {
|
||||
return truncateToWidth(result, maxWidth)
|
||||
|
||||
69
src/utils/worktree.test.ts
Normal file
69
src/utils/worktree.test.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
import { afterEach, expect, test } from 'bun:test'
|
||||
|
||||
import {
|
||||
_resetGitWorktreeMutationLocksForTesting,
|
||||
withGitWorktreeMutationLock,
|
||||
} from './worktree.js'
|
||||
|
||||
afterEach(() => {
|
||||
_resetGitWorktreeMutationLocksForTesting()
|
||||
})
|
||||
|
||||
test('withGitWorktreeMutationLock serializes mutations for the same repo', async () => {
|
||||
const order: string[] = []
|
||||
let releaseFirst!: () => void
|
||||
const firstGate = new Promise<void>(resolve => {
|
||||
releaseFirst = resolve
|
||||
})
|
||||
|
||||
const first = withGitWorktreeMutationLock('/repo', async () => {
|
||||
order.push('first:start')
|
||||
await firstGate
|
||||
order.push('first:end')
|
||||
})
|
||||
|
||||
const second = withGitWorktreeMutationLock('/repo', async () => {
|
||||
order.push('second:start')
|
||||
order.push('second:end')
|
||||
})
|
||||
|
||||
await Promise.resolve()
|
||||
await Promise.resolve()
|
||||
expect(order).toEqual(['first:start'])
|
||||
|
||||
releaseFirst()
|
||||
await Promise.all([first, second])
|
||||
|
||||
expect(order).toEqual([
|
||||
'first:start',
|
||||
'first:end',
|
||||
'second:start',
|
||||
'second:end',
|
||||
])
|
||||
})
|
||||
|
||||
test('withGitWorktreeMutationLock does not serialize different repos', async () => {
|
||||
const order: string[] = []
|
||||
let releaseFirst!: () => void
|
||||
const firstGate = new Promise<void>(resolve => {
|
||||
releaseFirst = resolve
|
||||
})
|
||||
|
||||
const first = withGitWorktreeMutationLock('/repo-a', async () => {
|
||||
order.push('a:start')
|
||||
await firstGate
|
||||
order.push('a:end')
|
||||
})
|
||||
|
||||
const second = withGitWorktreeMutationLock('/repo-b', async () => {
|
||||
order.push('b:start')
|
||||
order.push('b:end')
|
||||
})
|
||||
|
||||
await Promise.resolve()
|
||||
await Promise.resolve()
|
||||
expect(order).toEqual(['a:start', 'b:start', 'b:end'])
|
||||
|
||||
releaseFirst()
|
||||
await Promise.all([first, second])
|
||||
})
|
||||
@@ -192,6 +192,36 @@ type WorktreeCreateResult =
|
||||
existed: false
|
||||
}
|
||||
|
||||
const gitWorktreeMutationLocks = new Map<string, Promise<void>>()
|
||||
|
||||
export async function withGitWorktreeMutationLock<T>(
|
||||
repoRoot: string,
|
||||
fn: () => Promise<T>,
|
||||
): Promise<T> {
|
||||
const previous = gitWorktreeMutationLocks.get(repoRoot) ?? Promise.resolve()
|
||||
let releaseCurrent!: () => void
|
||||
const current = new Promise<void>(resolve => {
|
||||
releaseCurrent = resolve
|
||||
})
|
||||
const next = previous.catch(() => {}).then(() => current)
|
||||
gitWorktreeMutationLocks.set(repoRoot, next)
|
||||
|
||||
await previous.catch(() => {})
|
||||
|
||||
try {
|
||||
return await fn()
|
||||
} finally {
|
||||
releaseCurrent()
|
||||
if (gitWorktreeMutationLocks.get(repoRoot) === next) {
|
||||
gitWorktreeMutationLocks.delete(repoRoot)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function _resetGitWorktreeMutationLocksForTesting(): void {
|
||||
gitWorktreeMutationLocks.clear()
|
||||
}
|
||||
|
||||
// Env vars to prevent git/SSH from prompting for credentials (which hangs the CLI).
|
||||
// GIT_TERMINAL_PROMPT=0 prevents git from opening /dev/tty for credential prompts.
|
||||
// GIT_ASKPASS='' disables askpass GUI programs.
|
||||
@@ -254,124 +284,136 @@ async function getOrCreateWorktree(
|
||||
}
|
||||
}
|
||||
|
||||
// New worktree: fetch base branch then add
|
||||
await mkdir(worktreesDir(repoRoot), { recursive: true })
|
||||
|
||||
const fetchEnv = { ...process.env, ...GIT_NO_PROMPT_ENV }
|
||||
|
||||
let baseBranch: string
|
||||
let baseSha: string | null = null
|
||||
if (options?.prNumber) {
|
||||
const { code: prFetchCode, stderr: prFetchStderr } =
|
||||
await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['fetch', 'origin', `pull/${options.prNumber}/head`],
|
||||
{ cwd: repoRoot, stdin: 'ignore', env: fetchEnv },
|
||||
)
|
||||
if (prFetchCode !== 0) {
|
||||
throw new Error(
|
||||
`Failed to fetch PR #${options.prNumber}: ${prFetchStderr.trim() || 'PR may not exist or the repository may not have a remote named "origin"'}`,
|
||||
)
|
||||
return withGitWorktreeMutationLock(repoRoot, async () => {
|
||||
const lockedExistingHead = await readWorktreeHeadSha(worktreePath)
|
||||
if (lockedExistingHead) {
|
||||
return {
|
||||
worktreePath,
|
||||
worktreeBranch,
|
||||
headCommit: lockedExistingHead,
|
||||
existed: true,
|
||||
}
|
||||
}
|
||||
baseBranch = 'FETCH_HEAD'
|
||||
} else {
|
||||
// If origin/<branch> already exists locally, skip fetch. In large repos
|
||||
// (210k files, 16M objects) fetch burns ~6-8s on a local commit-graph
|
||||
// scan before even hitting the network. A slightly stale base is fine —
|
||||
// the user can pull in the worktree if they want latest.
|
||||
// resolveRef reads the loose/packed ref directly; when it succeeds we
|
||||
// already have the SHA, so the later rev-parse is skipped entirely.
|
||||
const [defaultBranch, gitDir] = await Promise.all([
|
||||
getDefaultBranch(),
|
||||
resolveGitDir(repoRoot),
|
||||
])
|
||||
const originRef = `origin/${defaultBranch}`
|
||||
const originSha = gitDir
|
||||
? await resolveRef(gitDir, `refs/remotes/origin/${defaultBranch}`)
|
||||
: null
|
||||
if (originSha) {
|
||||
baseBranch = originRef
|
||||
baseSha = originSha
|
||||
|
||||
// New worktree: fetch base branch then add
|
||||
await mkdir(worktreesDir(repoRoot), { recursive: true })
|
||||
|
||||
const fetchEnv = { ...process.env, ...GIT_NO_PROMPT_ENV }
|
||||
|
||||
let baseBranch: string
|
||||
let baseSha: string | null = null
|
||||
if (options?.prNumber) {
|
||||
const { code: prFetchCode, stderr: prFetchStderr } =
|
||||
await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['fetch', 'origin', `pull/${options.prNumber}/head`],
|
||||
{ cwd: repoRoot, stdin: 'ignore', env: fetchEnv },
|
||||
)
|
||||
if (prFetchCode !== 0) {
|
||||
throw new Error(
|
||||
`Failed to fetch PR #${options.prNumber}: ${prFetchStderr.trim() || 'PR may not exist or the repository may not have a remote named "origin"'}`,
|
||||
)
|
||||
}
|
||||
baseBranch = 'FETCH_HEAD'
|
||||
} else {
|
||||
const { code: fetchCode } = await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['fetch', 'origin', defaultBranch],
|
||||
{ cwd: repoRoot, stdin: 'ignore', env: fetchEnv },
|
||||
)
|
||||
baseBranch = fetchCode === 0 ? originRef : 'HEAD'
|
||||
// If origin/<branch> already exists locally, skip fetch. In large repos
|
||||
// (210k files, 16M objects) fetch burns ~6-8s on a local commit-graph
|
||||
// scan before even hitting the network. A slightly stale base is fine —
|
||||
// the user can pull in the worktree if they want latest.
|
||||
// resolveRef reads the loose/packed ref directly; when it succeeds we
|
||||
// already have the SHA, so the later rev-parse is skipped entirely.
|
||||
const [defaultBranch, gitDir] = await Promise.all([
|
||||
getDefaultBranch(),
|
||||
resolveGitDir(repoRoot),
|
||||
])
|
||||
const originRef = `origin/${defaultBranch}`
|
||||
const originSha = gitDir
|
||||
? await resolveRef(gitDir, `refs/remotes/origin/${defaultBranch}`)
|
||||
: null
|
||||
if (originSha) {
|
||||
baseBranch = originRef
|
||||
baseSha = originSha
|
||||
} else {
|
||||
const { code: fetchCode } = await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['fetch', 'origin', defaultBranch],
|
||||
{ cwd: repoRoot, stdin: 'ignore', env: fetchEnv },
|
||||
)
|
||||
baseBranch = fetchCode === 0 ? originRef : 'HEAD'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For the fetch/PR-fetch paths we still need the SHA — the fs-only resolveRef
|
||||
// above only covers the "origin/<branch> already exists locally" case.
|
||||
if (!baseSha) {
|
||||
const { stdout, code: shaCode } = await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['rev-parse', baseBranch],
|
||||
{ cwd: repoRoot },
|
||||
)
|
||||
if (shaCode !== 0) {
|
||||
throw new Error(
|
||||
`Failed to resolve base branch "${baseBranch}": git rev-parse failed`,
|
||||
)
|
||||
}
|
||||
baseSha = stdout.trim()
|
||||
}
|
||||
|
||||
const sparsePaths = getInitialSettings().worktree?.sparsePaths
|
||||
const addArgs = ['worktree', 'add']
|
||||
if (sparsePaths?.length) {
|
||||
addArgs.push('--no-checkout')
|
||||
}
|
||||
// -B (not -b): reset any orphan branch left behind by a removed worktree dir.
|
||||
// Saves a `git branch -D` subprocess (~15ms spawn overhead) on every create.
|
||||
addArgs.push('-B', worktreeBranch, worktreePath, baseBranch)
|
||||
|
||||
const { code: createCode, stderr: createStderr } =
|
||||
await execFileNoThrowWithCwd(gitExe(), addArgs, { cwd: repoRoot })
|
||||
if (createCode !== 0) {
|
||||
throw new Error(`Failed to create worktree: ${createStderr}`)
|
||||
}
|
||||
|
||||
if (sparsePaths?.length) {
|
||||
// If sparse-checkout or checkout fail after --no-checkout, the worktree
|
||||
// is registered and HEAD is set but the working tree is empty. Next run's
|
||||
// fast-resume (rev-parse HEAD) would succeed and present a broken worktree
|
||||
// as "resumed". Tear it down before propagating the error.
|
||||
const tearDown = async (msg: string): Promise<never> => {
|
||||
await execFileNoThrowWithCwd(
|
||||
// For the fetch/PR-fetch paths we still need the SHA — the fs-only resolveRef
|
||||
// above only covers the "origin/<branch> already exists locally" case.
|
||||
if (!baseSha) {
|
||||
const { stdout, code: shaCode } = await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['worktree', 'remove', '--force', worktreePath],
|
||||
['rev-parse', baseBranch],
|
||||
{ cwd: repoRoot },
|
||||
)
|
||||
throw new Error(msg)
|
||||
if (shaCode !== 0) {
|
||||
throw new Error(
|
||||
`Failed to resolve base branch "${baseBranch}": git rev-parse failed`,
|
||||
)
|
||||
}
|
||||
baseSha = stdout.trim()
|
||||
}
|
||||
const { code: sparseCode, stderr: sparseErr } =
|
||||
await execFileNoThrowWithCwd(
|
||||
|
||||
const sparsePaths = getInitialSettings().worktree?.sparsePaths
|
||||
const addArgs = ['worktree', 'add']
|
||||
if (sparsePaths?.length) {
|
||||
addArgs.push('--no-checkout')
|
||||
}
|
||||
// -B (not -b): reset any orphan branch left behind by a removed worktree dir.
|
||||
// Saves a `git branch -D` subprocess (~15ms spawn overhead) on every create.
|
||||
addArgs.push('-B', worktreeBranch, worktreePath, baseBranch)
|
||||
|
||||
const { code: createCode, stderr: createStderr } =
|
||||
await execFileNoThrowWithCwd(gitExe(), addArgs, { cwd: repoRoot })
|
||||
if (createCode !== 0) {
|
||||
throw new Error(`Failed to create worktree: ${createStderr}`)
|
||||
}
|
||||
|
||||
if (sparsePaths?.length) {
|
||||
// If sparse-checkout or checkout fail after --no-checkout, the worktree
|
||||
// is registered and HEAD is set but the working tree is empty. Next run's
|
||||
// fast-resume (rev-parse HEAD) would succeed and present a broken worktree
|
||||
// as "resumed". Tear it down before propagating the error.
|
||||
const tearDown = async (msg: string): Promise<never> => {
|
||||
await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['worktree', 'remove', '--force', worktreePath],
|
||||
{ cwd: repoRoot },
|
||||
)
|
||||
throw new Error(msg)
|
||||
}
|
||||
const { code: sparseCode, stderr: sparseErr } =
|
||||
await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['sparse-checkout', 'set', '--cone', '--', ...sparsePaths],
|
||||
{ cwd: worktreePath },
|
||||
)
|
||||
if (sparseCode !== 0) {
|
||||
await tearDown(`Failed to configure sparse-checkout: ${sparseErr}`)
|
||||
}
|
||||
const { code: coCode, stderr: coErr } = await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['sparse-checkout', 'set', '--cone', '--', ...sparsePaths],
|
||||
['checkout', 'HEAD'],
|
||||
{ cwd: worktreePath },
|
||||
)
|
||||
if (sparseCode !== 0) {
|
||||
await tearDown(`Failed to configure sparse-checkout: ${sparseErr}`)
|
||||
if (coCode !== 0) {
|
||||
await tearDown(`Failed to checkout sparse worktree: ${coErr}`)
|
||||
}
|
||||
}
|
||||
const { code: coCode, stderr: coErr } = await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['checkout', 'HEAD'],
|
||||
{ cwd: worktreePath },
|
||||
)
|
||||
if (coCode !== 0) {
|
||||
await tearDown(`Failed to checkout sparse worktree: ${coErr}`)
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
worktreePath,
|
||||
worktreeBranch,
|
||||
headCommit: baseSha,
|
||||
baseBranch,
|
||||
existed: false,
|
||||
}
|
||||
return {
|
||||
worktreePath,
|
||||
worktreeBranch,
|
||||
headCommit: baseSha,
|
||||
baseBranch,
|
||||
existed: false,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -984,39 +1026,41 @@ export async function removeAgentWorktree(
|
||||
return false
|
||||
}
|
||||
|
||||
// Run from the main repo root, not the worktree (which we're about to delete)
|
||||
const { code: removeCode, stderr: removeError } =
|
||||
await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['worktree', 'remove', '--force', worktreePath],
|
||||
{ cwd: gitRoot },
|
||||
)
|
||||
return withGitWorktreeMutationLock(gitRoot, async () => {
|
||||
// Run from the main repo root, not the worktree (which we're about to delete)
|
||||
const { code: removeCode, stderr: removeError } =
|
||||
await execFileNoThrowWithCwd(
|
||||
gitExe(),
|
||||
['worktree', 'remove', '--force', worktreePath],
|
||||
{ cwd: gitRoot },
|
||||
)
|
||||
|
||||
if (removeCode !== 0) {
|
||||
logForDebugging(`Failed to remove agent worktree: ${removeError}`, {
|
||||
level: 'error',
|
||||
})
|
||||
return false
|
||||
}
|
||||
logForDebugging(`Removed agent worktree at: ${worktreePath}`)
|
||||
if (removeCode !== 0) {
|
||||
logForDebugging(`Failed to remove agent worktree: ${removeError}`, {
|
||||
level: 'error',
|
||||
})
|
||||
return false
|
||||
}
|
||||
logForDebugging(`Removed agent worktree at: ${worktreePath}`)
|
||||
|
||||
if (!worktreeBranch) {
|
||||
if (!worktreeBranch) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Delete the temporary worktree branch from the main repo
|
||||
const { code: deleteBranchCode, stderr: deleteBranchError } =
|
||||
await execFileNoThrowWithCwd(gitExe(), ['branch', '-D', worktreeBranch], {
|
||||
cwd: gitRoot,
|
||||
})
|
||||
|
||||
if (deleteBranchCode !== 0) {
|
||||
logForDebugging(
|
||||
`Could not delete agent worktree branch: ${deleteBranchError}`,
|
||||
{ level: 'error' },
|
||||
)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Delete the temporary worktree branch from the main repo
|
||||
const { code: deleteBranchCode, stderr: deleteBranchError } =
|
||||
await execFileNoThrowWithCwd(gitExe(), ['branch', '-D', worktreeBranch], {
|
||||
cwd: gitRoot,
|
||||
})
|
||||
|
||||
if (deleteBranchCode !== 0) {
|
||||
logForDebugging(
|
||||
`Could not delete agent worktree branch: ${deleteBranchError}`,
|
||||
{ level: 'error' },
|
||||
)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"name": "openclaude-vscode",
|
||||
"displayName": "OpenClaude",
|
||||
"description": "Practical VS Code companion for OpenClaude with project-aware launch behavior and a real Control Center.",
|
||||
"version": "0.1.1",
|
||||
"version": "0.2.0",
|
||||
"publisher": "devnull-bootloader",
|
||||
"engines": {
|
||||
"vscode": "^1.95.0"
|
||||
@@ -19,7 +19,12 @@
|
||||
"onCommand:openclaude.openSetupDocs",
|
||||
"onCommand:openclaude.openWorkspaceProfile",
|
||||
"onCommand:openclaude.openControlCenter",
|
||||
"onView:openclaude.controlCenter"
|
||||
"onCommand:openclaude.newChat",
|
||||
"onCommand:openclaude.openChat",
|
||||
"onCommand:openclaude.resumeSession",
|
||||
"onCommand:openclaude.abortChat",
|
||||
"onView:openclaude.controlCenter",
|
||||
"onView:openclaude.chat"
|
||||
],
|
||||
"main": "./src/extension.js",
|
||||
"files": [
|
||||
@@ -28,6 +33,7 @@
|
||||
"src/extension.js",
|
||||
"src/presentation.js",
|
||||
"src/state.js",
|
||||
"src/chat/**",
|
||||
"themes/**"
|
||||
],
|
||||
"contributes": {
|
||||
@@ -61,6 +67,26 @@
|
||||
"command": "openclaude.openControlCenter",
|
||||
"title": "OpenClaude: Open Control Center",
|
||||
"category": "OpenClaude"
|
||||
},
|
||||
{
|
||||
"command": "openclaude.newChat",
|
||||
"title": "OpenClaude: New Chat",
|
||||
"category": "OpenClaude"
|
||||
},
|
||||
{
|
||||
"command": "openclaude.openChat",
|
||||
"title": "OpenClaude: Open Chat Panel",
|
||||
"category": "OpenClaude"
|
||||
},
|
||||
{
|
||||
"command": "openclaude.resumeSession",
|
||||
"title": "OpenClaude: Resume Session",
|
||||
"category": "OpenClaude"
|
||||
},
|
||||
{
|
||||
"command": "openclaude.abortChat",
|
||||
"title": "OpenClaude: Abort Generation",
|
||||
"category": "OpenClaude"
|
||||
}
|
||||
],
|
||||
"viewsContainers": {
|
||||
@@ -74,6 +100,11 @@
|
||||
},
|
||||
"views": {
|
||||
"openclaude": [
|
||||
{
|
||||
"id": "openclaude.chat",
|
||||
"name": "Chat",
|
||||
"type": "webview"
|
||||
},
|
||||
{
|
||||
"id": "openclaude.controlCenter",
|
||||
"name": "Control Center",
|
||||
@@ -81,6 +112,13 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
"keybindings": [
|
||||
{
|
||||
"command": "openclaude.openChat",
|
||||
"key": "ctrl+shift+l",
|
||||
"mac": "cmd+shift+l"
|
||||
}
|
||||
],
|
||||
"configuration": {
|
||||
"title": "OpenClaude",
|
||||
"properties": {
|
||||
@@ -98,6 +136,18 @@
|
||||
"type": "boolean",
|
||||
"default": false,
|
||||
"description": "Optionally set CLAUDE_CODE_USE_OPENAI=1 in launched OpenClaude terminals."
|
||||
},
|
||||
"openclaude.permissionMode": {
|
||||
"type": "string",
|
||||
"default": "acceptEdits",
|
||||
"enum": ["default", "acceptEdits", "bypassPermissions", "plan"],
|
||||
"enumDescriptions": [
|
||||
"Prompt for permission on each tool use (requires manual approval)",
|
||||
"Auto-approve file edits, prompt for other operations (recommended)",
|
||||
"Auto-approve all operations without prompting",
|
||||
"Read-only mode — no file modifications allowed"
|
||||
],
|
||||
"description": "Permission mode for chat sessions. Controls which tool operations are auto-approved."
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -111,7 +161,7 @@
|
||||
},
|
||||
"scripts": {
|
||||
"test": "node --test ./src/*.test.js",
|
||||
"lint": "node -e \"for (const file of require('node:fs').readdirSync('./src')) { if (file.endsWith('.js')) { require('node:child_process').execFileSync(process.execPath, ['--check', require('node:path').join('src', file)], { stdio: 'inherit' }); } }\"",
|
||||
"lint": "node scripts/lint.js",
|
||||
"package": "npx @vscode/vsce package --no-dependencies"
|
||||
},
|
||||
"keywords": [
|
||||
|
||||
17
vscode-extension/openclaude-vscode/scripts/lint.js
Normal file
17
vscode-extension/openclaude-vscode/scripts/lint.js
Normal file
@@ -0,0 +1,17 @@
|
||||
const { readdirSync } = require('node:fs');
|
||||
const { execFileSync } = require('node:child_process');
|
||||
const { join } = require('node:path');
|
||||
|
||||
/**
 * Recursively walks `dir` and runs `node --check` on every JavaScript
 * source file, skipping *.test.js files. execFileSync throws on the
 * first file that fails the syntax check, aborting the lint run.
 */
function check(dir) {
  const entries = readdirSync(dir, { withFileTypes: true });
  for (const entry of entries) {
    const fullPath = join(dir, entry.name);
    if (entry.isDirectory()) {
      check(fullPath);
      continue;
    }
    const isSource = entry.name.endsWith('.js') && !entry.name.endsWith('.test.js');
    if (isSource) {
      execFileSync(process.execPath, ['--check', fullPath], {
        stdio: 'inherit',
      });
    }
  }
}
|
||||
|
||||
check('./src');
|
||||
676
vscode-extension/openclaude-vscode/src/chat/chatProvider.js
Normal file
676
vscode-extension/openclaude-vscode/src/chat/chatProvider.js
Normal file
@@ -0,0 +1,676 @@
|
||||
/**
|
||||
* chatProvider — WebviewViewProvider (sidebar) and WebviewPanel manager
|
||||
* (editor tab) that wire ProcessManager events to the chat UI.
|
||||
*/
|
||||
|
||||
const vscode = require('vscode');
|
||||
const crypto = require('crypto');
|
||||
const { ProcessManager } = require('./processManager');
|
||||
const { toViewModel } = require('./messageParser');
|
||||
const { renderChatHtml } = require('./chatRenderer');
|
||||
const { isAssistantMessage, isPartialMessage, isStreamEvent,
|
||||
isContentBlockDelta, isContentBlockStart, isMessageStart,
|
||||
isResultMessage, isControlRequest, isToolProgressMessage,
|
||||
isStatusMessage, isRateLimitEvent, getTextContent,
|
||||
getToolUseBlocks } = require('./protocol');
|
||||
|
||||
/**
 * Opens `filePath` in a regular (non-preview) editor tab.
 * Any failure — bad path, unreadable file — is reported as a warning
 * toast rather than thrown.
 */
async function openFileInEditor(filePath) {
  try {
    const document = await vscode.workspace.openTextDocument(vscode.Uri.file(filePath));
    await vscode.window.showTextDocument(document, { preview: false });
  } catch {
    vscode.window.showWarningMessage(`Could not open file: ${filePath}`);
  }
}
|
||||
|
||||
/**
 * Reads the `openclaude.*` workspace configuration and derives the
 * launch parameters for a chat session process.
 * @returns {{command: string, cwd: (string|undefined), env: object, permissionMode: string}}
 *   `cwd` is the first workspace folder, or undefined when no folder is open;
 *   `env` carries CLAUDE_CODE_USE_OPENAI=1 only when the shim setting is on.
 */
function getLaunchConfig() {
  const config = vscode.workspace.getConfiguration('openclaude');
  const env = {};
  if (config.get('useOpenAIShim', false)) {
    env.CLAUDE_CODE_USE_OPENAI = '1';
  }
  const folders = vscode.workspace.workspaceFolders;
  const firstFolder = folders && folders.length > 0 ? folders[0] : null;
  return {
    command: config.get('launchCommand', 'openclaude'),
    cwd: firstFolder ? firstFolder.uri.fsPath : undefined,
    env,
    permissionMode: config.get('permissionMode', 'acceptEdits'),
  };
}
|
||||
|
||||
/**
 * ChatController — owns the single ProcessManager child process for the
 * active chat session, fans incoming CLI messages out to every registered
 * webview via broadcast(), and tracks streaming/turn state.
 *
 * State change events fired via onDidChangeState: 'connected',
 * 'streaming', 'idle'.
 */
class ChatController {
  constructor(sessionManager) {
    this._sessionManager = sessionManager;   // provides listSessions()/loadSession() (used by the view classes)
    this._process = null;                    // active ProcessManager, or null when no session process is running
    this._webviews = new Set();              // webviews that receive broadcast() messages
    this._accumulatedText = '';              // text streamed so far for the current assistant bubble
    this._toolUses = [];                     // tool_use blocks seen during the current stream
    this._messages = [];                     // transcript, used to restore UI on webview (re)open
    this._currentSessionId = null;           // learned from the first CLI message that carries session_id
    this._streaming = false;                 // true between message_start and turn completion ('result')
    this._lastResult = null;                 // last 'result' message; its usage is reused on process exit
    this._thinkingTokens = 0;                // size of thinking output seen (character count, not real tokens)
    this._thinkingStartTime = null;          // wall-clock start of the current thinking block
    this._currentBlockType = null;           // content-block type currently streaming (text/tool_use/thinking)

    this._onDidChangeState = new vscode.EventEmitter();
    this.onDidChangeState = this._onDidChangeState.event;
  }

  get sessionId() { return this._currentSessionId; }
  // NOTE(review): reflects "process alive", not the _streaming flag —
  // true even between turns while the process idles.
  get isStreaming() { return this._process && this._process.running; }
  get sessionManager() { return this._sessionManager; }

  // Adds a webview to the broadcast set; returns a disposable that removes it.
  registerWebview(webview) {
    this._webviews.add(webview);
    return { dispose: () => this._webviews.delete(webview) };
  }

  // Posts `msg` to every registered webview, ignoring failures.
  broadcast(msg) {
    for (const wv of this._webviews) {
      try { wv.postMessage(msg); } catch { /* webview might be disposed */ }
    }
  }

  // Internal alias kept for symmetry with other private members.
  _broadcast(msg) {
    this.broadcast(msg);
  }

  /**
   * (Re)starts the CLI process. Any existing process is disposed first.
   * @param {{sessionId?: string, continueSession?: boolean, model?: string, extraArgs?: string[]}} opts
   */
  async startSession(opts = {}) {
    this.stopSession();
    this._accumulatedText = '';
    this._toolUses = [];
    // Only clear messages if this is a brand new session (not continuing)
    if (!opts.continueSession && !opts.sessionId) {
      this._messages = [];
    }
    this._currentSessionId = opts.sessionId || this._currentSessionId || null;

    const { command, cwd, env, permissionMode } = getLaunchConfig();

    this._process = new ProcessManager({
      command,
      cwd,
      env,
      sessionId: opts.sessionId,
      continueSession: opts.continueSession || false,
      model: opts.model,
      permissionMode,
      extraArgs: opts.extraArgs || [],
    });

    // Resolved when the first 'system' message arrives — _doSend waits on it.
    this._readyResolve = null;
    this._readyPromise = new Promise(resolve => { this._readyResolve = resolve; });

    this._process.onMessage((msg) => {
      if (msg.type === 'system' && this._readyResolve) {
        this._readyResolve();
        this._readyResolve = null;
      }
      this._handleMessage(msg);
    });
    this._process.onError((err) => {
      this._broadcast({ type: 'error', message: err.message || String(err) });
    });
    this._process.onExit(({ code }) => {
      // Flush any remaining streamed text
      if (this._streaming && this._accumulatedText) {
        this._broadcast({ type: 'stream_end', text: this._accumulatedText, usage: null, final: true });
      } else if (this._streaming) {
        this._broadcast({ type: 'stream_end', text: '', usage: (this._lastResult || {}).usage || null, final: true });
      }
      this._streaming = false;
      this._accumulatedText = '';
      this._toolUses = [];
      this._lastResult = null;
      this._broadcast({
        type: 'connected',
        message: code === 0 ? 'Ready' : `Process exited (code ${code})`,
      });
      this._onDidChangeState.fire('idle');
    });

    try {
      this._process.start();
      this._broadcast({ type: 'connected', message: 'Connected' });
      this._onDidChangeState.fire('connected');
    } catch (err) {
      this._broadcast({ type: 'error', message: `Failed to start: ${err.message}` });
    }
  }

  // Disposes the child process (if any). Does not clear the transcript.
  stopSession() {
    if (this._process) {
      this._process.dispose();
      this._process = null;
    }
  }

  /**
   * Sends a user message, lazily (re)starting the process if needed.
   */
  async sendMessage(text) {
    // Keep the process alive for multi-turn — just send directly.
    // The CLI maintains full session state (tools, history) across turns.
    // Only start a new process if none exists or it died.
    if (!this._process || !this._process.running) {
      await this.startSession({
        sessionId: this._currentSessionId || undefined,
      });
    }
    await this._doSend(text);
  }

  // Waits (bounded) for CLI readiness, then writes the user message.
  async _doSend(text) {
    if (!this._process) return;
    // On first message after process start, wait for CLI to be ready.
    // On subsequent messages, the process is already running and accepting input.
    if (this._readyPromise) {
      // 8 s grace period so a silent CLI cannot hang the send forever.
      const grace = new Promise(resolve => setTimeout(resolve, 8000));
      await Promise.race([this._readyPromise, grace]);
      this._readyPromise = null;
    }
    this._accumulatedText = '';
    this._toolUses = [];
    try {
      this._process.sendUserMessage(text);
      this._messages.push({ role: 'user', text });
    } catch (err) {
      this._broadcast({ type: 'error', message: err.message });
    }
  }

  // Aborts the current generation; flushes whatever text was streamed.
  // NOTE(review): unlike other stream_end broadcasts this one omits
  // `final` and does not reset _streaming — confirm intended.
  abort() {
    if (this._process) {
      this._process.abort();
      this._broadcast({ type: 'stream_end', text: this._accumulatedText, usage: null });
      this._onDidChangeState.fire('idle');
    }
  }

  /**
   * Answers a permission (control_request) prompt.
   * @param {string} requestId - id echoed back to the CLI
   * @param {string} action - 'deny', 'allow', or 'allow-session' (remembers)
   * @param {string} [toolUseId]
   */
  sendPermissionResponse(requestId, action, toolUseId) {
    if (!this._process) return;
    if (action === 'deny') {
      try {
        this._process.write({
          type: 'control_response',
          response: {
            subtype: 'error',
            request_id: requestId,
            error: 'User denied permission',
          },
        });
      } catch (err) {
        this._broadcast({ type: 'error', message: err.message });
      }
      return;
    }
    try {
      this._process.sendControlResponse(requestId, {
        toolUseID: toolUseId || undefined,
        ...(action === 'allow-session' ? { remember: true } : {}),
      });
    } catch (err) {
      this._broadcast({ type: 'error', message: err.message });
    }
  }

  // Returns the live transcript array (not a copy — callers must not mutate).
  getMessages() { return this._messages; }

  // Dispatches one raw CLI message to the matching broadcast(s).
  // Order matters: control requests are checked before content handlers.
  _handleMessage(msg) {
    if (msg.session_id && !this._currentSessionId) {
      this._currentSessionId = msg.session_id;
    }

    // System message — extract model and session info
    if (msg.type === 'system') {
      this._broadcast({
        type: 'system_info',
        model: msg.model || null,
        sessionId: msg.session_id || msg.sessionId || null,
      });
      return;
    }

    // Control request (permission prompt) — check EARLY before other handlers
    if (msg.type === 'control_request' || isControlRequest(msg)) {
      const req = msg.request || {};
      // Lazy require avoids a circular-import problem at module load — TODO confirm
      const { toolDisplayName, parseToolInput } = require('./messageParser');
      this._broadcast({
        type: 'permission_request',
        requestId: msg.request_id,
        toolName: req.tool_name || 'Unknown',
        displayName: req.display_name || req.title || toolDisplayName(req.tool_name),
        description: req.description || '',
        inputPreview: parseToolInput(req.input),
        toolUseId: req.tool_use_id || null,
      });
      return;
    }

    // Control cancel request
    if (msg.type === 'control_cancel_request') {
      return;
    }

    // Handle Anthropic raw stream events (the primary streaming mechanism)
    if (isStreamEvent(msg)) {
      this._handleStreamEvent(msg);
      return;
    }

    // Assistant message — always mid-turn; true completion comes from 'result'
    if (isAssistantMessage(msg)) {
      const inner = msg.message || msg;
      const text = getTextContent(inner);
      const toolBlocks = getToolUseBlocks(inner);
      const { toolDisplayName, toolIcon } = require('./messageParser');
      const toolUseVms = toolBlocks.map(tu => ({
        id: tu.id,
        name: tu.name,
        displayName: toolDisplayName(tu.name),
        icon: toolIcon(tu.name),
        inputPreview: typeof tu.input === 'string' ? tu.input : JSON.stringify(tu.input || ''),
        input: tu.input,
        status: 'running',
      }));
      this._messages.push({ role: 'assistant', text, toolUses: toolUseVms });
      const usage = inner.usage || msg.usage || null;

      // Finalize current text bubble but stay streaming — true completion
      // is signaled by the 'result' message, not by the assistant message.
      this._broadcast({ type: 'stream_end', text, usage, final: false });
      this._accumulatedText = '';

      if (toolBlocks.length > 0) {
        for (const tu of toolBlocks) {
          this._broadcast({
            type: 'tool_input_ready',
            toolUseId: tu.id,
            input: tu.input,
            name: tu.name,
          });
        }
        this._broadcast({ type: 'status', content: 'Using tools...' });
      }
      return;
    }

    // User message with tool_use_result — this is the tool output
    if (msg.type === 'user' && msg.message) {
      const content = msg.message.content;
      if (Array.isArray(content)) {
        for (const block of content) {
          if (block.type === 'tool_result' && block.tool_use_id) {
            // Tool result content may be a string or an array of text blocks.
            const resultText = typeof block.content === 'string'
              ? block.content
              : Array.isArray(block.content)
                ? block.content.map(b => b.text || '').join('')
                : '';
            this._broadcast({
              type: 'tool_result',
              toolUseId: block.tool_use_id,
              // Cap at 2000 chars for the UI; empty output shown as '(done)'.
              content: resultText.slice(0, 2000) || '(done)',
              isError: block.is_error || false,
            });
          }
        }
      }
      this._broadcast({ type: 'status', content: 'Thinking...' });
      return;
    }

    // Session result — turn is complete. Go idle. The process stays alive
    // in stream-json mode for multi-turn conversation.
    if (msg.type === 'result' && msg.subtype) {
      this._lastResult = msg;
      // Only use result text if nothing was shown via streaming yet
      const text = this._accumulatedText || '';
      this._broadcast({ type: 'stream_end', text, usage: msg.usage || null, final: true });
      // Show turn info: if the model stopped without using tools (num_turns=1),
      // the user knows the model chose not to edit
      if (msg.num_turns !== undefined) {
        // NOTE(review): 'reason' is computed but never used below.
        const reason = msg.stop_reason || 'done';
        this._broadcast({
          type: 'status',
          content: msg.num_turns > 1
            ? 'Completed (' + msg.num_turns + ' turns)'
            : 'Ready',
        });
      }
      this._accumulatedText = '';
      this._toolUses = [];
      this._streaming = false;
      this._onDidChangeState.fire('idle');
      return;
    }

    if (isToolProgressMessage(msg)) {
      const vm = toViewModel(msg)[0];
      this._broadcast({
        type: 'tool_progress',
        toolUseId: vm.toolUseId,
        content: vm.content,
      });
      return;
    }

    if (isStatusMessage(msg)) {
      const vm = toViewModel(msg)[0];
      this._broadcast({ type: 'status', content: vm.content });
      return;
    }

    if (isRateLimitEvent(msg)) {
      const vm = toViewModel(msg)[0];
      this._broadcast({ type: 'rate_limit', message: vm.message });
      return;
    }

    // Log unhandled message types for debugging
    if (msg.type && msg.type !== 'stream_event') {
      this._broadcast({ type: 'status', content: '[debug] unhandled: ' + msg.type });
    }
  }

  // Handles a raw Anthropic streaming event (msg.event) and maintains the
  // per-turn streaming state (_streaming, _accumulatedText, _toolUses).
  _handleStreamEvent(msg) {
    const event = msg.event;
    if (!event) return;

    switch (event.type) {
      case 'message_start':
        this._accumulatedText = '';
        this._thinkingTokens = 0;
        this._currentBlockType = null;
        if (!this._streaming) {
          this._streaming = true;
          this._toolUses = [];
          this._onDidChangeState.fire('streaming');
        }
        this._broadcast({ type: 'stream_start' });
        break;

      case 'content_block_start':
        if (event.content_block) {
          this._currentBlockType = event.content_block.type;
          if (event.content_block.type === 'tool_use') {
            const tu = event.content_block;
            // Input JSON arrives incrementally via input_json_delta below.
            this._toolUses.push({ id: tu.id, name: tu.name, input: '' });
            const { toolDisplayName, toolIcon } = require('./messageParser');
            this._broadcast({
              type: 'tool_use',
              toolUse: {
                id: tu.id,
                name: tu.name,
                displayName: toolDisplayName(tu.name),
                icon: toolIcon(tu.name),
                inputPreview: '',
                input: tu.input || null,
                status: 'running',
              },
            });
          } else if (event.content_block.type === 'thinking') {
            this._thinkingTokens = 0;
            this._thinkingStartTime = Date.now();
            this._broadcast({ type: 'thinking_start' });
          }
        }
        break;

      case 'content_block_delta':
        if (event.delta) {
          if (event.delta.type === 'text_delta' && event.delta.text) {
            // Broadcast the full accumulated text, not just the delta —
            // the webview re-renders the whole bubble each time.
            this._accumulatedText += event.delta.text;
            this._broadcast({ type: 'stream_delta', text: this._accumulatedText });
          } else if (event.delta.type === 'thinking_delta') {
            this._thinkingTokens += (event.delta.thinking || '').length;
            const elapsed = Math.round((Date.now() - (this._thinkingStartTime || Date.now())) / 1000);
            this._broadcast({
              type: 'thinking_delta',
              tokens: this._thinkingTokens,
              elapsed,
            });
          } else if (event.delta.type === 'input_json_delta' && event.delta.partial_json) {
            // Append the partial JSON to the most recently started tool use.
            const lastTool = this._toolUses[this._toolUses.length - 1];
            if (lastTool) {
              lastTool.input = (lastTool.input || '') + event.delta.partial_json;
            }
          }
        }
        break;

      case 'content_block_stop':
        if (this._currentBlockType === 'thinking') {
          this._broadcast({ type: 'thinking_end' });
        }
        this._currentBlockType = null;
        break;

      case 'message_delta':
        break;

      case 'message_stop':
        break;

      default:
        break;
    }
  }

  // Tears down the process and the state event emitter.
  dispose() {
    this.stopSession();
    this._onDidChangeState.dispose();
  }
}
|
||||
|
||||
/**
 * OpenClaudeChatViewProvider — WebviewViewProvider that hosts the chat
 * UI in the sidebar and forwards webview messages to the shared
 * ChatController.
 */
class OpenClaudeChatViewProvider {
  constructor(chatController) {
    this._chatController = chatController;
    this._webviewView = null;   // currently resolved sidebar view, if any
  }

  // Called by VS Code when the sidebar view becomes visible.
  resolveWebviewView(webviewView, _context, _token) {
    this._webviewView = webviewView;
    const webview = webviewView.webview;
    webview.options = { enableScripts: true };

    // Wire this webview into the controller's broadcast set; unwire on dispose.
    const registration = this._chatController.registerWebview(webview);
    webviewView.onDidDispose(() => {
      registration.dispose();
      if (this._webviewView === webviewView) this._webviewView = null;
    });

    // NOTE(review): _getHtml takes no parameters — the argument is ignored.
    webview.html = this._getHtml(webview);
    this._attachMessageHandler(webview);
  }

  // Renders the chat HTML with a fresh CSP nonce.
  _getHtml() {
    const nonce = crypto.randomBytes(16).toString('hex');
    return renderChatHtml({ nonce, platform: process.platform });
  }

  // Routes messages posted by the webview script to controller actions.
  // NOTE(review): duplicated verbatim in OpenClaudeChatPanelManager —
  // consider extracting a shared helper.
  _attachMessageHandler(webview) {
    webview.onDidReceiveMessage(async (msg) => {
      switch (msg.type) {
        case 'send_message':
          this._chatController.sendMessage(msg.text);
          break;
        case 'abort':
          this._chatController.abort();
          break;
        case 'new_session':
          this._chatController.stopSession();
          webview.postMessage({ type: 'session_cleared' });
          break;
        case 'resume_session':
          this._chatController.stopSession();
          webview.postMessage({ type: 'session_cleared' });
          await this._loadAndDisplaySession(webview, msg.sessionId);
          await this._chatController.startSession({ sessionId: msg.sessionId });
          break;
        case 'permission_response':
          this._chatController.sendPermissionResponse(msg.requestId, msg.action, msg.toolUseId);
          break;
        case 'copy_code':
          if (msg.text) await vscode.env.clipboard.writeText(msg.text);
          break;
        case 'open_file':
          if (msg.path) await openFileInEditor(msg.path);
          break;
        case 'request_sessions':
          await this._sendSessionList(webview);
          break;
        case 'restore_request':
          this._restoreMessages(webview);
          break;
        case 'webview_ready':
          break;
      }
    });
  }

  // Posts the saved-session list to the webview; empty list on failure.
  async _sendSessionList(webview) {
    if (!this._chatController.sessionManager) return;
    try {
      const sessions = await this._chatController.sessionManager.listSessions();
      webview.postMessage({ type: 'session_list', sessions });
    } catch {
      webview.postMessage({ type: 'session_list', sessions: [] });
    }
  }

  // Re-sends the in-memory transcript so a re-opened webview can redraw it.
  _restoreMessages(webview) {
    const messages = this._chatController.getMessages();
    if (messages.length > 0) {
      webview.postMessage({ type: 'restore_messages', messages });
    }
  }

  // Loads a persisted session's messages into the controller and webview.
  async _loadAndDisplaySession(webview, sessionId) {
    if (!this._chatController.sessionManager) return;
    try {
      const messages = await this._chatController.sessionManager.loadSession(sessionId);
      if (messages && messages.length > 0) {
        // NOTE(review): reaches into the controller's private _messages.
        this._chatController._messages = messages;
        webview.postMessage({ type: 'restore_messages', messages });
      }
    } catch { /* session may not be loadable */ }
  }
}
|
||||
|
||||
/**
 * OpenClaudeChatPanelManager — hosts the same chat UI in an editor-tab
 * WebviewPanel (as opposed to the sidebar view). At most one panel exists
 * at a time; openPanel() reveals the existing one if present.
 */
class OpenClaudeChatPanelManager {
  constructor(chatController) {
    this._chatController = chatController;
    this._panel = null;   // singleton panel, or null when closed
  }

  // Creates (or reveals) the chat panel beside the active editor.
  openPanel() {
    if (this._panel) {
      this._panel.reveal();
      return;
    }

    this._panel = vscode.window.createWebviewPanel(
      'openclaude.chatPanel',
      'OpenClaude Chat',
      vscode.ViewColumn.Beside,
      {
        enableScripts: true,
        // Keep webview state when the tab is backgrounded.
        retainContextWhenHidden: true,
      },
    );

    const webview = this._panel.webview;
    const registration = this._chatController.registerWebview(webview);

    this._panel.onDidDispose(() => {
      registration.dispose();
      this._panel = null;
    });

    const nonce = crypto.randomBytes(16).toString('hex');
    webview.html = renderChatHtml({ nonce, platform: process.platform });
    this._attachMessageHandler(webview);

    // Replay the existing transcript into the fresh panel.
    const messages = this._chatController.getMessages();
    if (messages.length > 0) {
      webview.postMessage({ type: 'restore_messages', messages });
    }
  }

  // Routes messages posted by the webview script to controller actions.
  // NOTE(review): duplicated verbatim in OpenClaudeChatViewProvider —
  // consider extracting a shared helper.
  _attachMessageHandler(webview) {
    webview.onDidReceiveMessage(async (msg) => {
      switch (msg.type) {
        case 'send_message':
          this._chatController.sendMessage(msg.text);
          break;
        case 'abort':
          this._chatController.abort();
          break;
        case 'new_session':
          this._chatController.stopSession();
          webview.postMessage({ type: 'session_cleared' });
          break;
        case 'resume_session':
          this._chatController.stopSession();
          webview.postMessage({ type: 'session_cleared' });
          await this._loadAndDisplaySession(webview, msg.sessionId);
          await this._chatController.startSession({ sessionId: msg.sessionId });
          break;
        case 'permission_response':
          this._chatController.sendPermissionResponse(msg.requestId, msg.action, msg.toolUseId);
          break;
        case 'copy_code':
          if (msg.text) await vscode.env.clipboard.writeText(msg.text);
          break;
        case 'open_file':
          if (msg.path) await openFileInEditor(msg.path);
          break;
        case 'request_sessions':
          await this._sendSessionList(webview);
          break;
        case 'restore_request':
          this._restoreMessages(webview);
          break;
        case 'webview_ready':
          break;
      }
    });
  }

  // Posts the saved-session list to the webview; empty list on failure.
  async _sendSessionList(webview) {
    if (!this._chatController.sessionManager) return;
    try {
      const sessions = await this._chatController.sessionManager.listSessions();
      webview.postMessage({ type: 'session_list', sessions });
    } catch {
      webview.postMessage({ type: 'session_list', sessions: [] });
    }
  }

  // Re-sends the in-memory transcript so a re-opened panel can redraw it.
  _restoreMessages(webview) {
    const messages = this._chatController.getMessages();
    if (messages.length > 0) {
      webview.postMessage({ type: 'restore_messages', messages });
    }
  }

  // Loads a persisted session's messages into the controller and webview.
  async _loadAndDisplaySession(webview, sessionId) {
    if (!this._chatController.sessionManager) return;
    try {
      const messages = await this._chatController.sessionManager.loadSession(sessionId);
      if (messages && messages.length > 0) {
        // NOTE(review): reaches into the controller's private _messages.
        this._chatController._messages = messages;
        webview.postMessage({ type: 'restore_messages', messages });
      }
    } catch { /* session may not be loadable */ }
  }

  // Closes the panel if open.
  dispose() {
    if (this._panel) {
      this._panel.dispose();
      this._panel = null;
    }
  }
}
|
||||
|
||||
// Public surface of the chat module: the controller plus the two hosts
// (sidebar view provider and editor-tab panel manager).
module.exports = {
  ChatController,
  OpenClaudeChatViewProvider,
  OpenClaudeChatPanelManager,
};
|
||||
1354
vscode-extension/openclaude-vscode/src/chat/chatRenderer.js
Normal file
1354
vscode-extension/openclaude-vscode/src/chat/chatRenderer.js
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,90 @@
|
||||
/**
|
||||
* diffController — provides a TextDocumentContentProvider for virtual
|
||||
* diff documents and helpers to open VS Code's native diff editor when
|
||||
* tool use involves file edits.
|
||||
*/
|
||||
|
||||
const vscode = require('vscode');
|
||||
|
||||
// URI scheme under which DiffContentProvider serves virtual documents.
const SCHEME = 'openclaude-diff';
// Backing store for virtual-document bodies, keyed by full URI string.
// `const`: the Map is mutated in place but the binding is never reassigned.
const contentStore = new Map();
|
||||
|
||||
/**
 * DiffContentProvider — TextDocumentContentProvider for the
 * `openclaude-diff` scheme. Serves virtual read-only documents whose
 * bodies live in the module-level `contentStore` map.
 */
class DiffContentProvider {
  constructor() {
    this._onDidChange = new vscode.EventEmitter();
    this.onDidChange = this._onDidChange.event;
  }

  // Returns the stored body for `uri`, or '' when nothing was stored.
  provideTextDocumentContent(uri) {
    return contentStore.get(uri.toString()) || '';
  }

  // Tells VS Code the content for `uri` changed and should be re-requested.
  update(uri) {
    this._onDidChange.fire(uri);
  }

  dispose() {
    this._onDidChange.dispose();
  }
}
|
||||
|
||||
/**
 * Stores `content` under a virtual URI derived from `id` and returns
 * that URI (for use with the `openclaude-diff` scheme).
 */
function storeContent(id, content) {
  const virtualUri = vscode.Uri.parse(`${SCHEME}:/${id}`);
  const key = virtualUri.toString();
  contentStore.set(key, content);
  return virtualUri;
}
|
||||
|
||||
/** Removes the stored body for the virtual URI derived from `id`. */
function clearContent(id) {
  const key = vscode.Uri.parse(`${SCHEME}:/${id}`).toString();
  contentStore.delete(key);
}
|
||||
|
||||
// Drops every stored virtual-document body (e.g. on extension deactivate).
function clearAll() {
  contentStore.clear();
}
|
||||
|
||||
/**
 * Opens VS Code's native diff editor between two virtual documents.
 * @param {object} opts
 * @param {string} opts.filePath - Display path (used only for the tab title)
 * @param {string} opts.original - Original file content
 * @param {string} opts.modified - Modified file content
 * @param {string} [opts.toolUseId] - Unique ID for this diff (random if omitted)
 */
async function openDiff({ filePath, original, modified, toolUseId }) {
  const diffId = toolUseId || Math.random().toString(36).slice(2, 10);
  const leftUri = storeContent(`original-${diffId}`, original || '');
  const rightUri = storeContent(`modified-${diffId}`, modified || '');
  let baseName = 'file';
  if (filePath) {
    baseName = filePath.split(/[\\/]/).pop();
  }
  const title = `${baseName} (OpenClaude Diff)`;
  await vscode.commands.executeCommand('vscode.diff', leftUri, rightUri, title);
}
|
||||
|
||||
/**
 * Opens a diff between a real file on disk and modified content coming
 * from a tool-use result.
 * @param {object} opts
 * @param {string} opts.filePath - Absolute path to the real file
 * @param {string} opts.modified - Modified content
 * @param {string} [opts.toolUseId] - Unique ID for this diff (random if omitted)
 */
async function openFileDiff({ filePath, modified, toolUseId }) {
  const diffId = toolUseId || Math.random().toString(36).slice(2, 10);
  const rightUri = storeContent(`modified-${diffId}`, modified || '');
  const baseName = filePath.split(/[\\/]/).pop() || 'file';
  const title = `${baseName} (OpenClaude Edit)`;
  await vscode.commands.executeCommand('vscode.diff', vscode.Uri.file(filePath), rightUri, title);
}
|
||||
|
||||
// Public surface: the content provider + scheme constant, the two
// diff-opening helpers, and the content-store management functions.
module.exports = {
  DiffContentProvider,
  SCHEME,
  openDiff,
  openFileDiff,
  storeContent,
  clearContent,
  clearAll,
};
|
||||
177
vscode-extension/openclaude-vscode/src/chat/messageParser.js
Normal file
177
vscode-extension/openclaude-vscode/src/chat/messageParser.js
Normal file
@@ -0,0 +1,177 @@
|
||||
/**
|
||||
* messageParser — transforms raw SDK messages from the CLI into view-model
|
||||
* objects that the chat renderer can display.
|
||||
*/
|
||||
|
||||
const {
|
||||
isAssistantMessage,
|
||||
isPartialMessage,
|
||||
isResultMessage,
|
||||
isControlRequest,
|
||||
isStatusMessage,
|
||||
isToolProgressMessage,
|
||||
isSessionStateChanged,
|
||||
isRateLimitEvent,
|
||||
getTextContent,
|
||||
getToolUseBlocks,
|
||||
} = require('./protocol');
|
||||
|
||||
/**
 * Produces a short human-readable preview string for a tool-use input.
 * Prefers well-known fields (command, file_path/path, query); any other
 * object is pretty-printed as JSON; non-objects are stringified
 * (null/undefined become '').
 */
function parseToolInput(input) {
  const isObject = Boolean(input) && typeof input === 'object';
  if (!isObject) {
    return String(input ?? '');
  }
  if (input.command) {
    return input.command;
  }
  const pathValue = input.file_path || input.path;
  if (pathValue) {
    return pathValue;
  }
  if (input.query) {
    return input.query;
  }
  try {
    return JSON.stringify(input, null, 2);
  } catch {
    return String(input);
  }
}
|
||||
|
||||
/**
 * Maps an SDK tool name to a human-friendly display label.
 * Falls back to the raw name, or 'Tool' when no name was supplied.
 *
 * Uses Object.hasOwn for the lookup so names that collide with
 * Object.prototype members ('constructor', 'toString', …) — which can
 * arrive from arbitrary tool names — return the name itself instead of
 * leaking an inherited function object.
 */
function toolDisplayName(name) {
  const map = {
    Bash: 'Terminal',
    Read: 'Read File',
    Write: 'Write File',
    Edit: 'Edit File',
    MultiEdit: 'Multi Edit',
    Glob: 'Find Files',
    Grep: 'Search',
    LS: 'List Directory',
    WebFetch: 'Web Fetch',
    WebSearch: 'Web Search',
    TodoRead: 'Read Todos',
    TodoWrite: 'Write Todos',
    Task: 'Sub-agent',
  };
  if (name != null && Object.hasOwn(map, name)) {
    return map[name];
  }
  return name || 'Tool';
}
|
||||
|
||||
/**
 * Picks an emoji icon for a tool by its SDK name, defaulting to a wrench
 * for unknown tools.
 *
 * @param {string} [name] - SDK tool name.
 * @returns {string} Emoji for the chat UI.
 */
function toolIcon(name) {
  switch (name) {
    case 'Bash': return '\u{1F4BB}';
    case 'Read': return '\u{1F4C4}';
    case 'Write':
    case 'Edit':
    case 'MultiEdit':
      return '\u{270F}\uFE0F';
    case 'Glob': return '\u{1F50D}';
    case 'Grep': return '\u{1F50E}';
    case 'LS': return '\u{1F4C2}';
    case 'WebFetch':
    case 'WebSearch':
      return '\u{1F310}';
    case 'Task': return '\u{1F916}';
    default: return '\u{1F527}';
  }
}
|
||||
|
||||
/**
 * Converts an SDK message into one or more view-model entries for the chat UI.
 * Returns an array so partial messages can update in-place while final messages
 * produce a finalized entry.
 *
 * @param {object} msg - A parsed NDJSON message from the CLI.
 * @returns {Array<object>} Entries discriminated by `kind`: 'assistant',
 *   'assistant_partial', 'tool_result', 'permission_request', 'tool_progress',
 *   'status', 'session_state', 'rate_limit', or 'unknown'.
 */
function toViewModel(msg) {
  if (isAssistantMessage(msg)) {
    return [{
      kind: 'assistant',
      // Some message shapes carry the id on the envelope, others on the
      // nested `message` — check both.
      id: msg.id || msg.message?.id || null,
      text: getTextContent(msg.message || msg),
      toolUses: getToolUseBlocks(msg.message || msg).map(tu => ({
        id: tu.id,
        name: tu.name,
        displayName: toolDisplayName(tu.name),
        icon: toolIcon(tu.name),
        inputPreview: parseToolInput(tu.input),
        input: tu.input,
        // Final assistant message: its tool uses are treated as done.
        status: 'complete',
      })),
      model: msg.model || null,
      stopReason: msg.stop_reason || null,
      usage: msg.usage || null,
      final: true,
    }];
  }

  if (isPartialMessage(msg)) {
    const inner = msg.message || msg;
    return [{
      kind: 'assistant_partial',
      id: inner.id || null,
      text: getTextContent(inner),
      toolUses: getToolUseBlocks(inner).map(tu => ({
        id: tu.id,
        name: tu.name,
        displayName: toolDisplayName(tu.name),
        icon: toolIcon(tu.name),
        inputPreview: parseToolInput(tu.input),
        input: tu.input,
        // Still streaming, so tool uses are shown as in-flight.
        status: 'running',
      })),
      final: false,
    }];
  }

  if (isResultMessage(msg)) {
    return [{
      kind: 'tool_result',
      toolUseId: msg.tool_use_id,
      // `content` may be a plain string or an array of text blocks.
      content: typeof msg.content === 'string'
        ? msg.content
        : Array.isArray(msg.content)
          ? msg.content.map(b => b.text || '').join('')
          : '',
      isError: msg.is_error || false,
    }];
  }

  if (isControlRequest(msg)) {
    // Permission prompt: the CLI is asking the user to approve a tool call.
    // Field locations vary across SDK versions, hence the fallbacks.
    return [{
      kind: 'permission_request',
      requestId: msg.request_id || msg.id,
      toolName: msg.tool_name || msg.tool?.name || 'Unknown',
      displayName: toolDisplayName(msg.tool_name || msg.tool?.name),
      description: msg.description || msg.tool?.description || '',
      input: msg.tool_input || msg.input || null,
      inputPreview: parseToolInput(msg.tool_input || msg.input),
    }];
  }

  if (isToolProgressMessage(msg)) {
    return [{
      kind: 'tool_progress',
      toolUseId: msg.tool_use_id,
      content: msg.content || msg.progress || '',
    }];
  }

  if (isStatusMessage(msg)) {
    return [{
      kind: 'status',
      content: msg.content || msg.message || '',
    }];
  }

  if (isSessionStateChanged(msg)) {
    return [{
      kind: 'session_state',
      sessionId: msg.session_id || null,
      state: msg.state || null,
    }];
  }

  if (isRateLimitEvent(msg)) {
    return [{
      kind: 'rate_limit',
      retryAfter: msg.retry_after || null,
      message: msg.message || 'Rate limited',
    }];
  }

  // Unrecognized message types are surfaced raw so the UI can ignore or log them.
  return [{
    kind: 'unknown',
    type: msg.type,
    raw: msg,
  }];
}
|
||||
|
||||
// Public surface of the message parser; the helpers are exported individually
// so the renderer can reuse the name/icon mapping outside toViewModel.
module.exports = {
  toViewModel,
  toolDisplayName,
  toolIcon,
  parseToolInput,
};
|
||||
194
vscode-extension/openclaude-vscode/src/chat/processManager.js
Normal file
194
vscode-extension/openclaude-vscode/src/chat/processManager.js
Normal file
@@ -0,0 +1,194 @@
|
||||
/**
|
||||
* ProcessManager — spawns OpenClaude in print/SDK mode and manages the
|
||||
* NDJSON stdin/stdout lifecycle.
|
||||
*
|
||||
* Usage:
|
||||
* const pm = new ProcessManager({ command, cwd, env });
|
||||
* pm.onMessage(msg => { ... });
|
||||
* pm.onError(err => { ... });
|
||||
* pm.onExit(code => { ... });
|
||||
* await pm.start();
|
||||
* pm.sendUserMessage('Hello');
|
||||
* pm.abort(); // SIGINT (graceful)
|
||||
* pm.kill(); // SIGTERM (hard)
|
||||
* pm.dispose();
|
||||
*/
|
||||
|
||||
const { spawn } = require('child_process');
|
||||
const vscode = require('vscode');
|
||||
const { parseStdoutLine, serializeStdinMessage, buildUserMessage, buildControlResponse } = require('./protocol');
|
||||
|
||||
class ProcessManager {
  /**
   * @param {object} opts
   * @param {string} opts.command - The openclaude binary (e.g. 'openclaude')
   * @param {string} [opts.cwd] - Working directory for the child process
   * @param {Record<string,string>} [opts.env] - Extra env vars merged over process.env
   * @param {string} [opts.sessionId] - Session to resume
   * @param {boolean} [opts.continueSession] - Use --continue instead of --resume
   * @param {string} [opts.model] - Model override
   * @param {string} [opts.permissionMode] - CLI permission mode (defaults to 'acceptEdits')
   * @param {string[]} [opts.extraArgs] - Additional CLI flags
   */
  constructor(opts) {
    this._command = opts.command || 'openclaude';
    this._cwd = opts.cwd || undefined;
    this._env = opts.env || {};
    this._sessionId = opts.sessionId || null;
    this._continueSession = opts.continueSession || false;
    this._model = opts.model || null;
    this._permissionMode = opts.permissionMode || 'acceptEdits';
    this._extraArgs = opts.extraArgs || [];
    this._process = null;
    // Accumulates partial stdout until a full newline-terminated NDJSON line arrives.
    this._buffer = '';
    this._disposed = false;

    this._onMessageEmitter = new vscode.EventEmitter();
    this._onErrorEmitter = new vscode.EventEmitter();
    this._onExitEmitter = new vscode.EventEmitter();
    this.onMessage = this._onMessageEmitter.event;
    this.onError = this._onErrorEmitter.event;
    this.onExit = this._onExitEmitter.event;
  }

  /** @returns {boolean} True while the child process is alive. */
  get running() {
    return this._process !== null && !this._process.killed;
  }

  /** @returns {string|null} Session id captured from the first CLI message that carries one. */
  get sessionId() {
    return this._sessionId;
  }

  /**
   * Spawns the CLI in stream-json print mode and wires stdout/stderr handlers.
   * @throws {Error} If the manager was disposed or a process is already running.
   */
  start() {
    if (this._disposed) throw new Error('ProcessManager is disposed');
    if (this._process) throw new Error('Process already started');

    const args = [
      '--print',
      '--verbose',
      '--input-format=stream-json',
      '--output-format=stream-json',
      '--include-partial-messages',
      '--permission-mode', this._permissionMode || 'acceptEdits',
    ];

    // --resume takes precedence over --continue when both are configured.
    if (this._sessionId) {
      args.push('--resume', this._sessionId);
    } else if (this._continueSession) {
      args.push('--continue');
    }

    if (this._model) {
      args.push('--model', this._model);
    }

    args.push(...this._extraArgs);

    const spawnEnv = { ...process.env, ...this._env };
    const isWin = process.platform === 'win32';

    if (isWin) {
      // On Windows, npm global installs create .cmd shims that spawn()
      // cannot find without a shell. Build one command string so the
      // deprecation warning about unsanitised args does not fire.
      // NOTE(review): joining args into a shell string means args containing
      // spaces or shell metacharacters are not escaped — flag for hardening
      // if any arg ever carries user-controlled text.
      const cmdLine = [this._command, ...args].join(' ');
      this._process = spawn(cmdLine, [], {
        cwd: this._cwd,
        env: spawnEnv,
        stdio: ['pipe', 'pipe', 'pipe'],
        shell: true,
        windowsHide: true,
      });
    } else {
      this._process = spawn(this._command, args, {
        cwd: this._cwd,
        env: spawnEnv,
        stdio: ['pipe', 'pipe', 'pipe'],
        windowsHide: true,
      });
    }

    this._process.stdout.setEncoding('utf8');
    this._process.stderr.setEncoding('utf8');

    this._process.stdout.on('data', (chunk) => this._onData(chunk));
    this._process.stderr.on('data', (chunk) => this._onStderr(chunk));
    this._process.on('error', (err) => this._onErrorEmitter.fire(err));
    this._process.on('close', (code, signal) => {
      this._process = null;
      this._onExitEmitter.fire({ code, signal });
    });
  }

  /**
   * Splits buffered stdout into complete NDJSON lines and fires onMessage
   * for each parseable one. The trailing (possibly incomplete) fragment is
   * kept in the buffer for the next chunk.
   */
  _onData(chunk) {
    this._buffer += chunk;
    const lines = this._buffer.split('\n');
    this._buffer = lines.pop() || '';

    for (const line of lines) {
      const msg = parseStdoutLine(line);
      if (msg) {
        this._extractSessionId(msg);
        this._onMessageEmitter.fire(msg);
      }
    }
  }

  /** Records the session id the first time any message reports one. */
  _extractSessionId(msg) {
    if (msg.session_id && !this._sessionId) {
      this._sessionId = msg.session_id;
    }
  }

  /** Forwards stderr as onError events, filtering known non-error noise. */
  _onStderr(chunk) {
    const trimmed = chunk.trim();
    if (!trimmed) return;
    // Suppress common non-error noise from the CLI (deprecation warnings, etc.)
    if (/^\(node:\d+\)|^DeprecationWarning|^ExperimentalWarning/i.test(trimmed)) return;
    this._onErrorEmitter.fire(new Error(trimmed));
  }

  /** Sends a user chat turn to the CLI over stdin. */
  sendUserMessage(text) {
    this._write(buildUserMessage(text));
  }

  /** Answers a control (permission) request from the CLI. */
  sendControlResponse(requestId, result) {
    this._write(buildControlResponse(requestId, result));
  }

  /**
   * Serializes and writes an arbitrary protocol message to the child's stdin.
   * @throws {Error} If the process is not running or stdin is closed.
   */
  write(msg) {
    if (!this._process || !this._process.stdin.writable) {
      throw new Error('Process is not running');
    }
    this._process.stdin.write(serializeStdinMessage(msg));
  }

  // Internal alias kept so private call sites stay decoupled from the
  // public write() name.
  _write(msg) {
    this.write(msg);
  }

  /**
   * Gracefully interrupts the current generation with SIGINT.
   * (The previous implementation branched on process.platform but both
   * branches were identical, so the check was removed.)
   */
  abort() {
    if (this._process && !this._process.killed) {
      this._process.kill('SIGINT');
    }
  }

  /** Hard-terminates the child process with SIGTERM. */
  kill() {
    if (this._process && !this._process.killed) {
      this._process.kill('SIGTERM');
    }
  }

  /** Kills the process and releases all event emitters; the manager is unusable afterwards. */
  dispose() {
    this._disposed = true;
    this.kill();
    this._onMessageEmitter.dispose();
    this._onErrorEmitter.dispose();
    this._onExitEmitter.dispose();
  }
}
|
||||
|
||||
module.exports = { ProcessManager };
|
||||
186
vscode-extension/openclaude-vscode/src/chat/protocol.js
Normal file
186
vscode-extension/openclaude-vscode/src/chat/protocol.js
Normal file
@@ -0,0 +1,186 @@
|
||||
/**
|
||||
* NDJSON protocol helpers and message type constants for the OpenClaude
|
||||
* stream-json SDK wire format.
|
||||
*
|
||||
* The extension spawns `openclaude --print --input-format=stream-json
|
||||
* --output-format=stream-json` and speaks NDJSON over stdin/stdout.
|
||||
* This module provides lightweight parsing, serialization, and type guards
|
||||
* so the rest of the extension never touches raw JSON strings.
|
||||
*/
|
||||
|
||||
/**
 * Wire-format `type` values used by the OpenClaude CLI in stream-json mode.
 * Frozen so this shared constant cannot be mutated by consumers at runtime.
 */
const MESSAGE_TYPES = Object.freeze({
  ASSISTANT: 'assistant',
  USER: 'user',
  USER_REPLAY: 'user_replay',
  RESULT: 'result',
  SYSTEM: 'system',
  STREAM_EVENT: 'stream_event',
  PARTIAL: 'partial',
  COMPACT_BOUNDARY: 'compact_boundary',
  STATUS: 'status',
  API_RETRY: 'api_retry',
  LOCAL_COMMAND_OUTPUT: 'local_command_output',
  HOOK_STARTED: 'hook_started',
  HOOK_PROGRESS: 'hook_progress',
  HOOK_RESPONSE: 'hook_response',
  TOOL_PROGRESS: 'tool_progress',
  AUTH_STATUS: 'auth_status',
  TASK_NOTIFICATION: 'task_notification',
  TASK_STARTED: 'task_started',
  TASK_PROGRESS: 'task_progress',
  SESSION_STATE_CHANGED: 'session_state_changed',
  FILES_PERSISTED: 'files_persisted',
  TOOL_USE_SUMMARY: 'tool_use_summary',
  RATE_LIMIT: 'rate_limit',
  ELICITATION_COMPLETE: 'elicitation_complete',
  PROMPT_SUGGESTION: 'prompt_suggestion',
  STREAMLINED_TEXT: 'streamlined_text',
  STREAMLINED_TOOL_USE_SUMMARY: 'streamlined_tool_use_summary',
  POST_TURN_SUMMARY: 'post_turn_summary',
  CONTROL_RESPONSE: 'control_response',
  CONTROL_REQUEST: 'control_request',
});
|
||||
|
||||
/**
 * Parses one NDJSON line from the CLI's stdout.
 *
 * @param {string} [line] - A raw stdout line (possibly empty or non-JSON).
 * @returns {object|null} The parsed message, or null for blank/unparseable lines.
 */
function parseStdoutLine(line) {
  const payload = (line || '').trim();
  if (payload.length === 0) {
    return null;
  }
  let parsed = null;
  try {
    parsed = JSON.parse(payload);
  } catch {
    // Non-JSON noise on stdout is silently dropped by design.
  }
  return parsed;
}
|
||||
|
||||
/**
 * Encodes a protocol message as one newline-terminated NDJSON line for stdin.
 *
 * @param {object} msg - Message object to serialize.
 * @returns {string} JSON text followed by '\n'.
 */
function serializeStdinMessage(msg) {
  const json = JSON.stringify(msg);
  return `${json}\n`;
}
|
||||
|
||||
/**
 * Wraps user-typed text in the stdin envelope the CLI expects for a chat turn.
 *
 * @param {string} text - The user's message.
 * @returns {object} A 'user' protocol message ready for serializeStdinMessage.
 */
function buildUserMessage(text) {
  const message = {
    role: 'user',
    content: text,
  };
  return {
    type: 'user',
    message,
    parent_tool_use_id: null,
  };
}
|
||||
|
||||
/**
 * Builds the 'control_response' envelope that answers a CLI control
 * (permission) request.
 *
 * @param {string} requestId - Id of the control_request being answered.
 * @param {object} [result] - Response payload; defaults to an empty object.
 * @returns {object} Message ready for serializeStdinMessage.
 */
function buildControlResponse(requestId, result) {
  const response = {
    subtype: 'success',
    request_id: requestId,
    response: result || {},
  };
  return { type: 'control_response', response };
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// Type guards over parsed protocol messages and content blocks. Each returns
// a truthy value when the input matches; all tolerate null/undefined input.
// ---------------------------------------------------------------------------

// True for a final assistant turn.
function isAssistantMessage(msg) {
  return msg && msg.type === MESSAGE_TYPES.ASSISTANT;
}

// True for an in-progress (streaming) assistant message snapshot.
function isPartialMessage(msg) {
  return msg && msg.type === MESSAGE_TYPES.PARTIAL;
}

// True for a raw streaming event; also requires a nested `event` payload.
function isStreamEvent(msg) {
  return msg && msg.type === MESSAGE_TYPES.STREAM_EVENT && msg.event;
}

// Stream-event sub-guards, keyed on the nested event's type field.
function isContentBlockDelta(msg) {
  return isStreamEvent(msg) && msg.event.type === 'content_block_delta';
}

function isContentBlockStart(msg) {
  return isStreamEvent(msg) && msg.event.type === 'content_block_start';
}

function isMessageStart(msg) {
  return isStreamEvent(msg) && msg.event.type === 'message_start';
}

function isMessageStop(msg) {
  return isStreamEvent(msg) && msg.event.type === 'message_stop';
}

function isMessageDelta(msg) {
  return isStreamEvent(msg) && msg.event.type === 'message_delta';
}

// True for a tool-result message.
function isResultMessage(msg) {
  return msg && msg.type === MESSAGE_TYPES.RESULT;
}

// Content-block guards (operate on blocks inside message.content, not on
// whole messages).
function isToolUse(block) {
  return block && block.type === 'tool_use';
}

function isTextBlock(block) {
  return block && block.type === 'text';
}

function isThinkingBlock(block) {
  return block && block.type === 'thinking';
}

// True when the CLI is asking the extension for a decision (e.g. tool permission).
function isControlRequest(msg) {
  return msg && msg.type === MESSAGE_TYPES.CONTROL_REQUEST;
}

function isStatusMessage(msg) {
  return msg && msg.type === MESSAGE_TYPES.STATUS;
}

function isToolProgressMessage(msg) {
  return msg && msg.type === MESSAGE_TYPES.TOOL_PROGRESS;
}

function isSessionStateChanged(msg) {
  return msg && msg.type === MESSAGE_TYPES.SESSION_STATE_CHANGED;
}

function isRateLimitEvent(msg) {
  return msg && msg.type === MESSAGE_TYPES.RATE_LIMIT;
}
|
||||
|
||||
/**
 * Concatenates the text of every 'text' content block in a message.
 *
 * @param {object} [message] - Message whose `content` is a block array.
 * @returns {string} Joined text, or '' when content is absent or not an array.
 */
function getTextContent(message) {
  const blocks = message?.content;
  if (!Array.isArray(blocks)) {
    return '';
  }
  let out = '';
  for (const block of blocks) {
    if (block.type === 'text') {
      out += block.text || '';
    }
  }
  return out;
}
|
||||
|
||||
/**
 * Collects every 'tool_use' content block from a message.
 *
 * @param {object} [message] - Message whose `content` is a block array.
 * @returns {Array<object>} Tool-use blocks, or [] when none exist.
 */
function getToolUseBlocks(message) {
  if (!message || !Array.isArray(message.content)) {
    return [];
  }
  const uses = [];
  for (const block of message.content) {
    if (block.type === 'tool_use') {
      uses.push(block);
    }
  }
  return uses;
}
|
||||
|
||||
// Full public surface of the protocol module: constants, (de)serialization
// helpers, message builders, type guards, and content accessors.
module.exports = {
  MESSAGE_TYPES,
  parseStdoutLine,
  serializeStdinMessage,
  buildUserMessage,
  buildControlResponse,
  isAssistantMessage,
  isPartialMessage,
  isStreamEvent,
  isContentBlockDelta,
  isContentBlockStart,
  isMessageStart,
  isMessageStop,
  isMessageDelta,
  isResultMessage,
  isToolUse,
  isTextBlock,
  isThinkingBlock,
  isControlRequest,
  isStatusMessage,
  isToolProgressMessage,
  isSessionStateChanged,
  isRateLimitEvent,
  getTextContent,
  getToolUseBlocks,
};
|
||||
282
vscode-extension/openclaude-vscode/src/chat/sessionManager.js
Normal file
282
vscode-extension/openclaude-vscode/src/chat/sessionManager.js
Normal file
@@ -0,0 +1,282 @@
|
||||
/**
|
||||
* sessionManager — reads JSONL session history from disk, lists sessions,
|
||||
* and provides metadata for the session list UI.
|
||||
*
|
||||
* Session files live under:
|
||||
* ~/.openclaude/projects/<sanitized-cwd>/<sessionId>.jsonl
|
||||
*
|
||||
* Falls back to ~/.claude/projects/ for legacy installs.
|
||||
*/
|
||||
|
||||
const fs = require('fs');
|
||||
const fsp = require('fs/promises');
|
||||
const path = require('path');
|
||||
const os = require('os');
|
||||
const crypto = require('crypto');
|
||||
|
||||
// Longest directory-name stem kept before a disambiguating hash is appended.
const MAX_SANITIZED_LENGTH = 80;

/**
 * Converts a workspace path into a filesystem-safe project directory name.
 * Every non-alphanumeric character becomes '-'; overly long results are
 * truncated and suffixed with a short hash so distinct paths stay distinct.
 *
 * @param {string} name - The raw path (typically an absolute cwd).
 * @returns {string} Sanitized directory name.
 */
function sanitizePath(name) {
  const safe = name.replace(/[^a-zA-Z0-9]/g, '-');
  if (safe.length > MAX_SANITIZED_LENGTH) {
    return `${safe.slice(0, MAX_SANITIZED_LENGTH)}-${simpleHash(name)}`;
  }
  return safe;
}
|
||||
|
||||
/**
 * Deterministic 32-bit string hash (classic h*31+c), rendered in base-36.
 * Used only to disambiguate truncated directory names — not cryptographic.
 *
 * @param {string} str - Input string.
 * @returns {string} Base-36 rendering of the absolute hash value.
 */
function simpleHash(str) {
  let acc = 0;
  for (let i = 0; i < str.length; i += 1) {
    acc = (acc << 5) - acc + str.charCodeAt(i);
    acc |= 0; // clamp to 32-bit signed, matching the classic algorithm
  }
  return Math.abs(acc).toString(36);
}
|
||||
|
||||
/**
 * Determines the OpenClaude configuration directory.
 * Priority: CLAUDE_CONFIG_DIR env override, then ~/.openclaude, falling back
 * to the legacy ~/.claude only when it exists and ~/.openclaude does not.
 *
 * @returns {string} Absolute path of the config directory.
 */
function resolveConfigDir() {
  const override = process.env.CLAUDE_CONFIG_DIR;
  if (override) {
    return override;
  }
  const home = os.homedir();
  const preferred = path.join(home, '.openclaude');
  const legacy = path.join(home, '.claude');
  const useLegacy = !fs.existsSync(preferred) && fs.existsSync(legacy);
  return useLegacy ? legacy : preferred;
}
|
||||
|
||||
// Root directory containing one subdirectory per project (keyed by cwd).
function getProjectsDir() {
  return path.join(resolveConfigDir(), 'projects');
}

// Directory holding the JSONL session files for a given workspace cwd.
function getProjectDir(cwd) {
  return path.join(getProjectsDir(), sanitizePath(cwd));
}
|
||||
|
||||
class SessionManager {
  constructor() {
    // Workspace cwd used to scope session lookups; null means "all projects".
    this._cwd = null;
  }

  /** Sets the workspace cwd that scopes subsequent session lookups. */
  setCwd(cwd) {
    this._cwd = cwd;
  }

  /**
   * Lists session metadata, newest first. When a cwd is set only that
   * project's sessions are scanned; otherwise every project directory is.
   * @returns {Promise<Array<object>>} Entries from _extractSessionMeta.
   */
  async listSessions() {
    const projectDir = this._cwd
      ? getProjectDir(this._cwd)
      : null;

    const dirs = projectDir
      ? [projectDir]
      : await this._allProjectDirs();

    const sessions = [];
    for (const dir of dirs) {
      const items = await this._readSessionDir(dir);
      sessions.push(...items);
    }

    // Newest session first.
    sessions.sort((a, b) => b.timestamp - a.timestamp);
    return sessions;
  }

  /** Enumerates every project subdirectory under the projects root ([] on any error). */
  async _allProjectDirs() {
    const base = getProjectsDir();
    if (!fs.existsSync(base)) return [];
    try {
      const entries = await fsp.readdir(base, { withFileTypes: true });
      return entries
        .filter(e => e.isDirectory())
        .map(e => path.join(base, e.name));
    } catch {
      return [];
    }
  }

  /**
   * Reads one project directory and returns metadata for each .jsonl session
   * file. Unreadable files and directories are skipped silently — session
   * listing is best-effort.
   */
  async _readSessionDir(dir) {
    if (!fs.existsSync(dir)) return [];
    try {
      const files = await fsp.readdir(dir);
      const jsonlFiles = files.filter(f => f.endsWith('.jsonl'));
      const results = [];

      for (const file of jsonlFiles) {
        const filePath = path.join(dir, file);
        try {
          const meta = await this._extractSessionMeta(filePath);
          if (meta) results.push(meta);
        } catch { /* skip unreadable */ }
      }

      return results;
    } catch {
      return [];
    }
  }

  /**
   * Extracts list-view metadata (id, title, preview, timestamp) from the head
   * of a session JSONL file without reading the whole file.
   */
  async _extractSessionMeta(filePath) {
    const sessionId = path.basename(filePath, '.jsonl');
    const stat = await fsp.stat(filePath);
    // Read a larger head because JSONL files often start with system/snapshot
    // entries before the first user message.
    const head = await this._readHead(filePath, 65536);
    const lines = head.split('\n').filter(Boolean);

    let title = null;
    let preview = '';
    // Default to file mtime; replaced by the first in-file timestamp if found.
    let timestamp = stat.mtimeMs;
    let firstTimestamp = null;

    for (const line of lines) {
      try {
        const entry = JSON.parse(line);

        // Preview: first user message's text (string or first text block).
        if (!preview && entry.type === 'user' && entry.message) {
          const content = entry.message.content;
          if (typeof content === 'string') {
            preview = content.slice(0, 120);
          } else if (Array.isArray(content)) {
            const textBlock = content.find(b => b.type === 'text');
            preview = textBlock ? (textBlock.text || '').slice(0, 120) : '';
          }
        }

        // Explicit titles always win (last one seen).
        if (entry.type === 'custom-title' || entry.type === 'session-title') {
          title = entry.title || entry.name || null;
        }

        // A summary entry only supplies the title when no explicit title exists.
        if (entry.type === 'summary' && entry.summary && !title) {
          title = entry.summary;
        }

        // Timestamps may be epoch numbers or date strings; keep the first valid one.
        if (entry.timestamp && !firstTimestamp) {
          const t = typeof entry.timestamp === 'number'
            ? entry.timestamp
            : new Date(entry.timestamp).getTime();
          if (t && !isNaN(t)) firstTimestamp = t;
        }
      } catch { /* skip bad line */ }
    }

    if (firstTimestamp) timestamp = firstTimestamp;
    const timeLabel = formatRelativeTime(timestamp);

    return {
      id: sessionId,
      title: title || preview.slice(0, 60) || 'Untitled session',
      preview: preview || '',
      timestamp,
      timeLabel,
      filePath,
    };
  }

  /**
   * Finds a session file by id (in the scoped project dir, or across all
   * projects) and parses it into chat messages. Returns null if not found.
   */
  async loadSession(sessionId) {
    const projectDir = this._cwd ? getProjectDir(this._cwd) : null;
    const dirs = projectDir ? [projectDir] : await this._allProjectDirs();

    for (const dir of dirs) {
      const filePath = path.join(dir, `${sessionId}.jsonl`);
      if (fs.existsSync(filePath)) {
        return this._parseSessionFile(filePath);
      }
    }
    return null;
  }

  /**
   * Parses a full session JSONL file into ordered {role, text, toolUses}
   * messages. Two passes: first collect tool results (which arrive as user
   * messages containing tool_result blocks), then build the message list and
   * attach each result to its originating tool use.
   */
  async _parseSessionFile(filePath) {
    const content = await fsp.readFile(filePath, 'utf8');
    const lines = content.split('\n').filter(Boolean);
    const messages = [];
    const toolResults = new Map();

    // First pass: collect tool results from user messages
    for (const line of lines) {
      try {
        const entry = JSON.parse(line);
        if (entry.type === 'user' && entry.message && Array.isArray(entry.message.content)) {
          for (const block of entry.message.content) {
            if (block.type === 'tool_result' && block.tool_use_id) {
              // Result content may be a string or an array of text blocks.
              const resultText = typeof block.content === 'string'
                ? block.content
                : Array.isArray(block.content)
                  ? block.content.map(b => b.text || '').join('')
                  : '';
              toolResults.set(String(block.tool_use_id), {
                // Cap stored result text to keep the view model small.
                content: resultText.slice(0, 2000),
                isError: block.is_error || false,
              });
            }
          }
        }
      } catch { /* skip */ }
    }

    // Second pass: build messages with tool use details
    for (const line of lines) {
      try {
        const entry = JSON.parse(line);
        if (entry.type === 'user' && entry.message) {
          const c = entry.message.content;
          // Skip tool result messages (they're user messages with tool_result blocks)
          if (Array.isArray(c) && c.length > 0 && c[0].type === 'tool_result') continue;
          const text = typeof c === 'string'
            ? c
            : Array.isArray(c)
              ? c.filter(b => b.type === 'text').map(b => b.text).join('')
              : '';
          if (text) messages.push({ role: 'user', text });
        } else if (entry.type === 'assistant' && entry.message) {
          const c = entry.message.content;
          const text = typeof c === 'string'
            ? c
            : Array.isArray(c)
              ? c.filter(b => b.type === 'text').map(b => b.text).join('')
              : '';
          // Attach the collected result (from pass 1) to each tool use.
          const toolUses = Array.isArray(c)
            ? c.filter(b => b.type === 'tool_use').map(tu => {
                const result = toolResults.get(String(tu.id));
                return {
                  id: tu.id,
                  name: tu.name,
                  input: tu.input || null,
                  status: result ? (result.isError ? 'error' : 'complete') : 'complete',
                  result: result ? result.content : null,
                  isError: result ? result.isError : false,
                };
              })
            : [];
          messages.push({ role: 'assistant', text, toolUses });
        }
      } catch { /* skip */ }
    }

    return messages;
  }

  /**
   * Reads at most `bytes` bytes from the start of a file and decodes as UTF-8.
   * NOTE(review): a multi-byte character cut at the boundary will decode as a
   * replacement character at the tail — harmless here since only line-complete
   * JSON is consumed downstream.
   */
  async _readHead(filePath, bytes) {
    const fd = await fsp.open(filePath, 'r');
    try {
      const buf = Buffer.alloc(bytes);
      const { bytesRead } = await fd.read(buf, 0, bytes, 0);
      return buf.slice(0, bytesRead).toString('utf8');
    } finally {
      await fd.close();
    }
  }
}
|
||||
|
||||
/**
 * Renders a timestamp as a coarse relative-time label ('Just now', '5m ago',
 * '3h ago', '2d ago'), falling back to a locale date string beyond one week.
 *
 * @param {number} ts - Epoch milliseconds.
 * @returns {string} Human-readable relative time.
 */
function formatRelativeTime(ts) {
  const elapsedMs = Date.now() - ts;
  const mins = Math.floor(elapsedMs / 60000);
  if (mins < 1) {
    return 'Just now';
  }
  if (mins < 60) {
    return `${mins}m ago`;
  }
  const hours = Math.floor(mins / 60);
  if (hours < 24) {
    return `${hours}h ago`;
  }
  const days = Math.floor(hours / 24);
  return days < 7 ? `${days}d ago` : new Date(ts).toLocaleDateString();
}
|
||||
|
||||
module.exports = { SessionManager };
|
||||
@@ -12,6 +12,9 @@ const {
|
||||
resolveCommandCheckPath,
|
||||
} = require('./state');
|
||||
const { buildControlCenterViewModel } = require('./presentation');
|
||||
const { ChatController, OpenClaudeChatViewProvider, OpenClaudeChatPanelManager } = require('./chat/chatProvider');
|
||||
const { SessionManager } = require('./chat/sessionManager');
|
||||
const { DiffContentProvider, SCHEME: DIFF_SCHEME } = require('./chat/diffController');
|
||||
|
||||
const OPENCLAUDE_REPO_URL = 'https://github.com/Gitlawb/openclaude';
|
||||
const OPENCLAUDE_SETUP_URL = 'https://github.com/Gitlawb/openclaude/blob/main/README.md#quick-start';
|
||||
@@ -1041,11 +1044,58 @@ class OpenClaudeControlCenterProvider {
|
||||
* @param {vscode.ExtensionContext} context
|
||||
*/
|
||||
function activate(context) {
|
||||
// ── Control Center (existing) ──
|
||||
const provider = new OpenClaudeControlCenterProvider();
|
||||
const refreshProvider = () => {
|
||||
void provider.refresh();
|
||||
};
|
||||
|
||||
// ── Chat system ──
|
||||
const sessionManager = new SessionManager();
|
||||
const folders = vscode.workspace.workspaceFolders;
|
||||
if (folders && folders.length > 0) {
|
||||
sessionManager.setCwd(folders[0].uri.fsPath);
|
||||
}
|
||||
|
||||
const chatController = new ChatController(sessionManager);
|
||||
const chatViewProvider = new OpenClaudeChatViewProvider(chatController);
|
||||
const chatPanelManager = new OpenClaudeChatPanelManager(chatController);
|
||||
|
||||
// ── Diff content provider ──
|
||||
const diffProvider = new DiffContentProvider();
|
||||
const diffProviderReg = vscode.workspace.registerTextDocumentContentProvider(
|
||||
DIFF_SCHEME,
|
||||
diffProvider,
|
||||
);
|
||||
|
||||
// ── Status bar ──
|
||||
const statusBarItem = vscode.window.createStatusBarItem(
|
||||
vscode.StatusBarAlignment.Right,
|
||||
100,
|
||||
);
|
||||
statusBarItem.text = '$(comment-discussion) OpenClaude';
|
||||
statusBarItem.tooltip = 'Open OpenClaude Chat';
|
||||
statusBarItem.command = 'openclaude.openChat';
|
||||
statusBarItem.show();
|
||||
|
||||
chatController.onDidChangeState((state) => {
|
||||
switch (state) {
|
||||
case 'streaming':
|
||||
statusBarItem.text = '$(sync~spin) OpenClaude';
|
||||
statusBarItem.tooltip = 'OpenClaude is generating...';
|
||||
break;
|
||||
case 'connected':
|
||||
statusBarItem.text = '$(comment-discussion) OpenClaude';
|
||||
statusBarItem.tooltip = 'OpenClaude connected';
|
||||
break;
|
||||
default:
|
||||
statusBarItem.text = '$(comment-discussion) OpenClaude';
|
||||
statusBarItem.tooltip = 'Open OpenClaude Chat';
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
// ── Existing commands ──
|
||||
const startCommand = vscode.commands.registerCommand('openclaude.start', async () => {
|
||||
await launchOpenClaude();
|
||||
});
|
||||
@@ -1079,32 +1129,95 @@ function activate(context) {
|
||||
await vscode.commands.executeCommand('workbench.view.extension.openclaude');
|
||||
});
|
||||
|
||||
const providerDisposable = vscode.window.registerWebviewViewProvider(
|
||||
// ── New chat commands ──
|
||||
const newChatCommand = vscode.commands.registerCommand('openclaude.newChat', () => {
|
||||
chatController.stopSession();
|
||||
chatController.broadcast({ type: 'session_cleared' });
|
||||
});
|
||||
|
||||
const openChatCommand = vscode.commands.registerCommand('openclaude.openChat', () => {
|
||||
chatPanelManager.openPanel();
|
||||
});
|
||||
|
||||
const resumeSessionCommand = vscode.commands.registerCommand('openclaude.resumeSession', async () => {
|
||||
const sessions = await sessionManager.listSessions();
|
||||
if (sessions.length === 0) {
|
||||
await vscode.window.showInformationMessage('No sessions found to resume.');
|
||||
return;
|
||||
}
|
||||
const items = sessions.slice(0, 30).map(s => ({
|
||||
label: s.title || s.id,
|
||||
description: s.timeLabel,
|
||||
detail: s.preview,
|
||||
sessionId: s.id,
|
||||
}));
|
||||
const picked = await vscode.window.showQuickPick(items, {
|
||||
placeHolder: 'Select a session to resume',
|
||||
});
|
||||
if (picked) {
|
||||
chatController.stopSession();
|
||||
chatController.broadcast({ type: 'session_cleared' });
|
||||
await chatController.startSession({ sessionId: picked.sessionId });
|
||||
}
|
||||
});
|
||||
|
||||
const abortChatCommand = vscode.commands.registerCommand('openclaude.abortChat', () => {
|
||||
chatController.abort();
|
||||
});
|
||||
|
||||
// ── Register providers ──
|
||||
const controlCenterProviderReg = vscode.window.registerWebviewViewProvider(
|
||||
'openclaude.controlCenter',
|
||||
provider,
|
||||
);
|
||||
|
||||
const chatViewProviderReg = vscode.window.registerWebviewViewProvider(
|
||||
'openclaude.chat',
|
||||
chatViewProvider,
|
||||
{ webviewOptions: { retainContextWhenHidden: true } },
|
||||
);
|
||||
|
||||
const profileWatcher = vscode.workspace.createFileSystemWatcher(`**/${PROFILE_FILE_NAME}`);
|
||||
|
||||
context.subscriptions.push(
|
||||
// existing
|
||||
startCommand,
|
||||
startInWorkspaceRootCommand,
|
||||
openDocsCommand,
|
||||
openSetupDocsCommand,
|
||||
openWorkspaceProfileCommand,
|
||||
openUiCommand,
|
||||
providerDisposable,
|
||||
controlCenterProviderReg,
|
||||
// new chat
|
||||
newChatCommand,
|
||||
openChatCommand,
|
||||
resumeSessionCommand,
|
||||
abortChatCommand,
|
||||
chatViewProviderReg,
|
||||
diffProviderReg,
|
||||
statusBarItem,
|
||||
// watchers
|
||||
profileWatcher,
|
||||
vscode.workspace.onDidChangeConfiguration(event => {
|
||||
if (event.affectsConfiguration('openclaude')) {
|
||||
refreshProvider();
|
||||
}
|
||||
}),
|
||||
vscode.workspace.onDidChangeWorkspaceFolders(refreshProvider),
|
||||
vscode.workspace.onDidChangeWorkspaceFolders((e) => {
|
||||
refreshProvider();
|
||||
const folders = vscode.workspace.workspaceFolders;
|
||||
if (folders && folders.length > 0) {
|
||||
sessionManager.setCwd(folders[0].uri.fsPath);
|
||||
}
|
||||
}),
|
||||
vscode.window.onDidChangeActiveTextEditor(refreshProvider),
|
||||
profileWatcher.onDidCreate(refreshProvider),
|
||||
profileWatcher.onDidChange(refreshProvider),
|
||||
profileWatcher.onDidDelete(refreshProvider),
|
||||
// disposables
|
||||
{ dispose: () => chatController.dispose() },
|
||||
{ dispose: () => chatPanelManager.dispose() },
|
||||
{ dispose: () => diffProvider.dispose() },
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1116,4 +1229,7 @@ module.exports = {
|
||||
OpenClaudeControlCenterProvider,
|
||||
renderControlCenterHtml,
|
||||
resolveLaunchTargets,
|
||||
ChatController,
|
||||
OpenClaudeChatViewProvider,
|
||||
OpenClaudeChatPanelManager,
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user