feat: context preloading and hybrid context strategy (#860)

* feat: context preloading and hybrid context strategy

PR 2D - Section 2.7, 2.8:
- Add contextPreload.ts with pattern-based prediction
- Add hybridContextStrategy.ts with cache/fresh balancing
- Optimize for cost vs accuracy
- Add comprehensive tests (13 passing)

* feat: wire hybrid context strategy into API path

- Apply hybrid strategy after normalizeMessagesForAPI
- Feature-flag controlled (HYBRID_CONTEXT_STRATEGY)
- Optimizes cache/fresh balance for API requests

* fix: resolve PR 2D blocking issues

- Fix predictContextNeeds self-assign bug (matchedCategory = category)
- Add test for non-empty predictedNeed
- Preserve conversation tail in hybridStrategy (never drop last 3 messages)
- Add comment for hardcoded 200k cap in claude.ts

Fixes reviewer feedback from gnanam1990 and Vasanthdev2004

* fix: preserve tool_use/tool_result chains in hybridStrategy

- Increase MIN_TAIL to 5 (tool_use -> tool_result -> assistant -> user -> next)
- Add getMessageChain() to preserve paired messages
- Chains kept together in final selection

* fix: PR 860 - tool_use/tool_result pairing and safe token counting

Blocking:
- getMessageChain() now pairs by tool_use.id (block ID) not msg.message.id
- Find tool_use blocks by id, pair with tool_result having matching tool_use_id
- Fixes tool_result surviving while paired tool_use dropped

- Token counting now includes array content (tool_use, tool_result, thinking)
- Not just string content; this prevents undercounting the prompt size

- Deduplicate messages by UUID when combining chains + split + tail
- Prevents duplicate messages in final request

Non-blocking:
- Add regression test for tool_use/tool_result pairing

* fix: PR 860 - account for actual structured payload size in token counting

Blocking:
- getMessageTokenCount now calculates actual token count for structured blocks
- tool_use: uses JSON.stringify(input).length / 4 + base
- tool_result: counts actual content (string or array of text blocks)
- thinking: counts actual thinking text length / 4
- is_error flag adds small overhead

Non-blocking:
- Add tests for large tool_use input and large thinking blocks
This commit is contained in:
ArkhAngelLifeJiggy
2026-04-29 08:49:46 +01:00
committed by GitHub
parent 91f93ce615
commit 92d297e50e
5 changed files with 800 additions and 0 deletions

View File

@@ -1283,6 +1283,21 @@ async function* queryModel(
let messagesForAPI = normalizeMessagesForAPI(messages, filteredTools) let messagesForAPI = normalizeMessagesForAPI(messages, filteredTools)
queryCheckpoint('query_message_normalization_end') queryCheckpoint('query_message_normalization_end')
// Apply hybrid context strategy for optimal cache/fresh balance
if (feature('HYBRID_CONTEXT_STRATEGY')) {
const { applyHybridStrategy } = await import('../../utils/hybridContextStrategy.js')
// Cap at 200k to avoid edge case with very large context windows
const strategyResult = applyHybridStrategy(messagesForAPI, {
cacheWeight: 0.4,
freshWeight: 0.6,
maxTotalTokens: Math.min(
getContextWindowForModel(model, getSdkBetas()) - COMPACT_MAX_OUTPUT_TOKENS,
200000
),
})
messagesForAPI = strategyResult.selectedMessages
}
// Model-specific post-processing: strip tool-search-specific fields if the // Model-specific post-processing: strip tool-search-specific fields if the
// selected model doesn't support tool search. // selected model doesn't support tool search.
// //

View File

@@ -0,0 +1,104 @@
import { describe, expect, it } from 'bun:test'
import {
analyzeConversationPatterns,
predictContextNeeds,
preloadContext,
createPreloadStrategy,
} from './contextPreload.js'
// Minimal Message-shaped fixture; typed `any` so the tests don't depend on the
// project's full Message type. NOTE(review): every fixture shares id 'test'.
function createMessage(role: string, content: string, createdAt: number = Date.now()): any {
  return {
    message: { role, content, id: 'test', type: 'message', created_at: createdAt },
    sender: role,
  }
}

describe('contextPreload', () => {
  describe('analyzeConversationPatterns', () => {
    it('extracts patterns from messages', () => {
      const messages = [
        createMessage('user', 'Fix the error in my code', 1000),
        createMessage('assistant', 'I found the bug', 2000),
      ]
      const patterns = analyzeConversationPatterns(messages)
      // Smoke test: only asserts the call succeeds and returns an array.
      expect(patterns.length).toBeGreaterThanOrEqual(0)
    })
    it('detects debug patterns', () => {
      const messages = [
        // 'error' and 'debug' are keywords of the 'debug' category.
        createMessage('user', 'Debug this error please', 1000),
        createMessage('assistant', 'Found it', 2000),
      ]
      const patterns = analyzeConversationPatterns(messages)
      expect(patterns.some(p => p.userQuery === 'debug')).toBe(true)
    })
    it('detects code patterns', () => {
      const messages = [
        // 'write' and 'function' are keywords of the 'code' category.
        createMessage('user', 'Write a function for me', 1000),
        createMessage('assistant', 'Here is the code', 2000),
      ]
      const patterns = analyzeConversationPatterns(messages)
      expect(patterns.some(p => p.userQuery === 'code')).toBe(true)
    })
  })
  describe('predictContextNeeds', () => {
    it('predicts context needs based on query', () => {
      const patterns = [{ userQuery: 'debug', neededContext: ['error_history'], frequency: 1 }]
      // 'Fix the bug' hits 2/5 debug keywords -> confidence 0.4 >= 0.3 threshold.
      const prediction = predictContextNeeds('Fix the bug', patterns, {
        maxPreloadTokens: 10000,
        confidenceThreshold: 0.3,
      })
      expect(prediction.confidence).toBeGreaterThan(0)
      expect(prediction.predictedNeed.length).toBeGreaterThan(0)
    })
    it('returns non-empty predictedNeed when pattern matches', () => {
      const patterns = [
        { userQuery: 'debug', neededContext: ['error_history', 'stack_trace'], frequency: 2 },
      ]
      const prediction = predictContextNeeds('debug this error', patterns, {
        maxPreloadTokens: 10000,
        confidenceThreshold: 0.1,
      })
      // Regression test: predictedNeed must carry the pattern's context needs.
      expect(prediction.predictedNeed).toContain('error_history')
    })
  })
  describe('preloadContext', () => {
    it('preloads relevant context', () => {
      const messages = [
        createMessage('system', 'System prompt'),
        createMessage('user', 'Debug error'),
        createMessage('assistant', 'Fixed'),
      ]
      const prediction = { predictedNeed: ['error'], confidence: 0.8, suggestedMessages: [] }
      const result = preloadContext(messages, prediction, { maxPreloadTokens: 5000 })
      expect(result.length).toBeGreaterThan(0)
    })
  })
  describe('createPreloadStrategy', () => {
    it('creates strategy with all methods', () => {
      const strategy = createPreloadStrategy({ maxPreloadTokens: 10000 })
      expect(strategy.analyze).toBeDefined()
      expect(strategy.predict).toBeDefined()
      expect(strategy.preload).toBeDefined()
    })
  })
})

145
src/utils/contextPreload.ts Normal file
View File

@@ -0,0 +1,145 @@
/**
* Context Pre-loading - Production Grade
*
* Proactively loads relevant context before it's needed.
* Prediction based on conversation patterns.
*/
import { roughTokenCountEstimation } from '../services/tokenEstimation.js'
import type { Message } from '../types/message.js'
export interface PreloadConfig {
  /** Hard token budget for messages selected by preloadContext. */
  maxPreloadTokens: number
  /** Reserved for future use; not read anywhere in this module yet. */
  predictionWindow?: number
  /** Minimum keyword-match confidence (0..1) for a category match; predictContextNeeds defaults it to 0.5. */
  confidenceThreshold?: number
}

export interface PreloadPrediction {
  /** Context categories the next turn is predicted to need (e.g. 'error_history'). */
  predictedNeed: string[]
  /** Keyword-match confidence in [0, 1]. */
  confidence: number
  /** Placeholder for concrete message suggestions; currently always returned empty. */
  suggestedMessages: Message[]
}

export interface ConversationPattern {
  /** Matched query category — one of the PATTERN_KEYWORDS keys. */
  userQuery: string
  /** Context kinds the paired assistant reply referenced. */
  neededContext: string[]
  /** Occurrence count; analyzeConversationPatterns always emits 1 per hit. */
  frequency: number
}

// Keyword lists that classify a user query into a category. A query matches a
// category when any keyword appears as a case-insensitive substring.
const PATTERN_KEYWORDS: Record<string, string[]> = {
  'code': ['code', 'function', 'implement', 'write'],
  'debug': ['error', 'bug', 'fix', 'issue', 'debug'],
  'refactor': ['refactor', 'improve', 'clean', 'optimize'],
  'test': ['test', 'spec', 'coverage', 'verify'],
  'explain': ['explain', 'what', 'how', 'why', 'describe'],
  'search': ['find', 'search', 'look', 'grep', 'glob'],
}
/**
 * Scan the recent conversation for (user query -> assistant reply) pairs and
 * classify each user query into a PATTERN_KEYWORDS category.
 *
 * Fix over previous revision: a pair is only analyzed when the first message
 * is actually from the user and the second from the assistant. Previously any
 * adjacent pair was treated as (user, assistant), so assistant text containing
 * keywords produced spurious patterns with the following message's content as
 * the "assistant" context.
 *
 * @param messages full conversation; only the last 10 messages are inspected.
 * @returns one ConversationPattern (frequency 1) per category hit per pair.
 */
export function analyzeConversationPatterns(messages: Message[]): ConversationPattern[] {
  const patterns: ConversationPattern[] = []
  // Short-horizon prediction: only the recent tail matters.
  const recentMessages = messages.slice(-10)
  for (let i = 0; i < recentMessages.length - 1; i++) {
    const userMsg = recentMessages[i]
    const assistantMsg = recentMessages[i + 1]
    // Guard: require a real (user -> assistant) exchange.
    if (userMsg.message?.role !== 'user' || assistantMsg.message?.role !== 'assistant') continue
    const userContent = typeof userMsg.message?.content === 'string' ? userMsg.message.content : ''
    const assistantContent = typeof assistantMsg.message?.content === 'string' ? assistantMsg.message.content : ''
    // Lowercase once instead of inside every keyword check.
    const loweredQuery = userContent.toLowerCase()
    for (const [category, keywords] of Object.entries(PATTERN_KEYWORDS)) {
      if (keywords.some(k => loweredQuery.includes(k))) {
        patterns.push({
          userQuery: category,
          neededContext: extractContextNeeds(assistantContent),
          frequency: 1,
        })
      }
    }
  }
  return patterns
}
/**
 * Map substrings of an assistant reply to the context kinds they imply.
 * The returned order mirrors the marker table below.
 */
function extractContextNeeds(content: string): string[] {
  // (substring marker, implied context kind) pairs, checked in order.
  const markers: Array<[string, string]> = [
    ['file', 'file_context'],
    ['function', 'function_defs'],
    ['error', 'error_history'],
    ['test', 'test_files'],
  ]
  return markers
    .filter(([keyword]) => content.includes(keyword))
    .map(([, need]) => need)
}
export function predictContextNeeds(
currentQuery: string,
patterns: ConversationPattern[],
config: PreloadConfig,
): PreloadPrediction {
const threshold = config.confidenceThreshold ?? 0.5
let matchedCategory = ''
let highestConfidence = 0
for (const [category, keywords] of Object.entries(PATTERN_KEYWORDS)) {
const matches = keywords.filter(k => currentQuery.toLowerCase().includes(k)).length
const confidence = matches / keywords.length
if (confidence > highestConfidence && confidence >= threshold) {
highestConfidence = confidence
matchedCategory = category
}
}
const relevantPatterns = patterns.filter(p => p.userQuery === matchedCategory)
const allNeeds = relevantPatterns.flatMap(p => p.neededContext)
return {
predictedNeed: [...new Set(allNeeds)],
confidence: highestConfidence,
suggestedMessages: [],
}
}
/**
 * Select messages to pre-load, highest priority first, within the token budget.
 *
 * Priority: messages whose string content mentions any predicted need come
 * first, then newest first. Selection stops at the first message that would
 * exceed the budget.
 */
export function preloadContext(
  availableContext: Message[],
  prediction: PreloadPrediction,
  config: PreloadConfig,
): Message[] {
  const targetTokens = config.maxPreloadTokens ?? 30000
  const priorityTypes = prediction.predictedNeed

  // Precompute sort keys once per message instead of inside the comparator.
  const ranked = availableContext.map(msg => {
    const text = typeof msg.message?.content === 'string' ? msg.message.content : ''
    return {
      msg,
      text,
      priority: priorityTypes.some(t => text.includes(t)) ? 1 : 0,
      createdAt: msg.message?.created_at ?? 0,
    }
  })
  // Priority desc, then recency desc; stable sort keeps ties in input order.
  ranked.sort((a, b) => (b.priority - a.priority) || (b.createdAt - a.createdAt))

  const selected: Message[] = []
  let usedTokens = 0
  for (const { msg, text } of ranked) {
    const tokens = roughTokenCountEstimation(text)
    if (usedTokens + tokens > targetTokens) break
    selected.push(msg)
    usedTokens += tokens
  }
  return selected
}
/**
 * Bundle the analyze/predict/preload steps with a shared config so callers
 * don't have to thread PreloadConfig through every call.
 */
export function createPreloadStrategy(config: PreloadConfig) {
  const predict = (query: string, patterns: ConversationPattern[]) =>
    predictContextNeeds(query, patterns, config)
  const preload = (context: Message[], prediction: PreloadPrediction) =>
    preloadContext(context, prediction, config)
  return { analyze: analyzeConversationPatterns, predict, preload }
}

View File

@@ -0,0 +1,230 @@
import { describe, expect, it } from 'bun:test'
import {
splitContext,
applyHybridStrategy,
optimizeForCost,
optimizeForAccuracy,
getHybridStats,
} from './hybridContextStrategy.js'
// Minimal Message-shaped fixture; typed `any` so the tests don't depend on the
// project's full Message type. NOTE(review): every fixture shares
// message.id === 'test', which interacts with any dedup keyed on message id.
function createMessage(role: string, content: string, createdAt: number = Date.now()): any {
  return {
    message: { role, content, id: 'test', type: 'message', created_at: createdAt },
    sender: role,
  }
}

describe('hybridContextStrategy', () => {
  describe('splitContext', () => {
    it('splits context into cached and fresh', () => {
      const messages = [
        // ~24h-old system prompt: intended to land in the cacheable bucket.
        createMessage('system', 'System prompt', Date.now() - 86400000),
        createMessage('user', 'Hello'),
        createMessage('assistant', 'Hi there'),
      ]
      const split = splitContext(messages, {
        cacheWeight: 0.4,
        freshWeight: 0.6,
        maxTotalTokens: 10000,
      })
      expect(split.cachedTokens).toBeGreaterThanOrEqual(0)
      expect(split.freshTokens).toBeGreaterThanOrEqual(0)
      expect(split.totalTokens).toBeGreaterThan(0)
    })
    it('respects weight configuration', () => {
      const messages = [
        createMessage('system', 'Old system', Date.now() - 86400000),
        createMessage('user', 'Recent message', Date.now()),
      ]
      const split = splitContext(messages, {
        cacheWeight: 0.5,
        freshWeight: 0.5,
        maxTotalTokens: 10000,
      })
      // Shape check only; token amounts are covered by the previous test.
      expect(split.cached).toBeDefined()
      expect(split.fresh).toBeDefined()
    })
  })
  describe('applyHybridStrategy', () => {
    it('applies strategy and returns messages', () => {
      const messages = [
        createMessage('user', 'Message 1'),
        createMessage('assistant', 'Response 1'),
      ]
      const result = applyHybridStrategy(messages, {
        cacheWeight: 0.5,
        freshWeight: 0.5,
        maxTotalTokens: 10000,
      })
      expect(result.selectedMessages.length).toBeGreaterThan(0)
      // strategy is a closed union; anything else is a regression.
      expect(['cache_heavy', 'fresh_heavy', 'balanced']).toContain(result.strategy)
    })
    it('calculates estimated cost', () => {
      const messages = [
        createMessage('user', 'Test message'),
      ]
      const result = applyHybridStrategy(messages, {
        cacheWeight: 0.5,
        freshWeight: 0.5,
        maxTotalTokens: 10000,
      })
      expect(result.estimatedCost).toBeGreaterThanOrEqual(0)
    })
  })
  describe('optimizeForCost', () => {
    it('returns messages within budget', () => {
      const messages = [
        createMessage('user', 'Message 1'),
        createMessage('assistant', 'Response 1'),
      ]
      // $0.001 budget -> 1-token budget after the budget*1000 conversion.
      const result = optimizeForCost(messages, 0.001)
      expect(result.length).toBeGreaterThanOrEqual(0)
    })
  })
  describe('optimizeForAccuracy', () => {
    it('optimizes for accuracy with token limit', () => {
      const messages = [
        createMessage('user', 'Message 1'),
        createMessage('assistant', 'Response 1'),
      ]
      const result = optimizeForAccuracy(messages, 5000)
      expect(result.length).toBeGreaterThan(0)
    })
  })
  describe('getHybridStats', () => {
    it('returns statistics', () => {
      const messages = [
        createMessage('system', 'System', Date.now() - 86400000),
        createMessage('user', 'Hello'),
      ]
      const split = splitContext(messages, { cacheWeight: 0.5, freshWeight: 0.5, maxTotalTokens: 10000 })
      const stats = getHybridStats(split)
      expect(stats.cacheRatio).toBeGreaterThanOrEqual(0)
      expect(stats.freshRatio).toBeGreaterThanOrEqual(0)
      expect(stats.totalTokens).toBeGreaterThan(0)
    })
  })
  describe('tool_use/tool_result pairing', () => {
    it('preserves tool_use and tool_result together', () => {
      const toolUseId = 'tool-use-123'
      const messages = [
        // Assistant issues a tool call...
        {
          type: 'assistant',
          uuid: 'uuid-1',
          message: {
            role: 'assistant',
            content: [{ type: 'tool_use', id: toolUseId, name: 'Read' }],
            id: 'msg-1',
            created_at: 1000,
          },
        },
        // ...user message carries the matching tool_result (paired by tool_use_id)...
        {
          type: 'user',
          uuid: 'uuid-2',
          message: {
            role: 'user',
            content: [{ type: 'tool_result', tool_use_id: toolUseId, content: 'file content' }],
            id: 'msg-2',
            created_at: 2000,
          },
        },
        // ...followed by a plain assistant reply.
        {
          type: 'assistant',
          uuid: 'uuid-3',
          message: {
            role: 'assistant',
            content: 'Response after tool',
            id: 'msg-3',
            created_at: 3000,
          },
        },
      ] as any[]
      const result = applyHybridStrategy(messages, {
        cacheWeight: 0.5,
        freshWeight: 0.5,
        maxTotalTokens: 10000,
      })
      // Regression test: the pair must survive (or be dropped) together;
      // a tool_result without its tool_use breaks the API contract.
      const hasToolUse = result.selectedMessages.some(
        m => Array.isArray(m.message?.content) && m.message.content.some((b: any) => b.type === 'tool_use')
      )
      const hasToolResult = result.selectedMessages.some(
        m => Array.isArray(m.message?.content) && m.message.content.some((b: any) => b.type === 'tool_result')
      )
      expect(hasToolUse).toBe(true)
      expect(hasToolResult).toBe(true)
    })
    it('accounts for large tool_use input in token counting', () => {
      // ~5000 chars of structured input should dominate the token estimate.
      const largeInput = 'x'.repeat(5000)
      const messages = [
        {
          type: 'assistant',
          message: {
            role: 'assistant',
            content: [
              { type: 'tool_use', id: 'tu1', name: 'Edit', input: { path: 'test.js', content: largeInput } },
            ],
            created_at: 1000,
          },
        },
      ] as any[]
      const result = applyHybridStrategy(messages, {
        cacheWeight: 0.5,
        freshWeight: 0.5,
        maxTotalTokens: 20000,
      })
      expect(result.totalTokens).toBeGreaterThan(1000)
    })
    it('accounts for large thinking blocks in token counting', () => {
      // ~9000 chars of thinking text; must not be counted as zero.
      const longThinking = 'Thinking '.repeat(1000)
      const messages = [
        {
          type: 'assistant',
          message: {
            role: 'assistant',
            content: [
              { type: 'thinking', thinking: longThinking },
              { type: 'text', text: 'Final response' },
            ],
            created_at: 1000,
          },
        },
      ] as any[]
      const result = applyHybridStrategy(messages, {
        cacheWeight: 0.5,
        freshWeight: 0.5,
        maxTotalTokens: 20000,
      })
      expect(result.totalTokens).toBeGreaterThan(500)
    })
  })
})

View File

@@ -0,0 +1,306 @@
/**
* Hybrid Context Strategy - Production Grade
*
* Combines cached + new tokens intelligently.
* Optimizes for cost vs accuracy.
*/
import { roughTokenCountEstimation } from '../services/tokenEstimation.js'
import type { Message } from '../types/message.js'
export interface HybridConfig {
  /** Share of maxTotalTokens reserved for cache-friendly (older) messages. */
  cacheWeight: number
  /** Share of maxTotalTokens reserved for fresh (recent) messages. */
  freshWeight: number
  /** Overall token budget for the selected context. */
  maxTotalTokens: number
  /** Dollar threshold; passed through by optimizeForCost but not read by the strategy itself. */
  costThreshold?: number
}

export interface ContextSplit {
  /** Messages assigned to the cacheable bucket (older than 24h). */
  cached: Message[]
  /** Messages assigned to the fresh bucket. */
  fresh: Message[]
  cachedTokens: number
  freshTokens: number
  /** Always cachedTokens + freshTokens. */
  totalTokens: number
}

export interface HybridStrategyResult {
  /** Final selection, sorted by created_at ascending. */
  selectedMessages: Message[]
  totalTokens: number
  strategy: 'cache_heavy' | 'fresh_heavy' | 'balanced'
  /** Rough dollar estimate; see applyHybridStrategy. */
  estimatedCost: number
}

// Defaults merged under any caller-supplied HybridConfig.
const DEFAULT_CONFIG: Required<HybridConfig> = {
  cacheWeight: 0.4,
  freshWeight: 0.6,
  maxTotalTokens: 100000,
  costThreshold: 0.01,
}

// Keep enough for: tool_use -> tool_result -> assistant -> user -> next
// NOTE(review): name mixes CONSTANT_CASE and camelCase (MIN_TAIL_MESSAGES would
// be conventional), but renaming means touching every use site.
const MIN_TAILMessages = 5
/**
 * Group messages into tool-call chains so tool_use/tool_result pairs are kept
 * (or dropped) together, plus the remaining unpaired messages.
 *
 * Pairing is by the tool_use block's `id` matched against tool_result blocks'
 * `tool_use_id` — not by message id, which is not unique per tool call.
 *
 * Fixes over previous revision:
 * - membership is tracked by object identity, so messages without a `uuid`
 *   are no longer silently dropped from `orphans`;
 * - tool_result carriers are indexed in the same pass over `messages` instead
 *   of rescanning the whole list once per tool_use id (was O(n * ids)).
 *
 * @returns chains: one message group per tool_use id (tool_use carriers first,
 *          then tool_result carriers); orphans: every message in no chain.
 */
function getMessageChain(
  messages: Message[],
): { chains: Message[][]; orphans: Message[] } {
  // tool_use id -> messages carrying that tool_use block
  const toolUseMessages = new Map<string, Message[]>()
  // tool_use id -> messages carrying a matching tool_result block
  const toolResultMessages = new Map<string, Message[]>()

  for (const msg of messages) {
    const content = msg.message?.content
    if (!Array.isArray(content)) continue
    for (const block of content) {
      if (block?.type === 'tool_use' && block?.id) {
        const existing = toolUseMessages.get(block.id) ?? []
        existing.push(msg)
        toolUseMessages.set(block.id, existing)
      } else if (block?.type === 'tool_result' && block?.tool_use_id) {
        const existing = toolResultMessages.get(block.tool_use_id) ?? []
        existing.push(msg)
        toolResultMessages.set(block.tool_use_id, existing)
      }
    }
  }

  const chains: Message[][] = []
  const inChain = new Set<Message>()
  for (const [toolUseId, users] of toolUseMessages) {
    const chain = [...users, ...(toolResultMessages.get(toolUseId) ?? [])]
    chains.push(chain)
    for (const msg of chain) inChain.add(msg)
  }

  // Everything not claimed by a chain is an orphan, uuid or not.
  const orphans = messages.filter(msg => !inChain.has(msg))
  return { chains, orphans }
}
/**
 * Age of a message in hours. Messages with no timestamp (created_at 0 or
 * absent) are treated as very old (1000h) so they fall into the cache bucket.
 */
function getCacheAge(message: Message): number {
  const created = message.message?.created_at ?? 0
  if (created === 0) return 1000
  const ageMs = Date.now() - created
  return ageMs / 3_600_000
}
/**
 * Token estimate for one message, covering structured content blocks
 * (text, tool_use, tool_result, thinking) — not just string content.
 *
 * tool_use input is sized via its JSON serialization at ~4 chars/token plus a
 * fixed overhead; tool_result payloads of unknown shape get a flat 50, and
 * error results add a small surcharge. Unrecognized block types count as 0.
 */
function getMessageTokenCount(message: Message): number {
  const content = message.message?.content
  if (typeof content === 'string') {
    return roughTokenCountEstimation(content)
  }
  if (!Array.isArray(content)) return 0

  let tokens = 0
  for (const block of content) {
    if (typeof block !== 'object' || block === null) continue
    const b = block as Record<string, unknown>
    switch (b.type) {
      case 'text':
        if (typeof b.text === 'string') tokens += roughTokenCountEstimation(b.text)
        break
      case 'tool_use': {
        // Serialized input length at ~4 chars/token; +20 for fixed fields.
        const inputSize = JSON.stringify(b.input ?? {}).length
        tokens += Math.ceil(inputSize / 4) + 20
        break
      }
      case 'tool_result': {
        if (typeof b.content === 'string') {
          tokens += roughTokenCountEstimation(b.content)
        } else if (Array.isArray(b.content)) {
          for (const rc of b.content) {
            if (typeof rc === 'object' && rc !== null && 'text' in rc) {
              tokens += roughTokenCountEstimation((rc as { text: string }).text)
            }
          }
        } else {
          // Unknown payload shape: flat fallback estimate.
          tokens += 50
        }
        if (b.is_error === true) tokens += 10
        break
      }
      case 'thinking':
        if (typeof b.thinking === 'string') tokens += roughTokenCountEstimation(b.thinking)
        break
    }
  }
  return tokens
}
/**
 * Heuristic cache-worthiness score clamped to [0, 1]: content keywords raise
 * it, recency adjusts it, and system messages get a small bonus.
 */
function calculateCacheValue(message: Message): number {
  const content = typeof message.message?.content === 'string' ? message.message.content : ''
  const age = getCacheAge(message)

  let value = 0.5
  // Content-based signals, applied in a fixed order.
  const signals: Array<[boolean, number]> = [
    [content.includes('error') || content.includes('fail'), 0.3],
    [content.includes('function') || content.includes('class'), 0.2],
    [content.includes('important') || content.includes('key'), 0.15],
  ]
  for (const [hit, bonus] of signals) {
    if (hit) value += bonus
  }
  // Recency: fresh messages score higher, stale ones are penalized.
  if (age < 1) value += 0.2
  else if (age < 6) value += 0.1
  else value -= 0.2
  // System messages are slightly more cache-worthy.
  if (message.message?.role === 'system') value += 0.1

  return Math.max(0, Math.min(1, value))
}
/**
 * Partition messages into a cacheable bucket (older than 24h, filled in
 * descending cache-value order) and a fresh bucket, each capped by its
 * weighted share of the token budget. Messages fitting neither budget are
 * dropped.
 *
 * Note: the returned arrays are in cache-value order, not chronological
 * order; applyHybridStrategy re-sorts the final selection by created_at.
 */
export function splitContext(
  messages: Message[],
  config: HybridConfig,
): ContextSplit {
  const cfg = { ...DEFAULT_CONFIG, ...config }
  const cacheTarget = Math.floor(cfg.maxTotalTokens * cfg.cacheWeight)
  const freshTarget = Math.floor(cfg.maxTotalTokens * cfg.freshWeight)

  // Highest cache value first so the most cache-worthy messages claim budget.
  const byValue = [...messages].sort(
    (a, b) => calculateCacheValue(b) - calculateCacheValue(a),
  )

  const cached: Message[] = []
  const fresh: Message[] = []
  let cachedTokens = 0
  let freshTokens = 0
  for (const msg of byValue) {
    const tokens = getMessageTokenCount(msg)
    // Old messages go to the cache bucket while it has room for them;
    // otherwise they compete for the fresh budget like everything else.
    if (getCacheAge(msg) > 24 && cachedTokens < cacheTarget && cachedTokens + tokens <= cacheTarget) {
      cached.push(msg)
      cachedTokens += tokens
      continue
    }
    if (freshTokens + tokens <= freshTarget) {
      fresh.push(msg)
      freshTokens += tokens
    }
  }

  return {
    cached,
    fresh,
    cachedTokens,
    freshTokens,
    totalTokens: cachedTokens + freshTokens,
  }
}
/**
 * Select the messages to send to the API: tool_use/tool_result chains are
 * preserved as units, the cache/fresh split fills the budget, and the
 * conversation tail (last MIN_TAILMessages messages) is always kept.
 *
 * Fix over previous revision: duplicates from combining chains + split + tail
 * are removed by object identity — every candidate aliases an element of the
 * input `messages` array, so identity dedup removes exactly the repeats. The
 * previous key (`uuid ?? message.id ?? ''`) collapsed distinct messages that
 * shared a message id, or had neither uuid nor id, into a single message.
 *
 * @param messages full candidate history.
 * @param config   weights and token budget (merged over DEFAULT_CONFIG).
 * @returns chronological selection, its token estimate, the dominant strategy
 *          label, and a rough dollar estimate.
 */
export function applyHybridStrategy(
  messages: Message[],
  config: HybridConfig,
): HybridStrategyResult {
  const cfg = { ...DEFAULT_CONFIG, ...config }

  // Preserve message chains (tool_use/tool_result pairs).
  const { chains } = getMessageChain(messages)

  // Always preserve the conversation tail; only the rest competes for budget.
  const tailMessages = messages.slice(-MIN_TAILMessages)
  const coreMessages = messages.slice(0, -MIN_TAILMessages)
  const split = splitContext(coreMessages, cfg)

  // Label the split by whichever side dominates by 1.5x.
  let strategy: HybridStrategyResult['strategy'] = 'balanced'
  if (split.cachedTokens > split.freshTokens * 1.5) {
    strategy = 'cache_heavy'
  } else if (split.freshTokens > split.cachedTokens * 1.5) {
    strategy = 'fresh_heavy'
  }

  // Combine, dedupe by identity, and restore chronological order.
  const seen = new Set<Message>()
  const selectedMessages: Message[] = []
  for (const msg of [...chains.flat(), ...split.cached, ...split.fresh, ...tailMessages]) {
    if (!seen.has(msg)) {
      seen.add(msg)
      selectedMessages.push(msg)
    }
  }
  selectedMessages.sort(
    (a, b) => (a.message?.created_at ?? 0) - (b.message?.created_at ?? 0),
  )

  let totalTokens = 0
  for (const msg of selectedMessages) {
    totalTokens += getMessageTokenCount(msg)
  }
  // Rough blended price: $1/M tokens with an assumed ~50% cache discount.
  const estimatedCost = totalTokens * 0.000001 * 0.5

  return {
    selectedMessages,
    totalTokens,
    strategy,
    estimatedCost,
  }
}
/**
 * Cost-biased preset: weight heavily toward cacheable context and derive the
 * token budget from the dollar budget (budget * 1000 tokens).
 */
export function optimizeForCost(messages: Message[], budget: number): Message[] {
  const tokenBudget = Math.floor(budget * 1000)
  return applyHybridStrategy(messages, {
    cacheWeight: 0.7,
    freshWeight: 0.3,
    maxTotalTokens: tokenBudget,
    costThreshold: budget,
  }).selectedMessages
}
/**
 * Accuracy-biased preset: weight toward fresh context under a token cap.
 */
export function optimizeForAccuracy(messages: Message[], maxTokens: number): Message[] {
  const { selectedMessages } = applyHybridStrategy(messages, {
    cacheWeight: 0.3,
    freshWeight: 0.7,
    maxTotalTokens: maxTokens,
  })
  return selectedMessages
}
export function getHybridStats(split: ContextSplit) {
const cacheRatio = split.totalTokens > 0 ? split.cachedTokens / split.totalTokens : 0
const freshRatio = split.totalTokens > 0 ? split.freshTokens / split.totalTokens : 0
return {
cacheRatio: Math.round(cacheRatio * 100),
freshRatio: Math.round(freshRatio * 100),
totalTokens: split.totalTokens,
messageCount: split.cached.length + split.fresh.length,
efficiency: split.totalTokens / (split.cachedTokens + split.freshTokens + 1),
}
}