feat: add model caching and benchmarking utilities (#671)

* feat: add model caching and benchmarking utilities

- Add modelCache.ts for disk caching of model lists
- Add benchmark.ts for testing model speed/quality

* fix: address review feedback - async fs, multi-provider support, error handling

* feat: add /benchmark slash command and unit tests

Author: ArkhAngelLifeJiggy
Date: 2026-04-21 11:36:16 +01:00
Committed by: GitHub
Parent: 6a62e3ff76
Commit: 2b15e16421
4 changed files with 456 additions and 0 deletions

benchmark.ts

@@ -0,0 +1,205 @@
/**
* Model Benchmarking for OpenClaude
*
 * Measures and compares model response speed (time to first token and
 * tokens per second) for informed model selection.
 * Supports OpenAI-compatible endpoints, including Ollama, NVIDIA NIM,
 * and MiniMax; other providers (Anthropic, Bedrock, Vertex) report
 * benchmarking as unsupported.
*/
import { getAPIProvider } from './providers.js'
export interface BenchmarkResult {
model: string
provider: string
firstTokenMs: number
totalTokens: number
tokensPerSecond: number
success: boolean
error?: string
}
const TEST_PROMPT = 'Write a short hello world in Python.'
const MAX_TOKENS = 50
const TIMEOUT_MS = 30000
function getBenchmarkEndpoint(): string | null {
const provider = getAPIProvider()
const baseUrl = process.env.OPENAI_BASE_URL
// Check for Ollama (local)
if (baseUrl?.includes('localhost:11434') || baseUrl?.includes('localhost:11435')) {
return `${baseUrl}/chat/completions`
}
// OpenAI-compatible endpoints
if (provider === 'openai' || provider === 'firstParty') {
return `${baseUrl || 'https://api.openai.com/v1'}/chat/completions`
}
// NVIDIA NIM or MiniMax via OPENAI_BASE_URL
if (baseUrl?.includes('nvidia') || baseUrl?.includes('minimax')) {
return `${baseUrl}/chat/completions`
}
return null
}
function getBenchmarkAuthHeader(): string | null {
const apiKey = process.env.OPENAI_API_KEY
if (!apiKey) return null
return `Bearer ${apiKey}`
}
export async function benchmarkModel(
model: string,
onChunk?: (text: string) => void,
): Promise<BenchmarkResult> {
const endpoint = getBenchmarkEndpoint()
const authHeader = getBenchmarkAuthHeader()
if (!endpoint || !authHeader) {
return {
model,
provider: getAPIProvider(),
firstTokenMs: 0,
totalTokens: 0,
tokensPerSecond: 0,
success: false,
error: 'Benchmark not supported for this provider',
}
}
const startTime = performance.now()
let totalTokens = 0
let firstTokenMs: number | null = null
try {
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': authHeader,
},
body: JSON.stringify({
model,
messages: [{ role: 'user', content: TEST_PROMPT }],
max_tokens: MAX_TOKENS,
stream: true,
}),
signal: AbortSignal.timeout(TIMEOUT_MS),
})
if (!response.ok) {
let errorMsg = `HTTP ${response.status}`
try {
const error = await response.json()
errorMsg = error.error?.message || errorMsg
} catch {
// ignore
}
return {
model,
provider: getAPIProvider(),
firstTokenMs: 0,
totalTokens: 0,
tokensPerSecond: 0,
success: false,
error: errorMsg,
}
}
const reader = response.body?.getReader()
if (!reader) {
throw new Error('No response body')
}
    const decoder = new TextDecoder()
    let buffer = ''
    // Parse the SSE stream line by line, keeping any partial line in the buffer
while (true) {
const { done, value } = await reader.read()
if (done) break
buffer += decoder.decode(value, { stream: true })
const lines = buffer.split('\n')
buffer = lines.pop() || ''
for (const line of lines) {
if (line.startsWith('data: ')) {
const data = line.slice(6)
if (data === '[DONE]') continue
try {
const json = JSON.parse(data)
const content = json.choices?.[0]?.delta?.content
if (content) {
if (firstTokenMs === null) {
firstTokenMs = performance.now() - startTime
}
              // Rough token count estimate: ~4 characters per token
              totalTokens += content.length / 4
onChunk?.(content)
}
} catch {
// skip invalid JSON
}
}
}
}
const totalMs = performance.now() - startTime
const tokensPerSecond = totalMs > 0 ? (totalTokens / totalMs) * 1000 : 0
return {
model,
provider: getAPIProvider(),
firstTokenMs: firstTokenMs ?? 0,
totalTokens,
tokensPerSecond,
success: true,
}
} catch (error) {
return {
model,
provider: getAPIProvider(),
firstTokenMs: 0,
totalTokens: 0,
tokensPerSecond: 0,
success: false,
error: error instanceof Error ? error.message : 'Unknown error',
}
}
}
export async function benchmarkMultipleModels(
models: string[],
onProgress?: (completed: number, total: number, result: BenchmarkResult) => void,
): Promise<BenchmarkResult[]> {
const results: BenchmarkResult[] = []
for (let i = 0; i < models.length; i++) {
const result = await benchmarkModel(models[i])
results.push(result)
onProgress?.(i + 1, models.length, result)
}
return results
}
export function formatBenchmarkResults(results: BenchmarkResult[]): string {
  const header = 'Model'.padEnd(40) + 'TPS'.padStart(6) + ' ' + 'First Token'.padStart(12) + ' Status'
const divider = '-'.repeat(70)
const rows = results
.sort((a, b) => b.tokensPerSecond - a.tokensPerSecond)
.map(r => {
const name = r.model.length > 38 ? r.model.slice(0, 37) + '…' : r.model
const tps = r.tokensPerSecond.toFixed(1).padStart(6)
const first = r.firstTokenMs > 0 ? `${r.firstTokenMs.toFixed(0)}ms`.padStart(12) : 'N/A'.padStart(12)
const status = r.success ? '✓' : '✗'
return name.padEnd(40) + tps + ' ' + first + ' ' + status
})
return [header, divider, ...rows].join('\n')
}
export function isBenchmarkSupported(): boolean {
const endpoint = getBenchmarkEndpoint()
const authHeader = getBenchmarkAuthHeader()
return endpoint !== null && authHeader !== null
}
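
A caller might drive these exports like the sketch below. This is not part of the diff: the model names are placeholders, and it assumes OPENAI_BASE_URL and OPENAI_API_KEY are already set for an OpenAI-compatible endpoint.

import {
  benchmarkMultipleModels,
  formatBenchmarkResults,
  isBenchmarkSupported,
} from './benchmark.js'

async function runBenchmarkDemo() {
  if (!isBenchmarkSupported()) {
    console.log('Benchmarking is not supported for the current provider.')
    return
  }
  // Benchmark candidate models one at a time, logging progress as each finishes
  const results = await benchmarkMultipleModels(
    ['llama3', 'mistral'], // placeholder model names
    (completed, total, result) =>
      console.log(`[${completed}/${total}] ${result.model}: ${result.success ? 'ok' : result.error}`),
  )
  // Render the results as a table sorted by tokens per second
  console.log(formatBenchmarkResults(results))
}

runBenchmarkDemo()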

Unit tests for modelCache

@@ -0,0 +1,30 @@
import { describe, expect, it, mock } from 'bun:test'
import { isModelCacheValid, getCachedModelsFromDisk, saveModelsToCache } from '../model/modelCache.js'
// bun:test provides module mocking via mock.module (Vitest's vi is not exported)
mock.module('../model/ollamaModels.js', () => ({
  isOllamaProvider: mock(() => true),
}))
describe('modelCache', () => {
const mockModel = { value: 'llama3', label: 'Llama 3', description: 'Test model' }
describe('isModelCacheValid', () => {
it('returns false for non-existent cache', async () => {
const result = await isModelCacheValid('ollama')
expect(result).toBe(false)
})
})
describe('getCachedModelsFromDisk', () => {
    it('returns null when no cache is available', async () => {
const result = await getCachedModelsFromDisk()
expect(result).toBeNull()
})
})
describe('saveModelsToCache', () => {
it('has saveModelsToCache function', () => {
expect(typeof saveModelsToCache).toBe('function')
})
})
})
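
The shipped tests only exercise the empty-cache path. A round-trip test might look like the following sketch (an assumption, not part of the diff): it writes to the real cache directory under the home directory, so it is closer to an integration test, and the final cleanup only works if getAPIProvider() resolves to one of the providers clearModelCache knows about.

import { describe, expect, it } from 'bun:test'
import { saveModelsToCache, getCachedModelsFromDisk, clearModelCache } from '../model/modelCache.js'

describe('modelCache round trip', () => {
  it('returns the models that were just saved', async () => {
    // Point the helpers at a local Ollama endpoint so caching is enabled
    process.env.OPENAI_BASE_URL = 'http://localhost:11434/v1'
    const models = [{ value: 'llama3', label: 'Llama 3', description: 'Test model' }]
    await saveModelsToCache(models)
    const cached = await getCachedModelsFromDisk<(typeof models)[number]>()
    expect(cached).toEqual(models)
    await clearModelCache() // remove the on-disk cache file written above
  })
})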

modelCache.ts

@@ -0,0 +1,165 @@
/**
* Model Caching for OpenClaude
*
* Caches model lists to disk for faster startup and offline access.
 * Uses async fs operations for reads and writes to avoid blocking the event loop.
*/
import { access, readFile, writeFile, unlink } from 'node:fs/promises'
import { existsSync, mkdirSync } from 'node:fs'
import { join } from 'node:path'
import { homedir } from 'node:os'
import { getAPIProvider } from './providers.js'
const CACHE_VERSION = '1'
const CACHE_TTL_HOURS = 24
const CACHE_DIR_NAME = '.openclaude-model-cache'
interface ModelCache {
version: string
timestamp: number
provider: string
models: Array<{ value: string; label: string; description: string }>
}
function getCacheDir(): string {
  const home = homedir()
  const cacheDir = join(home, CACHE_DIR_NAME)
  // This function is synchronous, so use mkdirSync: the async mkdir from
  // fs/promises would return an unawaited promise and the directory might
  // not exist by the time the first cache write happens.
  if (!existsSync(cacheDir)) {
    mkdirSync(cacheDir, { recursive: true })
  }
  return cacheDir
}
function getCacheFilePath(provider: string): string {
return join(getCacheDir(), `${provider}.json`)
}
function isOpenAICompatibleProvider(): boolean {
  // Disk caching applies only to OpenAI-compatible providers:
  // local Ollama, NVIDIA NIM, MiniMax, or OpenAI itself.
  const baseUrl = process.env.OPENAI_BASE_URL || ''
  const isLocalOllama = baseUrl.includes('localhost:11434') || baseUrl.includes('localhost:11435')
  const isNvidia = baseUrl.includes('nvidia') || baseUrl.includes('integrate.api.nvidia')
  const isMiniMax = baseUrl.includes('minimax')
  return isLocalOllama || isNvidia || isMiniMax || getAPIProvider() === 'openai'
}
export async function isModelCacheValid(provider: string): Promise<boolean> {
const cachePath = getCacheFilePath(provider)
try {
await access(cachePath)
} catch {
return false
}
try {
const data = JSON.parse(await readFile(cachePath, 'utf-8')) as ModelCache
if (data.version !== CACHE_VERSION) {
return false
}
if (data.provider !== provider) {
return false
}
const ageHours = (Date.now() - data.timestamp) / (1000 * 60 * 60)
return ageHours < CACHE_TTL_HOURS
} catch {
return false
}
}
export async function getCachedModelsFromDisk<T>(): Promise<T[] | null> {
  const provider = getAPIProvider()
  if (!isOpenAICompatibleProvider()) {
    return null
  }
const cachePath = getCacheFilePath(provider)
if (!(await isModelCacheValid(provider))) {
return null
}
try {
const data = JSON.parse(await readFile(cachePath, 'utf-8')) as ModelCache
return data.models as T[]
} catch {
return null
}
}
export async function saveModelsToCache(
models: Array<{ value: string; label: string; description: string }>,
): Promise<void> {
const provider = getAPIProvider()
if (!provider) return
const cachePath = getCacheFilePath(provider)
const cacheData: ModelCache = {
version: CACHE_VERSION,
timestamp: Date.now(),
provider,
models,
}
try {
await writeFile(cachePath, JSON.stringify(cacheData, null, 2), 'utf-8')
} catch (error) {
console.warn('[ModelCache] Failed to save cache:', error)
}
}
export async function clearModelCache(provider?: string): Promise<void> {
  if (provider) {
    const cachePath = getCacheFilePath(provider)
    try {
      await unlink(cachePath)
    } catch {
      // ignore if the cache file doesn't exist
    }
  } else {
    const cacheDir = getCacheDir()
    // Delete each known provider cache independently, so one missing file
    // doesn't prevent the remaining caches from being removed.
    await Promise.allSettled([
      unlink(join(cacheDir, 'ollama.json')),
      unlink(join(cacheDir, 'nvidia-nim.json')),
      unlink(join(cacheDir, 'minimax.json')),
    ])
  }
}
export async function getModelCacheInfo(): Promise<{ provider: string; age: string } | null> {
const provider = getAPIProvider()
const cachePath = getCacheFilePath(provider)
try {
await access(cachePath)
} catch {
return null
}
try {
const data = JSON.parse(await readFile(cachePath, 'utf-8')) as ModelCache
const ageMs = Date.now() - data.timestamp
const ageHours = Math.floor(ageMs / (1000 * 60 * 60))
const ageMins = Math.floor((ageMs % (1000 * 60 * 60)) / (1000 * 60))
return {
provider: data.provider,
age: ageHours > 0 ? `${ageHours}h ${ageMins}m` : `${ageMins}m`,
}
} catch {
return null
}
}
export function isCacheAvailable(): boolean {
  return isOpenAICompatibleProvider()
}