orcs-code/smart_router.py
Misha Skvortsov 577e654ae7 feat: add support for Atomic Chat provider
- Introduced a new provider profile for Atomic Chat, allowing it to be used alongside existing providers.
- Updated `package.json` to include a new development script for launching Atomic Chat.
- Modified `smart_router.py` to recognize Atomic Chat as a local provider that does not require an API key.
- Enhanced provider discovery and launch scripts to handle Atomic Chat, including model listing and connection checks.
- Added tests to ensure proper environment setup and behavior for Atomic Chat profiles.

This update lets openclaude route requests to local LLMs served by Atomic Chat alongside the existing cloud providers.
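For example, pointing the router at a local Atomic Chat instance needs only the base URL and the smart-routing mode in `.env` (no API key), assuming the default port used in this change:

ATOMIC_CHAT_BASE_URL=http://127.0.0.1:1337
ROUTER_MODE=smart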
2026-04-02 10:37:54 +03:00


"""
smart_router.py
---------------
Intelligent auto-router for openclaude.
Instead of always using one fixed provider, the smart router:
- Pings all configured providers on startup
- Scores them by latency, cost, and health
- Routes each request to the optimal provider
- Falls back automatically if a provider fails
- Learns from real request timings over time
Usage in server.py:
from smart_router import SmartRouter
router = SmartRouter()
await router.initialize()
result = await router.route(messages, model)
.env config:
ROUTER_MODE=smart # or: fixed (default behaviour)
ROUTER_STRATEGY=latency # or: cost, balanced
ROUTER_FALLBACK=true # auto-retry on failure
Contribution to: https://github.com/Gitlawb/openclaude
"""
import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
# ── Provider definitions ──────────────────────────────────────────────────────
@dataclass
class Provider:
name: str # e.g. "openai", "gemini", "ollama"
ping_url: str # URL used to check health
api_key_env: str # env var name for API key
cost_per_1k_tokens: float # estimated cost USD per 1k tokens
big_model: str # model for sonnet/large requests
small_model: str # model for haiku/small requests
latency_ms: float = 9999.0 # updated by benchmark
healthy: bool = True # updated by health checks
request_count: int = 0 # total requests routed here
error_count: int = 0 # total errors from this provider
avg_latency_ms: float = 9999.0 # rolling average from real requests
@property
def api_key(self) -> Optional[str]:
return os.getenv(self.api_key_env)
@property
def is_configured(self) -> bool:
"""True if the provider has an API key set."""
if self.name in ("ollama", "atomic-chat"):
return True # Local providers need no API key
return bool(self.api_key)
@property
def error_rate(self) -> float:
if self.request_count == 0:
return 0.0
return self.error_count / self.request_count
def score(self, strategy: str = "balanced") -> float:
"""
Lower score = better provider.
strategy: 'latency' | 'cost' | 'balanced'
"""
if not self.healthy or not self.is_configured:
return float("inf")
latency_score = self.avg_latency_ms / 1000.0 # normalize to seconds
cost_score = self.cost_per_1k_tokens * 100 # normalize to similar scale
error_penalty = self.error_rate * 500 # heavy penalty for errors
if strategy == "latency":
return latency_score + error_penalty
elif strategy == "cost":
return cost_score + error_penalty
else: # balanced
return (latency_score * 0.5) + (cost_score * 0.5) + error_penalty
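# Worked example (illustrative numbers, not measurements): a healthy provider
# with avg_latency_ms=800, cost_per_1k_tokens=0.002 and no errors scores
# 0.8 under "latency", 0.2 under "cost", and 0.5 under "balanced"
# (0.8 * 0.5 + 0.2 * 0.5); the lowest score wins in select_provider()/route().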
# ── Default provider catalogue ────────────────────────────────────────────────
def build_default_providers() -> list[Provider]:
big = os.getenv("BIG_MODEL", "gpt-4.1")
small = os.getenv("SMALL_MODEL", "gpt-4.1-mini")
ollama_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
atomic_chat_url = os.getenv("ATOMIC_CHAT_BASE_URL", "http://127.0.0.1:1337")
return [
Provider(
name="openai",
ping_url="https://api.openai.com/v1/models",
api_key_env="OPENAI_API_KEY",
cost_per_1k_tokens=0.002,
big_model=big if "gpt" in big else "gpt-4.1",
small_model=small if "gpt" in small else "gpt-4.1-mini",
),
Provider(
name="gemini",
ping_url="https://generativelanguage.googleapis.com/v1/models",
api_key_env="GEMINI_API_KEY",
cost_per_1k_tokens=0.0005,
big_model=big if "gemini" in big else "gemini-2.5-pro",
small_model=small if "gemini" in small else "gemini-2.0-flash",
),
Provider(
name="ollama",
ping_url=f"{ollama_url}/api/tags",
api_key_env="",
cost_per_1k_tokens=0.0, # free — local
big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
),
Provider(
name="atomic-chat",
ping_url=f"{atomic_chat_url}/v1/models",
api_key_env="",
cost_per_1k_tokens=0.0, # free — local (Apple Silicon)
big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
),
]
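# Example of the model mapping above (assumed .env value, purely illustrative):
# with BIG_MODEL=gemini-2.5-pro, the gemini profile keeps that model, while
# openai falls back to "gpt-4.1" and the local ollama/atomic-chat profiles
# fall back to "llama3:8b", because the override only applies to providers
# whose model family matches.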
# ── Smart Router ──────────────────────────────────────────────────────────────
class SmartRouter:
"""
Intelligently routes Claude Code API requests to the best
available LLM provider based on latency, cost, and health.
"""
def __init__(
self,
providers: Optional[list[Provider]] = None,
strategy: Optional[str] = None,
fallback_enabled: Optional[bool] = None,
):
self.providers = providers or build_default_providers()
self.strategy = strategy or os.getenv("ROUTER_STRATEGY", "balanced")
self.fallback_enabled = (
fallback_enabled
if fallback_enabled is not None
else os.getenv("ROUTER_FALLBACK", "true").lower() == "true"
)
self._initialized = False
# ── Initialization ────────────────────────────────────────────────────────
async def initialize(self) -> None:
"""Ping all providers and build initial latency scores."""
logger.info("SmartRouter: benchmarking providers...")
await asyncio.gather(
*[self._ping_provider(p) for p in self.providers],
return_exceptions=True,
)
available = [p for p in self.providers if p.healthy and p.is_configured]
logger.info(
f"SmartRouter ready. Available providers: "
f"{[p.name for p in available]}"
)
if not available:
logger.warning(
"SmartRouter: no providers available! "
"Check your API keys in .env"
)
self._initialized = True
async def _ping_provider(self, provider: Provider) -> None:
"""Measure latency to a provider's health endpoint."""
if not provider.is_configured:
provider.healthy = False
logger.debug(f"SmartRouter: {provider.name} skipped — no API key")
return
headers = {}
if provider.api_key:
headers["Authorization"] = f"Bearer {provider.api_key}"
start = time.monotonic()
try:
async with httpx.AsyncClient(timeout=5.0) as client:
resp = await client.get(provider.ping_url, headers=headers)
elapsed_ms = (time.monotonic() - start) * 1000
if resp.status_code in (200, 400, 401, 403):
# 400/401/403 means reachable, just possibly bad key
# We still mark healthy for routing purposes
provider.healthy = True
provider.latency_ms = elapsed_ms
provider.avg_latency_ms = elapsed_ms
logger.info(
f"SmartRouter: {provider.name} OK "
f"({elapsed_ms:.0f}ms, status={resp.status_code})"
)
else:
provider.healthy = False
logger.warning(
f"SmartRouter: {provider.name} unhealthy "
f"(status={resp.status_code})"
)
except Exception as e:
provider.healthy = False
logger.warning(f"SmartRouter: {provider.name} unreachable — {e}")
# ── Routing logic ─────────────────────────────────────────────────────────
def select_provider(self, is_large_request: bool = False) -> Optional[Provider]:
"""
Pick the best available provider for this request.
Returns None if no providers are available.
"""
available = [
p for p in self.providers
if p.healthy and p.is_configured
]
if not available:
return None
return min(available, key=lambda p: p.score(self.strategy))
def get_model_for_provider(
self, provider: Provider, claude_model: str
) -> str:
"""Map a Claude model name to the provider's actual model."""
is_large = any(
keyword in claude_model.lower()
for keyword in ["opus", "sonnet", "large", "big"]
)
return provider.big_model if is_large else provider.small_model
def is_large_request(self, messages: list[dict]) -> bool:
"""Estimate if this is a large request based on message length."""
total_chars = sum(
len(str(m.get("content", ""))) for m in messages
)
return total_chars > 2000 # >2000 chars = treat as large
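# Example: three messages whose contents total ~2,500 characters exceed the
# 2,000-character threshold and are treated as a large request.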
def _update_latency(self, provider: Provider, duration_ms: float) -> None:
"""Exponential moving average update for latency tracking."""
alpha = 0.3 # weight for new observation
provider.avg_latency_ms = (
alpha * duration_ms + (1 - alpha) * provider.avg_latency_ms
)
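# Worked example of the moving average (alpha = 0.3): with a current average
# of 900 ms and a new observation of 500 ms, the update gives
# 0.3 * 500 + 0.7 * 900 = 780 ms.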
# ── Main routing entry point ──────────────────────────────────────────────
async def route(
self,
messages: list[dict],
claude_model: str = "claude-sonnet",
attempt: int = 0,
exclude_providers: Optional[list[str]] = None,
) -> dict:
"""
Route a request to the best provider.
Returns a dict with routing decision info:
{
"provider": provider name,
"model": actual model to use,
"api_key": API key for the provider,
"base_url": base URL for the provider,
}
Raises RuntimeError if no providers available.
"""
if not self._initialized:
await self.initialize()
exclude = set(exclude_providers or [])
large = self.is_large_request(messages)
available = [
p for p in self.providers
if p.healthy and p.is_configured and p.name not in exclude
]
if not available:
raise RuntimeError(
"SmartRouter: no providers available. "
"Check your API keys and provider health."
)
provider = min(available, key=lambda p: p.score(self.strategy))
model = self.get_model_for_provider(provider, claude_model)
logger.debug(
f"SmartRouter: routing to {provider.name}/{model} "
f"(strategy={self.strategy}, large={large}, attempt={attempt})"
)
return {
"provider": provider.name,
"model": model,
"api_key": provider.api_key or "none",
"provider_object": provider,
}
async def record_result(
self,
provider_name: str,
success: bool,
duration_ms: float,
) -> None:
"""
Record the outcome of a request.
Called after each proxied request to update provider scores.
"""
provider = next(
(p for p in self.providers if p.name == provider_name), None
)
if not provider:
return
provider.request_count += 1
if success:
self._update_latency(provider, duration_ms)
else:
provider.error_count += 1
# If the lifetime error rate exceeds 70% after at least 3 requests,
# temporarily mark the provider unhealthy.
if provider.request_count >= 3 and provider.error_rate > 0.7:
logger.warning(
f"SmartRouter: {provider_name} error rate high "
f"({provider.error_rate:.0%}), marking unhealthy"
)
provider.healthy = False
# Schedule re-check after 60s
asyncio.create_task(self._recheck_provider(provider, delay=60))
async def _recheck_provider(
self, provider: Provider, delay: float = 60
) -> None:
"""Re-ping a provider after a delay and restore if healthy."""
await asyncio.sleep(delay)
await self._ping_provider(provider)
if provider.healthy:
logger.info(
f"SmartRouter: {provider.name} recovered, "
f"re-adding to pool"
)
# ── Status report ─────────────────────────────────────────────────────────
def status(self) -> list[dict]:
"""Return current provider status for monitoring."""
return [
{
"provider": p.name,
"healthy": p.healthy,
"configured": p.is_configured,
"latency_ms": round(p.avg_latency_ms, 1),
"cost_per_1k": p.cost_per_1k_tokens,
"requests": p.request_count,
"errors": p.error_count,
"error_rate": f"{p.error_rate:.1%}",
"score": round(p.score(self.strategy), 3)
if p.healthy and p.is_configured
else "N/A",
}
for p in self.providers
]
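# ── Illustrative usage ────────────────────────────────────────────────────────
# Minimal sketch of the server.py integration described in the module docstring;
# the message payload, model name, and timing below are placeholders, not real
# traffic or part of the upstream server code.
async def _demo() -> None:
    router = SmartRouter()
    await router.initialize()
    messages = [{"role": "user", "content": "Hello from openclaude"}]
    decision = await router.route(messages, claude_model="claude-sonnet")
    print(f"Routing to {decision['provider']} / {decision['model']}")
    # After proxying the real request, report the outcome so scores stay current.
    await router.record_result(decision["provider"], success=True, duration_ms=420.0)
    print(router.status())


if __name__ == "__main__":
    asyncio.run(_demo())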