From b5f70473583718bfef2742eac5b82072820bb204 Mon Sep 17 00:00:00 2001 From: 3kin0x Date: Fri, 24 Apr 2026 20:26:02 +0200 Subject: [PATCH] Feature/memory pr (#889) * feat: multi-turn context and conversation arc memory PR 2E - Section 2.9, 2.10: - Add multiTurnContext.ts with turn tracking and state preservation - Add conversationArc.ts with goal/decision/milestone tracking - Wire into query.ts after tool execution - Feature-flags: MULTI_TURN_CONTEXT, CONVERSATION_ARC - Add comprehensive tests (22 passing) * feat(memory): resolve review blockers and integrate native Knowledge Graph into Conversation Arcs - Fix: Extract text from production block arrays in phase detector\n- Fix: Ensure proper turn segmentation in query loop\n- Fix: Respect options in multi-turn context tracker\n- Feat: Add native Knowledge Graph (Entities/Relations) to ConversationArc architecture\n- Test: Comprehensive test suite for all fixes and new graph features * test(perf): add automated performance benchmarks for Knowledge Graph extraction and summary --------- Co-authored-by: LifeJiggy --- src/query.ts | 28 ++ src/utils/conversationArc.perf.test.ts | 68 +++++ src/utils/conversationArc.test.ts | 191 +++++++++++++ src/utils/conversationArc.ts | 363 +++++++++++++++++++++++++ src/utils/multiTurnContext.test.ts | 134 +++++++++ src/utils/multiTurnContext.ts | 149 ++++++++++ 6 files changed, 933 insertions(+) create mode 100644 src/utils/conversationArc.perf.test.ts create mode 100644 src/utils/conversationArc.test.ts create mode 100644 src/utils/conversationArc.ts create mode 100644 src/utils/multiTurnContext.test.ts create mode 100644 src/utils/multiTurnContext.ts diff --git a/src/query.ts b/src/query.ts index f36b2610..569229c1 100644 --- a/src/query.ts +++ b/src/query.ts @@ -252,6 +252,12 @@ async function* queryLoop( | ToolUseSummaryMessage, Terminal > { + // Start a new turn for multi-turn context tracking + if (feature('MULTI_TURN_CONTEXT')) { + const { startNewTurn } = await import('./utils/multiTurnContext.js') + startNewTurn() + } + // Immutable params — never reassigned during the query loop. const { systemPrompt, @@ -1516,6 +1522,28 @@ async function* queryLoop( } queryCheckpoint('query_tool_execution_end') + // Track multi-turn context after tool execution + if (feature('MULTI_TURN_CONTEXT')) { + const { addMessageToTurn, addToolCallToTurn } = await import( + './utils/multiTurnContext.js' + ) + addMessageToTurn(assistantMessage) + for (const toolUse of toolUseBlocks) { + addToolCallToTurn({ + id: toolUse.id, + name: toolUse.name, + input: toolUse.input as Record, + timestamp: Date.now(), + }) + } + } + + // Update conversation arc phase + if (feature('CONVERSATION_ARC')) { + const { updateArcPhase } = await import('./utils/conversationArc.js') + updateArcPhase([assistantMessage]) + } + // Generate tool use summary after tool batch completes — passed to next recursive call let nextPendingToolUseSummary: | Promise diff --git a/src/utils/conversationArc.perf.test.ts b/src/utils/conversationArc.perf.test.ts new file mode 100644 index 00000000..eba0e971 --- /dev/null +++ b/src/utils/conversationArc.perf.test.ts @@ -0,0 +1,68 @@ +import { describe, expect, it, beforeEach } from 'bun:test' +import { + initializeArc, + updateArcPhase, + getArcSummary, + resetArc +} from './conversationArc.js' + +function createMessage(content: string): any { + return { + message: { role: 'user', content, id: 'test', type: 'message', created_at: Date.now() }, + sender: 'user', + } +} + +describe('Conversation Arc Performance Benchmarks', () => { + beforeEach(() => { + resetArc() + initializeArc() + }) + + it('performs automatic fact extraction in sub-millisecond time', () => { + const iterations = 100 + const complexContent = 'Deploying version v1.2.3 to /opt/prod/server on https://api.prod.local with JIRA_URL=https://jira.corp' + + const startTime = performance.now() + for (let i = 0; i < iterations; i++) { + updateArcPhase([createMessage(complexContent)]) + } + const duration = performance.now() - startTime + const averageTime = duration / iterations + + console.log(`[Benchmark] Avg extraction time: ${averageTime.toFixed(4)}ms`) + + // Performance guard: should definitely be under 0.5ms per message on any modern CI + expect(averageTime).toBeLessThan(0.5) + }) + + it('generates summaries quickly even with a populated graph', () => { + // Populate graph with 50 facts + for (let i = 0; i < 50; i++) { + updateArcPhase([createMessage(`Var_${i}=Value_${i} in /path/to/file_${i}`)]) + } + + const startTime = performance.now() + const summary = getArcSummary() + const duration = performance.now() - startTime + + console.log(`[Benchmark] Summary generation time (50 entities): ${duration.toFixed(4)}ms`) + expect(summary).toContain('Knowledge Graph:') + // Summary generation should be extremely fast + expect(duration).toBeLessThan(10) + }) + + it('maintains a compact memory footprint', () => { + const arc = initializeArc() + for (let i = 0; i < 100; i++) { + updateArcPhase([createMessage(`Fact_${i}=Value_${i}`)]) + } + + const serialized = JSON.stringify(arc) + const sizeKB = serialized.length / 1024 + console.log(`[Benchmark] Memory footprint (100 facts): ${sizeKB.toFixed(2)}KB`) + + // Should be well under 100KB for 100 simple facts + expect(sizeKB).toBeLessThan(100) + }) +}) diff --git a/src/utils/conversationArc.test.ts b/src/utils/conversationArc.test.ts new file mode 100644 index 00000000..06773815 --- /dev/null +++ b/src/utils/conversationArc.test.ts @@ -0,0 +1,191 @@ +import { describe, expect, it, beforeEach } from 'bun:test' +import { + initializeArc, + getArc, + updateArcPhase, + addGoal, + updateGoalStatus, + addDecision, + addMilestone, + addEntity, + addRelation, + getGraphSummary, + getArcSummary, + resetArc, + getArcStats, +} from './conversationArc.js' + +function createMessage(role: string, content: string): any { + return { + message: { role, content, id: 'test', type: 'message', created_at: Date.now() }, + sender: role, + } +} + +describe('conversationArc', () => { + beforeEach(() => { + resetArc() + }) + + describe('initializeArc', () => { + it('creates new arc', () => { + const arc = initializeArc() + expect(arc.id).toBeDefined() + expect(arc.currentPhase).toBe('init') + expect(arc.goals).toEqual([]) + expect(arc.decisions).toEqual([]) + }) + }) + + describe('Knowledge Graph', () => { + it('adds entities and relations', () => { + initializeArc() + const e1 = addEntity('system', 'RHEL9', { version: '9.4' }) + const e2 = addEntity('credential', 'Jira PAT') + + expect(e1.name).toBe('RHEL9') + expect(e1.attributes.version).toBe('9.4') + + addRelation(e1.id, e2.id, 'requires') + + const arc = getArc() + expect(Object.keys(arc!.knowledgeGraph.entities).length).toBe(2) + expect(arc!.knowledgeGraph.relations.length).toBe(1) + expect(arc!.knowledgeGraph.relations[0].type).toBe('requires') + }) + + it('generates a knowledge graph summary', () => { + initializeArc() + const e1 = addEntity('system', 'RHEL9', { os: 'linux' }) + const e2 = addEntity('feature', 'OpenClaude') + addRelation(e2.id, e1.id, 'runs_on') + + const summary = getArcSummary() + expect(summary).toContain('Knowledge Graph:') + expect(summary).toContain('[system] RHEL9 (os: linux)') + expect(summary).toContain('OpenClaude --(runs_on)--> RHEL9') + }) + + it('automatically learns facts from message content', () => { + initializeArc() + const complexMessage = createMessage('user', 'Set JIRA_URL=https://jira.local and look in /opt/app/bin version v1.2.3') + + updateArcPhase([complexMessage]) + + const summary = getGraphSummary() + expect(summary).toContain('[environment_variable] JIRA_URL') + expect(summary).toContain('[endpoint] jira.local') + expect(summary).toContain('[path] /opt/app/bin') + expect(summary).toContain('[version] v1.2.3') + }) + + it('throws error when adding relation to non-existent entity', () => { + initializeArc() + expect(() => addRelation('invalid1', 'invalid2', 'test')).toThrow('Source or target entity not found in graph') + }) + }) + + describe('resetArc', () => { + + it('returns existing arc or creates new', () => { + const arc1 = getArc() + const arc2 = getArc() + expect(arc1?.id).toBe(arc2?.id) + }) + }) + + describe('updateArcPhase', () => { + it('detects exploring phase', () => { + initializeArc() + updateArcPhase([createMessage('user', 'Find the file')]) + + expect(getArc()?.currentPhase).toBe('exploring') + }) + + it('detects phase from block array content', () => { + initializeArc() + const blockMessage = { + message: { + role: 'assistant', + content: [{ type: 'text', text: 'I will now implement the requested changes.' }], + id: 'test', + type: 'message', + created_at: Date.now(), + }, + sender: 'assistant', + } + updateArcPhase([blockMessage as any]) + + expect(getArc()?.currentPhase).toBe('implementing') + }) + + it('progresses phases forward only', () => { + initializeArc() + updateArcPhase([createMessage('user', 'Write code')]) + updateArcPhase([createMessage('user', 'Find file')]) + + // Phase should remain at implementing since it was detected first + expect(getArc()?.currentPhase).toBe('implementing') + }) + }) + + describe('goal management', () => { + it('adds goal', () => { + initializeArc() + const goal = addGoal('Fix the bug') + expect(goal.description).toBe('Fix the bug') + expect(goal.status).toBe('pending') + }) + + it('updates goal status', () => { + initializeArc() + const goal = addGoal('Test feature') + updateGoalStatus(goal.id, 'completed') + + const updated = getArc()?.goals.find(g => g.id === goal.id) + expect(updated?.status).toBe('completed') + expect(updated?.completedAt).toBeDefined() + }) + }) + + describe('addDecision', () => { + it('adds decision', () => { + initializeArc() + const decision = addDecision('Use TypeScript', 'Type safety') + expect(decision.description).toBe('Use TypeScript') + expect(decision.rationale).toBe('Type safety') + }) + }) + + describe('addMilestone', () => { + it('adds milestone', () => { + initializeArc() + const milestone = addMilestone('Phase 1 complete') + expect(milestone.description).toBe('Phase 1 complete') + expect(milestone.achievedAt).toBeDefined() + }) + }) + + describe('getArcSummary', () => { + it('returns summary string', () => { + initializeArc() + addGoal('Test goal') + const summary = getArcSummary() + + expect(summary).toContain('Phase:') + expect(summary).toContain('Goals:') + }) + }) + + describe('getArcStats', () => { + it('returns statistics', () => { + initializeArc() + addGoal('Goal 1') + addDecision('Decision 1') + + const stats = getArcStats() + expect(stats?.goalCount).toBe(1) + expect(stats?.decisionCount).toBe(1) + }) + }) +}) \ No newline at end of file diff --git a/src/utils/conversationArc.ts b/src/utils/conversationArc.ts new file mode 100644 index 00000000..41ca2ae3 --- /dev/null +++ b/src/utils/conversationArc.ts @@ -0,0 +1,363 @@ +/** + * Conversation Arc Memory - Production Grade + * + * Remembers conversation goals and key decisions. + * High-level abstraction of conversation progress. + */ + +import type { Message } from '../types/message.js' + +export interface Entity { + id: string + type: string // e.g., 'system', 'preference', 'credential' + name: string // e.g., 'RHEL9', 'Jira URL' + attributes: Record +} + +export interface Relation { + sourceId: string + targetId: string + type: string // e.g., 'runs_on', 'configured_as' +} + +export interface KnowledgeGraph { + entities: Record + relations: Relation[] +} + +export interface ConversationArc { + id: string + goals: Goal[] + decisions: Decision[] + milestones: Milestone[] + knowledgeGraph: KnowledgeGraph + currentPhase: 'init' | 'exploring' | 'implementing' | 'reviewing' | 'completed' + startTime: number + lastUpdateTime: number +} + +export interface Goal { + id: string + description: string + status: 'pending' | 'active' | 'completed' | 'abandoned' + createdAt: number + completedAt?: number +} + +export interface Decision { + id: string + description: string + rationale?: string + timestamp: number +} + +export interface Milestone { + id: string + description: string + achievedAt: number +} + +const ARC_KEYWORDS = { + init: ['start', 'begin', 'help', 'please'], + exploring: ['check', 'find', 'look', 'what', 'how', 'where', 'show'], + implementing: ['write', 'create', 'add', 'fix', 'update', 'modify', 'implement'], + reviewing: ['test', 'review', 'verify', 'check', 'ensure'], + completed: ['done', 'complete', 'finished', 'ready', 'good'], +} + +let conversationArc: ConversationArc | null = null + +export function initializeArc(): ConversationArc { + conversationArc = { + id: `arc_${Date.now()}`, + goals: [], + decisions: [], + milestones: [], + knowledgeGraph: { + entities: {}, + relations: [], + }, + currentPhase: 'init', + startTime: Date.now(), + lastUpdateTime: Date.now(), + } + return conversationArc +} + +export function getArc(): ConversationArc | null { + if (!conversationArc) { + return initializeArc() + } + return conversationArc +} + +function extractTextFromContent(content: unknown): string { + if (!content) return '' + if (typeof content === 'string') return content + if (Array.isArray(content)) { + return content + .filter((block: any) => block.type === 'text' && typeof block.text === 'string') + .map((block: any) => block.text) + .join('\\n') + } + return '' +} + +function detectPhase(content: string): ConversationArc['currentPhase'] | null { + const lower = content.toLowerCase() + + for (const [phase, keywords] of Object.entries(ARC_KEYWORDS)) { + if (keywords.some(k => lower.includes(k))) { + return phase as ConversationArc['currentPhase'] + } + } + + return null +} + +function extractFactsAutomatically(content: string): void { + const arc = getArc() + if (!arc) return + + // 1. Detect Environment Variables (KEY=VALUE) + const envMatches = content.matchAll(/(?:export\s+)?([A-Z_]+)=([^\s\n"']+)/g) + for (const match of envMatches) { + addEntity('environment_variable', match[1], { value: match[2] }) + } + + // 2. Detect Absolute Paths + const pathMatches = content.matchAll(/(\/(?:[\w.-]+\/)+[\w.-]+)/g) + for (const match of pathMatches) { + const path = match[1] + if (path.length > 5 && !path.includes('node_modules')) { + addEntity('path', path, { type: 'absolute' }) + } + } + + // 3. Detect Versions (v1.2.3 or version 1.2.3) + const versionMatches = content.matchAll(/(?:v|version\s+)(\d+\.\d+(?:\.\d+)?)/gi) + for (const match of versionMatches) { + addEntity('version', match[0], { semver: match[1] }) + } + + // 4. Detect Hostnames/URLs + const urlMatches = content.matchAll(/(https?:\/\/[^\s\n"']+)/g) + for (const match of urlMatches) { + try { + const url = new URL(match[1]) + addEntity('endpoint', url.hostname, { url: url.toString() }) + } catch { + // Ignore invalid URLs + } + } +} + +export function updateArcPhase(messages: Message[]): void { + const arc = getArc() + if (!arc) return + + for (const msg of messages.slice(-5).reverse()) { + const content = extractTextFromContent(msg.message?.content) + if (!content) continue + + // Phase detection + const detected = detectPhase(content) + if (detected && detected !== arc.currentPhase) { + const phaseOrder = [ + 'init', + 'exploring', + 'implementing', + 'reviewing', + 'completed', + ] + const oldIdx = phaseOrder.indexOf(arc.currentPhase) + const newIdx = phaseOrder.indexOf(detected) + + if (newIdx > oldIdx) { + arc.currentPhase = detected + arc.lastUpdateTime = Date.now() + } + } + + // NEW: Passive fact extraction (Automatic Learning) + extractFactsAutomatically(content) + } +} + +export function addGoal(description: string): Goal { + const arc = getArc() + if (!arc) throw new Error('Arc not initialized') + + const goal: Goal = { + id: `goal_${Date.now()}`, + description, + status: 'pending', + createdAt: Date.now(), + } + + arc.goals.push(goal) + arc.lastUpdateTime = Date.now() + + if (arc.currentPhase === 'init') { + arc.currentPhase = 'exploring' + } + + return goal +} + +export function updateGoalStatus(goalId: string, status: Goal['status']): void { + const arc = getArc() + if (!arc) return + + const goal = arc.goals.find(g => g.id === goalId) + if (!goal) return + + goal.status = status + if (status === 'completed') { + goal.completedAt = Date.now() + addMilestone(`Completed: ${goal.description}`) + } + + arc.lastUpdateTime = Date.now() +} + +export function addDecision(description: string, rationale?: string): Decision { + const arc = getArc() + if (!arc) throw new Error('Arc not initialized') + + const decision: Decision = { + id: `decision_${Date.now()}`, + description, + rationale, + timestamp: Date.now(), + } + + arc.decisions.push(decision) + arc.lastUpdateTime = Date.now() + + return decision +} + +export function addMilestone(description: string): Milestone { + const arc = getArc() + if (!arc) throw new Error('Arc not initialized') + + const milestone: Milestone = { + id: `milestone_${Date.now()}`, + description, + achievedAt: Date.now(), + } + + arc.milestones.push(milestone) + arc.lastUpdateTime = Date.now() + + return milestone +} + +export function addEntity( + type: string, + name: string, + attributes: Record = {}, +): Entity { + const arc = getArc() + if (!arc) throw new Error('Arc not initialized') + + const id = `entity_${Date.now()}_${Math.random().toString(36).slice(2, 7)}` + const entity: Entity = { id, type, name, attributes } + + arc.knowledgeGraph.entities[id] = entity + arc.lastUpdateTime = Date.now() + return entity +} + +export function addRelation( + sourceId: string, + targetId: string, + type: string, +): void { + const arc = getArc() + if (!arc) throw new Error('Arc not initialized') + + if (!arc.knowledgeGraph.entities[sourceId] || !arc.knowledgeGraph.entities[targetId]) { + throw new Error('Source or target entity not found in graph') + } + + arc.knowledgeGraph.relations.push({ sourceId, targetId, type }) + arc.lastUpdateTime = Date.now() +} + +export function getGraphSummary(): string { + const arc = getArc() + if (!arc || Object.keys(arc.knowledgeGraph.entities).length === 0) { + return '' + } + + let summary = '\\nKnowledge Graph:\\n' + for (const entity of Object.values(arc.knowledgeGraph.entities)) { + summary += `- [${entity.type}] ${entity.name}` + const attrs = Object.entries(entity.attributes) + if (attrs.length > 0) { + summary += ` (${attrs.map(([k, v]) => `${k}: ${v}`).join(', ')})` + } + summary += '\\n' + } + + for (const rel of arc.knowledgeGraph.relations) { + const src = arc.knowledgeGraph.entities[rel.sourceId]?.name + const tgt = arc.knowledgeGraph.entities[rel.targetId]?.name + if (src && tgt) { + summary += `- ${src} --(${rel.type})--> ${tgt}\\n` + } + } + + return summary +} + +export function getArcSummary(): string { + const arc = getArc() + if (!arc) return 'No conversation arc' + + const activeGoals = arc.goals.filter( + g => g.status === 'active' || g.status === 'pending', + ) + const completedGoals = arc.goals.filter(g => g.status === 'completed') + + let summary = `Phase: ${arc.currentPhase}\\n` + summary += `Goals: ${completedGoals.length}/${arc.goals.length} completed\\n` + + if (activeGoals.length > 0) { + summary += `Active: ${activeGoals[0].description.slice(0, 50)}...\\n` + } + + if (arc.decisions.length > 0) { + summary += `Decisions: ${arc.decisions.length}\\n` + } + + if (arc.milestones.length > 0) { + summary += `Latest milestone: ${arc.milestones[ + arc.milestones.length - 1 + ].description.slice(0, 40)}` + } + + summary += getGraphSummary() + + return summary +} + +export function resetArc(): void { + conversationArc = null +} + +export function getArcStats() { + const arc = getArc() + if (!arc) return null + + return { + phase: arc.currentPhase, + goalCount: arc.goals.length, + completedGoals: arc.goals.filter(g => g.status === 'completed').length, + decisionCount: arc.decisions.length, + milestoneCount: arc.milestones.length, + durationMs: arc.lastUpdateTime - arc.startTime, + } +} \ No newline at end of file diff --git a/src/utils/multiTurnContext.test.ts b/src/utils/multiTurnContext.test.ts new file mode 100644 index 00000000..3b990a3d --- /dev/null +++ b/src/utils/multiTurnContext.test.ts @@ -0,0 +1,134 @@ +import { describe, expect, it, beforeEach } from 'bun:test' +import { + startNewTurn, + getCurrentTurn, + addMessageToTurn, + addToolCallToTurn, + setTurnState, + getTurnState, + getTurnHistory, + getRecentTurns, + getMultiTurnStats, + resetMultiTurnState, + createMultiTurnTracker, +} from './multiTurnContext.js' + +function createMessage(role: string, content: string): any { + return { + message: { role, content, id: 'test', type: 'message', created_at: Date.now() }, + sender: role, + } +} + +describe('multiTurnContext', () => { + beforeEach(() => { + resetMultiTurnState() + }) + + describe('startNewTurn', () => { + it('creates a new turn', () => { + const turn = startNewTurn() + expect(turn.turnId).toBeDefined() + expect(turn.messages).toEqual([]) + expect(turn.toolCalls).toEqual([]) + }) + + it('tracks turn count', () => { + startNewTurn() + const turn2 = startNewTurn() + expect(turn2.turnId).toContain('turn_2') + }) + }) + + describe('addMessageToTurn', () => { + it('adds message to current turn', () => { + startNewTurn() + addMessageToTurn(createMessage('user', 'Hello')) + + const turn = getCurrentTurn() + expect(turn?.messages.length).toBe(1) + }) + + it('creates turn if none exists', () => { + addMessageToTurn(createMessage('user', 'Hello')) + expect(getCurrentTurn()).not.toBeNull() + }) + }) + + describe('addToolCallToTurn', () => { + it('adds tool call to turn', () => { + startNewTurn() + addToolCallToTurn({ id: 'tool1', name: 'read', input: { file: 'test' }, timestamp: Date.now() }) + + const turn = getCurrentTurn() + expect(turn?.toolCalls.length).toBe(1) + expect(turn?.toolCalls[0].name).toBe('read') + }) + }) + + describe('state management', () => { + it('sets and gets turn state', () => { + startNewTurn() + setTurnState('key1', 'value1') + expect(getTurnState('key1')).toBe('value1') + }) + + it('returns undefined for unknown keys', () => { + startNewTurn() + expect(getTurnState('unknown')).toBeUndefined() + }) + }) + + describe('getTurnHistory', () => { + it('returns turn history', () => { + startNewTurn() + startNewTurn() + + const history = getTurnHistory() + expect(history.length).toBe(2) + }) + }) + + describe('getRecentTurns', () => { + it('returns recent turns', () => { + for (let i = 0; i < 5; i++) startNewTurn() + + const recent = getRecentTurns(3) + expect(recent.length).toBe(3) + }) + }) + + describe('getMultiTurnStats', () => { + it('returns statistics', () => { + startNewTurn() + addMessageToTurn(createMessage('user', 'Test')) + + const stats = getMultiTurnStats() + expect(stats.totalTurns).toBe(1) + expect(stats.currentTurnActive).toBe(true) + }) + }) + + describe('createMultiTurnTracker', () => { + it('creates tracker with all methods', () => { + const tracker = createMultiTurnTracker() + expect(tracker.startTurn).toBeDefined() + expect(tracker.addMessage).toBeDefined() + expect(tracker.getStats).toBeDefined() + }) + + it('respects the maxTurns option', () => { + // Create a tracker with a very small maxTurns + createMultiTurnTracker({ maxTurns: 2 }) + + startNewTurn() // turn 1 + startNewTurn() // turn 2 + startNewTurn() // turn 3 - should drop turn 1 + + const history = getTurnHistory() + expect(history.length).toBe(2) + // The first remaining turn should be the 2nd one created + expect(history[0].turnId).toContain('turn_2') + }) + }) + }) \ No newline at end of file diff --git a/src/utils/multiTurnContext.ts b/src/utils/multiTurnContext.ts new file mode 100644 index 00000000..227e03fa --- /dev/null +++ b/src/utils/multiTurnContext.ts @@ -0,0 +1,149 @@ +/** + * Multi-Turn Context Tracking - Production Grade + * + * Tracks context across multiple tool use cycles. + * Preserves state between tool invocations. + */ + +import { roughTokenCountEstimation } from '../services/tokenEstimation.js' +import type { Message } from '../types/message.js' + +export interface TurnContext { + turnId: string + startTime: number + messages: Message[] + toolCalls: ToolCallInfo[] + state: Map + tokens: number +} + +export interface ToolCallInfo { + id: string + name: string + input: Record + result?: string + timestamp: number +} + +export interface MultiTurnOptions { + maxTurns?: number + maxTokensPerTurn?: number + preserveState?: boolean +} + +const DEFAULT_OPTIONS: Required = { + maxTurns: 10, + maxTokensPerTurn: 5000, + preserveState: true, +} + +let turnHistory: TurnContext[] = [] +let currentTurn: TurnContext | null = null +let turnCounter = 0 +let activeOptions: Required = { ...DEFAULT_OPTIONS } + +export function startNewTurn(): TurnContext { + const turn: TurnContext = { + turnId: `turn_${++turnCounter}_${Date.now()}`, + startTime: Date.now(), + messages: [], + toolCalls: [], + state: new Map(), + tokens: 0, + } + + if (turnHistory.length >= activeOptions.maxTurns) { + turnHistory = turnHistory.slice(-activeOptions.maxTurns + 1) + } + + currentTurn = turn + turnHistory.push(turn) + + return turn +} + +export function getCurrentTurn(): TurnContext | null { + return currentTurn +} + +export function addMessageToTurn(message: Message): void { + if (!currentTurn) { + currentTurn = startNewTurn() + } + + const content = typeof message.message?.content === 'string' + ? message.message.content + : JSON.stringify(message.message?.content) + + currentTurn.messages.push(message) + currentTurn.tokens += roughTokenCountEstimation(content) +} + +export function addToolCallToTurn(toolCall: ToolCallInfo): void { + if (!currentTurn) { + currentTurn = startNewTurn() + } + + currentTurn.toolCalls.push(toolCall) +} + +export function setTurnState(key: string, value: unknown): void { + if (!currentTurn) return + currentTurn.state.set(key, value) +} + +export function getTurnState(key: string): T | undefined { + if (!currentTurn) return undefined + return currentTurn.state.get(key) as T | undefined +} + +export function getTurnHistory(): TurnContext[] { + return turnHistory +} + +export function getRecentTurns(count: number): TurnContext[] { + return turnHistory.slice(-count) +} + +export function getTurnById(turnId: string): TurnContext | undefined { + return turnHistory.find(t => t.turnId === turnId) +} + +export function getCrossTurnContext(key: string): unknown[] { + return turnHistory.map(t => t.state.get(key)).filter(v => v !== undefined) +} + +export function getMultiTurnStats() { + return { + totalTurns: turnHistory.length, + currentTurnActive: currentTurn !== null, + totalTokens: turnHistory.reduce((sum, t) => sum + t.tokens, 0), + totalToolCalls: turnHistory.reduce((sum, t) => sum + t.toolCalls.length, 0), + } +} + +export function clearTurnHistory(): void { + turnHistory = [] + currentTurn = null +} + +export function resetMultiTurnState(): void { + clearTurnHistory() + turnCounter = 0 +} + +export function createMultiTurnTracker(options: MultiTurnOptions = {}) { + activeOptions = { ...DEFAULT_OPTIONS, ...options } + return { + startTurn: startNewTurn, + getCurrentTurn, + addMessage: addMessageToTurn, + addToolCall: addToolCallToTurn, + setState: (k: string, v: unknown) => setTurnState(k, v), + getState: (k: string) => getTurnState(k), + getHistory: getTurnHistory, + getRecent: (n: number) => getRecentTurns(n), + getStats: getMultiTurnStats, + reset: resetMultiTurnState, + } +} \ No newline at end of file