Merge pull request #207 from alamnahin/feat/router-large-request-modeling
fix(router): use large message size when selecting models
This commit is contained in:
@@ -228,9 +228,14 @@ class SmartRouter:
|
||||
return min(available, key=lambda p: p.score(self.strategy))
|
||||
|
||||
def get_model_for_provider(
|
||||
self, provider: Provider, claude_model: str
|
||||
self,
|
||||
provider: Provider,
|
||||
claude_model: str,
|
||||
is_large_request: bool = False,
|
||||
) -> str:
|
||||
"""Map a Claude model name to the provider's actual model."""
|
||||
if is_large_request:
|
||||
return provider.big_model
|
||||
is_large = any(
|
||||
keyword in claude_model.lower()
|
||||
for keyword in ["opus", "sonnet", "large", "big"]
|
||||
@@ -289,7 +294,11 @@ class SmartRouter:
|
||||
)
|
||||
|
||||
provider = min(available, key=lambda p: p.score(self.strategy))
|
||||
model = self.get_model_for_provider(provider, claude_model)
|
||||
model = self.get_model_for_provider(
|
||||
provider,
|
||||
claude_model,
|
||||
is_large_request=large,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"SmartRouter: routing to {provider.name}/{model} "
|
||||
|
||||
@@ -13,6 +13,11 @@ from smart_router import SmartRouter, Provider
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fake_api_key(monkeypatch):
|
||||
monkeypatch.setenv("FAKE_KEY", "test-key")
|
||||
|
||||
def make_provider(name, healthy=True, configured=True,
|
||||
latency=100.0, cost=0.002, errors=0, requests=0):
|
||||
p = Provider(
|
||||
@@ -122,6 +127,13 @@ def test_get_model_large_request():
|
||||
assert model == "openai-big"
|
||||
|
||||
|
||||
def test_get_model_large_message_overrides_claude_label():
|
||||
p = make_provider("openai")
|
||||
r = make_router()
|
||||
model = r.get_model_for_provider(p, "claude-haiku", is_large_request=True)
|
||||
assert model == "openai-big"
|
||||
|
||||
|
||||
def test_get_model_small_request():
|
||||
p = make_provider("openai")
|
||||
r = make_router()
|
||||
@@ -140,6 +152,16 @@ async def test_route_returns_best_provider():
|
||||
assert result["provider"] == "cheap"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_uses_big_model_for_large_message_bodies():
|
||||
p = make_provider("openai")
|
||||
r = make_router(providers=[p])
|
||||
result = await r.route([
|
||||
{"role": "user", "content": "x" * 3001},
|
||||
], "claude-haiku")
|
||||
assert result["model"] == "openai-big"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_raises_when_no_providers():
|
||||
p = make_provider("a", healthy=False)
|
||||
|
||||
Reference in New Issue
Block a user