Merge pull request #207 from alamnahin/feat/router-large-request-modeling
fix(router): use large message size when selecting models
@@ -228,9 +228,14 @@ class SmartRouter:
         return min(available, key=lambda p: p.score(self.strategy))
 
     def get_model_for_provider(
-        self, provider: Provider, claude_model: str
+        self,
+        provider: Provider,
+        claude_model: str,
+        is_large_request: bool = False,
     ) -> str:
         """Map a Claude model name to the provider's actual model."""
+        if is_large_request:
+            return provider.big_model
         is_large = any(
             keyword in claude_model.lower()
             for keyword in ["opus", "sonnet", "large", "big"]
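In short: an explicit is_large_request=True now bypasses the keyword heuristic on the Claude model name and returns the provider's big model directly. A minimal usage sketch (hypothetical; the constructor arguments and Provider fields are inferred from the tests below, not shown in this hunk):

    # Hypothetical sketch -- router/provider construction is abbreviated.
    router = SmartRouter(providers=[provider], strategy="cost")

    # "haiku" matches none of ["opus", "sonnet", "large", "big"],
    # so only the new flag can select the big model here.
    router.get_model_for_provider(provider, "claude-haiku", is_large_request=True)
    # -> provider.big_model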
@@ -289,7 +294,11 @@ class SmartRouter:
         )
 
         provider = min(available, key=lambda p: p.score(self.strategy))
-        model = self.get_model_for_provider(provider, claude_model)
+        model = self.get_model_for_provider(
+            provider,
+            claude_model,
+            is_large_request=large,
+        )
 
         logger.debug(
             f"SmartRouter: routing to {provider.name}/{model} "
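The large flag consumed here is computed earlier in route(), outside this hunk. The new test further down sends a single 3001-character message body, which suggests a character-count threshold around 3000. A sketch of that upstream check under that assumption (function name and exact cutoff are not confirmed by this diff):

    # Assumed shape of the upstream check -- not shown in this hunk.
    LARGE_REQUEST_CHAR_THRESHOLD = 3000  # assumption, implied by the 3001-char test

    def _is_large_request(messages: list[dict]) -> bool:
        # Total the string message bodies and compare against the cutoff.
        total = sum(
            len(m["content"])
            for m in messages
            if isinstance(m.get("content"), str)
        )
        return total > LARGE_REQUEST_CHAR_THRESHOLD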
@@ -13,6 +13,11 @@ from smart_router import SmartRouter, Provider
 
 # ── Fixtures ──────────────────────────────────────────────────────────────────
 
+
+@pytest.fixture(autouse=True)
+def fake_api_key(monkeypatch):
+    monkeypatch.setenv("FAKE_KEY", "test-key")
+
 def make_provider(name, healthy=True, configured=True,
                   latency=100.0, cost=0.002, errors=0, requests=0):
     p = Provider(
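Because the fixture is autouse=True, it runs before every test in the module without being listed as an argument, and monkeypatch restores the environment during teardown. A hypothetical test (not part of this diff) illustrating the effect:

    import os

    def test_fake_key_is_set():
        # No fixture parameter needed: autouse fixtures apply to every test,
        # and monkeypatch undoes the setenv afterwards.
        assert os.environ["FAKE_KEY"] == "test-key"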
@@ -122,6 +127,13 @@ def test_get_model_large_request():
     assert model == "openai-big"
 
 
+def test_get_model_large_message_overrides_claude_label():
+    p = make_provider("openai")
+    r = make_router()
+    model = r.get_model_for_provider(p, "claude-haiku", is_large_request=True)
+    assert model == "openai-big"
+
+
 def test_get_model_small_request():
     p = make_provider("openai")
     r = make_router()
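This test pins down precedence: the explicit flag wins even though "haiku" matches none of the big-model keywords. For contrast, without the flag the name heuristic alone decides, as in this sketch (not in the diff; the small-model outcome is an assumption based on the neighboring small-request test):

    # Sketch: with no flag, only the keyword heuristic applies.
    assert r.get_model_for_provider(p, "claude-opus") == "openai-big"  # "opus" matches
    r.get_model_for_provider(p, "claude-haiku")  # no match -> small model (assumption)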
@@ -140,6 +152,16 @@ async def test_route_returns_best_provider():
     assert result["provider"] == "cheap"
 
 
+@pytest.mark.asyncio
+async def test_route_uses_big_model_for_large_message_bodies():
+    p = make_provider("openai")
+    r = make_router(providers=[p])
+    result = await r.route([
+        {"role": "user", "content": "x" * 3001},
+    ], "claude-haiku")
+    assert result["model"] == "openai-big"
+
+
 @pytest.mark.asyncio
 async def test_route_raises_when_no_providers():
     p = make_provider("a", healthy=False)
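The routing-level test confirms the flag is threaded end to end: one 3001-character user message flips the selection to the big model even though the requested Claude model is a small one. A hypothetical boundary probe, assuming the cutoff sits exactly at 3000 characters:

    # Hypothetical boundary check; the exact cutoff is an assumption.
    at_limit = await r.route([{"role": "user", "content": "x" * 3000}], "claude-haiku")
    over = await r.route([{"role": "user", "content": "x" * 3001}], "claude-haiku")
    assert over["model"] == "openai-big"  # confirmed by the test above
    # at_limit is expected to stay on the small model (assumption).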