fix(router): use large request size for model selection
This commit is contained in:
@@ -228,9 +228,14 @@ class SmartRouter:
|
|||||||
return min(available, key=lambda p: p.score(self.strategy))
|
return min(available, key=lambda p: p.score(self.strategy))
|
||||||
|
|
||||||
def get_model_for_provider(
|
def get_model_for_provider(
|
||||||
self, provider: Provider, claude_model: str
|
self,
|
||||||
|
provider: Provider,
|
||||||
|
claude_model: str,
|
||||||
|
is_large_request: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Map a Claude model name to the provider's actual model."""
|
"""Map a Claude model name to the provider's actual model."""
|
||||||
|
if is_large_request:
|
||||||
|
return provider.big_model
|
||||||
is_large = any(
|
is_large = any(
|
||||||
keyword in claude_model.lower()
|
keyword in claude_model.lower()
|
||||||
for keyword in ["opus", "sonnet", "large", "big"]
|
for keyword in ["opus", "sonnet", "large", "big"]
|
||||||
@@ -289,7 +294,11 @@ class SmartRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
provider = min(available, key=lambda p: p.score(self.strategy))
|
provider = min(available, key=lambda p: p.score(self.strategy))
|
||||||
model = self.get_model_for_provider(provider, claude_model)
|
model = self.get_model_for_provider(
|
||||||
|
provider,
|
||||||
|
claude_model,
|
||||||
|
is_large_request=large,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"SmartRouter: routing to {provider.name}/{model} "
|
f"SmartRouter: routing to {provider.name}/{model} "
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ Run: pytest test_smart_router.py -v
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
from smart_router import SmartRouter, Provider
|
from smart_router import SmartRouter, Provider
|
||||||
|
|
||||||
@@ -27,7 +28,9 @@ def make_provider(name, healthy=True, configured=True,
|
|||||||
p.avg_latency_ms = latency
|
p.avg_latency_ms = latency
|
||||||
p.error_count = errors
|
p.error_count = errors
|
||||||
p.request_count = requests
|
p.request_count = requests
|
||||||
if not configured:
|
if configured:
|
||||||
|
os.environ.setdefault("FAKE_KEY", "test-key")
|
||||||
|
else:
|
||||||
p.api_key_env = "" # makes is_configured False for non-ollama
|
p.api_key_env = "" # makes is_configured False for non-ollama
|
||||||
return p
|
return p
|
||||||
|
|
||||||
@@ -122,6 +125,13 @@ def test_get_model_large_request():
|
|||||||
assert model == "openai-big"
|
assert model == "openai-big"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_model_large_message_overrides_claude_label():
|
||||||
|
p = make_provider("openai")
|
||||||
|
r = make_router()
|
||||||
|
model = r.get_model_for_provider(p, "claude-haiku", is_large_request=True)
|
||||||
|
assert model == "openai-big"
|
||||||
|
|
||||||
|
|
||||||
def test_get_model_small_request():
|
def test_get_model_small_request():
|
||||||
p = make_provider("openai")
|
p = make_provider("openai")
|
||||||
r = make_router()
|
r = make_router()
|
||||||
@@ -140,6 +150,16 @@ async def test_route_returns_best_provider():
|
|||||||
assert result["provider"] == "cheap"
|
assert result["provider"] == "cheap"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_route_uses_big_model_for_large_message_bodies():
|
||||||
|
p = make_provider("openai")
|
||||||
|
r = make_router(providers=[p])
|
||||||
|
result = await r.route([
|
||||||
|
{"role": "user", "content": "x" * 3001},
|
||||||
|
], "claude-haiku")
|
||||||
|
assert result["model"] == "openai-big"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_route_raises_when_no_providers():
|
async def test_route_raises_when_no_providers():
|
||||||
p = make_provider("a", healthy=False)
|
p = make_provider("a", healthy=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user