From 43deb49c2c877180ba8425aad3302ad8ea7605e4 Mon Sep 17 00:00:00 2001 From: "Md.Nahin Alam" <82433419+alamnahin@users.noreply.github.com> Date: Fri, 3 Apr 2026 03:45:33 +0600 Subject: [PATCH] fix(router): use large request size for model selection --- smart_router.py | 13 +++++++++++-- test_smart_router.py | 22 +++++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/smart_router.py b/smart_router.py index 14b90c03..feccc4eb 100644 --- a/smart_router.py +++ b/smart_router.py @@ -228,9 +228,14 @@ class SmartRouter: return min(available, key=lambda p: p.score(self.strategy)) def get_model_for_provider( - self, provider: Provider, claude_model: str + self, + provider: Provider, + claude_model: str, + is_large_request: bool = False, ) -> str: """Map a Claude model name to the provider's actual model.""" + if is_large_request: + return provider.big_model is_large = any( keyword in claude_model.lower() for keyword in ["opus", "sonnet", "large", "big"] @@ -289,7 +294,11 @@ class SmartRouter: ) provider = min(available, key=lambda p: p.score(self.strategy)) - model = self.get_model_for_provider(provider, claude_model) + model = self.get_model_for_provider( + provider, + claude_model, + is_large_request=large, + ) logger.debug( f"SmartRouter: routing to {provider.name}/{model} " diff --git a/test_smart_router.py b/test_smart_router.py index 97b73863..fac31c85 100644 --- a/test_smart_router.py +++ b/test_smart_router.py @@ -7,6 +7,7 @@ Run: pytest test_smart_router.py -v import pytest import asyncio +import os from unittest.mock import AsyncMock, MagicMock, patch from smart_router import SmartRouter, Provider @@ -27,7 +28,9 @@ def make_provider(name, healthy=True, configured=True, p.avg_latency_ms = latency p.error_count = errors p.request_count = requests - if not configured: + if configured: + os.environ.setdefault("FAKE_KEY", "test-key") + else: p.api_key_env = "" # makes is_configured False for non-ollama return p @@ -122,6 +125,13 @@ def test_get_model_large_request(): assert model == "openai-big" +def test_get_model_large_message_overrides_claude_label(): + p = make_provider("openai") + r = make_router() + model = r.get_model_for_provider(p, "claude-haiku", is_large_request=True) + assert model == "openai-big" + + def test_get_model_small_request(): p = make_provider("openai") r = make_router() @@ -140,6 +150,16 @@ async def test_route_returns_best_provider(): assert result["provider"] == "cheap" +@pytest.mark.asyncio +async def test_route_uses_big_model_for_large_message_bodies(): + p = make_provider("openai") + r = make_router(providers=[p]) + result = await r.route([ + {"role": "user", "content": "x" * 3001}, + ], "claude-haiku") + assert result["model"] == "openai-big" + + @pytest.mark.asyncio async def test_route_raises_when_no_providers(): p = make_provider("a", healthy=False)