diff --git a/smart_router.py b/smart_router.py
index 14b90c03..feccc4eb 100644
--- a/smart_router.py
+++ b/smart_router.py
@@ -228,9 +228,14 @@ class SmartRouter:
         return min(available, key=lambda p: p.score(self.strategy))
 
     def get_model_for_provider(
-        self, provider: Provider, claude_model: str
+        self,
+        provider: Provider,
+        claude_model: str,
+        is_large_request: bool = False,
     ) -> str:
         """Map a Claude model name to the provider's actual model."""
+        if is_large_request:
+            return provider.big_model
         is_large = any(
             keyword in claude_model.lower()
             for keyword in ["opus", "sonnet", "large", "big"]
@@ -289,7 +294,11 @@ class SmartRouter:
         )
 
         provider = min(available, key=lambda p: p.score(self.strategy))
-        model = self.get_model_for_provider(provider, claude_model)
+        model = self.get_model_for_provider(
+            provider,
+            claude_model,
+            is_large_request=large,
+        )
 
         logger.debug(
             f"SmartRouter: routing to {provider.name}/{model} "
diff --git a/test_smart_router.py b/test_smart_router.py
index 97b73863..651b94f1 100644
--- a/test_smart_router.py
+++ b/test_smart_router.py
@@ -13,6 +13,11 @@ from smart_router import SmartRouter, Provider
 
 # ── Fixtures ──────────────────────────────────────────────────────────────────
 
+
+@pytest.fixture(autouse=True)
+def fake_api_key(monkeypatch):
+    monkeypatch.setenv("FAKE_KEY", "test-key")
+
 def make_provider(name, healthy=True, configured=True,
                   latency=100.0, cost=0.002, errors=0, requests=0):
     p = Provider(
@@ -122,6 +127,13 @@ def test_get_model_large_request():
     assert model == "openai-big"
 
 
+def test_get_model_large_message_overrides_claude_label():
+    p = make_provider("openai")
+    r = make_router()
+    model = r.get_model_for_provider(p, "claude-haiku", is_large_request=True)
+    assert model == "openai-big"
+
+
 def test_get_model_small_request():
     p = make_provider("openai")
     r = make_router()
@@ -140,6 +152,16 @@ async def test_route_returns_best_provider():
     assert result["provider"] == "cheap"
 
 
+@pytest.mark.asyncio
+async def test_route_uses_big_model_for_large_message_bodies():
+    p = make_provider("openai")
+    r = make_router(providers=[p])
+    result = await r.route([
+        {"role": "user", "content": "x" * 3001},
+    ], "claude-haiku")
+    assert result["model"] == "openai-big"
+
+
 @pytest.mark.asyncio
 async def test_route_raises_when_no_providers():
     p = make_provider("a", healthy=False)