docs: organize Python helpers and refresh README (#334)
* docs: organize Python helpers and refresh README
* docs: add README status badges
* test: centralize Python helper test imports
* docs: add short provenance disclaimer
1  python/__init__.py  Normal file
@@ -0,0 +1 @@
# Python helper package for standalone provider-side utilities.
146  python/atomic_chat_provider.py  Normal file
@@ -0,0 +1,146 @@
"""
atomic_chat_provider.py
-----------------------
Adds native Atomic Chat support to openclaude.
Lets Claude Code route requests to any locally-running model via
Atomic Chat (Apple Silicon only) at 127.0.0.1:1337.

Atomic Chat exposes an OpenAI-compatible API, so messages are forwarded
directly without translation.

Usage (.env):
    PREFERRED_PROVIDER=atomic-chat
    ATOMIC_CHAT_BASE_URL=http://127.0.0.1:1337
"""

import httpx
import json
import logging
import os
from typing import AsyncIterator

logger = logging.getLogger(__name__)
ATOMIC_CHAT_BASE_URL = os.getenv("ATOMIC_CHAT_BASE_URL", "http://127.0.0.1:1337")


def _api_url(path: str) -> str:
    return f"{ATOMIC_CHAT_BASE_URL}/v1{path}"


async def check_atomic_chat_running() -> bool:
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(_api_url("/models"))
            return resp.status_code == 200
    except Exception:
        return False


async def list_atomic_chat_models() -> list[str]:
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(_api_url("/models"))
            resp.raise_for_status()
            data = resp.json()
            return [m["id"] for m in data.get("data", [])]
    except Exception as e:
        logger.warning(f"Could not list Atomic Chat models: {e}")
        return []


async def atomic_chat(
    model: str,
    messages: list[dict],
    system: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 1.0,
) -> dict:
    chat_messages = list(messages)
    if system:
        chat_messages.insert(0, {"role": "system", "content": system})

    payload = {
        "model": model,
        "messages": chat_messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": False,
    }

    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(_api_url("/chat/completions"), json=payload)
        resp.raise_for_status()
        data = resp.json()

    choice = data.get("choices", [{}])[0]
    assistant_text = choice.get("message", {}).get("content", "")
    usage = data.get("usage", {})

    return {
        "id": data.get("id", "msg_atomic_chat"),
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_text}],
        "model": model,
        "stop_reason": "end_turn",
        "stop_sequence": None,
        "usage": {
            "input_tokens": usage.get("prompt_tokens", 0),
            "output_tokens": usage.get("completion_tokens", 0),
        },
    }


async def atomic_chat_stream(
    model: str,
    messages: list[dict],
    system: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 1.0,
) -> AsyncIterator[str]:
    chat_messages = list(messages)
    if system:
        chat_messages.insert(0, {"role": "system", "content": system})

    payload = {
        "model": model,
        "messages": chat_messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True,
    }

    yield "event: message_start\n"
    yield f'data: {json.dumps({"type": "message_start", "message": {"id": "msg_atomic_chat_stream", "type": "message", "role": "assistant", "content": [], "model": model, "stop_reason": None, "usage": {"input_tokens": 0, "output_tokens": 0}}})}\n\n'
    yield "event: content_block_start\n"
    yield f'data: {json.dumps({"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})}\n\n'

    async with httpx.AsyncClient(timeout=120.0) as client:
        async with client.stream("POST", _api_url("/chat/completions"), json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line or not line.startswith("data: "):
                    continue
                raw = line[len("data: "):]
                if raw.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(raw)
                    delta = chunk.get("choices", [{}])[0].get("delta", {})
                    delta_text = delta.get("content", "")
                    if delta_text:
                        yield "event: content_block_delta\n"
                        yield f'data: {json.dumps({"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": delta_text}})}\n\n'

                    finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
                    if finish_reason:
                        usage = chunk.get("usage", {})
                        yield "event: content_block_stop\n"
                        yield f'data: {json.dumps({"type": "content_block_stop", "index": 0})}\n\n'
                        yield "event: message_delta\n"
                        yield f'data: {json.dumps({"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence": None}, "usage": {"output_tokens": usage.get("completion_tokens", 0)}})}\n\n'
                        yield "event: message_stop\n"
                        yield f'data: {json.dumps({"type": "message_stop"})}\n\n'
                        break
                except json.JSONDecodeError:
                    continue
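For orientation, a minimal usage sketch (not part of the committed files) showing how a caller could drive these helpers from an async script; the script name, the llama-3.1-8b model id, and a locally running Atomic Chat instance are assumptions.

# usage_sketch_atomic_chat.py - illustrative only, not part of this commit.
# Assumes Atomic Chat is reachable at the default ATOMIC_CHAT_BASE_URL.
import asyncio

from atomic_chat_provider import (
    atomic_chat,
    atomic_chat_stream,
    check_atomic_chat_running,
)


async def main() -> None:
    if not await check_atomic_chat_running():
        print("Atomic Chat is not running on 127.0.0.1:1337")
        return

    # Non-streaming: returns an Anthropic-style message dict.
    reply = await atomic_chat(
        model="llama-3.1-8b",
        messages=[{"role": "user", "content": "Say hello."}],
        system="Be brief.",
    )
    print(reply["content"][0]["text"])

    # Streaming: yields pre-formatted SSE lines ready to forward to the client.
    async for sse_line in atomic_chat_stream(
        model="llama-3.1-8b",
        messages=[{"role": "user", "content": "Count to three."}],
    ):
        print(sse_line, end="")


if __name__ == "__main__":
    asyncio.run(main())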
173  python/ollama_provider.py  Normal file
@@ -0,0 +1,173 @@
"""
ollama_provider.py
------------------
Adds native Ollama support to openclaude.
Lets Claude Code route requests to any locally-running Ollama model
(llama3, mistral, codellama, phi3, qwen2, deepseek-coder, etc.)
without needing an API key.

Usage (.env):
    PREFERRED_PROVIDER=ollama
    OLLAMA_BASE_URL=http://localhost:11434
    BIG_MODEL=codellama:34b
    SMALL_MODEL=llama3:8b
"""

import httpx
import logging
import os
from typing import AsyncIterator

logger = logging.getLogger(__name__)
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")


async def check_ollama_running() -> bool:
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
            return resp.status_code == 200
    except Exception:
        return False


async def list_ollama_models() -> list[str]:
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
            resp.raise_for_status()
            data = resp.json()
            return [m["name"] for m in data.get("models", [])]
    except Exception as e:
        logger.warning(f"Could not list Ollama models: {e}")
        return []


def normalize_ollama_model(model_name: str) -> str:
    if model_name.startswith("ollama/"):
        return model_name[len("ollama/"):]
    return model_name


def _extract_ollama_image_data(block: dict) -> str | None:
    source = block.get("source")
    if not isinstance(source, dict):
        return None
    if source.get("type") != "base64":
        return None
    data = source.get("data")
    if isinstance(data, str) and data:
        return data
    return None


def anthropic_to_ollama_messages(messages: list[dict]) -> list[dict]:
    ollama_messages = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if isinstance(content, str):
            ollama_messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            text_parts = []
            image_parts = []
            for block in content:
                if isinstance(block, dict):
                    if block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                    elif block.get("type") == "image":
                        image_data = _extract_ollama_image_data(block)
                        if image_data:
                            image_parts.append(image_data)
                        else:
                            text_parts.append("[image]")
                elif isinstance(block, str):
                    text_parts.append(block)
            ollama_message = {"role": role, "content": "\n".join(text_parts)}
            if image_parts:
                ollama_message["images"] = image_parts
            ollama_messages.append(ollama_message)
    return ollama_messages


async def ollama_chat(
    model: str,
    messages: list[dict],
    system: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 1.0,
) -> dict:
    model = normalize_ollama_model(model)
    ollama_messages = anthropic_to_ollama_messages(messages)
    if system:
        ollama_messages.insert(0, {"role": "system", "content": system})
    payload = {
        "model": model,
        "messages": ollama_messages,
        "stream": False,
        "options": {"num_predict": max_tokens, "temperature": temperature},
    }
    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(f"{OLLAMA_BASE_URL}/api/chat", json=payload)
        resp.raise_for_status()
        data = resp.json()
    assistant_text = data.get("message", {}).get("content", "")
    return {
        "id": f"msg_ollama_{data.get('created_at', 'unknown')}",
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_text}],
        "model": model,
        "stop_reason": "end_turn",
        "stop_sequence": None,
        "usage": {
            "input_tokens": data.get("prompt_eval_count", 0),
            "output_tokens": data.get("eval_count", 0),
        },
    }


async def ollama_chat_stream(
    model: str,
    messages: list[dict],
    system: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 1.0,
) -> AsyncIterator[str]:
    import json
    model = normalize_ollama_model(model)
    ollama_messages = anthropic_to_ollama_messages(messages)
    if system:
        ollama_messages.insert(0, {"role": "system", "content": system})
    payload = {
        "model": model,
        "messages": ollama_messages,
        "stream": True,
        "options": {"num_predict": max_tokens, "temperature": temperature},
    }
    yield "event: message_start\n"
    yield f'data: {json.dumps({"type": "message_start", "message": {"id": "msg_ollama_stream", "type": "message", "role": "assistant", "content": [], "model": model, "stop_reason": None, "usage": {"input_tokens": 0, "output_tokens": 0}}})}\n\n'
    yield "event: content_block_start\n"
    yield f'data: {json.dumps({"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})}\n\n'
    async with httpx.AsyncClient(timeout=120.0) as client:
        async with client.stream("POST", f"{OLLAMA_BASE_URL}/api/chat", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line:
                    continue
                try:
                    chunk = json.loads(line)
                    delta_text = chunk.get("message", {}).get("content", "")
                    if delta_text:
                        yield "event: content_block_delta\n"
                        yield f'data: {json.dumps({"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": delta_text}})}\n\n'
                    if chunk.get("done"):
                        yield "event: content_block_stop\n"
                        yield f'data: {json.dumps({"type": "content_block_stop", "index": 0})}\n\n'
                        yield "event: message_delta\n"
                        yield f'data: {json.dumps({"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence": None}, "usage": {"output_tokens": chunk.get("eval_count", 0)}})}\n\n'
                        yield "event: message_stop\n"
                        yield f'data: {json.dumps({"type": "message_stop"})}\n\n'
                        break
                except json.JSONDecodeError:
                    continue
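Similarly, a hedged sketch (not part of the commit) of the Ollama helpers, assuming an Ollama daemon at the default OLLAMA_BASE_URL with llama3:8b pulled; it also shows the Anthropic-to-Ollama message conversion that ollama_chat applies internally.

# usage_sketch_ollama.py - illustrative only, not part of this commit.
import asyncio

from ollama_provider import (
    anthropic_to_ollama_messages,
    normalize_ollama_model,
    ollama_chat,
)


async def main() -> None:
    # Anthropic-style content blocks collapse to plain text (plus an optional
    # "images" list) before being sent to /api/chat.
    blocks = [{"role": "user", "content": [{"type": "text", "text": "What is 6*7?"}]}]
    print(anthropic_to_ollama_messages(blocks))
    # -> [{'role': 'user', 'content': 'What is 6*7?'}]

    # The "ollama/" prefix used elsewhere in the proxy is stripped before the call.
    model = normalize_ollama_model("ollama/llama3:8b")
    reply = await ollama_chat(model=model, messages=blocks, system="Be brief.")
    print(reply["content"][0]["text"])


if __name__ == "__main__":
    asyncio.run(main())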
379  python/smart_router.py  Normal file
@@ -0,0 +1,379 @@
"""
smart_router.py
---------------
Intelligent auto-router for openclaude.

Instead of always using one fixed provider, the smart router:
- Pings all configured providers on startup
- Scores them by latency, cost, and health
- Routes each request to the optimal provider
- Falls back automatically if a provider fails
- Learns from real request timings over time

Usage in server.py:
    from smart_router import SmartRouter
    router = SmartRouter()
    await router.initialize()
    result = await router.route(messages, claude_model)

.env config:
    ROUTER_MODE=smart        # or: fixed (default behaviour)
    ROUTER_STRATEGY=latency  # or: cost, balanced
    ROUTER_FALLBACK=true     # auto-retry on failure

Contribution to: https://github.com/Gitlawb/openclaude
"""

import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Optional
import httpx

logger = logging.getLogger(__name__)

# ── Provider definitions ──────────────────────────────────────────────────────


@dataclass
class Provider:
    name: str                       # e.g. "openai", "gemini", "ollama"
    ping_url: str                   # URL used to check health
    api_key_env: str                # env var name for API key
    cost_per_1k_tokens: float       # estimated cost USD per 1k tokens
    big_model: str                  # model for sonnet/large requests
    small_model: str                # model for haiku/small requests
    latency_ms: float = 9999.0      # updated by benchmark
    healthy: bool = True            # updated by health checks
    request_count: int = 0          # total requests routed here
    error_count: int = 0            # total errors from this provider
    avg_latency_ms: float = 9999.0  # rolling average from real requests

    @property
    def api_key(self) -> Optional[str]:
        return os.getenv(self.api_key_env)

    @property
    def is_configured(self) -> bool:
        """True if the provider has an API key set."""
        if self.name in ("ollama", "atomic-chat"):
            return True  # Local providers need no API key
        return bool(self.api_key)

    @property
    def error_rate(self) -> float:
        if self.request_count == 0:
            return 0.0
        return self.error_count / self.request_count

    def score(self, strategy: str = "balanced") -> float:
        """
        Lower score = better provider.
        strategy: 'latency' | 'cost' | 'balanced'
        """
        if not self.healthy or not self.is_configured:
            return float("inf")

        latency_score = self.avg_latency_ms / 1000.0  # normalize to seconds
        cost_score = self.cost_per_1k_tokens * 100    # normalize to similar scale
        error_penalty = self.error_rate * 500         # heavy penalty for errors

        if strategy == "latency":
            return latency_score + error_penalty
        elif strategy == "cost":
            return cost_score + error_penalty
        else:  # balanced
            return (latency_score * 0.5) + (cost_score * 0.5) + error_penalty


# ── Default provider catalogue ────────────────────────────────────────────────


def build_default_providers() -> list[Provider]:
    big = os.getenv("BIG_MODEL", "gpt-4.1")
    small = os.getenv("SMALL_MODEL", "gpt-4.1-mini")
    ollama_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
    atomic_chat_url = os.getenv("ATOMIC_CHAT_BASE_URL", "http://127.0.0.1:1337")

    return [
        Provider(
            name="openai",
            ping_url="https://api.openai.com/v1/models",
            api_key_env="OPENAI_API_KEY",
            cost_per_1k_tokens=0.002,
            big_model=big if "gpt" in big else "gpt-4.1",
            small_model=small if "gpt" in small else "gpt-4.1-mini",
        ),
        Provider(
            name="gemini",
            ping_url="https://generativelanguage.googleapis.com/v1/models",
            api_key_env="GEMINI_API_KEY",
            cost_per_1k_tokens=0.0005,
            big_model=big if "gemini" in big else "gemini-2.5-pro",
            small_model=small if "gemini" in small else "gemini-2.0-flash",
        ),
        Provider(
            name="ollama",
            ping_url=f"{ollama_url}/api/tags",
            api_key_env="",
            cost_per_1k_tokens=0.0,  # free — local
            big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
            small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
        ),
        Provider(
            name="atomic-chat",
            ping_url=f"{atomic_chat_url}/v1/models",
            api_key_env="",
            cost_per_1k_tokens=0.0,  # free — local (Apple Silicon)
            big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
            small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
        ),
    ]


# ── Smart Router ──────────────────────────────────────────────────────────────


class SmartRouter:
    """
    Intelligently routes Claude Code API requests to the best
    available LLM provider based on latency, cost, and health.
    """

    def __init__(
        self,
        providers: Optional[list[Provider]] = None,
        strategy: Optional[str] = None,
        fallback_enabled: Optional[bool] = None,
    ):
        self.providers = providers or build_default_providers()
        self.strategy = strategy or os.getenv("ROUTER_STRATEGY", "balanced")
        self.fallback_enabled = (
            fallback_enabled
            if fallback_enabled is not None
            else os.getenv("ROUTER_FALLBACK", "true").lower() == "true"
        )
        self._initialized = False

    # ── Initialization ────────────────────────────────────────────────────────

    async def initialize(self) -> None:
        """Ping all providers and build initial latency scores."""
        logger.info("SmartRouter: benchmarking providers...")
        await asyncio.gather(
            *[self._ping_provider(p) for p in self.providers],
            return_exceptions=True,
        )
        available = [p for p in self.providers if p.healthy and p.is_configured]
        logger.info(
            f"SmartRouter ready. Available providers: "
            f"{[p.name for p in available]}"
        )
        if not available:
            logger.warning(
                "SmartRouter: no providers available! "
                "Check your API keys in .env"
            )
        self._initialized = True

    async def _ping_provider(self, provider: Provider) -> None:
        """Measure latency to a provider's health endpoint."""
        if not provider.is_configured:
            provider.healthy = False
            logger.debug(f"SmartRouter: {provider.name} skipped — no API key")
            return

        headers = {}
        if provider.api_key:
            headers["Authorization"] = f"Bearer {provider.api_key}"

        start = time.monotonic()
        try:
            async with httpx.AsyncClient(timeout=5.0) as client:
                resp = await client.get(provider.ping_url, headers=headers)
            elapsed_ms = (time.monotonic() - start) * 1000
            if resp.status_code in (200, 400, 401, 403):
                # 400/401/403 means reachable, just possibly bad key.
                # We still mark healthy for routing purposes.
                provider.healthy = True
                provider.latency_ms = elapsed_ms
                provider.avg_latency_ms = elapsed_ms
                logger.info(
                    f"SmartRouter: {provider.name} OK "
                    f"({elapsed_ms:.0f}ms, status={resp.status_code})"
                )
            else:
                provider.healthy = False
                logger.warning(
                    f"SmartRouter: {provider.name} unhealthy "
                    f"(status={resp.status_code})"
                )
        except Exception as e:
            provider.healthy = False
            logger.warning(f"SmartRouter: {provider.name} unreachable — {e}")

    # ── Routing logic ─────────────────────────────────────────────────────────

    def select_provider(self, is_large_request: bool = False) -> Optional[Provider]:
        """
        Pick the best available provider for this request.
        Returns None if no providers are available.
        """
        available = [
            p for p in self.providers
            if p.healthy and p.is_configured
        ]
        if not available:
            return None

        return min(available, key=lambda p: p.score(self.strategy))

    def get_model_for_provider(
        self,
        provider: Provider,
        claude_model: str,
        is_large_request: bool = False,
    ) -> str:
        """Map a Claude model name to the provider's actual model."""
        if is_large_request:
            return provider.big_model
        is_large = any(
            keyword in claude_model.lower()
            for keyword in ["opus", "sonnet", "large", "big"]
        )
        return provider.big_model if is_large else provider.small_model

    def is_large_request(self, messages: list[dict]) -> bool:
        """Estimate if this is a large request based on message length."""
        total_chars = sum(
            len(str(m.get("content", ""))) for m in messages
        )
        return total_chars > 2000  # >2000 chars = treat as large

    def _update_latency(self, provider: Provider, duration_ms: float) -> None:
        """Exponential moving average update for latency tracking."""
        alpha = 0.3  # weight for new observation
        provider.avg_latency_ms = (
            alpha * duration_ms + (1 - alpha) * provider.avg_latency_ms
        )

    # ── Main routing entry point ──────────────────────────────────────────────

    async def route(
        self,
        messages: list[dict],
        claude_model: str = "claude-sonnet",
        attempt: int = 0,
        exclude_providers: Optional[list[str]] = None,
    ) -> dict:
        """
        Route a request to the best provider.
        Returns a dict with routing decision info:
            {
                "provider": provider name,
                "model": actual model to use,
                "api_key": API key for the provider ("none" for local providers),
                "provider_object": the selected Provider instance,
            }
        Raises RuntimeError if no providers are available.
        """
        if not self._initialized:
            await self.initialize()

        exclude = set(exclude_providers or [])
        large = self.is_large_request(messages)

        available = [
            p for p in self.providers
            if p.healthy and p.is_configured and p.name not in exclude
        ]

        if not available:
            raise RuntimeError(
                "SmartRouter: no providers available. "
                "Check your API keys and provider health."
            )

        provider = min(available, key=lambda p: p.score(self.strategy))
        model = self.get_model_for_provider(
            provider,
            claude_model,
            is_large_request=large,
        )

        logger.debug(
            f"SmartRouter: routing to {provider.name}/{model} "
            f"(strategy={self.strategy}, large={large}, attempt={attempt})"
        )

        return {
            "provider": provider.name,
            "model": model,
            "api_key": provider.api_key or "none",
            "provider_object": provider,
        }

    async def record_result(
        self,
        provider_name: str,
        success: bool,
        duration_ms: float,
    ) -> None:
        """
        Record the outcome of a request.
        Called after each proxied request to update provider scores.
        """
        provider = next(
            (p for p in self.providers if p.name == provider_name), None
        )
        if not provider:
            return

        provider.request_count += 1
        if success:
            self._update_latency(provider, duration_ms)
        else:
            provider.error_count += 1
            # Once at least 3 requests have been seen and the error rate
            # exceeds 70%, mark the provider unhealthy temporarily.
            recent_errors = provider.error_count
            recent_total = provider.request_count
            if recent_total >= 3 and (recent_errors / recent_total) > 0.7:
                logger.warning(
                    f"SmartRouter: {provider_name} error rate high "
                    f"({provider.error_rate:.0%}), marking unhealthy"
                )
                provider.healthy = False
                # Schedule re-check after 60s
                asyncio.create_task(self._recheck_provider(provider, delay=60))

    async def _recheck_provider(
        self, provider: Provider, delay: float = 60
    ) -> None:
        """Re-ping a provider after a delay and restore if healthy."""
        await asyncio.sleep(delay)
        await self._ping_provider(provider)
        if provider.healthy:
            logger.info(
                f"SmartRouter: {provider.name} recovered, "
                f"re-adding to pool"
            )

    # ── Status report ─────────────────────────────────────────────────────────

    def status(self) -> list[dict]:
        """Return current provider status for monitoring."""
        return [
            {
                "provider": p.name,
                "healthy": p.healthy,
                "configured": p.is_configured,
                "latency_ms": round(p.avg_latency_ms, 1),
                "cost_per_1k": p.cost_per_1k_tokens,
                "requests": p.request_count,
                "errors": p.error_count,
                "error_rate": f"{p.error_rate:.1%}",
                "score": round(p.score(self.strategy), 3)
                if p.healthy and p.is_configured
                else "N/A",
            }
            for p in self.providers
        ]
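A minimal sketch (not part of the commit) that mirrors the "Usage in server.py" docstring above, assuming at least one provider is configured via .env; reporting the outcome afterwards keeps the router's latency and error scores current.

# usage_sketch_router.py - illustrative only, not part of this commit.
import asyncio
import time

from smart_router import SmartRouter


async def main() -> None:
    router = SmartRouter()        # reads ROUTER_STRATEGY / ROUTER_FALLBACK from the environment
    await router.initialize()     # pings every provider and records initial latency

    messages = [{"role": "user", "content": "Summarize this repository."}]
    decision = await router.route(messages, claude_model="claude-sonnet")
    print(decision["provider"], decision["model"])

    # After proxying the request, report the outcome so scores keep learning.
    start = time.monotonic()
    success = True  # outcome of the proxied call in a real server
    await router.record_result(
        decision["provider"],
        success=success,
        duration_ms=(time.monotonic() - start) * 1000,
    )
    print(router.status())


if __name__ == "__main__":
    asyncio.run(main())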
1  python/tests/__init__.py  Normal file
@@ -0,0 +1 @@
# Pytest package marker for the Python helper test suite.
5  python/tests/conftest.py  Normal file
@@ -0,0 +1,5 @@
from pathlib import Path
import sys

# Make the sibling `python/` helper modules importable from this test package.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
130  python/tests/test_atomic_chat_provider.py  Normal file
@@ -0,0 +1,130 @@
"""
test_atomic_chat_provider.py
Run: pytest python/tests/test_atomic_chat_provider.py -v
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from atomic_chat_provider import (
    atomic_chat,
    list_atomic_chat_models,
    check_atomic_chat_running,
)


@pytest.mark.asyncio
async def test_atomic_chat_running_true():
    mock_response = MagicMock()
    mock_response.status_code = 200
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        result = await check_atomic_chat_running()
        assert result is True


@pytest.mark.asyncio
async def test_atomic_chat_running_false_on_exception():
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("refused"))
        result = await check_atomic_chat_running()
        assert result is False


@pytest.mark.asyncio
async def test_list_models_returns_ids():
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "data": [{"id": "llama-3.1-8b"}, {"id": "mistral-7b"}],
    }
    mock_response.raise_for_status = MagicMock()
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        models = await list_atomic_chat_models()
        assert "llama-3.1-8b" in models
        assert "mistral-7b" in models


@pytest.mark.asyncio
async def test_list_models_empty_on_failure():
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("down"))
        models = await list_atomic_chat_models()
        assert models == []


@pytest.mark.asyncio
async def test_atomic_chat_returns_anthropic_format():
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()
    mock_response.json.return_value = {
        "id": "chatcmpl-abc123",
        "choices": [{"message": {"content": "42 is the answer."}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 8},
    }
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
        result = await atomic_chat(
            model="llama-3.1-8b",
            messages=[{"role": "user", "content": "What is 6*7?"}],
        )
        assert result["type"] == "message"
        assert result["role"] == "assistant"
        assert "42" in result["content"][0]["text"]
        assert result["usage"]["input_tokens"] == 10
        assert result["usage"]["output_tokens"] == 8


@pytest.mark.asyncio
async def test_atomic_chat_prepends_system():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "id": "chatcmpl-xyz",
            "choices": [{"message": {"content": "ok"}}],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
        }
        return m

    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await atomic_chat(
            model="llama-3.1-8b",
            messages=[{"role": "user", "content": "Hi"}],
            system="Be helpful.",
        )
    assert captured["messages"][0]["role"] == "system"
    assert "helpful" in captured["messages"][0]["content"]


@pytest.mark.asyncio
async def test_atomic_chat_sends_correct_payload():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "id": "chatcmpl-xyz",
            "choices": [{"message": {"content": "ok"}}],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
        }
        return m

    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await atomic_chat(
            model="test-model",
            messages=[{"role": "user", "content": "Test"}],
            max_tokens=2048,
            temperature=0.5,
        )
    assert captured["model"] == "test-model"
    assert captured["max_tokens"] == 2048
    assert captured["temperature"] == 0.5
    assert captured["stream"] is False
192  python/tests/test_ollama_provider.py  Normal file
@@ -0,0 +1,192 @@
"""
test_ollama_provider.py
Run: pytest python/tests/test_ollama_provider.py -v
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from ollama_provider import (
    normalize_ollama_model,
    anthropic_to_ollama_messages,
    ollama_chat,
    list_ollama_models,
    check_ollama_running,
)


def test_normalize_strips_prefix():
    assert normalize_ollama_model("ollama/llama3:8b") == "llama3:8b"


def test_normalize_no_prefix():
    assert normalize_ollama_model("codellama:34b") == "codellama:34b"


def test_normalize_empty():
    assert normalize_ollama_model("") == ""


def test_converts_string_content():
    messages = [{"role": "user", "content": "Hello!"}]
    result = anthropic_to_ollama_messages(messages)
    assert result == [{"role": "user", "content": "Hello!"}]


def test_converts_text_block_list():
    messages = [{"role": "user", "content": [{"type": "text", "text": "What is Python?"}]}]
    result = anthropic_to_ollama_messages(messages)
    assert result[0]["content"] == "What is Python?"


def test_converts_image_block_to_placeholder():
    messages = [{"role": "user", "content": [{"type": "image", "source": {}}, {"type": "text", "text": "Describe this"}]}]
    result = anthropic_to_ollama_messages(messages)
    assert "[image]" in result[0]["content"]
    assert "Describe this" in result[0]["content"]


def test_converts_base64_image_block_to_ollama_images():
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": "YWJjMTIz",
                },
            },
            {"type": "text", "text": "Describe this"},
        ],
    }]
    result = anthropic_to_ollama_messages(messages)
    assert result[0]["images"] == ["YWJjMTIz"]
    assert "Describe this" in result[0]["content"]


def test_converts_multi_turn():
    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "How are you?"},
    ]
    result = anthropic_to_ollama_messages(messages)
    assert len(result) == 3
    assert result[1]["role"] == "assistant"


@pytest.mark.asyncio
async def test_ollama_running_true():
    mock_response = MagicMock()
    mock_response.status_code = 200
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        result = await check_ollama_running()
        assert result is True


@pytest.mark.asyncio
async def test_ollama_running_false_on_exception():
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("refused"))
        result = await check_ollama_running()
        assert result is False


@pytest.mark.asyncio
async def test_list_models_returns_names():
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"models": [{"name": "llama3:8b"}, {"name": "codellama:34b"}]}
    mock_response.raise_for_status = MagicMock()
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        models = await list_ollama_models()
        assert "llama3:8b" in models


@pytest.mark.asyncio
async def test_ollama_chat_returns_anthropic_format():
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()
    mock_response.json.return_value = {
        "message": {"content": "42 is the answer."},
        "created_at": "2026-01-01T00:00:00Z",
        "prompt_eval_count": 10,
        "eval_count": 8,
    }
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
        result = await ollama_chat(
            model="llama3:8b",
            messages=[{"role": "user", "content": "What is 6*7?"}],
        )
        assert result["type"] == "message"
        assert result["role"] == "assistant"
        assert "42" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_ollama_chat_prepends_system():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "message": {"content": "ok"},
            "created_at": "",
            "prompt_eval_count": 1,
            "eval_count": 1,
        }
        return m

    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await ollama_chat(
            model="llama3:8b",
            messages=[{"role": "user", "content": "Hi"}],
            system="Be helpful.",
        )
    assert captured["messages"][0]["role"] == "system"
    assert "helpful" in captured["messages"][0]["content"]


@pytest.mark.asyncio
async def test_ollama_chat_includes_base64_images_in_payload():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "message": {"content": "ok"},
            "created_at": "",
            "prompt_eval_count": 1,
            "eval_count": 1,
        }
        return m

    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await ollama_chat(
            model="llama3:8b",
            messages=[{
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": "ZHVtbXk=",
                        },
                    },
                    {"type": "text", "text": "What is in this image?"},
                ],
            }],
        )

    assert captured["messages"][0]["images"] == ["ZHVtbXk="]
    assert "What is in this image?" in captured["messages"][0]["content"]
231  python/tests/test_smart_router.py  Normal file
@@ -0,0 +1,231 @@
"""
test_smart_router.py
--------------------
Tests for the SmartRouter.
Run: pytest python/tests/test_smart_router.py -v
"""

import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from smart_router import SmartRouter, Provider


# ── Fixtures ──────────────────────────────────────────────────────────────────


@pytest.fixture(autouse=True)
def fake_api_key(monkeypatch):
    monkeypatch.setenv("FAKE_KEY", "test-key")


def make_provider(name, healthy=True, configured=True,
                  latency=100.0, cost=0.002, errors=0, requests=0):
    p = Provider(
        name=name,
        ping_url=f"https://{name}.example.com/health",
        api_key_env="FAKE_KEY",
        cost_per_1k_tokens=cost,
        big_model=f"{name}-big",
        small_model=f"{name}-small",
    )
    p.healthy = healthy
    p.avg_latency_ms = latency
    p.error_count = errors
    p.request_count = requests
    if not configured:
        p.api_key_env = ""  # makes is_configured False for non-local providers
    return p


def make_router(providers=None, strategy="balanced"):
    r = SmartRouter(providers=providers, strategy=strategy)
    r._initialized = True
    return r


# ── Provider.score() ──────────────────────────────────────────────────────────

def test_score_unhealthy_is_inf():
    p = make_provider("openai", healthy=False)
    assert p.score() == float("inf")


def test_score_unconfigured_is_inf():
    p = make_provider("openai", configured=False)
    assert p.score() == float("inf")


def test_score_latency_strategy_prefers_faster():
    fast = make_provider("fast", latency=50.0, cost=0.01)
    slow = make_provider("slow", latency=500.0, cost=0.001)
    assert fast.score("latency") < slow.score("latency")


def test_score_cost_strategy_prefers_cheaper():
    cheap = make_provider("cheap", latency=500.0, cost=0.0001)
    expensive = make_provider("expensive", latency=50.0, cost=0.05)
    assert cheap.score("cost") < expensive.score("cost")


def test_score_balanced_strategy_uses_both():
    p = make_provider("test", latency=200.0, cost=0.002)
    s = p.score("balanced")
    assert s > 0


def test_score_error_rate_penalty():
    clean = make_provider("clean", errors=0, requests=10)
    dirty = make_provider("dirty", errors=8, requests=10)
    assert clean.score() < dirty.score()


# ── SmartRouter.is_large_request() ───────────────────────────────────────────

def test_is_large_request_short():
    r = make_router()
    msgs = [{"role": "user", "content": "Hello!"}]
    assert r.is_large_request(msgs) is False


def test_is_large_request_long():
    r = make_router()
    msgs = [{"role": "user", "content": "x" * 3000}]
    assert r.is_large_request(msgs) is True


# ── SmartRouter.select_provider() ────────────────────────────────────────────

def test_select_provider_picks_best_score():
    p1 = make_provider("slow", latency=800.0)
    p2 = make_provider("fast", latency=50.0)
    r = make_router(providers=[p1, p2], strategy="latency")
    selected = r.select_provider()
    assert selected.name == "fast"


def test_select_provider_skips_unhealthy():
    p1 = make_provider("bad", healthy=False)
    p2 = make_provider("good", healthy=True)
    r = make_router(providers=[p1, p2])
    selected = r.select_provider()
    assert selected.name == "good"


def test_select_provider_returns_none_when_all_down():
    p1 = make_provider("a", healthy=False)
    p2 = make_provider("b", healthy=False)
    r = make_router(providers=[p1, p2])
    assert r.select_provider() is None


# ── SmartRouter.get_model_for_provider() ─────────────────────────────────────

def test_get_model_large_request():
    p = make_provider("openai")
    r = make_router()
    model = r.get_model_for_provider(p, "claude-sonnet")
    assert model == "openai-big"


def test_get_model_large_message_overrides_claude_label():
    p = make_provider("openai")
    r = make_router()
    model = r.get_model_for_provider(p, "claude-haiku", is_large_request=True)
    assert model == "openai-big"


def test_get_model_small_request():
    p = make_provider("openai")
    r = make_router()
    model = r.get_model_for_provider(p, "claude-haiku")
    assert model == "openai-small"


# ── SmartRouter.route() ───────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_route_returns_best_provider():
    p1 = make_provider("expensive", cost=0.05, latency=50.0)
    p2 = make_provider("cheap", cost=0.0005, latency=200.0)
    r = make_router(providers=[p1, p2], strategy="cost")
    result = await r.route([{"role": "user", "content": "Hi"}], "claude-haiku")
    assert result["provider"] == "cheap"


@pytest.mark.asyncio
async def test_route_uses_big_model_for_large_message_bodies():
    p = make_provider("openai")
    r = make_router(providers=[p])
    result = await r.route([
        {"role": "user", "content": "x" * 3001},
    ], "claude-haiku")
    assert result["model"] == "openai-big"


@pytest.mark.asyncio
async def test_route_raises_when_no_providers():
    p = make_provider("a", healthy=False)
    r = make_router(providers=[p])
    with pytest.raises(RuntimeError, match="no providers available"):
        await r.route([{"role": "user", "content": "Hi"}])


@pytest.mark.asyncio
async def test_route_excludes_providers():
    p1 = make_provider("openai", latency=50.0)
    p2 = make_provider("gemini", latency=200.0)
    r = make_router(providers=[p1, p2], strategy="latency")
    result = await r.route(
        [{"role": "user", "content": "Hi"}],
        exclude_providers=["openai"],
    )
    assert result["provider"] == "gemini"


# ── SmartRouter.record_result() ──────────────────────────────────────────────

@pytest.mark.asyncio
async def test_record_result_updates_latency():
    p = make_provider("openai", latency=200.0)
    r = make_router(providers=[p])
    await r.record_result("openai", success=True, duration_ms=100.0)
    assert p.avg_latency_ms < 200.0  # should decrease toward 100


@pytest.mark.asyncio
async def test_record_result_increments_requests():
    p = make_provider("openai")
    r = make_router(providers=[p])
    await r.record_result("openai", success=True, duration_ms=100.0)
    assert p.request_count == 1


@pytest.mark.asyncio
async def test_record_result_increments_errors():
    p = make_provider("openai")
    r = make_router(providers=[p])
    await r.record_result("openai", success=False, duration_ms=0)
    assert p.error_count == 1


# ── SmartRouter.status() ─────────────────────────────────────────────────────

def test_status_returns_all_providers():
    p1 = make_provider("openai")
    p2 = make_provider("gemini")
    r = make_router(providers=[p1, p2])
    status = r.status()
    assert len(status) == 2
    names = [s["provider"] for s in status]
    assert "openai" in names
    assert "gemini" in names


def test_status_contains_required_fields():
    p = make_provider("openai")
    r = make_router(providers=[p])
    status = r.status()[0]
    for field in ["provider", "healthy", "latency_ms",
                  "cost_per_1k", "requests", "errors", "score"]:
        assert field in status