fix(router): use large request size for model selection
This commit is contained in:
@@ -228,9 +228,14 @@ class SmartRouter:
|
||||
return min(available, key=lambda p: p.score(self.strategy))
|
||||
|
||||
def get_model_for_provider(
|
||||
self, provider: Provider, claude_model: str
|
||||
self,
|
||||
provider: Provider,
|
||||
claude_model: str,
|
||||
is_large_request: bool = False,
|
||||
) -> str:
|
||||
"""Map a Claude model name to the provider's actual model."""
|
||||
if is_large_request:
|
||||
return provider.big_model
|
||||
is_large = any(
|
||||
keyword in claude_model.lower()
|
||||
for keyword in ["opus", "sonnet", "large", "big"]
|
||||
@@ -289,7 +294,11 @@ class SmartRouter:
|
||||
)
|
||||
|
||||
provider = min(available, key=lambda p: p.score(self.strategy))
|
||||
model = self.get_model_for_provider(provider, claude_model)
|
||||
model = self.get_model_for_provider(
|
||||
provider,
|
||||
claude_model,
|
||||
is_large_request=large,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"SmartRouter: routing to {provider.name}/{model} "
|
||||
|
||||
@@ -7,6 +7,7 @@ Run: pytest test_smart_router.py -v
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from smart_router import SmartRouter, Provider
|
||||
|
||||
@@ -27,7 +28,9 @@ def make_provider(name, healthy=True, configured=True,
|
||||
p.avg_latency_ms = latency
|
||||
p.error_count = errors
|
||||
p.request_count = requests
|
||||
if not configured:
|
||||
if configured:
|
||||
os.environ.setdefault("FAKE_KEY", "test-key")
|
||||
else:
|
||||
p.api_key_env = "" # makes is_configured False for non-ollama
|
||||
return p
|
||||
|
||||
@@ -122,6 +125,13 @@ def test_get_model_large_request():
|
||||
assert model == "openai-big"
|
||||
|
||||
|
||||
def test_get_model_large_message_overrides_claude_label():
    """With is_large_request=True, a "claude-haiku" label still maps to the big model."""
    openai = make_provider("openai")
    router = make_router()
    # The model label alone would not be treated as large; the flag is what
    # is expected to drive the choice here.
    chosen = router.get_model_for_provider(
        openai,
        "claude-haiku",
        is_large_request=True,
    )
    assert chosen == "openai-big"
|
||||
|
||||
|
||||
def test_get_model_small_request():
|
||||
p = make_provider("openai")
|
||||
r = make_router()
|
||||
@@ -140,6 +150,16 @@ async def test_route_returns_best_provider():
|
||||
assert result["provider"] == "cheap"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_route_uses_big_model_for_large_message_bodies():
    """Routing a 3001-character user message as "claude-haiku" yields the big model."""
    openai = make_provider("openai")
    router = make_router(providers=[openai])
    # A single user message whose content is 3001 characters long.
    oversized_body = "x" * 3001
    messages = [{"role": "user", "content": oversized_body}]
    outcome = await router.route(messages, "claude-haiku")
    assert outcome["model"] == "openai-big"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_raises_when_no_providers():
|
||||
p = make_provider("a", healthy=False)
|
||||
|
||||
Reference in New Issue
Block a user