feat: add smart auto-router with latency/cost scoring

gnanam1990
2026-04-01 10:19:32 +05:30
parent c957d495ac
commit 6b163e2e7e
2 changed files with 569 additions and 0 deletions

smart_router.py (new file, 361 lines)

@@ -0,0 +1,361 @@
"""
smart_router.py
---------------
Intelligent auto-router for openclaude.
Instead of always using one fixed provider, the smart router:
- Pings all configured providers on startup
- Scores them by latency, cost, and health
- Routes each request to the optimal provider
- Falls back automatically if a provider fails
- Learns from real request timings over time
Usage in server.py:
from smart_router import SmartRouter
router = SmartRouter()
await router.initialize()
result = await router.route(messages, model)
.env config:
ROUTER_MODE=smart # or: fixed (default behaviour)
ROUTER_STRATEGY=latency # or: cost, balanced
ROUTER_FALLBACK=true # auto-retry on failure
Contribution to: https://github.com/Gitlawb/openclaude
"""
import asyncio
import logging
import os
import time
from dataclasses import dataclass
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
# ── Provider definitions ──────────────────────────────────────────────────────
@dataclass
class Provider:
    name: str                       # e.g. "openai", "gemini", "ollama"
    ping_url: str                   # URL used to check health
    api_key_env: str                # env var name for API key
    cost_per_1k_tokens: float       # estimated cost USD per 1k tokens
    big_model: str                  # model for sonnet/large requests
    small_model: str                # model for haiku/small requests
    latency_ms: float = 9999.0      # updated by benchmark
    healthy: bool = True            # updated by health checks
    request_count: int = 0          # total requests routed here
    error_count: int = 0            # total errors from this provider
    avg_latency_ms: float = 9999.0  # rolling average from real requests

    @property
    def api_key(self) -> Optional[str]:
        return os.getenv(self.api_key_env)

    @property
    def is_configured(self) -> bool:
        """True if the provider has an API key set."""
        if self.name == "ollama":
            return True  # Ollama needs no API key
        return bool(self.api_key)

    @property
    def error_rate(self) -> float:
        if self.request_count == 0:
            return 0.0
        return self.error_count / self.request_count

    def score(self, strategy: str = "balanced") -> float:
        """
        Lower score = better provider.
        strategy: 'latency' | 'cost' | 'balanced'
        """
        if not self.healthy or not self.is_configured:
            return float("inf")
        latency_score = self.avg_latency_ms / 1000.0  # normalize to seconds
        cost_score = self.cost_per_1k_tokens * 100    # normalize to similar scale
        error_penalty = self.error_rate * 500         # heavy penalty for errors
        if strategy == "latency":
            return latency_score + error_penalty
        elif strategy == "cost":
            return cost_score + error_penalty
        else:  # balanced
            return (latency_score * 0.5) + (cost_score * 0.5) + error_penalty
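
# Illustrative example (not executed): with an average latency of 200 ms,
# a cost of $0.002/1k tokens, and no errors, the balanced score is
#   0.5 * (200 / 1000) + 0.5 * (0.002 * 100) = 0.1 + 0.1 = 0.2
# so a provider must roughly halve its latency or cost to beat a peer.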
# ── Default provider catalogue ────────────────────────────────────────────────
def build_default_providers() -> list[Provider]:
    big = os.getenv("BIG_MODEL", "gpt-4.1")
    small = os.getenv("SMALL_MODEL", "gpt-4.1-mini")
    ollama_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
    return [
        Provider(
            name="openai",
            ping_url="https://api.openai.com/v1/models",
            api_key_env="OPENAI_API_KEY",
            cost_per_1k_tokens=0.002,
            big_model=big if "gpt" in big else "gpt-4.1",
            small_model=small if "gpt" in small else "gpt-4.1-mini",
        ),
        Provider(
            name="gemini",
            ping_url="https://generativelanguage.googleapis.com/v1/models",
            api_key_env="GEMINI_API_KEY",
            cost_per_1k_tokens=0.0005,
            big_model=big if "gemini" in big else "gemini-2.5-pro",
            small_model=small if "gemini" in small else "gemini-2.0-flash",
        ),
        Provider(
            name="ollama",
            ping_url=f"{ollama_url}/api/tags",
            api_key_env="",
            cost_per_1k_tokens=0.0,  # free — local
            big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
            small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
        ),
    ]
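
# Illustrative sketch: extra providers can be appended to the catalogue
# before constructing the router. The name, URL, and "MY_PROVIDER_KEY"
# below are hypothetical placeholders, not part of the default config.
#
#   providers = build_default_providers()
#   providers.append(Provider(
#       name="my-provider",
#       ping_url="https://api.my-provider.example/v1/models",
#       api_key_env="MY_PROVIDER_KEY",
#       cost_per_1k_tokens=0.001,
#       big_model="my-big-model",
#       small_model="my-small-model",
#   ))
#   router = SmartRouter(providers=providers)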
# ── Smart Router ──────────────────────────────────────────────────────────────
class SmartRouter:
    """
    Intelligently routes Claude Code API requests to the best
    available LLM provider based on latency, cost, and health.
    """

    def __init__(
        self,
        providers: Optional[list[Provider]] = None,
        strategy: Optional[str] = None,
        fallback_enabled: Optional[bool] = None,
    ):
        self.providers = providers or build_default_providers()
        self.strategy = strategy or os.getenv("ROUTER_STRATEGY", "balanced")
        self.fallback_enabled = (
            fallback_enabled
            if fallback_enabled is not None
            else os.getenv("ROUTER_FALLBACK", "true").lower() == "true"
        )
        self._initialized = False

    # ── Initialization ────────────────────────────────────────────────────────
    async def initialize(self) -> None:
        """Ping all providers and build initial latency scores."""
        logger.info("SmartRouter: benchmarking providers...")
        await asyncio.gather(
            *[self._ping_provider(p) for p in self.providers],
            return_exceptions=True,
        )
        available = [p for p in self.providers if p.healthy and p.is_configured]
        logger.info(
            f"SmartRouter ready. Available providers: "
            f"{[p.name for p in available]}"
        )
        if not available:
            logger.warning(
                "SmartRouter: no providers available! "
                "Check your API keys in .env"
            )
        self._initialized = True

    async def _ping_provider(self, provider: Provider) -> None:
        """Measure latency to a provider's health endpoint."""
        if not provider.is_configured:
            provider.healthy = False
            logger.debug(f"SmartRouter: {provider.name} skipped — no API key")
            return
        headers = {}
        if provider.api_key:
            headers["Authorization"] = f"Bearer {provider.api_key}"
        start = time.monotonic()
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                resp = await client.get(provider.ping_url, headers=headers)
                elapsed_ms = (time.monotonic() - start) * 1000
            if resp.status_code in (200, 400, 401, 403):
                # 400/401/403 means reachable, just possibly bad key
                # We still mark healthy for routing purposes
                provider.healthy = True
                provider.latency_ms = elapsed_ms
                provider.avg_latency_ms = elapsed_ms
                logger.info(
                    f"SmartRouter: {provider.name} OK "
                    f"({elapsed_ms:.0f}ms, status={resp.status_code})"
                )
            else:
                provider.healthy = False
                logger.warning(
                    f"SmartRouter: {provider.name} unhealthy "
                    f"(status={resp.status_code})"
                )
        except Exception as e:
            provider.healthy = False
            logger.warning(f"SmartRouter: {provider.name} unreachable — {e}")

    # ── Routing logic ─────────────────────────────────────────────────────────
    def select_provider(self, is_large_request: bool = False) -> Optional[Provider]:
        """
        Pick the best available provider for this request.
        Returns None if no providers are available.
        """
        available = [
            p for p in self.providers
            if p.healthy and p.is_configured
        ]
        if not available:
            return None
        return min(available, key=lambda p: p.score(self.strategy))

    def get_model_for_provider(
        self, provider: Provider, claude_model: str
    ) -> str:
        """Map a Claude model name to the provider's actual model."""
        is_large = any(
            keyword in claude_model.lower()
            for keyword in ["opus", "sonnet", "large", "big"]
        )
        return provider.big_model if is_large else provider.small_model
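
    # Illustrative mapping: "claude-3-opus" and "claude-sonnet" contain a
    # large-model keyword and resolve to provider.big_model, while
    # "claude-haiku" falls through to provider.small_model.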
    def is_large_request(self, messages: list[dict]) -> bool:
        """Estimate if this is a large request based on message length."""
        total_chars = sum(
            len(str(m.get("content", ""))) for m in messages
        )
        return total_chars > 2000  # >2000 chars = treat as large

    def _update_latency(self, provider: Provider, duration_ms: float) -> None:
        """Exponential moving average update for latency tracking."""
        alpha = 0.3  # weight for new observation
        provider.avg_latency_ms = (
            alpha * duration_ms + (1 - alpha) * provider.avg_latency_ms
        )
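
    # Illustrative example (not executed): with avg_latency_ms = 200.0 and a
    # new observation of 100.0 ms, the update yields
    #   0.3 * 100.0 + 0.7 * 200.0 = 170.0 ms
    # so one fast request pulls the rolling average partway toward it.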
    # ── Main routing entry point ──────────────────────────────────────────────
    async def route(
        self,
        messages: list[dict],
        claude_model: str = "claude-sonnet",
        attempt: int = 0,
        exclude_providers: Optional[list[str]] = None,
    ) -> dict:
        """
        Route a request to the best provider.
        Returns a dict with routing decision info:
        {
            "provider": provider name,
            "model": actual model to use,
            "api_key": API key for the provider,
            "provider_object": the selected Provider instance,
        }
        Raises RuntimeError if no providers are available.
        """
        if not self._initialized:
            await self.initialize()
        exclude = set(exclude_providers or [])
        large = self.is_large_request(messages)
        available = [
            p for p in self.providers
            if p.healthy and p.is_configured and p.name not in exclude
        ]
        if not available:
            raise RuntimeError(
                "SmartRouter: no providers available. "
                "Check your API keys and provider health."
            )
        provider = min(available, key=lambda p: p.score(self.strategy))
        model = self.get_model_for_provider(provider, claude_model)
        logger.debug(
            f"SmartRouter: routing to {provider.name}/{model} "
            f"(strategy={self.strategy}, large={large}, attempt={attempt})"
        )
        return {
            "provider": provider.name,
            "model": model,
            "api_key": provider.api_key or "none",
            "provider_object": provider,
        }

    async def record_result(
        self,
        provider_name: str,
        success: bool,
        duration_ms: float,
    ) -> None:
        """
        Record the outcome of a request.
        Called after each proxied request to update provider scores.
        """
        provider = next(
            (p for p in self.providers if p.name == provider_name), None
        )
        if not provider:
            return
        provider.request_count += 1
        if success:
            self._update_latency(provider, duration_ms)
        else:
            provider.error_count += 1
            # Once at least 3 requests have been seen and the overall error
            # rate exceeds 70%, mark the provider unhealthy temporarily
            total_errors = provider.error_count
            total_requests = provider.request_count
            if total_requests >= 3 and (total_errors / total_requests) > 0.7:
                logger.warning(
                    f"SmartRouter: {provider_name} error rate high "
                    f"({provider.error_rate:.0%}), marking unhealthy"
                )
                provider.healthy = False
                # Schedule re-check after 60s
                asyncio.create_task(self._recheck_provider(provider, delay=60))

    async def _recheck_provider(
        self, provider: Provider, delay: float = 60
    ) -> None:
        """Re-ping a provider after a delay and restore it if healthy."""
        await asyncio.sleep(delay)
        await self._ping_provider(provider)
        if provider.healthy:
            logger.info(
                f"SmartRouter: {provider.name} recovered, "
                f"re-adding to pool"
            )

    # ── Status report ─────────────────────────────────────────────────────────
    def status(self) -> list[dict]:
        """Return current provider status for monitoring."""
        return [
            {
                "provider": p.name,
                "healthy": p.healthy,
                "configured": p.is_configured,
                "latency_ms": round(p.avg_latency_ms, 1),
                "cost_per_1k": p.cost_per_1k_tokens,
                "requests": p.request_count,
                "errors": p.error_count,
                "error_rate": f"{p.error_rate:.1%}",
                "score": round(p.score(self.strategy), 3)
                if p.healthy and p.is_configured
                else "N/A",
            }
            for p in self.providers
        ]
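
# Minimal self-check sketch (illustrative, not part of the proxy flow):
# running the module directly benchmarks whichever providers are configured
# via the .env keys described in the docstring and prints the status table.
if __name__ == "__main__":
    async def _demo() -> None:
        router = SmartRouter()
        await router.initialize()
        for entry in router.status():
            print(entry)

    asyncio.run(_demo())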

test_smart_router.py (new file, 208 lines)

@@ -0,0 +1,208 @@
"""
test_smart_router.py
--------------------
Tests for the SmartRouter.
Run: pytest test_smart_router.py -v
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from smart_router import SmartRouter, Provider
# ── Fixtures ──────────────────────────────────────────────────────────────────
def make_provider(name, healthy=True, configured=True,
                  latency=100.0, cost=0.002, errors=0, requests=0):
    p = Provider(
        name=name,
        ping_url=f"https://{name}.example.com/health",
        api_key_env="FAKE_KEY",
        cost_per_1k_tokens=cost,
        big_model=f"{name}-big",
        small_model=f"{name}-small",
    )
    p.healthy = healthy
    p.avg_latency_ms = latency
    p.error_count = errors
    p.request_count = requests
    if not configured:
        p.api_key_env = ""  # makes is_configured False for non-ollama
    return p


def make_router(providers=None, strategy="balanced"):
    r = SmartRouter(providers=providers, strategy=strategy)
    r._initialized = True
    return r


# ── Provider.score() ──────────────────────────────────────────────────────────
def test_score_unhealthy_is_inf():
    p = make_provider("openai", healthy=False)
    assert p.score() == float("inf")


def test_score_unconfigured_is_inf():
    p = make_provider("openai", configured=False)
    assert p.score() == float("inf")


def test_score_latency_strategy_prefers_faster():
    fast = make_provider("fast", latency=50.0, cost=0.01)
    slow = make_provider("slow", latency=500.0, cost=0.001)
    assert fast.score("latency") < slow.score("latency")


def test_score_cost_strategy_prefers_cheaper():
    cheap = make_provider("cheap", latency=500.0, cost=0.0001)
    expensive = make_provider("expensive", latency=50.0, cost=0.05)
    assert cheap.score("cost") < expensive.score("cost")


def test_score_balanced_strategy_uses_both():
    p = make_provider("test", latency=200.0, cost=0.002)
    s = p.score("balanced")
    assert s > 0


def test_score_error_rate_penalty():
    clean = make_provider("clean", errors=0, requests=10)
    dirty = make_provider("dirty", errors=8, requests=10)
    assert clean.score() < dirty.score()
# ── SmartRouter.is_large_request() ───────────────────────────────────────────
def test_is_large_request_short():
    r = make_router()
    msgs = [{"role": "user", "content": "Hello!"}]
    assert r.is_large_request(msgs) is False


def test_is_large_request_long():
    r = make_router()
    msgs = [{"role": "user", "content": "x" * 3000}]
    assert r.is_large_request(msgs) is True


# ── SmartRouter.select_provider() ────────────────────────────────────────────
def test_select_provider_picks_best_score():
    p1 = make_provider("slow", latency=800.0)
    p2 = make_provider("fast", latency=50.0)
    r = make_router(providers=[p1, p2], strategy="latency")
    selected = r.select_provider()
    assert selected.name == "fast"


def test_select_provider_skips_unhealthy():
    p1 = make_provider("bad", healthy=False)
    p2 = make_provider("good", healthy=True)
    r = make_router(providers=[p1, p2])
    selected = r.select_provider()
    assert selected.name == "good"


def test_select_provider_returns_none_when_all_down():
    p1 = make_provider("a", healthy=False)
    p2 = make_provider("b", healthy=False)
    r = make_router(providers=[p1, p2])
    assert r.select_provider() is None


# ── SmartRouter.get_model_for_provider() ─────────────────────────────────────
def test_get_model_large_request():
    p = make_provider("openai")
    r = make_router()
    model = r.get_model_for_provider(p, "claude-sonnet")
    assert model == "openai-big"


def test_get_model_small_request():
    p = make_provider("openai")
    r = make_router()
    model = r.get_model_for_provider(p, "claude-haiku")
    assert model == "openai-small"
# ── SmartRouter.route() ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_route_returns_best_provider():
    p1 = make_provider("expensive", cost=0.05, latency=50.0)
    p2 = make_provider("cheap", cost=0.0005, latency=200.0)
    r = make_router(providers=[p1, p2], strategy="cost")
    result = await r.route([{"role": "user", "content": "Hi"}], "claude-haiku")
    assert result["provider"] == "cheap"


@pytest.mark.asyncio
async def test_route_raises_when_no_providers():
    p = make_provider("a", healthy=False)
    r = make_router(providers=[p])
    with pytest.raises(RuntimeError, match="no providers available"):
        await r.route([{"role": "user", "content": "Hi"}])


@pytest.mark.asyncio
async def test_route_excludes_providers():
    p1 = make_provider("openai", latency=50.0)
    p2 = make_provider("gemini", latency=200.0)
    r = make_router(providers=[p1, p2], strategy="latency")
    result = await r.route(
        [{"role": "user", "content": "Hi"}],
        exclude_providers=["openai"],
    )
    assert result["provider"] == "gemini"


# ── SmartRouter.record_result() ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_record_result_updates_latency():
    p = make_provider("openai", latency=200.0)
    r = make_router(providers=[p])
    await r.record_result("openai", success=True, duration_ms=100.0)
    assert p.avg_latency_ms < 200.0  # should decrease toward 100


@pytest.mark.asyncio
async def test_record_result_increments_requests():
    p = make_provider("openai")
    r = make_router(providers=[p])
    await r.record_result("openai", success=True, duration_ms=100.0)
    assert p.request_count == 1


@pytest.mark.asyncio
async def test_record_result_increments_errors():
    p = make_provider("openai")
    r = make_router(providers=[p])
    await r.record_result("openai", success=False, duration_ms=0)
    assert p.error_count == 1
# ── SmartRouter.status() ─────────────────────────────────────────────────────
def test_status_returns_all_providers():
    p1 = make_provider("openai")
    p2 = make_provider("gemini")
    r = make_router(providers=[p1, p2])
    status = r.status()
    assert len(status) == 2
    names = [s["provider"] for s in status]
    assert "openai" in names
    assert "gemini" in names


def test_status_contains_required_fields():
    p = make_provider("openai")
    r = make_router(providers=[p])
    status = r.status()[0]
    for field in ["provider", "healthy", "latency_ms",
                  "cost_per_1k", "requests", "errors", "score"]:
        assert field in status
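

# ── Extra coverage: unhealthy marking (illustrative addition) ─────────────────
@pytest.mark.asyncio
async def test_record_result_marks_unhealthy_after_repeated_failures():
    # Sketch based on record_result's behaviour: once at least 3 requests
    # have been seen and the error rate exceeds 70%, the provider is marked
    # unhealthy. The re-check coroutine is patched out so no 60-second
    # background task outlives the test.
    p = make_provider("openai")
    r = make_router(providers=[p])
    with patch.object(r, "_recheck_provider", new=AsyncMock()):
        for _ in range(3):
            await r.record_result("openai", success=False, duration_ms=0)
    assert p.healthy is False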