From 174eb8ad3b94166e86ab63e95c69d6da46982799 Mon Sep 17 00:00:00 2001 From: Vasanthdev2004 Date: Wed, 1 Apr 2026 11:10:51 +0530 Subject: [PATCH] feat: add intelligent provider profile recommendation --- PLAYBOOK.md | 26 ++ README.md | 12 + package.json | 3 + scripts/provider-bootstrap.ts | 60 +++-- scripts/provider-discovery.ts | 129 ++++++++++ scripts/provider-launch.ts | 63 +++-- scripts/provider-recommend.ts | 277 +++++++++++++++++++++ src/utils/providerRecommendation.test.ts | 118 +++++++++ src/utils/providerRecommendation.ts | 297 +++++++++++++++++++++++ 9 files changed, 945 insertions(+), 40 deletions(-) create mode 100644 scripts/provider-discovery.ts create mode 100644 scripts/provider-recommend.ts create mode 100644 src/utils/providerRecommendation.test.ts create mode 100644 src/utils/providerRecommendation.ts diff --git a/PLAYBOOK.md b/PLAYBOOK.md index 662ee4dc..dfdaec76 100644 --- a/PLAYBOOK.md +++ b/PLAYBOOK.md @@ -37,6 +37,18 @@ If everything is healthy, OpenClaude starts directly. bun run profile:init -- --provider ollama --model llama3.1:8b ``` +Or let OpenClaude recommend the best local model for your goal: + +```powershell +bun run profile:init -- --provider ollama --goal coding +``` + +Preview recommendations before saving: + +```powershell +bun run profile:recommend -- --goal coding --benchmark +``` + ### 3.2 Confirm profile file ```powershell @@ -171,6 +183,12 @@ Fix: bun run profile:init -- --provider ollama --model llama3.1:8b ``` +Or auto-pick a local profile: + +```powershell +bun run profile:auto -- --goal balanced +``` + ## 6.5 Placeholder key (`SUA_CHAVE`) error Cause: @@ -202,6 +220,14 @@ bun run profile:fast # llama3.2:3b bun run profile:code # qwen2.5-coder:7b ``` +Goal-based auto-selection: + +```powershell +bun run profile:auto -- --goal latency +bun run profile:auto -- --goal balanced +bun run profile:auto -- --goal coding +``` + ## 8. Practical Prompt Playbook (Copy/Paste) ## 8.1 Code understanding diff --git a/README.md b/README.md index 5d17d276..358bf95d 100644 --- a/README.md +++ b/README.md @@ -206,12 +206,21 @@ Use profile launchers to avoid repeated environment setup: # one-time profile bootstrap (auto-detect ollama, otherwise openai) bun run profile:init +# preview the best provider/model for your goal +bun run profile:recommend -- --goal coding --benchmark + +# auto-apply the best available profile for your goal +bun run profile:auto -- --goal latency + # openai bootstrap with explicit key bun run profile:init -- --provider openai --api-key sk-... # ollama bootstrap with custom model bun run profile:init -- --provider ollama --model llama3.1:8b +# ollama bootstrap with intelligent model auto-selection +bun run profile:init -- --provider ollama --goal coding + # launch using persisted profile (.openclaude-profile.json) bun run dev:profile @@ -222,6 +231,9 @@ bun run dev:openai bun run dev:ollama ``` +`profile:recommend` ranks installed Ollama models for `latency`, `balanced`, or `coding`, and `profile:auto` can persist the recommendation directly. +If no profile exists yet, `dev:profile` now uses the same goal-aware defaults when picking the initial model. + `dev:openai` and `dev:ollama` run `doctor:runtime` first and only launch the app if checks pass. For `dev:ollama`, make sure Ollama is running locally before launch. diff --git a/package.json b/package.json index 15f9f348..ab44903f 100644 --- a/package.json +++ b/package.json @@ -20,11 +20,14 @@ "dev:ollama": "bun run scripts/provider-launch.ts ollama", "dev:ollama:fast": "bun run scripts/provider-launch.ts ollama --fast --bare", "profile:init": "bun run scripts/provider-bootstrap.ts", + "profile:recommend": "bun run scripts/provider-recommend.ts", + "profile:auto": "bun run scripts/provider-recommend.ts --apply", "profile:fast": "bun run profile:init -- --provider ollama --model llama3.2:3b", "profile:code": "bun run profile:init -- --provider ollama --model qwen2.5-coder:7b", "dev:fast": "bun run profile:fast && bun run dev:ollama:fast", "dev:code": "bun run profile:code && bun run dev:profile", "start": "node dist/cli.mjs", + "test:provider-recommendation": "node --test --experimental-strip-types src/utils/providerRecommendation.test.ts", "typecheck": "tsc --noEmit", "smoke": "bun run build && node dist/cli.mjs --version", "doctor:runtime": "bun run scripts/system-check.ts", diff --git a/scripts/provider-bootstrap.ts b/scripts/provider-bootstrap.ts index 7c066a00..31915b39 100644 --- a/scripts/provider-bootstrap.ts +++ b/scripts/provider-bootstrap.ts @@ -1,6 +1,16 @@ // @ts-nocheck import { writeFileSync } from 'node:fs' import { resolve } from 'node:path' +import { + getGoalDefaultOpenAIModel, + normalizeRecommendationGoal, + recommendOllamaModel, +} from '../src/utils/providerRecommendation.ts' +import { + getOllamaChatBaseUrl, + hasLocalOllama, + listOllamaModels, +} from './provider-discovery.ts' type ProviderProfile = 'openai' | 'ollama' @@ -27,51 +37,55 @@ function parseProviderArg(): ProviderProfile | 'auto' { return 'auto' } -async function hasLocalOllama(): Promise { - const endpoint = 'http://localhost:11434/api/tags' - const controller = new AbortController() - const timeout = setTimeout(() => controller.abort(), 1200) - - try { - const response = await fetch(endpoint, { - method: 'GET', - signal: controller.signal, - }) - return response.ok - } catch { - return false - } finally { - clearTimeout(timeout) - } -} - function sanitizeApiKey(key: string | null): string | undefined { if (!key || key === 'SUA_CHAVE') return undefined return key } +async function resolveOllamaModel( + argModel: string | null, + argBaseUrl: string | null, + goal: ReturnType, +): Promise { + if (argModel) return argModel + + const discovered = await listOllamaModels(argBaseUrl || undefined) + const recommended = recommendOllamaModel(discovered, goal) + if (recommended) { + return recommended.name + } + + return process.env.OPENAI_MODEL || 'llama3.1:8b' +} + async function main(): Promise { const provider = parseProviderArg() const argModel = parseArg('--model') const argBaseUrl = parseArg('--base-url') const argApiKey = parseArg('--api-key') + const goal = normalizeRecommendationGoal( + parseArg('--goal') || process.env.OPENCLAUDE_PROFILE_GOAL, + ) let selected: ProviderProfile if (provider === 'auto') { - selected = (await hasLocalOllama()) ? 'ollama' : 'openai' + selected = (await hasLocalOllama(argBaseUrl || undefined)) ? 'ollama' : 'openai' } else { selected = provider } const env: ProfileFile['env'] = {} if (selected === 'ollama') { - env.OPENAI_BASE_URL = argBaseUrl || 'http://localhost:11434/v1' - env.OPENAI_MODEL = argModel || process.env.OPENAI_MODEL || 'llama3.1:8b' + env.OPENAI_BASE_URL = getOllamaChatBaseUrl(argBaseUrl || undefined) + env.OPENAI_MODEL = await resolveOllamaModel(argModel, argBaseUrl, goal) const key = sanitizeApiKey(argApiKey || process.env.OPENAI_API_KEY || null) if (key) env.OPENAI_API_KEY = key } else { env.OPENAI_BASE_URL = argBaseUrl || process.env.OPENAI_BASE_URL || 'https://api.openai.com/v1' - env.OPENAI_MODEL = argModel || process.env.OPENAI_MODEL || 'gpt-4o' + env.OPENAI_MODEL = + argModel || + process.env.OPENAI_MODEL || + getGoalDefaultOpenAIModel(goal) const key = sanitizeApiKey(argApiKey || process.env.OPENAI_API_KEY || null) if (!key) { console.error('OpenAI profile requires a real API key. Use --api-key or set OPENAI_API_KEY.') @@ -90,6 +104,8 @@ async function main(): Promise { writeFileSync(outputPath, JSON.stringify(profile, null, 2), 'utf8') console.log(`Saved profile: ${selected}`) + console.log(`Goal: ${goal}`) + console.log(`Model: ${profile.env.OPENAI_MODEL}`) console.log(`Path: ${outputPath}`) console.log('Next: bun run dev:profile') } diff --git a/scripts/provider-discovery.ts b/scripts/provider-discovery.ts new file mode 100644 index 00000000..9e3aacda --- /dev/null +++ b/scripts/provider-discovery.ts @@ -0,0 +1,129 @@ +import type { OllamaModelDescriptor } from '../src/utils/providerRecommendation.ts' + +export const DEFAULT_OLLAMA_BASE_URL = 'http://localhost:11434' + +function withTimeoutSignal(timeoutMs: number): { + signal: AbortSignal + clear: () => void +} { + const controller = new AbortController() + const timeout = setTimeout(() => controller.abort(), timeoutMs) + return { + signal: controller.signal, + clear: () => clearTimeout(timeout), + } +} + +function trimTrailingSlash(value: string): string { + return value.replace(/\/+$/, '') +} + +export function getOllamaApiBaseUrl(baseUrl?: string): string { + const parsed = new URL( + baseUrl || process.env.OLLAMA_BASE_URL || DEFAULT_OLLAMA_BASE_URL, + ) + const pathname = trimTrailingSlash(parsed.pathname) + parsed.pathname = pathname.endsWith('/v1') + ? pathname.slice(0, -3) || '/' + : pathname || '/' + parsed.search = '' + parsed.hash = '' + return trimTrailingSlash(parsed.toString()) +} + +export function getOllamaChatBaseUrl(baseUrl?: string): string { + return `${getOllamaApiBaseUrl(baseUrl)}/v1` +} + +export async function hasLocalOllama(baseUrl?: string): Promise { + const { signal, clear } = withTimeoutSignal(1200) + try { + const response = await fetch(`${getOllamaApiBaseUrl(baseUrl)}/api/tags`, { + method: 'GET', + signal, + }) + return response.ok + } catch { + return false + } finally { + clear() + } +} + +export async function listOllamaModels( + baseUrl?: string, +): Promise { + const { signal, clear } = withTimeoutSignal(5000) + try { + const response = await fetch(`${getOllamaApiBaseUrl(baseUrl)}/api/tags`, { + method: 'GET', + signal, + }) + if (!response.ok) { + return [] + } + + const data = await response.json() as { + models?: Array<{ + name?: string + size?: number + details?: { + family?: string + families?: string[] + parameter_size?: string + quantization_level?: string + } + }> + } + + return (data.models ?? []) + .filter(model => Boolean(model.name)) + .map(model => ({ + name: model.name!, + sizeBytes: typeof model.size === 'number' ? model.size : null, + family: model.details?.family ?? null, + families: model.details?.families ?? [], + parameterSize: model.details?.parameter_size ?? null, + quantizationLevel: model.details?.quantization_level ?? null, + })) + } catch { + return [] + } finally { + clear() + } +} + +export async function benchmarkOllamaModel( + modelName: string, + baseUrl?: string, +): Promise { + const start = Date.now() + const { signal, clear } = withTimeoutSignal(20000) + try { + const response = await fetch(`${getOllamaApiBaseUrl(baseUrl)}/api/chat`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + signal, + body: JSON.stringify({ + model: modelName, + stream: false, + messages: [{ role: 'user', content: 'Reply with OK.' }], + options: { + temperature: 0, + num_predict: 8, + }, + }), + }) + if (!response.ok) { + return null + } + await response.json() + return Date.now() - start + } catch { + return null + } finally { + clear() + } +} diff --git a/scripts/provider-launch.ts b/scripts/provider-launch.ts index fa103d53..26666072 100644 --- a/scripts/provider-launch.ts +++ b/scripts/provider-launch.ts @@ -2,6 +2,16 @@ import { spawn } from 'node:child_process' import { existsSync, readFileSync } from 'node:fs' import { resolve } from 'node:path' +import { + getGoalDefaultOpenAIModel, + normalizeRecommendationGoal, + recommendOllamaModel, +} from '../src/utils/providerRecommendation.ts' +import { + getOllamaChatBaseUrl, + hasLocalOllama, + listOllamaModels, +} from './provider-discovery.ts' type ProviderProfile = 'openai' | 'ollama' @@ -18,20 +28,29 @@ type LaunchOptions = { requestedProfile: ProviderProfile | 'auto' | null passthroughArgs: string[] fast: boolean + goal: ReturnType } function parseLaunchOptions(argv: string[]): LaunchOptions { let requestedProfile: ProviderProfile | 'auto' | null = 'auto' const passthroughArgs: string[] = [] let fast = false + let goal = normalizeRecommendationGoal(process.env.OPENCLAUDE_PROFILE_GOAL) - for (const arg of argv) { + for (let i = 0; i < argv.length; i++) { + const arg = argv[i]! const lower = arg.toLowerCase() if (lower === '--fast') { fast = true continue } + if (lower === '--goal') { + goal = normalizeRecommendationGoal(argv[i + 1] ?? null) + i++ + continue + } + if ((lower === 'auto' || lower === 'openai' || lower === 'ollama') && requestedProfile === 'auto') { requestedProfile = lower as ProviderProfile | 'auto' continue @@ -54,6 +73,7 @@ function parseLaunchOptions(argv: string[]): LaunchOptions { requestedProfile, passthroughArgs, fast, + goal, } } @@ -71,18 +91,12 @@ function loadPersistedProfile(): ProfileFile | null { } } -async function hasLocalOllama(): Promise { - const endpoint = 'http://localhost:11434/api/tags' - const controller = new AbortController() - const timeout = setTimeout(() => controller.abort(), 1200) - try { - const response = await fetch(endpoint, { signal: controller.signal }) - return response.ok - } catch { - return false - } finally { - clearTimeout(timeout) - } +async function resolveOllamaDefaultModel( + goal: ReturnType, +): Promise { + const models = await listOllamaModels() + const recommended = recommendOllamaModel(models, goal) + return recommended?.name || process.env.OPENAI_MODEL || 'llama3.1:8b' } function runCommand(command: string, env: NodeJS.ProcessEnv): Promise { @@ -99,7 +113,11 @@ function runCommand(command: string, env: NodeJS.ProcessEnv): Promise { }) } -function buildEnv(profile: ProviderProfile, persisted: ProfileFile | null): NodeJS.ProcessEnv { +async function buildEnv( + profile: ProviderProfile, + persisted: ProfileFile | null, + goal: ReturnType, +): Promise { const persistedEnv = persisted?.env ?? {} const env: NodeJS.ProcessEnv = { ...process.env, @@ -107,8 +125,14 @@ function buildEnv(profile: ProviderProfile, persisted: ProfileFile | null): Node } if (profile === 'ollama') { - env.OPENAI_BASE_URL = persistedEnv.OPENAI_BASE_URL || process.env.OPENAI_BASE_URL || 'http://localhost:11434/v1' - env.OPENAI_MODEL = persistedEnv.OPENAI_MODEL || process.env.OPENAI_MODEL || 'llama3.1:8b' + env.OPENAI_BASE_URL = + persistedEnv.OPENAI_BASE_URL || + process.env.OPENAI_BASE_URL || + getOllamaChatBaseUrl() + env.OPENAI_MODEL = + persistedEnv.OPENAI_MODEL || + process.env.OPENAI_MODEL || + await resolveOllamaDefaultModel(goal) if (!process.env.OPENAI_API_KEY || process.env.OPENAI_API_KEY === 'SUA_CHAVE') { delete env.OPENAI_API_KEY } @@ -116,7 +140,10 @@ function buildEnv(profile: ProviderProfile, persisted: ProfileFile | null): Node } env.OPENAI_BASE_URL = process.env.OPENAI_BASE_URL || persistedEnv.OPENAI_BASE_URL || 'https://api.openai.com/v1' - env.OPENAI_MODEL = process.env.OPENAI_MODEL || persistedEnv.OPENAI_MODEL || 'gpt-4o' + env.OPENAI_MODEL = + process.env.OPENAI_MODEL || + persistedEnv.OPENAI_MODEL || + getGoalDefaultOpenAIModel(goal) env.OPENAI_API_KEY = process.env.OPENAI_API_KEY || persistedEnv.OPENAI_API_KEY return env } @@ -165,7 +192,7 @@ async function main(): Promise { profile = requestedProfile } - const env = buildEnv(profile, persisted) + const env = await buildEnv(profile, persisted, options.goal) if (options.fast) { applyFastFlags(env) } diff --git a/scripts/provider-recommend.ts b/scripts/provider-recommend.ts new file mode 100644 index 00000000..8cfdc883 --- /dev/null +++ b/scripts/provider-recommend.ts @@ -0,0 +1,277 @@ +// @ts-nocheck +import { writeFileSync } from 'node:fs' +import { resolve } from 'node:path' + +import { + applyBenchmarkLatency, + getGoalDefaultOpenAIModel, + normalizeRecommendationGoal, + rankOllamaModels, + type BenchmarkedOllamaModel, + type RecommendationGoal, +} from '../src/utils/providerRecommendation.ts' +import { + benchmarkOllamaModel, + getOllamaChatBaseUrl, + hasLocalOllama, + listOllamaModels, +} from './provider-discovery.ts' + +type ProviderProfile = 'openai' | 'ollama' + +type ProfileFile = { + profile: ProviderProfile + env: { + OPENAI_BASE_URL?: string + OPENAI_MODEL?: string + OPENAI_API_KEY?: string + } + createdAt: string +} + +type CliOptions = { + apply: boolean + benchmark: boolean + goal: RecommendationGoal + json: boolean + provider: ProviderProfile | 'auto' + baseUrl: string | null +} + +function parseOptions(argv: string[]): CliOptions { + const options: CliOptions = { + apply: false, + benchmark: false, + goal: normalizeRecommendationGoal(process.env.OPENCLAUDE_PROFILE_GOAL), + json: false, + provider: 'auto', + baseUrl: null, + } + + for (let i = 0; i < argv.length; i++) { + const arg = argv[i]?.toLowerCase() + if (!arg) continue + + if (arg === '--apply') { + options.apply = true + continue + } + if (arg === '--benchmark') { + options.benchmark = true + continue + } + if (arg === '--json') { + options.json = true + continue + } + if (arg === '--goal') { + options.goal = normalizeRecommendationGoal(argv[i + 1] ?? null) + i++ + continue + } + if (arg === '--provider') { + const provider = argv[i + 1]?.toLowerCase() + if ( + provider === 'openai' || + provider === 'ollama' || + provider === 'auto' + ) { + options.provider = provider + } + i++ + continue + } + if (arg === '--base-url') { + options.baseUrl = argv[i + 1] ?? null + i++ + } + } + + return options +} + +function sanitizeApiKey(key: string | undefined): string | undefined { + if (!key || key === 'SUA_CHAVE') return undefined + return key +} + +function printHumanSummary(payload: { + goal: RecommendationGoal + recommendedProfile: ProviderProfile + recommendedModel: string + rankedModels: BenchmarkedOllamaModel[] + benchmarked: boolean + applied: boolean +}): void { + console.log(`Recommendation goal: ${payload.goal}`) + console.log(`Recommended profile: ${payload.recommendedProfile}`) + console.log(`Recommended model: ${payload.recommendedModel}`) + + if (payload.rankedModels.length > 0) { + console.log('\nRanked Ollama models:') + for (const [index, model] of payload.rankedModels.slice(0, 5).entries()) { + const benchmarkPart = + payload.benchmarked && model.benchmarkMs !== null + ? ` | ${Math.round(model.benchmarkMs)}ms` + : '' + console.log( + `${index + 1}. ${model.name} | score=${model.score}${benchmarkPart} | ${model.summary}`, + ) + } + } + + if (payload.applied) { + console.log('\nSaved .openclaude-profile.json with the recommended profile.') + console.log('Next: bun run dev:profile') + } else { + console.log( + '\nTip: run `bun run profile:auto -- --goal ' + + payload.goal + + '` to apply this automatically.', + ) + } +} + +async function maybeApplyProfile( + profile: ProviderProfile, + model: string, + goal: RecommendationGoal, + baseUrl: string | null, +): Promise { + const env: ProfileFile['env'] = {} + if (profile === 'ollama') { + env.OPENAI_BASE_URL = getOllamaChatBaseUrl(baseUrl ?? undefined) + env.OPENAI_MODEL = model + const key = sanitizeApiKey(process.env.OPENAI_API_KEY) + if (key) env.OPENAI_API_KEY = key + } else { + const key = sanitizeApiKey(process.env.OPENAI_API_KEY) + if (!key) { + console.error('Cannot apply an OpenAI profile without OPENAI_API_KEY.') + return false + } + env.OPENAI_BASE_URL = + process.env.OPENAI_BASE_URL || 'https://api.openai.com/v1' + env.OPENAI_MODEL = model || getGoalDefaultOpenAIModel(goal) + env.OPENAI_API_KEY = key + } + + const profileFile: ProfileFile = { + profile, + env, + createdAt: new Date().toISOString(), + } + + writeFileSync( + resolve(process.cwd(), '.openclaude-profile.json'), + JSON.stringify(profileFile, null, 2), + 'utf8', + ) + return true +} + +async function main(): Promise { + const options = parseOptions(process.argv.slice(2)) + const ollamaAvailable = + options.provider !== 'openai' && + (await hasLocalOllama(options.baseUrl ?? undefined)) + const ollamaModels = ollamaAvailable + ? await listOllamaModels(options.baseUrl ?? undefined) + : [] + + const heuristicRanked = rankOllamaModels(ollamaModels, options.goal) + const benchmarkInput = options.benchmark ? heuristicRanked.slice(0, 3) : [] + + const benchmarkResults: Record = {} + for (const model of benchmarkInput) { + benchmarkResults[model.name] = await benchmarkOllamaModel( + model.name, + options.baseUrl ?? undefined, + ) + } + + const rankedModels: BenchmarkedOllamaModel[] = options.benchmark + ? applyBenchmarkLatency(heuristicRanked, benchmarkResults, options.goal) + : heuristicRanked.map(model => ({ + ...model, + benchmarkMs: null, + })) + + const recommendedOllama = rankedModels[0] ?? null + const openAIConfigured = Boolean(sanitizeApiKey(process.env.OPENAI_API_KEY)) + + let recommendedProfile: ProviderProfile + let recommendedModel: string + + if (options.provider === 'openai') { + recommendedProfile = 'openai' + recommendedModel = getGoalDefaultOpenAIModel(options.goal) + } else if (options.provider === 'ollama') { + if (!recommendedOllama) { + console.error( + 'No Ollama models were discovered. Pull a model first or switch to --provider openai.', + ) + process.exit(1) + } + recommendedProfile = 'ollama' + recommendedModel = recommendedOllama.name + } else if (recommendedOllama) { + recommendedProfile = 'ollama' + recommendedModel = recommendedOllama.name + } else { + recommendedProfile = 'openai' + recommendedModel = getGoalDefaultOpenAIModel(options.goal) + } + + let applied = false + if (options.apply) { + applied = await maybeApplyProfile( + recommendedProfile, + recommendedModel, + options.goal, + options.baseUrl, + ) + if (!applied) { + process.exit(1) + } + } + + const payload = { + goal: options.goal, + provider: options.provider, + ollamaAvailable, + openAIConfigured, + recommendedProfile, + recommendedModel, + benchmarked: options.benchmark, + rankedModels, + applied, + } + + if (options.json) { + console.log(JSON.stringify(payload, null, 2)) + return + } + + printHumanSummary({ + goal: options.goal, + recommendedProfile, + recommendedModel, + rankedModels, + benchmarked: options.benchmark, + applied, + }) + + if (!recommendedOllama && !openAIConfigured) { + console.log( + '\nNo local Ollama model was detected and OPENAI_API_KEY is unset.', + ) + console.log( + 'Next steps: `ollama pull qwen2.5-coder:7b` or set OPENAI_API_KEY.', + ) + } +} + +await main() + +export {} diff --git a/src/utils/providerRecommendation.test.ts b/src/utils/providerRecommendation.test.ts new file mode 100644 index 00000000..986e403f --- /dev/null +++ b/src/utils/providerRecommendation.test.ts @@ -0,0 +1,118 @@ +import assert from 'node:assert/strict' +import test from 'node:test' + +import { + applyBenchmarkLatency, + getGoalDefaultOpenAIModel, + normalizeRecommendationGoal, + rankOllamaModels, + recommendOllamaModel, + type OllamaModelDescriptor, +} from './providerRecommendation.ts' + +function model( + name: string, + overrides: Partial = {}, +): OllamaModelDescriptor { + return { + name, + sizeBytes: null, + family: null, + families: [], + parameterSize: null, + quantizationLevel: null, + ...overrides, + } +} + +test('normalizes recommendation goals safely', () => { + assert.equal(normalizeRecommendationGoal('coding'), 'coding') + assert.equal(normalizeRecommendationGoal(' LATENCY '), 'latency') + assert.equal(normalizeRecommendationGoal('weird'), 'balanced') + assert.equal(normalizeRecommendationGoal(undefined), 'balanced') +}) + +test('coding goal prefers coding-oriented ollama models', () => { + const recommended = recommendOllamaModel( + [ + model('llama3.1:8b', { + parameterSize: '8B', + quantizationLevel: 'Q4_K_M', + }), + model('qwen2.5-coder:7b', { + parameterSize: '7B', + quantizationLevel: 'Q4_K_M', + }), + ], + 'coding', + ) + + assert.equal(recommended?.name, 'qwen2.5-coder:7b') +}) + +test('latency goal prefers smaller models', () => { + const recommended = recommendOllamaModel( + [ + model('llama3.1:70b', { + parameterSize: '70B', + quantizationLevel: 'Q4_K_M', + }), + model('llama3.2:3b', { + parameterSize: '3B', + quantizationLevel: 'Q4_K_M', + }), + ], + 'latency', + ) + + assert.equal(recommended?.name, 'llama3.2:3b') +}) + +test('non-chat embedding models are heavily demoted', () => { + const ranked = rankOllamaModels( + [ + model('nomic-embed-text', { parameterSize: '0.5B' }), + model('mistral:7b-instruct', { + parameterSize: '7B', + quantizationLevel: 'Q4_K_M', + }), + ], + 'balanced', + ) + + assert.equal(ranked[0]?.name, 'mistral:7b-instruct') +}) + +test('benchmark latency can reorder close recommendations', () => { + const ranked = rankOllamaModels( + [ + model('llama3.1:8b', { + parameterSize: '8B', + quantizationLevel: 'Q4_K_M', + }), + model('mistral:7b-instruct', { + parameterSize: '7B', + quantizationLevel: 'Q4_K_M', + }), + ], + 'latency', + ) + + const benchmarked = applyBenchmarkLatency( + ranked, + { + 'llama3.1:8b': 2000, + 'mistral:7b-instruct': 350, + }, + 'latency', + ) + + assert.equal(benchmarked[0]?.name, 'mistral:7b-instruct') + assert.equal(benchmarked[0]?.benchmarkMs, 350) +}) + +test('goal defaults choose sensible openai models', () => { + assert.equal(getGoalDefaultOpenAIModel('latency'), 'gpt-4o-mini') + assert.equal(getGoalDefaultOpenAIModel('balanced'), 'gpt-4o') + assert.equal(getGoalDefaultOpenAIModel('coding'), 'gpt-4o') +}) diff --git a/src/utils/providerRecommendation.ts b/src/utils/providerRecommendation.ts new file mode 100644 index 00000000..e49c37aa --- /dev/null +++ b/src/utils/providerRecommendation.ts @@ -0,0 +1,297 @@ +export type RecommendationGoal = 'latency' | 'balanced' | 'coding' + +export type OllamaModelDescriptor = { + name: string + sizeBytes?: number | null + family?: string | null + families?: string[] + parameterSize?: string | null + quantizationLevel?: string | null +} + +export type RankedOllamaModel = OllamaModelDescriptor & { + score: number + reasons: string[] + summary: string +} + +export type BenchmarkedOllamaModel = RankedOllamaModel & { + benchmarkMs: number | null +} + +const CODING_HINTS = [ + 'coder', + 'codellama', + 'codegemma', + 'starcoder', + 'deepseek-coder', + 'qwen2.5-coder', + 'qwen-coder', +] + +const GENERAL_HINTS = [ + 'llama', + 'qwen', + 'mistral', + 'gemma', + 'phi', + 'deepseek', +] + +const INSTRUCT_HINTS = ['instruct', 'chat', 'assistant'] +const NON_CHAT_HINTS = ['embed', 'embedding', 'rerank', 'bge', 'whisper'] + +function modelHaystack(model: OllamaModelDescriptor): string { + return [ + model.name, + model.family ?? '', + ...(model.families ?? []), + model.parameterSize ?? '', + model.quantizationLevel ?? '', + ] + .join(' ') + .toLowerCase() +} + +function includesAny(text: string, needles: string[]): boolean { + return needles.some(needle => text.includes(needle)) +} + +function inferParameterBillions(model: OllamaModelDescriptor): number | null { + const text = `${model.parameterSize ?? ''} ${model.name}`.toLowerCase() + const match = text.match(/(\d+(?:\.\d+)?)\s*b\b/) + if (match?.[1]) { + return Number(match[1]) + } + if (typeof model.sizeBytes === 'number' && model.sizeBytes > 0) { + return Number((model.sizeBytes / 1_000_000_000).toFixed(1)) + } + return null +} + +function quantizationBucket(model: OllamaModelDescriptor): string { + return (model.quantizationLevel ?? model.name).toLowerCase() +} + +function scoreSizeTier( + paramsB: number | null, + goal: RecommendationGoal, + reasons: string[], +): number { + if (paramsB === null) { + reasons.push('unknown size') + return 0 + } + + if (goal === 'latency') { + if (paramsB <= 4) { + reasons.push('tiny model for low latency') + return 32 + } + if (paramsB <= 8) { + reasons.push('small model for fast responses') + return 26 + } + if (paramsB <= 14) { + reasons.push('mid-sized model with acceptable latency') + return 16 + } + if (paramsB <= 24) { + reasons.push('larger model may be slower') + return 8 + } + reasons.push('large model likely slower locally') + return paramsB <= 40 ? 0 : -8 + } + + if (goal === 'coding') { + if (paramsB >= 7 && paramsB <= 14) { + reasons.push('strong coding size tier') + return 24 + } + if (paramsB > 14 && paramsB <= 34) { + reasons.push('large coding-capable size tier') + return 28 + } + if (paramsB > 34) { + reasons.push('very large model with higher quality potential') + return 18 + } + reasons.push('compact model may trade off coding depth') + return 12 + } + + if (paramsB >= 7 && paramsB <= 14) { + reasons.push('great balanced size tier') + return 26 + } + if (paramsB >= 3 && paramsB < 7) { + reasons.push('compact balanced size tier') + return 18 + } + if (paramsB > 14 && paramsB <= 24) { + reasons.push('high quality balanced size tier') + return 20 + } + if (paramsB > 24) { + reasons.push('large model for quality-first usage') + return 10 + } + reasons.push('very small model for general usage') + return 8 +} + +function scoreQuantization( + model: OllamaModelDescriptor, + goal: RecommendationGoal, + reasons: string[], +): number { + const quant = quantizationBucket(model) + if (quant.includes('q4')) { + reasons.push('efficient Q4 quantization') + return goal === 'latency' ? 8 : 4 + } + if (quant.includes('q5')) { + reasons.push('balanced Q5 quantization') + return goal === 'latency' ? 6 : 5 + } + if (quant.includes('q8')) { + reasons.push('higher quality Q8 quantization') + return goal === 'latency' ? 2 : 5 + } + return 0 +} + +function compareRankedModels( + a: RankedOllamaModel | BenchmarkedOllamaModel, + b: RankedOllamaModel | BenchmarkedOllamaModel, + goal: RecommendationGoal, +): number { + if (b.score !== a.score) { + return b.score - a.score + } + + const aSize = inferParameterBillions(a) ?? Number.POSITIVE_INFINITY + const bSize = inferParameterBillions(b) ?? Number.POSITIVE_INFINITY + + if (goal === 'latency') { + return aSize - bSize + } + + if (goal === 'coding') { + return bSize - aSize + } + + const target = 14 + return Math.abs(aSize - target) - Math.abs(bSize - target) +} + +export function normalizeRecommendationGoal( + goal: string | null | undefined, +): RecommendationGoal { + const normalized = goal?.trim().toLowerCase() + if ( + normalized === 'latency' || + normalized === 'balanced' || + normalized === 'coding' + ) { + return normalized + } + return 'balanced' +} + +export function getGoalDefaultOpenAIModel(goal: RecommendationGoal): string { + switch (goal) { + case 'latency': + return 'gpt-4o-mini' + case 'coding': + return 'gpt-4o' + case 'balanced': + default: + return 'gpt-4o' + } +} + +export function rankOllamaModels( + models: OllamaModelDescriptor[], + goal: RecommendationGoal, +): RankedOllamaModel[] { + return models + .map(model => { + const haystack = modelHaystack(model) + const reasons: string[] = [] + let score = 0 + + if (includesAny(haystack, NON_CHAT_HINTS)) { + score -= 40 + reasons.push('not a chat-first model') + } + + if (includesAny(haystack, CODING_HINTS)) { + score += goal === 'coding' ? 24 : goal === 'balanced' ? 10 : 4 + reasons.push('coding-oriented model family') + } + + if (includesAny(haystack, GENERAL_HINTS)) { + score += goal === 'latency' ? 4 : goal === 'coding' ? 6 : 8 + reasons.push('strong general-purpose model family') + } + + if (includesAny(haystack, INSTRUCT_HINTS)) { + score += goal === 'latency' ? 2 : 6 + reasons.push('chat/instruct tuned') + } + + if (haystack.includes('vision') || haystack.includes('vl')) { + score -= 2 + reasons.push('vision model adds extra overhead') + } + + score += scoreSizeTier(inferParameterBillions(model), goal, reasons) + score += scoreQuantization(model, goal, reasons) + + const summary = reasons.slice(0, 3).join(', ') + return { + ...model, + score, + reasons, + summary, + } + }) + .sort((a, b) => compareRankedModels(a, b, goal)) +} + +export function recommendOllamaModel( + models: OllamaModelDescriptor[], + goal: RecommendationGoal, +): RankedOllamaModel | null { + return rankOllamaModels(models, goal)[0] ?? null +} + +export function applyBenchmarkLatency( + models: RankedOllamaModel[], + benchmarkMs: Record, + goal: RecommendationGoal, +): BenchmarkedOllamaModel[] { + const divisor = + goal === 'latency' ? 120 : goal === 'coding' ? 500 : 240 + + return models + .map(model => { + const latency = benchmarkMs[model.name] ?? null + const benchmarkPenalty = latency === null ? 0 : latency / divisor + const reasons = + latency === null + ? model.reasons + : [`benchmarked at ${Math.round(latency)}ms`, ...model.reasons] + + return { + ...model, + benchmarkMs: latency, + reasons, + summary: reasons.slice(0, 3).join(', '), + score: Number((model.score - benchmarkPenalty).toFixed(2)), + } + }) + .sort((a, b) => compareRankedModels(a, b, goal)) +}