diff --git a/smart_router.py b/smart_router.py
new file mode 100644
index 00000000..0a54a791
--- /dev/null
+++ b/smart_router.py
@@ -0,0 +1,361 @@
+"""
+smart_router.py
+---------------
+Intelligent auto-router for openclaude.
+
+Instead of always using one fixed provider, the smart router:
+- Pings all configured providers on startup
+- Scores them by latency, cost, and health
+- Routes each request to the optimal provider
+- Falls back automatically if a provider fails
+- Learns from real request timings over time
+
+Usage in server.py:
+    from smart_router import SmartRouter
+    router = SmartRouter()
+    await router.initialize()
+    result = await router.route(messages, model, stream)
+
+.env config:
+    ROUTER_MODE=smart        # or: fixed (default behaviour)
+    ROUTER_STRATEGY=latency  # or: cost, balanced
+    ROUTER_FALLBACK=true     # auto-retry on failure
+
+Contribution to: https://github.com/Gitlawb/openclaude
+"""
+
+import asyncio
+import logging
+import os
+import time
+from dataclasses import dataclass, field
+from typing import Optional
+
+import httpx
+
+logger = logging.getLogger(__name__)
+
+# ── Provider definitions ──────────────────────────────────────────────────────
+
+
+@dataclass
+class Provider:
+    name: str                       # e.g. "openai", "gemini", "ollama"
+    ping_url: str                   # URL used to check health
+    api_key_env: str                # env var name for API key
+    cost_per_1k_tokens: float       # estimated cost USD per 1k tokens
+    big_model: str                  # model for sonnet/large requests
+    small_model: str                # model for haiku/small requests
+    latency_ms: float = 9999.0      # updated by benchmark
+    healthy: bool = True            # updated by health checks
+    request_count: int = 0          # total requests routed here
+    error_count: int = 0            # total errors from this provider
+    avg_latency_ms: float = 9999.0  # rolling average from real requests
+
+    @property
+    def api_key(self) -> Optional[str]:
+        return os.getenv(self.api_key_env)
+
+    @property
+    def is_configured(self) -> bool:
+        """True if the provider has an API key set."""
+        if self.name == "ollama":
+            return True  # Ollama needs no API key
+        return bool(self.api_key)
+
+    @property
+    def error_rate(self) -> float:
+        if self.request_count == 0:
+            return 0.0
+        return self.error_count / self.request_count
+
+    def score(self, strategy: str = "balanced") -> float:
+        """
+        Lower score = better provider.
+        strategy: 'latency' | 'cost' | 'balanced'
+        """
+        if not self.healthy or not self.is_configured:
+            return float("inf")
+
+        latency_score = self.avg_latency_ms / 1000.0  # normalize to seconds
+        cost_score = self.cost_per_1k_tokens * 100    # normalize to similar scale
+        error_penalty = self.error_rate * 500         # heavy penalty for errors
+
+        if strategy == "latency":
+            return latency_score + error_penalty
+        elif strategy == "cost":
+            return cost_score + error_penalty
+        else:  # balanced
+            return (latency_score * 0.5) + (cost_score * 0.5) + error_penalty
+
+
+# ── Default provider catalogue ────────────────────────────────────────────────
+
+
+def build_default_providers() -> list[Provider]:
+    big = os.getenv("BIG_MODEL", "gpt-4.1")
+    small = os.getenv("SMALL_MODEL", "gpt-4.1-mini")
+    ollama_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
+
+    return [
+        Provider(
+            name="openai",
+            ping_url="https://api.openai.com/v1/models",
+            api_key_env="OPENAI_API_KEY",
+            cost_per_1k_tokens=0.002,
+            big_model=big if "gpt" in big else "gpt-4.1",
+            small_model=small if "gpt" in small else "gpt-4.1-mini",
+        ),
+        Provider(
+            name="gemini",
+            ping_url="https://generativelanguage.googleapis.com/v1/models",
+            api_key_env="GEMINI_API_KEY",
+            cost_per_1k_tokens=0.0005,
+            big_model=big if "gemini" in big else "gemini-2.5-pro",
+            small_model=small if "gemini" in small else "gemini-2.0-flash",
+        ),
+        Provider(
+            name="ollama",
+            ping_url=f"{ollama_url}/api/tags",
+            api_key_env="",
+            cost_per_1k_tokens=0.0,  # free — local
+            big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
+            small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
+        ),
+    ]
+
+
+# ── Smart Router ──────────────────────────────────────────────────────────────
+
+
+class SmartRouter:
+    """
+    Intelligently routes Claude Code API requests to the best
+    available LLM provider based on latency, cost, and health.
+    """
+
+    def __init__(
+        self,
+        providers: Optional[list[Provider]] = None,
+        strategy: Optional[str] = None,
+        fallback_enabled: Optional[bool] = None,
+    ):
+        self.providers = providers or build_default_providers()
+        self.strategy = strategy or os.getenv("ROUTER_STRATEGY", "balanced")
+        self.fallback_enabled = (
+            fallback_enabled
+            if fallback_enabled is not None
+            else os.getenv("ROUTER_FALLBACK", "true").lower() == "true"
+        )
+        self._initialized = False
+
+    # ── Initialization ────────────────────────────────────────────────────────
+
+    async def initialize(self) -> None:
+        """Ping all providers and build initial latency scores."""
+        logger.info("SmartRouter: benchmarking providers...")
+        await asyncio.gather(
+            *[self._ping_provider(p) for p in self.providers],
+            return_exceptions=True,
+        )
+        available = [p for p in self.providers if p.healthy and p.is_configured]
+        logger.info(
+            f"SmartRouter ready. Available providers: "
+            f"{[p.name for p in available]}"
+        )
+        if not available:
+            logger.warning(
+                "SmartRouter: no providers available! "
+                "Check your API keys in .env"
+            )
+        self._initialized = True
+
+    async def _ping_provider(self, provider: Provider) -> None:
+        """Measure latency to a provider's health endpoint."""
+        if not provider.is_configured:
+            provider.healthy = False
+            logger.debug(f"SmartRouter: {provider.name} skipped — no API key")
+            return
+
+        headers = {}
+        if provider.api_key:
+            headers["Authorization"] = f"Bearer {provider.api_key}"
+
+        start = time.monotonic()
+        try:
+            async with httpx.AsyncClient(timeout=5.0) as client:
+                resp = await client.get(provider.ping_url, headers=headers)
+            elapsed_ms = (time.monotonic() - start) * 1000
+            if resp.status_code in (200, 400, 401, 403):
+                # 400/401/403 means reachable, just possibly bad key
+                # We still mark healthy for routing purposes
+                provider.healthy = True
+                provider.latency_ms = elapsed_ms
+                provider.avg_latency_ms = elapsed_ms
+                logger.info(
+                    f"SmartRouter: {provider.name} OK "
+                    f"({elapsed_ms:.0f}ms, status={resp.status_code})"
+                )
+            else:
+                provider.healthy = False
+                logger.warning(
+                    f"SmartRouter: {provider.name} unhealthy "
+                    f"(status={resp.status_code})"
+                )
+        except Exception as e:
+            provider.healthy = False
+            logger.warning(f"SmartRouter: {provider.name} unreachable — {e}")
+
+    # ── Routing logic ─────────────────────────────────────────────────────────
+
+    def select_provider(self, is_large_request: bool = False) -> Optional[Provider]:
+        """
+        Pick the best available provider for this request.
+        Returns None if no providers are available.
+        """
+        available = [
+            p for p in self.providers
+            if p.healthy and p.is_configured
+        ]
+        if not available:
+            return None
+
+        return min(available, key=lambda p: p.score(self.strategy))
+
+    def get_model_for_provider(
+        self, provider: Provider, claude_model: str
+    ) -> str:
+        """Map a Claude model name to the provider's actual model."""
+        is_large = any(
+            keyword in claude_model.lower()
+            for keyword in ["opus", "sonnet", "large", "big"]
+        )
+        return provider.big_model if is_large else provider.small_model
+
+    def is_large_request(self, messages: list[dict]) -> bool:
+        """Estimate if this is a large request based on message length."""
+        total_chars = sum(
+            len(str(m.get("content", ""))) for m in messages
+        )
+        return total_chars > 2000  # >2000 chars = treat as large
+
+    def _update_latency(self, provider: Provider, duration_ms: float) -> None:
+        """Exponential moving average update for latency tracking."""
+        alpha = 0.3  # weight for new observation
+        provider.avg_latency_ms = (
+            alpha * duration_ms + (1 - alpha) * provider.avg_latency_ms
+        )
+
+    # ── Main routing entry point ──────────────────────────────────────────────
+
+    async def route(
+        self,
+        messages: list[dict],
+        claude_model: str = "claude-sonnet",
+        attempt: int = 0,
+        exclude_providers: Optional[list[str]] = None,
+    ) -> dict:
+        """
+        Route a request to the best provider.
+        Returns a dict with routing decision info:
+            {
+                "provider": provider name,
+                "model": actual model to use,
+                "api_key": API key for the provider,
+                "base_url": base URL for the provider,
+            }
+        Raises RuntimeError if no providers available.
+        """
+        if not self._initialized:
+            await self.initialize()
+
+        exclude = set(exclude_providers or [])
+        large = self.is_large_request(messages)
+
+        available = [
+            p for p in self.providers
+            if p.healthy and p.is_configured and p.name not in exclude
+        ]
+
+        if not available:
+            raise RuntimeError(
+                "SmartRouter: no providers available. "
+                "Check your API keys and provider health."
+            )
+
+        provider = min(available, key=lambda p: p.score(self.strategy))
+        model = self.get_model_for_provider(provider, claude_model)
+
+        logger.debug(
+            f"SmartRouter: routing to {provider.name}/{model} "
+            f"(strategy={self.strategy}, large={large}, attempt={attempt})"
+        )
+
+        return {
+            "provider": provider.name,
+            "model": model,
+            "api_key": provider.api_key or "none",
+            "provider_object": provider,
+        }
+
+    async def record_result(
+        self,
+        provider_name: str,
+        success: bool,
+        duration_ms: float,
+    ) -> None:
+        """
+        Record the outcome of a request.
+        Called after each proxied request to update provider scores.
+        """
+        provider = next(
+            (p for p in self.providers if p.name == provider_name), None
+        )
+        if not provider:
+            return
+
+        provider.request_count += 1
+        if success:
+            self._update_latency(provider, duration_ms)
+        else:
+            provider.error_count += 1
+            # After 3 consecutive failures, mark unhealthy temporarily
+            recent_errors = provider.error_count
+            recent_total = provider.request_count
+            if recent_total >= 3 and (recent_errors / recent_total) > 0.7:
+                logger.warning(
+                    f"SmartRouter: {provider_name} error rate high "
+                    f"({provider.error_rate:.0%}), marking unhealthy"
+                )
+                provider.healthy = False
+                # Schedule re-check after 60s
+                asyncio.create_task(self._recheck_provider(provider, delay=60))
+
+    async def _recheck_provider(
+        self, provider: Provider, delay: float = 60
+    ) -> None:
+        """Re-ping a provider after a delay and restore if healthy."""
+        await asyncio.sleep(delay)
+        await self._ping_provider(provider)
+        if provider.healthy:
+            logger.info(
+                f"SmartRouter: {provider.name} recovered, "
+                f"re-adding to pool"
+            )
+
+    # ── Status report ─────────────────────────────────────────────────────────
+
+    def status(self) -> list[dict]:
+        """Return current provider status for monitoring."""
+        return [
+            {
+                "provider": p.name,
+                "healthy": p.healthy,
+                "configured": p.is_configured,
+                "latency_ms": round(p.avg_latency_ms, 1),
+                "cost_per_1k": p.cost_per_1k_tokens,
+                "requests": p.request_count,
+                "errors": p.error_count,
+                "error_rate": f"{p.error_rate:.1%}",
+                "score": round(p.score(self.strategy), 3)
+                if p.healthy and p.is_configured
+                else "N/A",
+            }
+            for p in self.providers
+        ]
diff --git a/test_smart_router.py b/test_smart_router.py
new file mode 100644
index 00000000..97b73863
--- /dev/null
+++ b/test_smart_router.py
@@ -0,0 +1,208 @@
+"""
+test_smart_router.py
+--------------------
+Tests for the SmartRouter.
+Run: pytest test_smart_router.py -v
+"""
+
+import pytest
+import asyncio
+from unittest.mock import AsyncMock, MagicMock, patch
+from smart_router import SmartRouter, Provider
+
+
+# ── Fixtures ──────────────────────────────────────────────────────────────────
+
+def make_provider(name, healthy=True, configured=True,
+                  latency=100.0, cost=0.002, errors=0, requests=0):
+    p = Provider(
+        name=name,
+        ping_url=f"https://{name}.example.com/health",
+        api_key_env="FAKE_KEY",
+        cost_per_1k_tokens=cost,
+        big_model=f"{name}-big",
+        small_model=f"{name}-small",
+    )
+    p.healthy = healthy
+    p.avg_latency_ms = latency
+    p.error_count = errors
+    p.request_count = requests
+    if not configured:
+        p.api_key_env = ""  # makes is_configured False for non-ollama
+    return p
+
+
+def make_router(providers=None, strategy="balanced"):
+    r = SmartRouter(providers=providers, strategy=strategy)
+    r._initialized = True
+    return r
+
+
+# ── Provider.score() ──────────────────────────────────────────────────────────
+
+def test_score_unhealthy_is_inf():
+    p = make_provider("openai", healthy=False)
+    assert p.score() == float("inf")
+
+
+def test_score_unconfigured_is_inf():
+    p = make_provider("openai", configured=False)
+    assert p.score() == float("inf")
+
+
+def test_score_latency_strategy_prefers_faster():
+    fast = make_provider("fast", latency=50.0, cost=0.01)
+    slow = make_provider("slow", latency=500.0, cost=0.001)
+    assert fast.score("latency") < slow.score("latency")
+
+
+def test_score_cost_strategy_prefers_cheaper():
+    cheap = make_provider("cheap", latency=500.0, cost=0.0001)
+    expensive = make_provider("expensive", latency=50.0, cost=0.05)
+    assert cheap.score("cost") < expensive.score("cost")
+
+
+def test_score_balanced_strategy_uses_both():
+    p = make_provider("test", latency=200.0, cost=0.002)
+    s = p.score("balanced")
+    assert s > 0
+
+
+def test_score_error_rate_penalty():
+    clean = make_provider("clean", errors=0, requests=10)
+    dirty = make_provider("dirty", errors=8, requests=10)
+    assert clean.score() < dirty.score()
+
+
+# ── SmartRouter.is_large_request() ───────────────────────────────────────────
+
+def test_is_large_request_short():
+    r = make_router()
+    msgs = [{"role": "user", "content": "Hello!"}]
+    assert r.is_large_request(msgs) is False
+
+
+def test_is_large_request_long():
+    r = make_router()
+    msgs = [{"role": "user", "content": "x" * 3000}]
+    assert r.is_large_request(msgs) is True
+
+
+# ── SmartRouter.select_provider() ────────────────────────────────────────────
+
+def test_select_provider_picks_best_score():
+    p1 = make_provider("slow", latency=800.0)
+    p2 = make_provider("fast", latency=50.0)
+    r = make_router(providers=[p1, p2], strategy="latency")
+    selected = r.select_provider()
+    assert selected.name == "fast"
+
+
+def test_select_provider_skips_unhealthy():
+    p1 = make_provider("bad", healthy=False)
+    p2 = make_provider("good", healthy=True)
+    r = make_router(providers=[p1, p2])
+    selected = r.select_provider()
+    assert selected.name == "good"
+
+
+def test_select_provider_returns_none_when_all_down():
+    p1 = make_provider("a", healthy=False)
+    p2 = make_provider("b", healthy=False)
+    r = make_router(providers=[p1, p2])
+    assert r.select_provider() is None
+
+
+# ── SmartRouter.get_model_for_provider() ─────────────────────────────────────
+
+def test_get_model_large_request():
+    p = make_provider("openai")
+    r = make_router()
+    model = r.get_model_for_provider(p, "claude-sonnet")
+    assert model == "openai-big"
+
+
+def test_get_model_small_request():
+    p = make_provider("openai")
+    r = make_router()
+    model = r.get_model_for_provider(p, "claude-haiku")
+    assert model == "openai-small"
+
+
+# ── SmartRouter.route() ───────────────────────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_route_returns_best_provider():
+    p1 = make_provider("expensive", cost=0.05, latency=50.0)
+    p2 = make_provider("cheap", cost=0.0005, latency=200.0)
+    r = make_router(providers=[p1, p2], strategy="cost")
+    result = await r.route([{"role": "user", "content": "Hi"}], "claude-haiku")
+    assert result["provider"] == "cheap"
+
+
+@pytest.mark.asyncio
+async def test_route_raises_when_no_providers():
+    p = make_provider("a", healthy=False)
+    r = make_router(providers=[p])
+    with pytest.raises(RuntimeError, match="no providers available"):
+        await r.route([{"role": "user", "content": "Hi"}])
+
+
+@pytest.mark.asyncio
+async def test_route_excludes_providers():
+    p1 = make_provider("openai", latency=50.0)
+    p2 = make_provider("gemini", latency=200.0)
+    r = make_router(providers=[p1, p2], strategy="latency")
+    result = await r.route(
+        [{"role": "user", "content": "Hi"}],
+        exclude_providers=["openai"]
+    )
+    assert result["provider"] == "gemini"
+
+
+# ── SmartRouter.record_result() ──────────────────────────────────────────────
+
+@pytest.mark.asyncio
+async def test_record_result_updates_latency():
+    p = make_provider("openai", latency=200.0)
+    r = make_router(providers=[p])
+    await r.record_result("openai", success=True, duration_ms=100.0)
+    assert p.avg_latency_ms < 200.0  # should decrease toward 100
+
+
+@pytest.mark.asyncio
+async def test_record_result_increments_requests():
+    p = make_provider("openai")
+    r = make_router(providers=[p])
+    await r.record_result("openai", success=True, duration_ms=100.0)
+    assert p.request_count == 1
+
+
+@pytest.mark.asyncio
+async def test_record_result_increments_errors():
+    p = make_provider("openai")
+    r = make_router(providers=[p])
+    await r.record_result("openai", success=False, duration_ms=0)
+    assert p.error_count == 1
+
+
+# ── SmartRouter.status() ─────────────────────────────────────────────────────
+
+def test_status_returns_all_providers():
+    p1 = make_provider("openai")
+    p2 = make_provider("gemini")
+    r = make_router(providers=[p1, p2])
+    status = r.status()
+    assert len(status) == 2
+    names = [s["provider"] for s in status]
+    assert "openai" in names
+    assert "gemini" in names
+
+
+def test_status_contains_required_fields():
+    p = make_provider("openai")
+    r = make_router(providers=[p])
+    status = r.status()[0]
+    for field in ["provider", "healthy", "latency_ms",
+                  "cost_per_1k", "requests", "errors", "score"]:
+        assert field in status