fix(router): use large request size for model selection

This commit is contained in:
Md.Nahin Alam
2026-04-03 03:45:33 +06:00
parent 63ad0196d6
commit 43deb49c2c
2 changed files with 32 additions and 3 deletions

View File

@@ -228,9 +228,14 @@ class SmartRouter:
return min(available, key=lambda p: p.score(self.strategy))
def get_model_for_provider(
self, provider: Provider, claude_model: str
self,
provider: Provider,
claude_model: str,
is_large_request: bool = False,
) -> str:
"""Map a Claude model name to the provider's actual model."""
if is_large_request:
return provider.big_model
is_large = any(
keyword in claude_model.lower()
for keyword in ["opus", "sonnet", "large", "big"]
@@ -289,7 +294,11 @@ class SmartRouter:
)
provider = min(available, key=lambda p: p.score(self.strategy))
model = self.get_model_for_provider(provider, claude_model)
model = self.get_model_for_provider(
provider,
claude_model,
is_large_request=large,
)
logger.debug(
f"SmartRouter: routing to {provider.name}/{model} "

View File

@@ -7,6 +7,7 @@ Run: pytest test_smart_router.py -v
import pytest
import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch
from smart_router import SmartRouter, Provider
@@ -27,7 +28,9 @@ def make_provider(name, healthy=True, configured=True,
p.avg_latency_ms = latency
p.error_count = errors
p.request_count = requests
if not configured:
if configured:
os.environ.setdefault("FAKE_KEY", "test-key")
else:
p.api_key_env = "" # makes is_configured False for non-ollama
return p
@@ -122,6 +125,13 @@ def test_get_model_large_request():
assert model == "openai-big"
def test_get_model_large_message_overrides_claude_label():
    """is_large_request=True must force the big model even for a small Claude label."""
    provider = make_provider("openai")
    router = make_router()
    chosen = router.get_model_for_provider(
        provider, "claude-haiku", is_large_request=True
    )
    assert chosen == "openai-big"
def test_get_model_small_request():
p = make_provider("openai")
r = make_router()
@@ -140,6 +150,16 @@ async def test_route_returns_best_provider():
assert result["provider"] == "cheap"
@pytest.mark.asyncio
async def test_route_uses_big_model_for_large_message_bodies():
    """A message body over the size threshold should select the provider's big model."""
    provider = make_provider("openai")
    router = make_router(providers=[provider])
    messages = [{"role": "user", "content": "x" * 3001}]
    result = await router.route(messages, "claude-haiku")
    assert result["model"] == "openai-big"
@pytest.mark.asyncio
async def test_route_raises_when_no_providers():
p = make_provider("a", healthy=False)