Merge pull request #207 from alamnahin/feat/router-large-request-modeling

fix(router): use large message size when selecting models
This commit is contained in:
Kevin Codex
2026-04-03 08:58:29 +08:00
committed by GitHub
2 changed files with 33 additions and 2 deletions

View File

@@ -228,9 +228,14 @@ class SmartRouter:
return min(available, key=lambda p: p.score(self.strategy)) return min(available, key=lambda p: p.score(self.strategy))
def get_model_for_provider( def get_model_for_provider(
self, provider: Provider, claude_model: str self,
provider: Provider,
claude_model: str,
is_large_request: bool = False,
) -> str: ) -> str:
"""Map a Claude model name to the provider's actual model.""" """Map a Claude model name to the provider's actual model."""
if is_large_request:
return provider.big_model
is_large = any( is_large = any(
keyword in claude_model.lower() keyword in claude_model.lower()
for keyword in ["opus", "sonnet", "large", "big"] for keyword in ["opus", "sonnet", "large", "big"]
@@ -289,7 +294,11 @@ class SmartRouter:
) )
provider = min(available, key=lambda p: p.score(self.strategy)) provider = min(available, key=lambda p: p.score(self.strategy))
model = self.get_model_for_provider(provider, claude_model) model = self.get_model_for_provider(
provider,
claude_model,
is_large_request=large,
)
logger.debug( logger.debug(
f"SmartRouter: routing to {provider.name}/{model} " f"SmartRouter: routing to {provider.name}/{model} "

View File

@@ -13,6 +13,11 @@ from smart_router import SmartRouter, Provider
# ── Fixtures ────────────────────────────────────────────────────────────────── # ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def fake_api_key(monkeypatch):
    """Inject a dummy API key into the environment for every test in the module."""
    monkeypatch.setenv("FAKE_KEY", "test-key")
def make_provider(name, healthy=True, configured=True, def make_provider(name, healthy=True, configured=True,
latency=100.0, cost=0.002, errors=0, requests=0): latency=100.0, cost=0.002, errors=0, requests=0):
p = Provider( p = Provider(
@@ -122,6 +127,13 @@ def test_get_model_large_request():
assert model == "openai-big" assert model == "openai-big"
def test_get_model_large_message_overrides_claude_label():
    """A large request body must force the big model even when the Claude
    model label ("haiku") would normally map to the small one."""
    provider = make_provider("openai")
    router = make_router()
    chosen = router.get_model_for_provider(
        provider,
        "claude-haiku",
        is_large_request=True,
    )
    assert chosen == "openai-big"
def test_get_model_small_request(): def test_get_model_small_request():
p = make_provider("openai") p = make_provider("openai")
r = make_router() r = make_router()
@@ -140,6 +152,16 @@ async def test_route_returns_best_provider():
assert result["provider"] == "cheap" assert result["provider"] == "cheap"
@pytest.mark.asyncio
async def test_route_uses_big_model_for_large_message_bodies():
    """Routing a message body above the size threshold should pick the
    provider's big model regardless of the requested Claude label."""
    provider = make_provider("openai")
    router = make_router(providers=[provider])
    big_body = [{"role": "user", "content": "x" * 3001}]
    result = await router.route(big_body, "claude-haiku")
    assert result["model"] == "openai-big"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_route_raises_when_no_providers(): async def test_route_raises_when_no_providers():
p = make_provider("a", healthy=False) p = make_provider("a", healthy=False)