feat: add streaming token counter (#797)
* feat: add streaming token counter
- Add StreamingTokenCounter for real-time token counting during generation
- Tracks output tokens as they arrive from stream
- Calculates tokens per second rate
- Add tests (4 passing)
PR 4A: Streaming Token Counter (Features 1.2, 1.7)
* refactor: move StreamingTokenCounter to separate file
- Extract StreamingTokenCounter from tokens.ts to streamingTokenCounter.ts
- Add getEstimatedRemainingTokens() method
- Update test import
* fix: word-boundary token counting for stable stream totals
- Accumulate raw content, count only at word boundaries
- Eliminates instability from arbitrary chunk boundaries
- Add finalize() to flush remaining content on stream end
- Add characterCount getter for raw content tracking
- Rename getEstimatedRemainingTokens -> getEstimatedGenerationTimeMs
- Add comprehensive tests
* fix: update streamingTokens test for word-boundary API
- Add finalize() call before checking output tokens
- Use characterCount for interim checks
- Add spaces to trigger word boundary counting
* fix: add estimateRemainingTokens/Time methods
- Add estimateRemainingTokens(target) method
- Add estimateRemainingTimeMs(target) method
- Covers non-blocking: now properly estimates remaining tokens
* fix: PR 797 - fix word boundary counting, consolidate tests
Blockers (Vasanthdev2004):
- recountAtWordBoundary now searches forward from lastCountedIndex+1
- Finds NEXT space after already-counted region, not before it
- Provides accurate live token counts during streaming, not just finalize()
Non-blocking (gnanam1990):
- Delete streamingTokens.test.ts, merge tests into streamingTokenCounter.test.ts
- Added interim-counting test to verify counting updates during streaming
* fix: PR 797 - fix word boundary advancement after space
Blocking:
- Fix recountAtWordBoundary to skip past space when searching for next boundary
- After counting at a space, indexOf(' ') returns 0 (the space itself)
- Now starts search from index 1 to find the NEXT word boundary
- Short chunks now properly trigger count advancement
Non-blocking:
- Add test verifying count increases after each word boundary
- Add test for space-skipping behavior
This commit is contained in:
committed by
GitHub
parent
92d297e50e
commit
0ca4333537
165
src/utils/streamingTokenCounter.test.ts
Normal file
165
src/utils/streamingTokenCounter.test.ts
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import { describe, expect, it } from 'bun:test'
|
||||||
|
import { StreamingTokenCounter } from './streamingTokenCounter.js'
|
||||||
|
|
||||||
|
describe('StreamingTokenCounter', () => {
|
||||||
|
describe('start', () => {
|
||||||
|
it('resets state and sets input tokens', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(1000)
|
||||||
|
expect(counter.total).toBe(1000)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('addChunk', () => {
|
||||||
|
it('accumulates content', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(500)
|
||||||
|
counter.addChunk('Hello world ')
|
||||||
|
expect(counter.characterCount).toBe(12)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('accumulates multiple chunks', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(500)
|
||||||
|
counter.addChunk('Hello ')
|
||||||
|
counter.addChunk('world ')
|
||||||
|
expect(counter.characterCount).toBeGreaterThanOrEqual(10)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('handles empty chunks', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(50)
|
||||||
|
counter.addChunk(undefined)
|
||||||
|
counter.addChunk('')
|
||||||
|
expect(counter.output).toBe(0)
|
||||||
|
expect(counter.total).toBe(50)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('updates cached token count at word boundaries during streaming', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(100)
|
||||||
|
counter.addChunk('Hello ')
|
||||||
|
const afterFirst = counter.output
|
||||||
|
expect(afterFirst).toBeGreaterThan(0)
|
||||||
|
counter.addChunk('world ')
|
||||||
|
const afterSecond = counter.output
|
||||||
|
expect(afterSecond).toBeGreaterThan(afterFirst)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('advances count past space after word boundary', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start()
|
||||||
|
counter.addChunk('Hello ') // counts Hello
|
||||||
|
const count1 = counter.output
|
||||||
|
|
||||||
|
counter.addChunk('world') // short chunk, no space - shouldn't advance
|
||||||
|
const count2 = counter.output
|
||||||
|
expect(count2).toBe(count1)
|
||||||
|
|
||||||
|
counter.addChunk(' ') // space triggers count
|
||||||
|
const count3 = counter.output
|
||||||
|
expect(count3).toBeGreaterThan(count2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('finalize', () => {
|
||||||
|
it('counts all content after finalize', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(500)
|
||||||
|
counter.addChunk('Hello world')
|
||||||
|
counter.finalize()
|
||||||
|
expect(counter.output).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('counts tokens after finalize', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(100)
|
||||||
|
counter.addChunk('Hello ')
|
||||||
|
counter.addChunk('world ')
|
||||||
|
counter.finalize()
|
||||||
|
expect(counter.output).toBeGreaterThan(0)
|
||||||
|
expect(counter.total).toBe(100 + counter.output)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('total', () => {
|
||||||
|
it('sums input and output after finalize', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(500)
|
||||||
|
counter.addChunk('Test content ')
|
||||||
|
counter.finalize()
|
||||||
|
expect(counter.total).toBeGreaterThanOrEqual(500)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('tokensPerSecond', () => {
|
||||||
|
it('calculates tokens per second', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start()
|
||||||
|
counter.addChunk('123456789 ')
|
||||||
|
expect(typeof counter.tokensPerSecond).toBe('number')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('estimateRemainingTokens', () => {
|
||||||
|
it('returns positive when under target', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(500)
|
||||||
|
counter.addChunk('Hello ')
|
||||||
|
counter.finalize()
|
||||||
|
expect(counter.estimateRemainingTokens(1000)).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns 0 when at or over target', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(500)
|
||||||
|
counter.addChunk('Hello ')
|
||||||
|
counter.finalize()
|
||||||
|
expect(counter.estimateRemainingTokens(1)).toBe(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('estimateRemainingTimeMs', () => {
|
||||||
|
it('returns estimate based on rate', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start()
|
||||||
|
counter.addChunk('Hello world ')
|
||||||
|
expect(counter.estimateRemainingTimeMs(100)).toBeGreaterThanOrEqual(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('characterCount', () => {
|
||||||
|
it('returns accumulated character count', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.addChunk('Hello')
|
||||||
|
expect(counter.characterCount).toBe(5)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('accumulates content from chunks', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(100)
|
||||||
|
counter.addChunk('Hello ')
|
||||||
|
counter.addChunk('world ')
|
||||||
|
expect(counter.characterCount).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('reset', () => {
|
||||||
|
it('clears all state', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(500)
|
||||||
|
counter.addChunk('Hello world ')
|
||||||
|
counter.reset()
|
||||||
|
expect(counter.characterCount).toBe(0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('resets correctly', () => {
|
||||||
|
const counter = new StreamingTokenCounter()
|
||||||
|
counter.start(100)
|
||||||
|
counter.addChunk('test ')
|
||||||
|
counter.reset()
|
||||||
|
expect(counter.characterCount).toBe(0)
|
||||||
|
expect(counter.total).toBe(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
133
src/utils/streamingTokenCounter.ts
Normal file
133
src/utils/streamingTokenCounter.ts
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
/**
|
||||||
|
* Streaming Token Counter - Accurate token counting during generation
|
||||||
|
*
|
||||||
|
* Accumulates raw content and counts tokens at consistent boundaries
|
||||||
|
* to avoid dependency on arbitrary chunk boundaries.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { roughTokenCountEstimation } from '../services/tokenEstimation.js'
|
||||||
|
|
||||||
|
export class StreamingTokenCounter {
|
||||||
|
private inputTokens = 0
|
||||||
|
private accumulatedContent = ''
|
||||||
|
private lastCountedIndex = 0
|
||||||
|
private cachedOutputTokens = 0
|
||||||
|
private startTime = 0
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Start tracking a new stream
|
||||||
|
* @param initialInputTokens - Token count for system prompt + history
|
||||||
|
*/
|
||||||
|
start(initialInputTokens?: number): void {
|
||||||
|
this.reset()
|
||||||
|
this.startTime = Date.now()
|
||||||
|
this.inputTokens = initialInputTokens ?? 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add content from a streaming chunk
|
||||||
|
* Accumulates raw content, counting only at word boundaries
|
||||||
|
* to avoid instability from arbitrary chunk boundaries.
|
||||||
|
*/
|
||||||
|
addChunk(deltaContent?: string): void {
|
||||||
|
if (deltaContent) {
|
||||||
|
this.accumulatedContent += deltaContent
|
||||||
|
this.recountAtWordBoundary()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recount tokens at word boundaries for stability.
|
||||||
|
* Only counts after whitespace to avoid mid-word splits.
|
||||||
|
*/
|
||||||
|
private recountAtWordBoundary(): void {
|
||||||
|
const content = this.accumulatedContent
|
||||||
|
const unprocessedContent = content.slice(this.lastCountedIndex)
|
||||||
|
|
||||||
|
const searchStart = unprocessedContent[0] === ' ' ? 1 : 0
|
||||||
|
const nextSpaceIndex = unprocessedContent.indexOf(' ', searchStart)
|
||||||
|
|
||||||
|
const shouldCount =
|
||||||
|
nextSpaceIndex > 0 ||
|
||||||
|
unprocessedContent.length > 50 ||
|
||||||
|
unprocessedContent.length === 0
|
||||||
|
|
||||||
|
let boundaryIndex: number
|
||||||
|
if (nextSpaceIndex > 0) {
|
||||||
|
boundaryIndex = this.lastCountedIndex + nextSpaceIndex
|
||||||
|
} else if (unprocessedContent.length > 50) {
|
||||||
|
boundaryIndex = content.length
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const toCount = content.slice(0, boundaryIndex)
|
||||||
|
this.cachedOutputTokens = roughTokenCountEstimation(toCount)
|
||||||
|
this.lastCountedIndex = boundaryIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Flush remaining content and finalize count.
|
||||||
|
* Call this when stream completes.
|
||||||
|
*/
|
||||||
|
finalize(): number {
|
||||||
|
if (this.accumulatedContent.length > this.lastCountedIndex) {
|
||||||
|
this.cachedOutputTokens = roughTokenCountEstimation(this.accumulatedContent)
|
||||||
|
this.lastCountedIndex = this.accumulatedContent.length
|
||||||
|
}
|
||||||
|
return this.cachedOutputTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Get total tokens (input + output) */
|
||||||
|
get total(): number {
|
||||||
|
return this.inputTokens + this.cachedOutputTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Get output tokens only */
|
||||||
|
get output(): number {
|
||||||
|
return this.cachedOutputTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Get elapsed time in milliseconds */
|
||||||
|
get elapsedMs(): number {
|
||||||
|
return this.startTime > 0 ? Date.now() - this.startTime : 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Get tokens per second generation rate */
|
||||||
|
get tokensPerSecond(): number {
|
||||||
|
if (this.elapsedMs === 0) return 0
|
||||||
|
return (this.cachedOutputTokens / this.elapsedMs) * 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Get estimated total generation time based on current rate */
|
||||||
|
getEstimatedGenerationTimeMs(): number {
|
||||||
|
if (this.tokensPerSecond === 0) return 0
|
||||||
|
return Math.round((this.cachedOutputTokens / this.tokensPerSecond) * 1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Estimate remaining tokens until target output size */
|
||||||
|
estimateRemainingTokens(targetOutputTokens: number): number {
|
||||||
|
return Math.max(0, targetOutputTokens - this.cachedOutputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Estimate remaining time based on target output tokens */
|
||||||
|
estimateRemainingTimeMs(targetOutputTokens: number): number {
|
||||||
|
if (this.tokensPerSecond === 0) return 0
|
||||||
|
const remaining = this.estimateRemainingTokens(targetOutputTokens)
|
||||||
|
return Math.round((remaining / this.tokensPerSecond) * 1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Get character count for raw content */
|
||||||
|
get characterCount(): number {
|
||||||
|
return this.accumulatedContent.length
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Reset counter */
|
||||||
|
reset(): void {
|
||||||
|
this.inputTokens = 0
|
||||||
|
this.accumulatedContent = ''
|
||||||
|
this.lastCountedIndex = 0
|
||||||
|
this.cachedOutputTokens = 0
|
||||||
|
this.startTime = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user