feat: context preloading and hybrid context strategy (#860)
* feat: context preloading and hybrid context strategy PR 2D - Section 2.7, 2.8: - Add contextPreload.ts with pattern-based prediction - Add hybridContextStrategy.ts with cache/fresh balancing - Optimize for cost vs accuracy - Add comprehensive tests (13 passing) * feat: wire hybrid context strategy into API path - Apply hybrid strategy after normalizeMessagesForAPI - Feature-flag controlled (HYBRID_CONTEXT_STRATEGY) - Optimizes cache/fresh balance for API requests * fix: resolve PR 2D blocking issues - Fix predictContextNeeds self-assign bug (matchedCategory = category) - Add test for non-empty predictedNeed - Preserve conversation tail in hybridStrategy (never drop last 3 messages) - Add comment for hardcoded 200k cap in claude.ts Fixes reviewer feedback from gnanam1990 and Vasanthdev2004 * fix: preserve tool_use/tool_result chains in hybridStrategy - Increase MIN_TAIL to 5 (tool_use -> tool_result -> assistant -> user -> next) - Add getMessageChain() to preserve paired messages - Chains kept together in final selection * fix: PR 860 - tool_use/tool_result pairing and safe token counting Blocking: - getMessageChain() now pairs by tool_use.id (block ID) not msg.message.id - Find tool_use blocks by id, pair with tool_result having matching tool_use_id - Fixes tool_result surviving while paired tool_use dropped - Token counting now includes array content (tool_use, tool_result, thinking) - Not just string content, prevents undercounting prompt size - Deduplicate messages by UUID when combining chains + split + tail - Prevents duplicate messages in final request Non-blocking: - Add regression test for tool_use/tool_result pairing * fix: PR 860 - account for actual structured payload size in token counting Blocking: - getMessageTokenCount now calculates actual token count for structured blocks - tool_use: uses JSON.stringify(input).length / 4 + base - tool_result: counts actual content (string or array of text blocks) - thinking: counts actual thinking text length / 4 - is_error flag adds small overhead Non-blocking: - Add tests for large tool_use input and large thinking blocks
This commit is contained in:
committed by
GitHub
parent
91f93ce615
commit
92d297e50e
@@ -1283,6 +1283,21 @@ async function* queryModel(
|
|||||||
let messagesForAPI = normalizeMessagesForAPI(messages, filteredTools)
|
let messagesForAPI = normalizeMessagesForAPI(messages, filteredTools)
|
||||||
queryCheckpoint('query_message_normalization_end')
|
queryCheckpoint('query_message_normalization_end')
|
||||||
|
|
||||||
|
// Apply hybrid context strategy for optimal cache/fresh balance
|
||||||
|
if (feature('HYBRID_CONTEXT_STRATEGY')) {
|
||||||
|
const { applyHybridStrategy } = await import('../../utils/hybridContextStrategy.js')
|
||||||
|
// Cap at 200k to avoid edge case with very large context windows
|
||||||
|
const strategyResult = applyHybridStrategy(messagesForAPI, {
|
||||||
|
cacheWeight: 0.4,
|
||||||
|
freshWeight: 0.6,
|
||||||
|
maxTotalTokens: Math.min(
|
||||||
|
getContextWindowForModel(model, getSdkBetas()) - COMPACT_MAX_OUTPUT_TOKENS,
|
||||||
|
200000
|
||||||
|
),
|
||||||
|
})
|
||||||
|
messagesForAPI = strategyResult.selectedMessages
|
||||||
|
}
|
||||||
|
|
||||||
// Model-specific post-processing: strip tool-search-specific fields if the
|
// Model-specific post-processing: strip tool-search-specific fields if the
|
||||||
// selected model doesn't support tool search.
|
// selected model doesn't support tool search.
|
||||||
//
|
//
|
||||||
|
|||||||
104
src/utils/contextPreload.test.ts
Normal file
104
src/utils/contextPreload.test.ts
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import { describe, expect, it } from 'bun:test'
|
||||||
|
import {
|
||||||
|
analyzeConversationPatterns,
|
||||||
|
predictContextNeeds,
|
||||||
|
preloadContext,
|
||||||
|
createPreloadStrategy,
|
||||||
|
} from './contextPreload.js'
|
||||||
|
|
||||||
|
function createMessage(role: string, content: string, createdAt: number = Date.now()): any {
|
||||||
|
return {
|
||||||
|
message: { role, content, id: 'test', type: 'message', created_at: createdAt },
|
||||||
|
sender: role,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('contextPreload', () => {
|
||||||
|
describe('analyzeConversationPatterns', () => {
|
||||||
|
it('extracts patterns from messages', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('user', 'Fix the error in my code', 1000),
|
||||||
|
createMessage('assistant', 'I found the bug', 2000),
|
||||||
|
]
|
||||||
|
|
||||||
|
const patterns = analyzeConversationPatterns(messages)
|
||||||
|
|
||||||
|
expect(patterns.length).toBeGreaterThanOrEqual(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('detects debug patterns', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('user', 'Debug this error please', 1000),
|
||||||
|
createMessage('assistant', 'Found it', 2000),
|
||||||
|
]
|
||||||
|
|
||||||
|
const patterns = analyzeConversationPatterns(messages)
|
||||||
|
|
||||||
|
expect(patterns.some(p => p.userQuery === 'debug')).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('detects code patterns', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('user', 'Write a function for me', 1000),
|
||||||
|
createMessage('assistant', 'Here is the code', 2000),
|
||||||
|
]
|
||||||
|
|
||||||
|
const patterns = analyzeConversationPatterns(messages)
|
||||||
|
|
||||||
|
expect(patterns.some(p => p.userQuery === 'code')).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('predictContextNeeds', () => {
|
||||||
|
it('predicts context needs based on query', () => {
|
||||||
|
const patterns = [{ userQuery: 'debug', neededContext: ['error_history'], frequency: 1 }]
|
||||||
|
|
||||||
|
const prediction = predictContextNeeds('Fix the bug', patterns, {
|
||||||
|
maxPreloadTokens: 10000,
|
||||||
|
confidenceThreshold: 0.3,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(prediction.confidence).toBeGreaterThan(0)
|
||||||
|
expect(prediction.predictedNeed.length).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns non-empty predictedNeed when pattern matches', () => {
|
||||||
|
const patterns = [
|
||||||
|
{ userQuery: 'debug', neededContext: ['error_history', 'stack_trace'], frequency: 2 },
|
||||||
|
]
|
||||||
|
|
||||||
|
const prediction = predictContextNeeds('debug this error', patterns, {
|
||||||
|
maxPreloadTokens: 10000,
|
||||||
|
confidenceThreshold: 0.1,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(prediction.predictedNeed).toContain('error_history')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('preloadContext', () => {
|
||||||
|
it('preloads relevant context', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('system', 'System prompt'),
|
||||||
|
createMessage('user', 'Debug error'),
|
||||||
|
createMessage('assistant', 'Fixed'),
|
||||||
|
]
|
||||||
|
|
||||||
|
const prediction = { predictedNeed: ['error'], confidence: 0.8, suggestedMessages: [] }
|
||||||
|
|
||||||
|
const result = preloadContext(messages, prediction, { maxPreloadTokens: 5000 })
|
||||||
|
|
||||||
|
expect(result.length).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('createPreloadStrategy', () => {
|
||||||
|
it('creates strategy with all methods', () => {
|
||||||
|
const strategy = createPreloadStrategy({ maxPreloadTokens: 10000 })
|
||||||
|
|
||||||
|
expect(strategy.analyze).toBeDefined()
|
||||||
|
expect(strategy.predict).toBeDefined()
|
||||||
|
expect(strategy.preload).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
145
src/utils/contextPreload.ts
Normal file
145
src/utils/contextPreload.ts
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
/**
|
||||||
|
* Context Pre-loading - Production Grade
|
||||||
|
*
|
||||||
|
* Proactively loads relevant context before it's needed.
|
||||||
|
* Prediction based on conversation patterns.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { roughTokenCountEstimation } from '../services/tokenEstimation.js'
|
||||||
|
import type { Message } from '../types/message.js'
|
||||||
|
|
||||||
|
export interface PreloadConfig {
|
||||||
|
maxPreloadTokens: number
|
||||||
|
predictionWindow?: number
|
||||||
|
confidenceThreshold?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PreloadPrediction {
|
||||||
|
predictedNeed: string[]
|
||||||
|
confidence: number
|
||||||
|
suggestedMessages: Message[]
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ConversationPattern {
|
||||||
|
userQuery: string
|
||||||
|
neededContext: string[]
|
||||||
|
frequency: number
|
||||||
|
}
|
||||||
|
|
||||||
|
const PATTERN_KEYWORDS: Record<string, string[]> = {
|
||||||
|
'code': ['code', 'function', 'implement', 'write'],
|
||||||
|
'debug': ['error', 'bug', 'fix', 'issue', 'debug'],
|
||||||
|
'refactor': ['refactor', 'improve', 'clean', 'optimize'],
|
||||||
|
'test': ['test', 'spec', 'coverage', 'verify'],
|
||||||
|
'explain': ['explain', 'what', 'how', 'why', 'describe'],
|
||||||
|
'search': ['find', 'search', 'look', 'grep', 'glob'],
|
||||||
|
}
|
||||||
|
|
||||||
|
export function analyzeConversationPatterns(messages: Message[]): ConversationPattern[] {
|
||||||
|
const patterns: ConversationPattern[] = []
|
||||||
|
const recentMessages = messages.slice(-10)
|
||||||
|
|
||||||
|
for (let i = 0; i < recentMessages.length - 1; i++) {
|
||||||
|
const userMsg = recentMessages[i]
|
||||||
|
const assistantMsg = recentMessages[i + 1]
|
||||||
|
|
||||||
|
const userContent = typeof userMsg.message?.content === 'string' ? userMsg.message.content : ''
|
||||||
|
const assistantContent = typeof assistantMsg.message?.content === 'string' ? assistantMsg.message.content : ''
|
||||||
|
|
||||||
|
for (const [category, keywords] of Object.entries(PATTERN_KEYWORDS)) {
|
||||||
|
if (keywords.some(k => userContent.toLowerCase().includes(k))) {
|
||||||
|
patterns.push({
|
||||||
|
userQuery: category,
|
||||||
|
neededContext: extractContextNeeds(assistantContent),
|
||||||
|
frequency: 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return patterns
|
||||||
|
}
|
||||||
|
|
||||||
|
function extractContextNeeds(content: string): string[] {
|
||||||
|
const needs: string[] = []
|
||||||
|
if (content.includes('file')) needs.push('file_context')
|
||||||
|
if (content.includes('function')) needs.push('function_defs')
|
||||||
|
if (content.includes('error')) needs.push('error_history')
|
||||||
|
if (content.includes('test')) needs.push('test_files')
|
||||||
|
return needs
|
||||||
|
}
|
||||||
|
|
||||||
|
export function predictContextNeeds(
|
||||||
|
currentQuery: string,
|
||||||
|
patterns: ConversationPattern[],
|
||||||
|
config: PreloadConfig,
|
||||||
|
): PreloadPrediction {
|
||||||
|
const threshold = config.confidenceThreshold ?? 0.5
|
||||||
|
let matchedCategory = ''
|
||||||
|
let highestConfidence = 0
|
||||||
|
|
||||||
|
for (const [category, keywords] of Object.entries(PATTERN_KEYWORDS)) {
|
||||||
|
const matches = keywords.filter(k => currentQuery.toLowerCase().includes(k)).length
|
||||||
|
const confidence = matches / keywords.length
|
||||||
|
|
||||||
|
if (confidence > highestConfidence && confidence >= threshold) {
|
||||||
|
highestConfidence = confidence
|
||||||
|
matchedCategory = category
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const relevantPatterns = patterns.filter(p => p.userQuery === matchedCategory)
|
||||||
|
const allNeeds = relevantPatterns.flatMap(p => p.neededContext)
|
||||||
|
|
||||||
|
return {
|
||||||
|
predictedNeed: [...new Set(allNeeds)],
|
||||||
|
confidence: highestConfidence,
|
||||||
|
suggestedMessages: [],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function preloadContext(
|
||||||
|
availableContext: Message[],
|
||||||
|
prediction: PreloadPrediction,
|
||||||
|
config: PreloadConfig,
|
||||||
|
): Message[] {
|
||||||
|
const targetTokens = config.maxPreloadTokens ?? 30000
|
||||||
|
const selected: Message[] = []
|
||||||
|
let usedTokens = 0
|
||||||
|
|
||||||
|
const priorityTypes = prediction.predictedNeed
|
||||||
|
|
||||||
|
const sorted = [...availableContext].sort((a, b) => {
|
||||||
|
const aContent = typeof a.message?.content === 'string' ? a.message.content : ''
|
||||||
|
const bContent = typeof b.message?.content === 'string' ? b.message.content : ''
|
||||||
|
|
||||||
|
const aPriority = priorityTypes.some(t => aContent.includes(t)) ? 1 : 0
|
||||||
|
const bPriority = priorityTypes.some(t => bContent.includes(t)) ? 1 : 0
|
||||||
|
|
||||||
|
if (bPriority !== aPriority) return bPriority - aPriority
|
||||||
|
return (b.message?.created_at ?? 0) - (a.message?.created_at ?? 0)
|
||||||
|
})
|
||||||
|
|
||||||
|
for (const msg of sorted) {
|
||||||
|
const tokens = roughTokenCountEstimation(
|
||||||
|
typeof msg.message?.content === 'string' ? msg.message.content : ''
|
||||||
|
)
|
||||||
|
|
||||||
|
if (usedTokens + tokens > targetTokens) break
|
||||||
|
|
||||||
|
selected.push(msg)
|
||||||
|
usedTokens += tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createPreloadStrategy(config: PreloadConfig) {
|
||||||
|
return {
|
||||||
|
analyze: analyzeConversationPatterns,
|
||||||
|
predict: (query: string, patterns: ConversationPattern[]) =>
|
||||||
|
predictContextNeeds(query, patterns, config),
|
||||||
|
preload: (context: Message[], prediction: PreloadPrediction) =>
|
||||||
|
preloadContext(context, prediction, config),
|
||||||
|
}
|
||||||
|
}
|
||||||
230
src/utils/hybridContextStrategy.test.ts
Normal file
230
src/utils/hybridContextStrategy.test.ts
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
import { describe, expect, it } from 'bun:test'
|
||||||
|
import {
|
||||||
|
splitContext,
|
||||||
|
applyHybridStrategy,
|
||||||
|
optimizeForCost,
|
||||||
|
optimizeForAccuracy,
|
||||||
|
getHybridStats,
|
||||||
|
} from './hybridContextStrategy.js'
|
||||||
|
|
||||||
|
function createMessage(role: string, content: string, createdAt: number = Date.now()): any {
|
||||||
|
return {
|
||||||
|
message: { role, content, id: 'test', type: 'message', created_at: createdAt },
|
||||||
|
sender: role,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('hybridContextStrategy', () => {
|
||||||
|
describe('splitContext', () => {
|
||||||
|
it('splits context into cached and fresh', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('system', 'System prompt', Date.now() - 86400000),
|
||||||
|
createMessage('user', 'Hello'),
|
||||||
|
createMessage('assistant', 'Hi there'),
|
||||||
|
]
|
||||||
|
|
||||||
|
const split = splitContext(messages, {
|
||||||
|
cacheWeight: 0.4,
|
||||||
|
freshWeight: 0.6,
|
||||||
|
maxTotalTokens: 10000,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(split.cachedTokens).toBeGreaterThanOrEqual(0)
|
||||||
|
expect(split.freshTokens).toBeGreaterThanOrEqual(0)
|
||||||
|
expect(split.totalTokens).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('respects weight configuration', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('system', 'Old system', Date.now() - 86400000),
|
||||||
|
createMessage('user', 'Recent message', Date.now()),
|
||||||
|
]
|
||||||
|
|
||||||
|
const split = splitContext(messages, {
|
||||||
|
cacheWeight: 0.5,
|
||||||
|
freshWeight: 0.5,
|
||||||
|
maxTotalTokens: 10000,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(split.cached).toBeDefined()
|
||||||
|
expect(split.fresh).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('applyHybridStrategy', () => {
|
||||||
|
it('applies strategy and returns messages', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('user', 'Message 1'),
|
||||||
|
createMessage('assistant', 'Response 1'),
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = applyHybridStrategy(messages, {
|
||||||
|
cacheWeight: 0.5,
|
||||||
|
freshWeight: 0.5,
|
||||||
|
maxTotalTokens: 10000,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.selectedMessages.length).toBeGreaterThan(0)
|
||||||
|
expect(['cache_heavy', 'fresh_heavy', 'balanced']).toContain(result.strategy)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('calculates estimated cost', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('user', 'Test message'),
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = applyHybridStrategy(messages, {
|
||||||
|
cacheWeight: 0.5,
|
||||||
|
freshWeight: 0.5,
|
||||||
|
maxTotalTokens: 10000,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.estimatedCost).toBeGreaterThanOrEqual(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('optimizeForCost', () => {
|
||||||
|
it('returns messages within budget', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('user', 'Message 1'),
|
||||||
|
createMessage('assistant', 'Response 1'),
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = optimizeForCost(messages, 0.001)
|
||||||
|
|
||||||
|
expect(result.length).toBeGreaterThanOrEqual(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('optimizeForAccuracy', () => {
|
||||||
|
it('optimizes for accuracy with token limit', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('user', 'Message 1'),
|
||||||
|
createMessage('assistant', 'Response 1'),
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = optimizeForAccuracy(messages, 5000)
|
||||||
|
|
||||||
|
expect(result.length).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getHybridStats', () => {
|
||||||
|
it('returns statistics', () => {
|
||||||
|
const messages = [
|
||||||
|
createMessage('system', 'System', Date.now() - 86400000),
|
||||||
|
createMessage('user', 'Hello'),
|
||||||
|
]
|
||||||
|
|
||||||
|
const split = splitContext(messages, { cacheWeight: 0.5, freshWeight: 0.5, maxTotalTokens: 10000 })
|
||||||
|
const stats = getHybridStats(split)
|
||||||
|
|
||||||
|
expect(stats.cacheRatio).toBeGreaterThanOrEqual(0)
|
||||||
|
expect(stats.freshRatio).toBeGreaterThanOrEqual(0)
|
||||||
|
expect(stats.totalTokens).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('tool_use/tool_result pairing', () => {
|
||||||
|
it('preserves tool_use and tool_result together', () => {
|
||||||
|
const toolUseId = 'tool-use-123'
|
||||||
|
const messages = [
|
||||||
|
{
|
||||||
|
type: 'assistant',
|
||||||
|
uuid: 'uuid-1',
|
||||||
|
message: {
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'tool_use', id: toolUseId, name: 'Read' }],
|
||||||
|
id: 'msg-1',
|
||||||
|
created_at: 1000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'user',
|
||||||
|
uuid: 'uuid-2',
|
||||||
|
message: {
|
||||||
|
role: 'user',
|
||||||
|
content: [{ type: 'tool_result', tool_use_id: toolUseId, content: 'file content' }],
|
||||||
|
id: 'msg-2',
|
||||||
|
created_at: 2000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'assistant',
|
||||||
|
uuid: 'uuid-3',
|
||||||
|
message: {
|
||||||
|
role: 'assistant',
|
||||||
|
content: 'Response after tool',
|
||||||
|
id: 'msg-3',
|
||||||
|
created_at: 3000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
] as any[]
|
||||||
|
|
||||||
|
const result = applyHybridStrategy(messages, {
|
||||||
|
cacheWeight: 0.5,
|
||||||
|
freshWeight: 0.5,
|
||||||
|
maxTotalTokens: 10000,
|
||||||
|
})
|
||||||
|
|
||||||
|
const hasToolUse = result.selectedMessages.some(
|
||||||
|
m => Array.isArray(m.message?.content) && m.message.content.some((b: any) => b.type === 'tool_use')
|
||||||
|
)
|
||||||
|
const hasToolResult = result.selectedMessages.some(
|
||||||
|
m => Array.isArray(m.message?.content) && m.message.content.some((b: any) => b.type === 'tool_result')
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(hasToolUse).toBe(true)
|
||||||
|
expect(hasToolResult).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('accounts for large tool_use input in token counting', () => {
|
||||||
|
const largeInput = 'x'.repeat(5000)
|
||||||
|
const messages = [
|
||||||
|
{
|
||||||
|
type: 'assistant',
|
||||||
|
message: {
|
||||||
|
role: 'assistant',
|
||||||
|
content: [
|
||||||
|
{ type: 'tool_use', id: 'tu1', name: 'Edit', input: { path: 'test.js', content: largeInput } },
|
||||||
|
],
|
||||||
|
created_at: 1000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
] as any[]
|
||||||
|
|
||||||
|
const result = applyHybridStrategy(messages, {
|
||||||
|
cacheWeight: 0.5,
|
||||||
|
freshWeight: 0.5,
|
||||||
|
maxTotalTokens: 20000,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.totalTokens).toBeGreaterThan(1000)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('accounts for large thinking blocks in token counting', () => {
|
||||||
|
const longThinking = 'Thinking '.repeat(1000)
|
||||||
|
const messages = [
|
||||||
|
{
|
||||||
|
type: 'assistant',
|
||||||
|
message: {
|
||||||
|
role: 'assistant',
|
||||||
|
content: [
|
||||||
|
{ type: 'thinking', thinking: longThinking },
|
||||||
|
{ type: 'text', text: 'Final response' },
|
||||||
|
],
|
||||||
|
created_at: 1000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
] as any[]
|
||||||
|
|
||||||
|
const result = applyHybridStrategy(messages, {
|
||||||
|
cacheWeight: 0.5,
|
||||||
|
freshWeight: 0.5,
|
||||||
|
maxTotalTokens: 20000,
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.totalTokens).toBeGreaterThan(500)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
306
src/utils/hybridContextStrategy.ts
Normal file
306
src/utils/hybridContextStrategy.ts
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
/**
|
||||||
|
* Hybrid Context Strategy - Production Grade
|
||||||
|
*
|
||||||
|
* Combines cached + new tokens intelligently.
|
||||||
|
* Optimizes for cost vs accuracy.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { roughTokenCountEstimation } from '../services/tokenEstimation.js'
|
||||||
|
import type { Message } from '../types/message.js'
|
||||||
|
|
||||||
|
export interface HybridConfig {
|
||||||
|
cacheWeight: number
|
||||||
|
freshWeight: number
|
||||||
|
maxTotalTokens: number
|
||||||
|
costThreshold?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ContextSplit {
|
||||||
|
cached: Message[]
|
||||||
|
fresh: Message[]
|
||||||
|
cachedTokens: number
|
||||||
|
freshTokens: number
|
||||||
|
totalTokens: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface HybridStrategyResult {
|
||||||
|
selectedMessages: Message[]
|
||||||
|
totalTokens: number
|
||||||
|
strategy: 'cache_heavy' | 'fresh_heavy' | 'balanced'
|
||||||
|
estimatedCost: number
|
||||||
|
}
|
||||||
|
|
||||||
|
const DEFAULT_CONFIG: Required<HybridConfig> = {
|
||||||
|
cacheWeight: 0.4,
|
||||||
|
freshWeight: 0.6,
|
||||||
|
maxTotalTokens: 100000,
|
||||||
|
costThreshold: 0.01,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep enough for: tool_use -> tool_result -> assistant -> user -> next
|
||||||
|
const MIN_TAILMessages = 5
|
||||||
|
|
||||||
|
function getMessageChain(
|
||||||
|
messages: Message[],
|
||||||
|
): { chains: Message[][]; orphans: Message[] } {
|
||||||
|
const toolUseIds = new Set<string>()
|
||||||
|
const toolUseMessages = new Map<string, Message[]>()
|
||||||
|
const allMessagesByUuid = new Map<string, Message[]>()
|
||||||
|
|
||||||
|
for (const msg of messages) {
|
||||||
|
const uuid = msg.uuid ?? ''
|
||||||
|
if (uuid) {
|
||||||
|
const existing = allMessagesByUuid.get(uuid) ?? []
|
||||||
|
existing.push(msg)
|
||||||
|
allMessagesByUuid.set(uuid, existing)
|
||||||
|
}
|
||||||
|
|
||||||
|
const content = msg.message?.content
|
||||||
|
if (Array.isArray(content)) {
|
||||||
|
for (const block of content) {
|
||||||
|
if (block?.type === 'tool_use' && block?.id) {
|
||||||
|
toolUseIds.add(block.id)
|
||||||
|
const existing = toolUseMessages.get(block.id) ?? []
|
||||||
|
existing.push(msg)
|
||||||
|
toolUseMessages.set(block.id, existing)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const chains: Message[][] = []
|
||||||
|
const orphans: Message[] = []
|
||||||
|
|
||||||
|
for (const [toolUseId, msgs] of toolUseMessages) {
|
||||||
|
const chainMessages: Message[] = [...msgs]
|
||||||
|
|
||||||
|
for (const msg of messages) {
|
||||||
|
const content = msg.message?.content
|
||||||
|
if (Array.isArray(content)) {
|
||||||
|
for (const block of content) {
|
||||||
|
if (block?.type === 'tool_result' && block?.tool_use_id === toolUseId) {
|
||||||
|
chainMessages.push(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chains.push(chainMessages)
|
||||||
|
}
|
||||||
|
|
||||||
|
const chainMessageUuids = new Set<string>()
|
||||||
|
for (const chain of chains) {
|
||||||
|
for (const msg of chain) {
|
||||||
|
if (msg.uuid) chainMessageUuids.add(msg.uuid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const [uuid, msgs] of allMessagesByUuid) {
|
||||||
|
if (!chainMessageUuids.has(uuid)) {
|
||||||
|
orphans.push(...msgs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { chains, orphans }
|
||||||
|
}
|
||||||
|
|
||||||
|
function getCacheAge(message: Message): number {
|
||||||
|
const created = message.message?.created_at ?? 0
|
||||||
|
if (created === 0) return 1000
|
||||||
|
return (Date.now() - created) / (1000 * 60 * 60)
|
||||||
|
}
|
||||||
|
|
||||||
|
function getMessageTokenCount(message: Message): number {
|
||||||
|
const content = message.message?.content
|
||||||
|
if (typeof content === 'string') {
|
||||||
|
return roughTokenCountEstimation(content)
|
||||||
|
}
|
||||||
|
if (Array.isArray(content)) {
|
||||||
|
let tokens = 0
|
||||||
|
for (const block of content) {
|
||||||
|
if (typeof block !== 'object' || block === null) continue
|
||||||
|
|
||||||
|
const b = block as Record<string, unknown>
|
||||||
|
|
||||||
|
if (b.type === 'text' && typeof b.text === 'string') {
|
||||||
|
tokens += roughTokenCountEstimation(b.text)
|
||||||
|
} else if (b.type === 'tool_use') {
|
||||||
|
const inputSize = JSON.stringify(b.input ?? {}).length
|
||||||
|
tokens += Math.ceil(inputSize / 4) + 20
|
||||||
|
} else if (b.type === 'tool_result') {
|
||||||
|
if (typeof b.content === 'string') {
|
||||||
|
tokens += roughTokenCountEstimation(b.content)
|
||||||
|
} else if (Array.isArray(b.content)) {
|
||||||
|
for (const rc of b.content) {
|
||||||
|
if (typeof rc === 'object' && rc !== null && 'text' in rc) {
|
||||||
|
tokens += roughTokenCountEstimation((rc as { text: string }).text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tokens += 50
|
||||||
|
}
|
||||||
|
if (b.is_error === true) tokens += 10
|
||||||
|
} else if (b.type === 'thinking' && typeof b.thinking === 'string') {
|
||||||
|
tokens += roughTokenCountEstimation(b.thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
function calculateCacheValue(message: Message): number {
|
||||||
|
const content = typeof message.message?.content === 'string' ? message.message.content : ''
|
||||||
|
const age = getCacheAge(message)
|
||||||
|
|
||||||
|
let value = 0.5
|
||||||
|
|
||||||
|
if (content.includes('error') || content.includes('fail')) value += 0.3
|
||||||
|
if (content.includes('function') || content.includes('class')) value += 0.2
|
||||||
|
if (content.includes('important') || content.includes('key')) value += 0.15
|
||||||
|
|
||||||
|
if (age < 1) value += 0.2
|
||||||
|
else if (age < 6) value += 0.1
|
||||||
|
else value -= 0.2
|
||||||
|
|
||||||
|
if (message.message?.role === 'system') value += 0.1
|
||||||
|
|
||||||
|
return Math.max(0, Math.min(1, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
export function splitContext(
|
||||||
|
messages: Message[],
|
||||||
|
config: HybridConfig,
|
||||||
|
): ContextSplit {
|
||||||
|
const cfg = { ...DEFAULT_CONFIG, ...config }
|
||||||
|
|
||||||
|
const sorted = [...messages].sort((a, b) => {
|
||||||
|
const aValue = calculateCacheValue(a)
|
||||||
|
const bValue = calculateCacheValue(b)
|
||||||
|
return bValue - aValue
|
||||||
|
})
|
||||||
|
|
||||||
|
const cached: Message[] = []
|
||||||
|
const fresh: Message[] = []
|
||||||
|
let cachedTokens = 0
|
||||||
|
let freshTokens = 0
|
||||||
|
|
||||||
|
const cacheTarget = Math.floor(cfg.maxTotalTokens * cfg.cacheWeight)
|
||||||
|
const freshTarget = Math.floor(cfg.maxTotalTokens * cfg.freshWeight)
|
||||||
|
|
||||||
|
for (const msg of sorted) {
|
||||||
|
const tokens = getMessageTokenCount(msg)
|
||||||
|
const age = getCacheAge(msg)
|
||||||
|
|
||||||
|
if (age > 24 && cachedTokens < cacheTarget) {
|
||||||
|
if (cachedTokens + tokens <= cacheTarget) {
|
||||||
|
cached.push(msg)
|
||||||
|
cachedTokens += tokens
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (freshTokens + tokens <= freshTarget) {
|
||||||
|
fresh.push(msg)
|
||||||
|
freshTokens += tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
cached,
|
||||||
|
fresh,
|
||||||
|
cachedTokens,
|
||||||
|
freshTokens,
|
||||||
|
totalTokens: cachedTokens + freshTokens,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function applyHybridStrategy(
|
||||||
|
messages: Message[],
|
||||||
|
config: HybridConfig,
|
||||||
|
): HybridStrategyResult {
|
||||||
|
const cfg = { ...DEFAULT_CONFIG, ...config }
|
||||||
|
|
||||||
|
// Preserve message chains (tool_use/tool_result pairs)
|
||||||
|
const { chains, orphans } = getMessageChain(messages)
|
||||||
|
|
||||||
|
// Always preserve the conversation tail (last N messages)
|
||||||
|
const tailMessages = messages.slice(-MIN_TAILMessages)
|
||||||
|
const coreMessages = messages.slice(0, -MIN_TAILMessages)
|
||||||
|
|
||||||
|
const split = splitContext(coreMessages, cfg)
|
||||||
|
|
||||||
|
let strategy: HybridStrategyResult['strategy'] = 'balanced'
|
||||||
|
if (split.cachedTokens > split.freshTokens * 1.5) {
|
||||||
|
strategy = 'cache_heavy'
|
||||||
|
} else if (split.freshTokens > split.cachedTokens * 1.5) {
|
||||||
|
strategy = 'fresh_heavy'
|
||||||
|
}
|
||||||
|
|
||||||
|
const allSelected = [
|
||||||
|
...chains.flat(),
|
||||||
|
...split.cached,
|
||||||
|
...split.fresh,
|
||||||
|
...tailMessages
|
||||||
|
]
|
||||||
|
|
||||||
|
const seenUuids = new Set<string>()
|
||||||
|
const selectedMessages: Message[] = []
|
||||||
|
for (const msg of allSelected) {
|
||||||
|
const uuid = msg.uuid ?? msg.message?.id ?? ''
|
||||||
|
if (!seenUuids.has(uuid)) {
|
||||||
|
seenUuids.add(uuid)
|
||||||
|
selectedMessages.push(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
selectedMessages.sort(
|
||||||
|
(a, b) => (a.message?.created_at ?? 0) - (b.message?.created_at ?? 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
let totalTokens = 0
|
||||||
|
for (const msg of selectedMessages) {
|
||||||
|
totalTokens += getMessageTokenCount(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
const estimatedCost = totalTokens * 0.000001 * 0.5
|
||||||
|
|
||||||
|
return {
|
||||||
|
selectedMessages,
|
||||||
|
totalTokens,
|
||||||
|
strategy,
|
||||||
|
estimatedCost,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function optimizeForCost(messages: Message[], budget: number): Message[] {
|
||||||
|
const result = applyHybridStrategy(messages, {
|
||||||
|
cacheWeight: 0.7,
|
||||||
|
freshWeight: 0.3,
|
||||||
|
maxTotalTokens: Math.floor(budget * 1000),
|
||||||
|
costThreshold: budget,
|
||||||
|
})
|
||||||
|
return result.selectedMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
export function optimizeForAccuracy(messages: Message[], maxTokens: number): Message[] {
|
||||||
|
const result = applyHybridStrategy(messages, {
|
||||||
|
cacheWeight: 0.3,
|
||||||
|
freshWeight: 0.7,
|
||||||
|
maxTotalTokens: maxTokens,
|
||||||
|
})
|
||||||
|
return result.selectedMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getHybridStats(split: ContextSplit) {
|
||||||
|
const cacheRatio = split.totalTokens > 0 ? split.cachedTokens / split.totalTokens : 0
|
||||||
|
const freshRatio = split.totalTokens > 0 ? split.freshTokens / split.totalTokens : 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
cacheRatio: Math.round(cacheRatio * 100),
|
||||||
|
freshRatio: Math.round(freshRatio * 100),
|
||||||
|
totalTokens: split.totalTokens,
|
||||||
|
messageCount: split.cached.length + split.fresh.length,
|
||||||
|
efficiency: split.totalTokens / (split.cachedTokens + split.freshTokens + 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user