diff --git a/src/utils/streamingTokenCounter.test.ts b/src/utils/streamingTokenCounter.test.ts new file mode 100644 index 00000000..f4cd5fd0 --- /dev/null +++ b/src/utils/streamingTokenCounter.test.ts @@ -0,0 +1,165 @@ +import { describe, expect, it } from 'bun:test' +import { StreamingTokenCounter } from './streamingTokenCounter.js' + +describe('StreamingTokenCounter', () => { + describe('start', () => { + it('resets state and sets input tokens', () => { + const counter = new StreamingTokenCounter() + counter.start(1000) + expect(counter.total).toBe(1000) + }) + }) + + describe('addChunk', () => { + it('accumulates content', () => { + const counter = new StreamingTokenCounter() + counter.start(500) + counter.addChunk('Hello world ') + expect(counter.characterCount).toBe(12) + }) + + it('accumulates multiple chunks', () => { + const counter = new StreamingTokenCounter() + counter.start(500) + counter.addChunk('Hello ') + counter.addChunk('world ') + expect(counter.characterCount).toBeGreaterThanOrEqual(10) + }) + + it('handles empty chunks', () => { + const counter = new StreamingTokenCounter() + counter.start(50) + counter.addChunk(undefined) + counter.addChunk('') + expect(counter.output).toBe(0) + expect(counter.total).toBe(50) + }) + + it('updates cached token count at word boundaries during streaming', () => { + const counter = new StreamingTokenCounter() + counter.start(100) + counter.addChunk('Hello ') + const afterFirst = counter.output + expect(afterFirst).toBeGreaterThan(0) + counter.addChunk('world ') + const afterSecond = counter.output + expect(afterSecond).toBeGreaterThan(afterFirst) + }) + + it('advances count past space after word boundary', () => { + const counter = new StreamingTokenCounter() + counter.start() + counter.addChunk('Hello ') // counts Hello + const count1 = counter.output + + counter.addChunk('world') // short chunk, no space - shouldn't advance + const count2 = counter.output + expect(count2).toBe(count1) + + counter.addChunk(' ') // space triggers count + const count3 = counter.output + expect(count3).toBeGreaterThan(count2) + }) + }) + + describe('finalize', () => { + it('counts all content after finalize', () => { + const counter = new StreamingTokenCounter() + counter.start(500) + counter.addChunk('Hello world') + counter.finalize() + expect(counter.output).toBeGreaterThan(0) + }) + + it('counts tokens after finalize', () => { + const counter = new StreamingTokenCounter() + counter.start(100) + counter.addChunk('Hello ') + counter.addChunk('world ') + counter.finalize() + expect(counter.output).toBeGreaterThan(0) + expect(counter.total).toBe(100 + counter.output) + }) + }) + + describe('total', () => { + it('sums input and output after finalize', () => { + const counter = new StreamingTokenCounter() + counter.start(500) + counter.addChunk('Test content ') + counter.finalize() + expect(counter.total).toBeGreaterThanOrEqual(500) + }) + }) + + describe('tokensPerSecond', () => { + it('calculates tokens per second', () => { + const counter = new StreamingTokenCounter() + counter.start() + counter.addChunk('123456789 ') + expect(typeof counter.tokensPerSecond).toBe('number') + }) + }) + + describe('estimateRemainingTokens', () => { + it('returns positive when under target', () => { + const counter = new StreamingTokenCounter() + counter.start(500) + counter.addChunk('Hello ') + counter.finalize() + expect(counter.estimateRemainingTokens(1000)).toBeGreaterThan(0) + }) + + it('returns 0 when at or over target', () => { + const counter = new StreamingTokenCounter() + counter.start(500) + counter.addChunk('Hello ') + counter.finalize() + expect(counter.estimateRemainingTokens(1)).toBe(0) + }) + }) + + describe('estimateRemainingTimeMs', () => { + it('returns estimate based on rate', () => { + const counter = new StreamingTokenCounter() + counter.start() + counter.addChunk('Hello world ') + expect(counter.estimateRemainingTimeMs(100)).toBeGreaterThanOrEqual(0) + }) + }) + + describe('characterCount', () => { + it('returns accumulated character count', () => { + const counter = new StreamingTokenCounter() + counter.addChunk('Hello') + expect(counter.characterCount).toBe(5) + }) + + it('accumulates content from chunks', () => { + const counter = new StreamingTokenCounter() + counter.start(100) + counter.addChunk('Hello ') + counter.addChunk('world ') + expect(counter.characterCount).toBeGreaterThan(0) + }) + }) + + describe('reset', () => { + it('clears all state', () => { + const counter = new StreamingTokenCounter() + counter.start(500) + counter.addChunk('Hello world ') + counter.reset() + expect(counter.characterCount).toBe(0) + }) + + it('resets correctly', () => { + const counter = new StreamingTokenCounter() + counter.start(100) + counter.addChunk('test ') + counter.reset() + expect(counter.characterCount).toBe(0) + expect(counter.total).toBe(0) + }) + }) +}) \ No newline at end of file diff --git a/src/utils/streamingTokenCounter.ts b/src/utils/streamingTokenCounter.ts new file mode 100644 index 00000000..09ba6a4a --- /dev/null +++ b/src/utils/streamingTokenCounter.ts @@ -0,0 +1,133 @@ +/** + * Streaming Token Counter - Accurate token counting during generation + * + * Accumulates raw content and counts tokens at consistent boundaries + * to avoid dependency on arbitrary chunk boundaries. + */ + +import { roughTokenCountEstimation } from '../services/tokenEstimation.js' + +export class StreamingTokenCounter { + private inputTokens = 0 + private accumulatedContent = '' + private lastCountedIndex = 0 + private cachedOutputTokens = 0 + private startTime = 0 + + /** + * Start tracking a new stream + * @param initialInputTokens - Token count for system prompt + history + */ + start(initialInputTokens?: number): void { + this.reset() + this.startTime = Date.now() + this.inputTokens = initialInputTokens ?? 0 + } + + /** + * Add content from a streaming chunk + * Accumulates raw content, counting only at word boundaries + * to avoid instability from arbitrary chunk boundaries. + */ + addChunk(deltaContent?: string): void { + if (deltaContent) { + this.accumulatedContent += deltaContent + this.recountAtWordBoundary() + } + } + + /** + * Recount tokens at word boundaries for stability. + * Only counts after whitespace to avoid mid-word splits. + */ + private recountAtWordBoundary(): void { + const content = this.accumulatedContent + const unprocessedContent = content.slice(this.lastCountedIndex) + + const searchStart = unprocessedContent[0] === ' ' ? 1 : 0 + const nextSpaceIndex = unprocessedContent.indexOf(' ', searchStart) + + const shouldCount = + nextSpaceIndex > 0 || + unprocessedContent.length > 50 || + unprocessedContent.length === 0 + + let boundaryIndex: number + if (nextSpaceIndex > 0) { + boundaryIndex = this.lastCountedIndex + nextSpaceIndex + } else if (unprocessedContent.length > 50) { + boundaryIndex = content.length + } else { + return + } + + const toCount = content.slice(0, boundaryIndex) + this.cachedOutputTokens = roughTokenCountEstimation(toCount) + this.lastCountedIndex = boundaryIndex + } + + /** + * Flush remaining content and finalize count. + * Call this when stream completes. + */ + finalize(): number { + if (this.accumulatedContent.length > this.lastCountedIndex) { + this.cachedOutputTokens = roughTokenCountEstimation(this.accumulatedContent) + this.lastCountedIndex = this.accumulatedContent.length + } + return this.cachedOutputTokens + } + + /** Get total tokens (input + output) */ + get total(): number { + return this.inputTokens + this.cachedOutputTokens + } + + /** Get output tokens only */ + get output(): number { + return this.cachedOutputTokens + } + + /** Get elapsed time in milliseconds */ + get elapsedMs(): number { + return this.startTime > 0 ? Date.now() - this.startTime : 0 + } + + /** Get tokens per second generation rate */ + get tokensPerSecond(): number { + if (this.elapsedMs === 0) return 0 + return (this.cachedOutputTokens / this.elapsedMs) * 1000 + } + + /** Get estimated total generation time based on current rate */ + getEstimatedGenerationTimeMs(): number { + if (this.tokensPerSecond === 0) return 0 + return Math.round((this.cachedOutputTokens / this.tokensPerSecond) * 1000) + } + + /** Estimate remaining tokens until target output size */ + estimateRemainingTokens(targetOutputTokens: number): number { + return Math.max(0, targetOutputTokens - this.cachedOutputTokens) + } + + /** Estimate remaining time based on target output tokens */ + estimateRemainingTimeMs(targetOutputTokens: number): number { + if (this.tokensPerSecond === 0) return 0 + const remaining = this.estimateRemainingTokens(targetOutputTokens) + return Math.round((remaining / this.tokensPerSecond) * 1000) + } + + /** Get character count for raw content */ + get characterCount(): number { + return this.accumulatedContent.length + } + + /** Reset counter */ + reset(): void { + this.inputTokens = 0 + this.accumulatedContent = '' + this.lastCountedIndex = 0 + this.cachedOutputTokens = 0 + this.startTime = 0 + } +} \ No newline at end of file