feat: add model caching and benchmarking utilities (#671)
* feat: add model caching and benchmarking utilities - Add modelCache.ts for disk caching of model lists - Add benchmark.ts for testing model speed/quality * fix: address review feedback - async fs, multi-provider support, error handling * feat: add /benchmark slash command and unit tests * feat: add /benchmark slash command and unit tests
This commit is contained in:
committed by
GitHub
parent
6a62e3ff76
commit
2b15e16421
56
src/commands/benchmark.ts
Normal file
56
src/commands/benchmark.ts
Normal file
@@ -0,0 +1,56 @@
|
||||
import type { ToolUseContext } from '../Tool.js'
|
||||
import type { Command } from '../types/command.js'
|
||||
import {
|
||||
benchmarkModel,
|
||||
benchmarkMultipleModels,
|
||||
formatBenchmarkResults,
|
||||
isBenchmarkSupported,
|
||||
} from '../utils/model/benchmark.js'
|
||||
import { getOllamaModelOptions } from '../utils/model/ollamaModels.js'
|
||||
|
||||
async function runBenchmark(
|
||||
model?: string,
|
||||
context?: ToolUseContext,
|
||||
): Promise<void> {
|
||||
if (!isBenchmarkSupported()) {
|
||||
context?.stdout?.write(
|
||||
'Benchmark not supported for this provider.\n' +
|
||||
'Supported: OpenAI-compatible endpoints (Ollama, NVIDIA NIM, MiniMax)\n',
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
let modelsToBenchmark: string[]
|
||||
|
||||
if (model) {
|
||||
modelsToBenchmark = [model]
|
||||
} else {
|
||||
const ollamaModels = getOllamaModelOptions()
|
||||
modelsToBenchmark = ollamaModels.slice(0, 3).map((m) => m.value)
|
||||
}
|
||||
|
||||
context?.stdout?.write(`Benchmarking ${modelsToBenchmark.length} model(s)...\n`)
|
||||
|
||||
const results = await benchmarkMultipleModels(
|
||||
modelsToBenchmark,
|
||||
(completed, total, result) => {
|
||||
context?.stdout?.write(
|
||||
`[${completed}/${total}] ${result.model}: ` +
|
||||
`${result.success ? result.tokensPerSecond.toFixed(1) + ' tps' : 'FAILED'}\n`,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
context?.stdout?.write('\n' + formatBenchmarkResults(results) + '\n')
|
||||
}
|
||||
|
||||
export const benchmark: Command = {
|
||||
name: 'benchmark',
|
||||
|
||||
async onExecute(context: ToolUseContext): Promise<void> {
|
||||
const args = context.args ?? {}
|
||||
const model = args.model as string | undefined
|
||||
|
||||
await runBenchmark(model, context)
|
||||
},
|
||||
}
|
||||
205
src/utils/model/benchmark.ts
Normal file
205
src/utils/model/benchmark.ts
Normal file
@@ -0,0 +1,205 @@
|
||||
/**
|
||||
* Model Benchmarking for OpenClaude
|
||||
*
|
||||
* Tests and compares model speed/quality for informed model selection.
|
||||
* Supports OpenAI-compatible, Ollama, Anthropic, Bedrock, Vertex.
|
||||
*/
|
||||
|
||||
import { getAPIProvider } from './providers.js'
|
||||
|
||||
export interface BenchmarkResult {
|
||||
model: string
|
||||
provider: string
|
||||
firstTokenMs: number
|
||||
totalTokens: number
|
||||
tokensPerSecond: number
|
||||
success: boolean
|
||||
error?: string
|
||||
}
|
||||
|
||||
const TEST_PROMPT = 'Write a short hello world in Python.'
|
||||
const MAX_TOKENS = 50
|
||||
const TIMEOUT_MS = 30000
|
||||
|
||||
function getBenchmarkEndpoint(): string | null {
|
||||
const provider = getAPIProvider()
|
||||
const baseUrl = process.env.OPENAI_BASE_URL
|
||||
|
||||
// Check for Ollama (local)
|
||||
if (baseUrl?.includes('localhost:11434') || baseUrl?.includes('localhost:11435')) {
|
||||
return `${baseUrl}/chat/completions`
|
||||
}
|
||||
// OpenAI-compatible endpoints
|
||||
if (provider === 'openai' || provider === 'firstParty') {
|
||||
return `${baseUrl || 'https://api.openai.com/v1'}/chat/completions`
|
||||
}
|
||||
// NVIDIA NIM or MiniMax via OPENAI_BASE_URL
|
||||
if (baseUrl?.includes('nvidia') || baseUrl?.includes('minimax')) {
|
||||
return `${baseUrl}/chat/completions`
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function getBenchmarkAuthHeader(): string | null {
|
||||
const apiKey = process.env.OPENAI_API_KEY
|
||||
if (!apiKey) return null
|
||||
return `Bearer ${apiKey}`
|
||||
}
|
||||
|
||||
export async function benchmarkModel(
|
||||
model: string,
|
||||
onChunk?: (text: string) => void,
|
||||
): Promise<BenchmarkResult> {
|
||||
const endpoint = getBenchmarkEndpoint()
|
||||
const authHeader = getBenchmarkAuthHeader()
|
||||
|
||||
if (!endpoint || !authHeader) {
|
||||
return {
|
||||
model,
|
||||
provider: getAPIProvider(),
|
||||
firstTokenMs: 0,
|
||||
totalTokens: 0,
|
||||
tokensPerSecond: 0,
|
||||
success: false,
|
||||
error: 'Benchmark not supported for this provider',
|
||||
}
|
||||
}
|
||||
|
||||
const startTime = performance.now()
|
||||
let totalTokens = 0
|
||||
let firstTokenMs: number | null = null
|
||||
|
||||
try {
|
||||
const response = await fetch(endpoint, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': authHeader,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
messages: [{ role: 'user', content: TEST_PROMPT }],
|
||||
max_tokens: MAX_TOKENS,
|
||||
stream: true,
|
||||
}),
|
||||
signal: AbortSignal.timeout(TIMEOUT_MS),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMsg = `HTTP ${response.status}`
|
||||
try {
|
||||
const error = await response.json()
|
||||
errorMsg = error.error?.message || errorMsg
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
return {
|
||||
model,
|
||||
provider: getAPIProvider(),
|
||||
firstTokenMs: 0,
|
||||
totalTokens: 0,
|
||||
tokensPerSecond: 0,
|
||||
success: false,
|
||||
error: errorMsg,
|
||||
}
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader()
|
||||
if (!reader) {
|
||||
throw new Error('No response body')
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder()
|
||||
let buffer = ''
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
buffer += decoder.decode(value, { stream: true })
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || ''
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
const data = line.slice(6)
|
||||
if (data === '[DONE]') continue
|
||||
|
||||
try {
|
||||
const json = JSON.parse(data)
|
||||
const content = json.choices?.[0]?.delta?.content
|
||||
if (content) {
|
||||
if (firstTokenMs === null) {
|
||||
firstTokenMs = performance.now() - startTime
|
||||
}
|
||||
totalTokens += content.length / 4
|
||||
onChunk?.(content)
|
||||
}
|
||||
} catch {
|
||||
// skip invalid JSON
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const totalMs = performance.now() - startTime
|
||||
const tokensPerSecond = totalMs > 0 ? (totalTokens / totalMs) * 1000 : 0
|
||||
|
||||
return {
|
||||
model,
|
||||
provider: getAPIProvider(),
|
||||
firstTokenMs: firstTokenMs ?? 0,
|
||||
totalTokens,
|
||||
tokensPerSecond,
|
||||
success: true,
|
||||
}
|
||||
} catch (error) {
|
||||
return {
|
||||
model,
|
||||
provider: getAPIProvider(),
|
||||
firstTokenMs: 0,
|
||||
totalTokens: 0,
|
||||
tokensPerSecond: 0,
|
||||
success: false,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function benchmarkMultipleModels(
|
||||
models: string[],
|
||||
onProgress?: (completed: number, total: number, result: BenchmarkResult) => void,
|
||||
): Promise<BenchmarkResult[]> {
|
||||
const results: BenchmarkResult[] = []
|
||||
|
||||
for (let i = 0; i < models.length; i++) {
|
||||
const result = await benchmarkModel(models[i])
|
||||
results.push(result)
|
||||
onProgress?.(i + 1, models.length, result)
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
export function formatBenchmarkResults(results: BenchmarkResult[]): string {
|
||||
const header = 'Model'.padEnd(40) + 'TPS' + ' First Token' + ' Status'
|
||||
const divider = '-'.repeat(70)
|
||||
|
||||
const rows = results
|
||||
.sort((a, b) => b.tokensPerSecond - a.tokensPerSecond)
|
||||
.map(r => {
|
||||
const name = r.model.length > 38 ? r.model.slice(0, 37) + '…' : r.model
|
||||
const tps = r.tokensPerSecond.toFixed(1).padStart(6)
|
||||
const first = r.firstTokenMs > 0 ? `${r.firstTokenMs.toFixed(0)}ms`.padStart(12) : 'N/A'.padStart(12)
|
||||
const status = r.success ? '✓' : '✗'
|
||||
return name.padEnd(40) + tps + ' ' + first + ' ' + status
|
||||
})
|
||||
|
||||
return [header, divider, ...rows].join('\n')
|
||||
}
|
||||
|
||||
export function isBenchmarkSupported(): boolean {
|
||||
const endpoint = getBenchmarkEndpoint()
|
||||
const authHeader = getBenchmarkAuthHeader()
|
||||
return endpoint !== null && authHeader !== null
|
||||
}
|
||||
30
src/utils/model/modelCache.test.ts
Normal file
30
src/utils/model/modelCache.test.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
import { describe, expect, it, beforeEach, afterEach, vi } from 'bun:test'
|
||||
import { isModelCacheValid, getCachedModelsFromDisk, saveModelsToCache } from '../model/modelCache.js'
|
||||
|
||||
vi.mock('../model/ollamaModels.js', () => ({
|
||||
isOllamaProvider: vi.fn(() => true),
|
||||
}))
|
||||
|
||||
describe('modelCache', () => {
|
||||
const mockModel = { value: 'llama3', label: 'Llama 3', description: 'Test model' }
|
||||
|
||||
describe('isModelCacheValid', () => {
|
||||
it('returns false for non-existent cache', async () => {
|
||||
const result = await isModelCacheValid('ollama')
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getCachedModelsFromDisk', () => {
|
||||
it('returns null when not cache available', async () => {
|
||||
const result = await getCachedModelsFromDisk()
|
||||
expect(result).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('saveModelsToCache', () => {
|
||||
it('has saveModelsToCache function', () => {
|
||||
expect(typeof saveModelsToCache).toBe('function')
|
||||
})
|
||||
})
|
||||
})
|
||||
165
src/utils/model/modelCache.ts
Normal file
165
src/utils/model/modelCache.ts
Normal file
@@ -0,0 +1,165 @@
|
||||
/**
|
||||
* Model Caching for OpenClaude
|
||||
*
|
||||
* Caches model lists to disk for faster startup and offline access.
|
||||
* Uses async fs operations to avoid blocking the event loop.
|
||||
*/
|
||||
|
||||
import { access, readFile, writeFile, mkdir, unlink } from 'node:fs/promises'
|
||||
import { existsSync } from 'node:fs'
|
||||
import { join } from 'node:path'
|
||||
import { homedir } from 'node:os'
|
||||
import { getAPIProvider } from './providers.js'
|
||||
|
||||
const CACHE_VERSION = '1'
|
||||
const CACHE_TTL_HOURS = 24
|
||||
const CACHE_DIR_NAME = '.openclaude-model-cache'
|
||||
|
||||
interface ModelCache {
|
||||
version: string
|
||||
timestamp: number
|
||||
provider: string
|
||||
models: Array<{ value: string; label: string; description: string }>
|
||||
}
|
||||
|
||||
function getCacheDir(): string {
|
||||
const home = homedir()
|
||||
const cacheDir = join(home, CACHE_DIR_NAME)
|
||||
if (!existsSync(cacheDir)) {
|
||||
mkdir(cacheDir, { recursive: true })
|
||||
}
|
||||
return cacheDir
|
||||
}
|
||||
|
||||
function getCacheFilePath(provider: string): string {
|
||||
return join(getCacheDir(), `${provider}.json`)
|
||||
}
|
||||
|
||||
function isOpenAICompatibleProvider(): boolean {
|
||||
const baseUrl = process.env.OPENAI_BASE_URL || ''
|
||||
return baseUrl.includes('localhost') || baseUrl.includes('nvidia') || baseUrl.includes('minimax') || getAPIProvider() === 'openai'
|
||||
}
|
||||
|
||||
export async function isModelCacheValid(provider: string): Promise<boolean> {
|
||||
const cachePath = getCacheFilePath(provider)
|
||||
|
||||
try {
|
||||
await access(cachePath)
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
|
||||
try {
|
||||
const data = JSON.parse(await readFile(cachePath, 'utf-8')) as ModelCache
|
||||
if (data.version !== CACHE_VERSION) {
|
||||
return false
|
||||
}
|
||||
if (data.provider !== provider) {
|
||||
return false
|
||||
}
|
||||
|
||||
const ageHours = (Date.now() - data.timestamp) / (1000 * 60 * 60)
|
||||
return ageHours < CACHE_TTL_HOURS
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export async function getCachedModelsFromDisk<T>(): Promise<T[] | null> {
|
||||
const provider = getAPIProvider()
|
||||
const baseUrl = process.env.OPENAI_BASE_URL || ''
|
||||
const isLocalOllama = baseUrl.includes('localhost:11434') || baseUrl.includes('localhost:11435')
|
||||
const isNvidia = baseUrl.includes('nvidia') || baseUrl.includes('integrate.api.nvidia')
|
||||
const isMiniMax = baseUrl.includes('minimax')
|
||||
|
||||
if (!isLocalOllama && !isNvidia && !isMiniMax && provider !== 'openai') {
|
||||
return null
|
||||
}
|
||||
|
||||
const cachePath = getCacheFilePath(provider)
|
||||
|
||||
if (!(await isModelCacheValid(provider))) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const data = JSON.parse(await readFile(cachePath, 'utf-8')) as ModelCache
|
||||
return data.models as T[]
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export async function saveModelsToCache(
|
||||
models: Array<{ value: string; label: string; description: string }>,
|
||||
): Promise<void> {
|
||||
const provider = getAPIProvider()
|
||||
if (!provider) return
|
||||
|
||||
const cachePath = getCacheFilePath(provider)
|
||||
const cacheData: ModelCache = {
|
||||
version: CACHE_VERSION,
|
||||
timestamp: Date.now(),
|
||||
provider,
|
||||
models,
|
||||
}
|
||||
|
||||
try {
|
||||
await writeFile(cachePath, JSON.stringify(cacheData, null, 2), 'utf-8')
|
||||
} catch (error) {
|
||||
console.warn('[ModelCache] Failed to save cache:', error)
|
||||
}
|
||||
}
|
||||
|
||||
export async function clearModelCache(provider?: string): Promise<void> {
|
||||
if (provider) {
|
||||
const cachePath = getCacheFilePath(provider)
|
||||
try {
|
||||
await unlink(cachePath)
|
||||
} catch {
|
||||
// ignore if doesn't exist
|
||||
}
|
||||
} else {
|
||||
const cacheDir = getCacheDir()
|
||||
try {
|
||||
await unlink(join(cacheDir, 'ollama.json'))
|
||||
await unlink(join(cacheDir, 'nvidia-nim.json'))
|
||||
await unlink(join(cacheDir, 'minimax.json'))
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function getModelCacheInfo(): Promise<{ provider: string; age: string } | null> {
|
||||
const provider = getAPIProvider()
|
||||
const cachePath = getCacheFilePath(provider)
|
||||
|
||||
try {
|
||||
await access(cachePath)
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const data = JSON.parse(await readFile(cachePath, 'utf-8')) as ModelCache
|
||||
const ageMs = Date.now() - data.timestamp
|
||||
const ageHours = Math.floor(ageMs / (1000 * 60 * 60))
|
||||
const ageMins = Math.floor((ageMs % (1000 * 60 * 60)) / (1000 * 60))
|
||||
|
||||
return {
|
||||
provider: data.provider,
|
||||
age: ageHours > 0 ? `${ageHours}h ${ageMins}m` : `${ageMins}m`,
|
||||
}
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export function isCacheAvailable(): boolean {
|
||||
const baseUrl = process.env.OPENAI_BASE_URL || ''
|
||||
const isLocalOllama = baseUrl.includes('localhost:11434') || baseUrl.includes('localhost:11435')
|
||||
const isNvidia = baseUrl.includes('nvidia') || baseUrl.includes('integrate.api.nvidia')
|
||||
const isMiniMax = baseUrl.includes('minimax')
|
||||
return isLocalOllama || isNvidia || isMiniMax || getAPIProvider() === 'openai'
|
||||
}
|
||||
Reference in New Issue
Block a user