diff --git a/README.md b/README.md index a71cbcc8..a6206c9f 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,50 @@ For best results, use models with strong tool/function calling support. --- +## Agent Routing + +Route different agents to different AI providers within the same session. Useful for cost optimization (cheap model for code review, powerful model for complex coding) or leveraging model strengths. + +### Configuration + +Add to `~/.claude/settings.json`: + +```json +{ + "agentModels": { + "deepseek-chat": { + "base_url": "https://api.deepseek.com/v1", + "api_key": "sk-your-key" + }, + "gpt-4o": { + "base_url": "https://api.openai.com/v1", + "api_key": "sk-your-key" + } + }, + "agentRouting": { + "Explore": "deepseek-chat", + "Plan": "gpt-4o", + "general-purpose": "gpt-4o", + "frontend-dev": "deepseek-chat", + "default": "gpt-4o" + } +} +``` + +### How It Works + +- **agentModels**: Maps model names to OpenAI-compatible API endpoints +- **agentRouting**: Maps agent types or team member names to model names +- **Priority**: `name` > `subagent_type` > `"default"` > global provider +- **Matching**: Case-insensitive, hyphen/underscore equivalent (`general-purpose` = `general_purpose`) +- **Teams**: Team members are routed by their `name` — no extra config needed + +When no routing match is found, the global provider (env vars) is used as fallback. + +> **Note:** `api_key` values in `settings.json` are stored in plaintext. Keep this file private and do not commit it to version control. + +--- + ## Web Search and Fetch `WebFetch` works out of the box. 
import { describe, expect, test } from 'bun:test'
import { resolveAgentProvider } from './agentRouting.js'
import type { SettingsJson } from '../../utils/settings/types.js'

// Provider table shared by every fixture in this suite.
const providerTable = {
  'deepseek-chat': { base_url: 'https://api.deepseek.com/v1', api_key: 'sk-ds' },
  'gpt-4o': { base_url: 'https://api.openai.com/v1', api_key: 'sk-oai' },
}

// Fully populated settings: two providers, three explicit routes and a default.
const baseSettings = {
  agentModels: providerTable,
  agentRouting: {
    Explore: 'deepseek-chat',
    'general-purpose': 'gpt-4o',
    'frontend-dev': 'deepseek-chat',
    default: 'gpt-4o',
  },
} as unknown as SettingsJson

// Overrides the resolver is expected to produce for each provider.
const deepseekOverride = {
  model: 'deepseek-chat',
  baseURL: 'https://api.deepseek.com/v1',
  apiKey: 'sk-ds',
}
const openaiOverride = {
  model: 'gpt-4o',
  baseURL: 'https://api.openai.com/v1',
  apiKey: 'sk-oai',
}

/** Builds settings that reuse the shared provider table with a custom routing map. */
function settingsWithRouting(agentRouting: Record<string, string>): SettingsJson {
  return {
    agentModels: baseSettings.agentModels,
    agentRouting,
  } as unknown as SettingsJson
}

describe('resolveAgentProvider', () => {
  // Priority chain: name, then subagentType, then "default", then null.

  test('name takes priority over subagentType', () => {
    expect(resolveAgentProvider('frontend-dev', 'Explore', baseSettings)).toEqual(
      deepseekOverride,
    )
  })

  test('subagentType used when name has no match', () => {
    expect(resolveAgentProvider('unknown-name', 'Explore', baseSettings)).toEqual(
      deepseekOverride,
    )
  })

  test('falls back to "default" when neither name nor subagentType match', () => {
    expect(resolveAgentProvider('nobody', 'unknown-type', baseSettings)).toEqual(
      openaiOverride,
    )
  })

  test('returns null when no routing match and no default', () => {
    const exploreOnly = settingsWithRouting({ Explore: 'deepseek-chat' })
    expect(resolveAgentProvider('nobody', 'unknown-type', exploreOnly)).toBeNull()
  })

  test('returns null when name and subagentType are both undefined', () => {
    const exploreOnly = settingsWithRouting({ Explore: 'deepseek-chat' })
    expect(resolveAgentProvider(undefined, undefined, exploreOnly)).toBeNull()
  })

  // Key normalization: case and hyphen/underscore differences are ignored.

  test('matching is case-insensitive', () => {
    const resolved = resolveAgentProvider(undefined, 'explore', baseSettings)
    expect(resolved?.model).toBe('deepseek-chat')
  })

  test('matching is case-insensitive (UPPER)', () => {
    const resolved = resolveAgentProvider(undefined, 'EXPLORE', baseSettings)
    expect(resolved?.model).toBe('deepseek-chat')
  })

  test('hyphen and underscore are equivalent', () => {
    const resolved = resolveAgentProvider(undefined, 'general_purpose', baseSettings)
    expect(resolved?.model).toBe('gpt-4o')
  })

  test('underscore in config matches hyphen in input', () => {
    const underscoreConfig = settingsWithRouting({ general_purpose: 'deepseek-chat' })
    const resolved = resolveAgentProvider(undefined, 'general-purpose', underscoreConfig)
    expect(resolved?.model).toBe('deepseek-chat')
  })

  // Incomplete or inconsistent configuration must yield null.

  test('returns null when settings is null', () => {
    expect(resolveAgentProvider('Explore', 'Explore', null)).toBeNull()
  })

  test('returns null when agentRouting is missing', () => {
    const modelsOnly = { agentModels: baseSettings.agentModels } as unknown as SettingsJson
    expect(resolveAgentProvider(undefined, 'Explore', modelsOnly)).toBeNull()
  })

  test('returns null when agentModels is missing', () => {
    const routingOnly = { agentRouting: baseSettings.agentRouting } as unknown as SettingsJson
    expect(resolveAgentProvider(undefined, 'Explore', routingOnly)).toBeNull()
  })

  test('returns null when routing references non-existent model', () => {
    const danglingRoute = {
      agentModels: {},
      agentRouting: { Explore: 'non-existent-model' },
    } as unknown as SettingsJson
    expect(resolveAgentProvider(undefined, 'Explore', danglingRoute)).toBeNull()
  })

  test('subagentType only (no name)', () => {
    const resolved = resolveAgentProvider(undefined, 'Explore', baseSettings)
    expect(resolved?.model).toBe('deepseek-chat')
  })

  test('name only (no subagentType)', () => {
    const resolved = resolveAgentProvider('frontend-dev', undefined, baseSettings)
    expect(resolved?.model).toBe('deepseek-chat')
  })
})
/**
 * Provider override resolved from agent routing config.
 * When present, the API client should use these instead of global env vars.
 */
export interface ProviderOverride {
  /** Model name to send to the API (e.g. "deepseek-chat", "gpt-4o") */
  model: string
  /** OpenAI-compatible base URL */
  baseURL: string
  /** API key for this provider */
  apiKey: string
}

/**
 * Normalize an agent identifier for case-insensitive, hyphen/underscore-agnostic
 * matching: lowercases the key and removes every `-` and `_`.
 */
function normalize(key: string): string {
  return key.toLowerCase().replace(/[-_]/g, '')
}

/**
 * Look up `agentRouting` by name or subagent type, then resolve the routed
 * model name via `agentModels`.
 *
 * Priority: name > subagentType > "default" > null (use global provider)
 *
 * @param name - Agent (team member) name, if any.
 * @param subagentType - Agent type identifier, if any.
 * @param settings - Settings object holding `agentRouting` and `agentModels`.
 * @returns The provider override for the first candidate with a routing entry,
 *   or `null` when settings or either map is missing, no candidate matches, or
 *   the routed model name is not a key defined in `agentModels`.
 */
export function resolveAgentProvider(
  name: string | undefined,
  subagentType: string | undefined,
  settings: SettingsJson | null,
): ProviderOverride | null {
  if (!settings) return null

  const routing = settings.agentRouting
  const models = settings.agentModels
  if (!routing || !models) return null

  // Build normalized lookup from routing config.
  // Warn on duplicate normalized keys (e.g. "explore-agent" and "explore_agent"
  // both normalize to "exploreagent") to prevent silent shadowing.
  const normalizedRouting = new Map<string, string>()
  for (const [key, value] of Object.entries(routing)) {
    const nk = normalize(key)
    if (normalizedRouting.has(nk)) {
      console.error(`[agentRouting] Warning: routing key "${key}" collides with an existing key after normalization (both map to "${nk}"). First entry wins.`)
      continue
    }
    normalizedRouting.set(nk, value)
  }

  // Try name first, then subagentType, then "default". Undefined and empty
  // identifiers are dropped so they can never match a routing key.
  const candidates = [name, subagentType, 'default'].filter(
    (candidate): candidate is string => Boolean(candidate),
  )
  let modelName: string | undefined

  for (const candidate of candidates) {
    const match = normalizedRouting.get(normalize(candidate))
    if (match) {
      modelName = match
      break
    }
  }

  if (!modelName) return null

  // Only accept model names the config defines itself. A plain property read
  // would also resolve inherited members such as "toString" or "constructor"
  // and produce an override with undefined baseURL/apiKey.
  if (!Object.prototype.hasOwnProperty.call(models, modelName)) return null

  const modelConfig = models[modelName]
  if (!modelConfig) return null

  return {
    model: modelName,
    baseURL: modelConfig.base_url,
    apiKey: modelConfig.api_key,
  }
}
AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS, }) const result = yield* executeNonStreamingRequest( - { model: options.model, source: options.querySource }, + { model: options.model, source: options.querySource, providerOverride: options.providerOverride }, { model: options.model, fallbackModel: options.fallbackModel, diff --git a/src/services/api/client.ts b/src/services/api/client.ts index a32e0779..ab73d805 100644 --- a/src/services/api/client.ts +++ b/src/services/api/client.ts @@ -95,12 +95,14 @@ export async function getAnthropicClient({ model, fetchOverride, source, + providerOverride, }: { apiKey?: string maxRetries: number model?: string fetchOverride?: ClientOptions['fetch'] source?: string + providerOverride?: { model: string; baseURL: string; apiKey: string } }): Promise { const containerId = process.env.CLAUDE_CODE_CONTAINER_ID const remoteSessionId = process.env.CLAUDE_CODE_REMOTE_SESSION_ID @@ -154,6 +156,24 @@ export async function getAnthropicClient({ fetch: resolvedFetch, }), } + // Agent routing override: use per-agent provider when configured. + // Strip auth-related headers to prevent leaking Anthropic credentials + // to third-party endpoints (SSRF / credential forwarding mitigation). 
+ if (providerOverride) { + const { createOpenAIShimClient } = await import('./openaiShim.js') + const safeHeaders: Record = {} + for (const [k, v] of Object.entries(defaultHeaders)) { + const lower = k.toLowerCase() + if (lower === 'authorization' || lower === 'x-api-key' || lower === 'api-key') continue + safeHeaders[k] = v + } + return createOpenAIShimClient({ + defaultHeaders: safeHeaders, + maxRetries, + timeout: parseInt(process.env.API_TIMEOUT_MS || String(600 * 1000), 10), + providerOverride, + }) as unknown as Anthropic + } if ( isEnvTruthy(process.env.CLAUDE_CODE_USE_OPENAI) || isEnvTruthy(process.env.CLAUDE_CODE_USE_GITHUB) || diff --git a/src/services/api/openaiShim.ts b/src/services/api/openaiShim.ts index 68cb31e9..b8407be3 100644 --- a/src/services/api/openaiShim.ts +++ b/src/services/api/openaiShim.ts @@ -683,10 +683,12 @@ class OpenAIShimStream { class OpenAIShimMessages { private defaultHeaders: Record private reasoningEffort?: 'low' | 'medium' | 'high' | 'xhigh' + private providerOverride?: { model: string; baseURL: string; apiKey: string } - constructor(defaultHeaders: Record, reasoningEffort?: 'low' | 'medium' | 'high' | 'xhigh') { + constructor(defaultHeaders: Record, reasoningEffort?: 'low' | 'medium' | 'high' | 'xhigh', providerOverride?: { model: string; baseURL: string; apiKey: string }) { this.defaultHeaders = defaultHeaders this.reasoningEffort = reasoningEffort + this.providerOverride = providerOverride } create( @@ -698,7 +700,7 @@ class OpenAIShimMessages { let httpResponse: Response | undefined const promise = (async () => { - const request = resolveProviderRequest({ model: params.model, reasoningEffortOverride: self.reasoningEffort }) + const request = resolveProviderRequest({ model: self.providerOverride?.model ?? 
params.model, baseUrl: self.providerOverride?.baseURL, reasoningEffortOverride: self.reasoningEffort }) const response = await self._doRequest(request, params, options) httpResponse = response @@ -857,7 +859,7 @@ class OpenAIShimMessages { ...(options?.headers ?? {}), } - const apiKey = process.env.OPENAI_API_KEY ?? '' + const apiKey = this.providerOverride?.apiKey ?? process.env.OPENAI_API_KEY ?? '' // Detect Azure endpoints by hostname (not raw URL) to prevent bypass via // path segments like https://evil.com/cognitiveservices.azure.com/ let isAzure = false @@ -1056,8 +1058,8 @@ class OpenAIShimBeta { messages: OpenAIShimMessages reasoningEffort?: 'low' | 'medium' | 'high' | 'xhigh' - constructor(defaultHeaders: Record, reasoningEffort?: 'low' | 'medium' | 'high' | 'xhigh') { - this.messages = new OpenAIShimMessages(defaultHeaders, reasoningEffort) + constructor(defaultHeaders: Record, reasoningEffort?: 'low' | 'medium' | 'high' | 'xhigh', providerOverride?: { model: string; baseURL: string; apiKey: string }) { + this.messages = new OpenAIShimMessages(defaultHeaders, reasoningEffort, providerOverride) this.reasoningEffort = reasoningEffort } } @@ -1067,6 +1069,7 @@ export function createOpenAIShimClient(options: { maxRetries?: number timeout?: number reasoningEffort?: 'low' | 'medium' | 'high' | 'xhigh' + providerOverride?: { model: string; baseURL: string; apiKey: string } }): unknown { hydrateGithubModelsTokenFromSecureStorage() @@ -1089,7 +1092,7 @@ export function createOpenAIShimClient(options: { const beta = new OpenAIShimBeta({ ...(options.defaultHeaders ?? 
{}), - }, options.reasoningEffort) + }, options.reasoningEffort, options.providerOverride) return { beta, diff --git a/src/tools/AgentTool/AgentTool.tsx b/src/tools/AgentTool/AgentTool.tsx index dd7fafde..51b434f3 100644 --- a/src/tools/AgentTool/AgentTool.tsx +++ b/src/tools/AgentTool/AgentTool.tsx @@ -644,7 +644,8 @@ export const AgentTool = buildTool({ useExactTools: true }), worktreePath: worktreeInfo?.worktreePath, - description + description, + agentName: name, }; // Helper to wrap execution with a cwd override: explicit cwd arg (KAIROS) diff --git a/src/tools/AgentTool/runAgent.ts b/src/tools/AgentTool/runAgent.ts index e2cb942e..f6fede20 100644 --- a/src/tools/AgentTool/runAgent.ts +++ b/src/tools/AgentTool/runAgent.ts @@ -57,6 +57,8 @@ import { clearSessionHooks } from '../../utils/hooks/sessionHooks.js' import { executeSubagentStartHooks } from '../../utils/hooks.js' import { createUserMessage } from '../../utils/messages.js' import { getAgentModel } from '../../utils/model/agent.js' +import { resolveAgentProvider } from '../../services/api/agentRouting.js' +import { getInitialSettings } from '../../utils/settings/settings.js' import type { ModelAlias } from '../../utils/model/aliases.js' import { clearAgentTranscriptSubdir, @@ -267,6 +269,7 @@ export async function* runAgent({ description, transcriptSubdir, onQueryProgress, + agentName, }: { agentDefinition: AgentDefinition promptMessages: Message[] @@ -326,6 +329,8 @@ export async function* runAgent({ * during long single-block streams (e.g. thinking) where no assistant * message is yielded for >60s. 
*/ onQueryProgress?: () => void + /** Agent name (team member name) for routing resolution */ + agentName?: string }): AsyncGenerator { // Track subagent usage for feature discovery @@ -344,6 +349,14 @@ export async function* runAgent({ permissionMode, ) + // Resolve per-agent provider routing from settings + const providerOverride = resolveAgentProvider( + agentName, + agentDefinition.agentType, + getInitialSettings(), + ) + const effectiveModel = providerOverride ? providerOverride.model : resolvedAgentModel + const agentId = override?.agentId ? override.agentId : createAgentId() // Route this agent's transcript into a grouping subdirectory if requested @@ -675,7 +688,8 @@ export async function* runAgent({ commands: [], debug: toolUseContext.options.debug, verbose: toolUseContext.options.verbose, - mainLoopModel: resolvedAgentModel, + mainLoopModel: effectiveModel, + providerOverride: providerOverride ?? undefined, // For fork children (useExactTools), inherit thinking config to match the // parent's API request prefix for prompt cache hits. For regular // sub-agents, disable thinking to control output token costs. diff --git a/src/utils/settings/types.ts b/src/utils/settings/types.ts index ba89edd8..e53d6601 100644 --- a/src/utils/settings/types.ts +++ b/src/utils/settings/types.ts @@ -713,6 +713,27 @@ export const SettingsSchema = lazySchema(() => .string() .optional() .describe('Advisor model for the server-side advisor tool.'), + agentModels: z + .record( + z.string(), + z.object({ + base_url: z.string().url().describe('OpenAI-compatible API endpoint (must be https:// or http://)'), + api_key: z.string().describe('API key for this provider'), + }), + ) + .optional() + .describe( + 'Map of model name to provider connection info. 
' + + 'Example: { "deepseek-chat": { "base_url": "https://api.deepseek.com/v1", "api_key": "sk-xxx" } }', + ), + agentRouting: z + .record(z.string(), z.string()) + .optional() + .describe( + 'Map of agent identifier (subagent_type or team member name) to model name. ' + + 'Use "default" key as fallback. Model name must exist in agentModels. ' + + 'Example: { "Explore": "deepseek-chat", "general-purpose": "gpt-4o", "default": "gpt-4o" }', + ), fastMode: z .boolean() .optional()