docs: organize Python helpers and refresh README (#334)
* docs: organize Python helpers and refresh README
* docs: add README status badges
* test: centralize Python helper test imports
* docs: add short provenance disclaimer
1  python/tests/__init__.py  Normal file
@@ -0,0 +1 @@
# Pytest package marker for the Python helper test suite.
5  python/tests/conftest.py  Normal file
@@ -0,0 +1,5 @@
from pathlib import Path
import sys

# Make the sibling `python/` helper modules importable from this test package.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
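Note: `parents[1]` of the resolved conftest path is the `python/` directory (parent 0 is `python/tests/`), which is why the test files below can import `atomic_chat_provider`, `ollama_provider`, and `smart_router` by bare module name. A minimal sketch of the resolution, runnable from the repository root (the relative path literal is an assumption about the checkout layout):

from pathlib import Path

# conftest.py inserts parents[1] at the front of sys.path before collection
conftest = Path("python/tests/conftest.py").resolve()
assert conftest.parents[0].name == "tests"   # python/tests/
assert conftest.parents[1].name == "python"  # the directory placed on sys.path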
130  python/tests/test_atomic_chat_provider.py  Normal file
@@ -0,0 +1,130 @@
"""
test_atomic_chat_provider.py

Run: pytest python/tests/test_atomic_chat_provider.py -v
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from atomic_chat_provider import (
    atomic_chat,
    list_atomic_chat_models,
    check_atomic_chat_running,
)


@pytest.mark.asyncio
async def test_atomic_chat_running_true():
    mock_response = MagicMock()
    mock_response.status_code = 200
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        result = await check_atomic_chat_running()
        assert result is True


@pytest.mark.asyncio
async def test_atomic_chat_running_false_on_exception():
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("refused"))
        result = await check_atomic_chat_running()
        assert result is False


@pytest.mark.asyncio
async def test_list_models_returns_ids():
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {
        "data": [{"id": "llama-3.1-8b"}, {"id": "mistral-7b"}],
    }
    mock_response.raise_for_status = MagicMock()
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        models = await list_atomic_chat_models()
        assert "llama-3.1-8b" in models
        assert "mistral-7b" in models


@pytest.mark.asyncio
async def test_list_models_empty_on_failure():
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("down"))
        models = await list_atomic_chat_models()
        assert models == []


@pytest.mark.asyncio
async def test_atomic_chat_returns_anthropic_format():
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()
    mock_response.json.return_value = {
        "id": "chatcmpl-abc123",
        "choices": [{"message": {"content": "42 is the answer."}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 8},
    }
    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
        result = await atomic_chat(
            model="llama-3.1-8b",
            messages=[{"role": "user", "content": "What is 6*7?"}],
        )
        assert result["type"] == "message"
        assert result["role"] == "assistant"
        assert "42" in result["content"][0]["text"]
        assert result["usage"]["input_tokens"] == 10
        assert result["usage"]["output_tokens"] == 8


@pytest.mark.asyncio
async def test_atomic_chat_prepends_system():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "id": "chatcmpl-xyz",
            "choices": [{"message": {"content": "ok"}}],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
        }
        return m

    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await atomic_chat(
            model="llama-3.1-8b",
            messages=[{"role": "user", "content": "Hi"}],
            system="Be helpful.",
        )
    assert captured["messages"][0]["role"] == "system"
    assert "helpful" in captured["messages"][0]["content"]


@pytest.mark.asyncio
async def test_atomic_chat_sends_correct_payload():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "id": "chatcmpl-xyz",
            "choices": [{"message": {"content": "ok"}}],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
        }
        return m

    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await atomic_chat(
            model="test-model",
            messages=[{"role": "user", "content": "Test"}],
            max_tokens=2048,
            temperature=0.5,
        )
    assert captured["model"] == "test-model"
    assert captured["max_tokens"] == 2048
    assert captured["temperature"] == 0.5
    assert captured["stream"] is False
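The tests above pin down the surface of `atomic_chat_provider` without showing it: an OpenAI-style chat-completions backend is called through `httpx.AsyncClient`, and the response is reshaped into an Anthropic-style message dict. A minimal sketch consistent with those assertions; the base URL, endpoint paths, and default parameter values are assumptions, not the module's actual code. (`list_atomic_chat_models` follows the same pattern: GET the models endpoint, collect the `id` fields, return `[]` on any exception.)

import httpx

ATOMIC_CHAT_URL = "http://localhost:8080"  # assumed default, likely env-configurable

async def check_atomic_chat_running() -> bool:
    # The tests assert True on a 200 GET and False on any exception.
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{ATOMIC_CHAT_URL}/v1/models")
            return resp.status_code == 200
    except Exception:
        return False

async def atomic_chat(model, messages, system=None, max_tokens=1024, temperature=0.7):
    # Payload shape is pinned by test_atomic_chat_sends_correct_payload.
    payload = {"model": model, "messages": list(messages),
               "max_tokens": max_tokens, "temperature": temperature, "stream": False}
    if system:
        # test_atomic_chat_prepends_system: the system prompt becomes messages[0].
        payload["messages"].insert(0, {"role": "system", "content": system})
    async with httpx.AsyncClient() as client:
        resp = await client.post(f"{ATOMIC_CHAT_URL}/v1/chat/completions", json=payload)
    resp.raise_for_status()
    data = resp.json()
    # Reshape the OpenAI-style response into the Anthropic-style dict the tests assert on.
    return {
        "id": data.get("id"),
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": data["choices"][0]["message"]["content"]}],
        "usage": {"input_tokens": data["usage"]["prompt_tokens"],
                  "output_tokens": data["usage"]["completion_tokens"]},
    }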
192  python/tests/test_ollama_provider.py  Normal file
@@ -0,0 +1,192 @@
"""
test_ollama_provider.py

Run: pytest python/tests/test_ollama_provider.py -v
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from ollama_provider import (
    normalize_ollama_model,
    anthropic_to_ollama_messages,
    ollama_chat,
    list_ollama_models,
    check_ollama_running,
)


def test_normalize_strips_prefix():
    assert normalize_ollama_model("ollama/llama3:8b") == "llama3:8b"


def test_normalize_no_prefix():
    assert normalize_ollama_model("codellama:34b") == "codellama:34b"


def test_normalize_empty():
    assert normalize_ollama_model("") == ""


def test_converts_string_content():
    messages = [{"role": "user", "content": "Hello!"}]
    result = anthropic_to_ollama_messages(messages)
    assert result == [{"role": "user", "content": "Hello!"}]


def test_converts_text_block_list():
    messages = [{"role": "user", "content": [{"type": "text", "text": "What is Python?"}]}]
    result = anthropic_to_ollama_messages(messages)
    assert result[0]["content"] == "What is Python?"


def test_converts_image_block_to_placeholder():
    messages = [{"role": "user", "content": [{"type": "image", "source": {}}, {"type": "text", "text": "Describe this"}]}]
    result = anthropic_to_ollama_messages(messages)
    assert "[image]" in result[0]["content"]
    assert "Describe this" in result[0]["content"]


def test_converts_base64_image_block_to_ollama_images():
    messages = [{
        "role": "user",
        "content": [
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": "YWJjMTIz",
                },
            },
            {"type": "text", "text": "Describe this"},
        ],
    }]
    result = anthropic_to_ollama_messages(messages)
    assert result[0]["images"] == ["YWJjMTIz"]
    assert "Describe this" in result[0]["content"]


def test_converts_multi_turn():
    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "How are you?"},
    ]
    result = anthropic_to_ollama_messages(messages)
    assert len(result) == 3
    assert result[1]["role"] == "assistant"


@pytest.mark.asyncio
async def test_ollama_running_true():
    mock_response = MagicMock()
    mock_response.status_code = 200
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        result = await check_ollama_running()
        assert result is True


@pytest.mark.asyncio
async def test_ollama_running_false_on_exception():
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("refused"))
        result = await check_ollama_running()
        assert result is False


@pytest.mark.asyncio
async def test_list_models_returns_names():
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"models": [{"name": "llama3:8b"}, {"name": "codellama:34b"}]}
    mock_response.raise_for_status = MagicMock()
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        models = await list_ollama_models()
        assert "llama3:8b" in models


@pytest.mark.asyncio
async def test_ollama_chat_returns_anthropic_format():
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()
    mock_response.json.return_value = {
        "message": {"content": "42 is the answer."},
        "created_at": "2026-01-01T00:00:00Z",
        "prompt_eval_count": 10,
        "eval_count": 8,
    }
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
        result = await ollama_chat(
            model="llama3:8b",
            messages=[{"role": "user", "content": "What is 6*7?"}],
        )
        assert result["type"] == "message"
        assert result["role"] == "assistant"
        assert "42" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_ollama_chat_prepends_system():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "message": {"content": "ok"},
            "created_at": "",
            "prompt_eval_count": 1,
            "eval_count": 1,
        }
        return m

    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await ollama_chat(
            model="llama3:8b",
            messages=[{"role": "user", "content": "Hi"}],
            system="Be helpful.",
        )
    assert captured["messages"][0]["role"] == "system"
    assert "helpful" in captured["messages"][0]["content"]


@pytest.mark.asyncio
async def test_ollama_chat_includes_base64_images_in_payload():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "message": {"content": "ok"},
            "created_at": "",
            "prompt_eval_count": 1,
            "eval_count": 1,
        }
        return m

    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await ollama_chat(
            model="llama3:8b",
            messages=[{
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": "ZHVtbXk=",
                        },
                    },
                    {"type": "text", "text": "What is in this image?"},
                ],
            }],
        )

    assert captured["messages"][0]["images"] == ["ZHVtbXk="]
    assert "What is in this image?" in captured["messages"][0]["content"]
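The conversion tests fully determine `anthropic_to_ollama_messages` for the covered cases: plain string content passes through, text blocks are flattened into one string, base64 image blocks move into Ollama's `images` list, and any other image block degrades to an `[image]` placeholder. A self-contained sketch that satisfies exactly those assertions (the join separator and the dict-access defaults are assumptions):

def anthropic_to_ollama_messages(messages):
    out = []
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, str):
            # Plain-string content passes through unchanged.
            out.append({"role": msg["role"], "content": content})
            continue
        texts, images = [], []
        for block in content:
            if block.get("type") == "text":
                texts.append(block.get("text", ""))
            elif block.get("type") == "image":
                src = block.get("source", {})
                if src.get("type") == "base64":
                    # Ollama takes raw base64 strings in a per-message "images" list.
                    images.append(src.get("data", ""))
                else:
                    # Non-base64 sources degrade to a text placeholder.
                    texts.append("[image]")
        converted = {"role": msg["role"], "content": "\n".join(texts)}
        if images:
            converted["images"] = images
        out.append(converted)
    return out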
231  python/tests/test_smart_router.py  Normal file
@@ -0,0 +1,231 @@
"""
test_smart_router.py
--------------------
Tests for the SmartRouter.

Run: pytest python/tests/test_smart_router.py -v
"""

import pytest
from smart_router import SmartRouter, Provider


# ── Fixtures ──────────────────────────────────────────────────────────────────


@pytest.fixture(autouse=True)
def fake_api_key(monkeypatch):
    monkeypatch.setenv("FAKE_KEY", "test-key")


def make_provider(name, healthy=True, configured=True,
                  latency=100.0, cost=0.002, errors=0, requests=0):
    p = Provider(
        name=name,
        ping_url=f"https://{name}.example.com/health",
        api_key_env="FAKE_KEY",
        cost_per_1k_tokens=cost,
        big_model=f"{name}-big",
        small_model=f"{name}-small",
    )
    p.healthy = healthy
    p.avg_latency_ms = latency
    p.error_count = errors
    p.request_count = requests
    if not configured:
        p.api_key_env = ""  # makes is_configured False for non-local providers
    return p


def make_router(providers=None, strategy="balanced"):
    r = SmartRouter(providers=providers, strategy=strategy)
    r._initialized = True
    return r


# ── Provider.score() ──────────────────────────────────────────────────────────


def test_score_unhealthy_is_inf():
    p = make_provider("openai", healthy=False)
    assert p.score() == float("inf")


def test_score_unconfigured_is_inf():
    p = make_provider("openai", configured=False)
    assert p.score() == float("inf")


def test_score_latency_strategy_prefers_faster():
    fast = make_provider("fast", latency=50.0, cost=0.01)
    slow = make_provider("slow", latency=500.0, cost=0.001)
    assert fast.score("latency") < slow.score("latency")


def test_score_cost_strategy_prefers_cheaper():
    cheap = make_provider("cheap", latency=500.0, cost=0.0001)
    expensive = make_provider("expensive", latency=50.0, cost=0.05)
    assert cheap.score("cost") < expensive.score("cost")


def test_score_balanced_strategy_uses_both():
    p = make_provider("test", latency=200.0, cost=0.002)
    s = p.score("balanced")
    assert s > 0


def test_score_error_rate_penalty():
    clean = make_provider("clean", errors=0, requests=10)
    dirty = make_provider("dirty", errors=8, requests=10)
    assert clean.score() < dirty.score()


# ── SmartRouter.is_large_request() ───────────────────────────────────────────


def test_is_large_request_short():
    r = make_router()
    msgs = [{"role": "user", "content": "Hello!"}]
    assert r.is_large_request(msgs) is False


def test_is_large_request_long():
    r = make_router()
    msgs = [{"role": "user", "content": "x" * 3000}]
    assert r.is_large_request(msgs) is True


# ── SmartRouter.select_provider() ────────────────────────────────────────────


def test_select_provider_picks_best_score():
    p1 = make_provider("slow", latency=800.0)
    p2 = make_provider("fast", latency=50.0)
    r = make_router(providers=[p1, p2], strategy="latency")
    selected = r.select_provider()
    assert selected.name == "fast"


def test_select_provider_skips_unhealthy():
    p1 = make_provider("bad", healthy=False)
    p2 = make_provider("good", healthy=True)
    r = make_router(providers=[p1, p2])
    selected = r.select_provider()
    assert selected.name == "good"


def test_select_provider_returns_none_when_all_down():
    p1 = make_provider("a", healthy=False)
    p2 = make_provider("b", healthy=False)
    r = make_router(providers=[p1, p2])
    assert r.select_provider() is None


# ── SmartRouter.get_model_for_provider() ─────────────────────────────────────


def test_get_model_large_request():
    p = make_provider("openai")
    r = make_router()
    model = r.get_model_for_provider(p, "claude-sonnet")
    assert model == "openai-big"


def test_get_model_large_message_overrides_claude_label():
    p = make_provider("openai")
    r = make_router()
    model = r.get_model_for_provider(p, "claude-haiku", is_large_request=True)
    assert model == "openai-big"


def test_get_model_small_request():
    p = make_provider("openai")
    r = make_router()
    model = r.get_model_for_provider(p, "claude-haiku")
    assert model == "openai-small"


# ── SmartRouter.route() ───────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_route_returns_best_provider():
    p1 = make_provider("expensive", cost=0.05, latency=50.0)
    p2 = make_provider("cheap", cost=0.0005, latency=200.0)
    r = make_router(providers=[p1, p2], strategy="cost")
    result = await r.route([{"role": "user", "content": "Hi"}], "claude-haiku")
    assert result["provider"] == "cheap"


@pytest.mark.asyncio
async def test_route_uses_big_model_for_large_message_bodies():
    p = make_provider("openai")
    r = make_router(providers=[p])
    result = await r.route([
        {"role": "user", "content": "x" * 3001},
    ], "claude-haiku")
    assert result["model"] == "openai-big"


@pytest.mark.asyncio
async def test_route_raises_when_no_providers():
    p = make_provider("a", healthy=False)
    r = make_router(providers=[p])
    with pytest.raises(RuntimeError, match="no providers available"):
        await r.route([{"role": "user", "content": "Hi"}])


@pytest.mark.asyncio
async def test_route_excludes_providers():
    p1 = make_provider("openai", latency=50.0)
    p2 = make_provider("gemini", latency=200.0)
    r = make_router(providers=[p1, p2], strategy="latency")
    result = await r.route(
        [{"role": "user", "content": "Hi"}],
        exclude_providers=["openai"],
    )
    assert result["provider"] == "gemini"


# ── SmartRouter.record_result() ──────────────────────────────────────────────


@pytest.mark.asyncio
async def test_record_result_updates_latency():
    p = make_provider("openai", latency=200.0)
    r = make_router(providers=[p])
    await r.record_result("openai", success=True, duration_ms=100.0)
    assert p.avg_latency_ms < 200.0  # should decrease toward 100


@pytest.mark.asyncio
async def test_record_result_increments_requests():
    p = make_provider("openai")
    r = make_router(providers=[p])
    await r.record_result("openai", success=True, duration_ms=100.0)
    assert p.request_count == 1


@pytest.mark.asyncio
async def test_record_result_increments_errors():
    p = make_provider("openai")
    r = make_router(providers=[p])
    await r.record_result("openai", success=False, duration_ms=0)
    assert p.error_count == 1


# ── SmartRouter.status() ─────────────────────────────────────────────────────


def test_status_returns_all_providers():
    p1 = make_provider("openai")
    p2 = make_provider("gemini")
    r = make_router(providers=[p1, p2])
    status = r.status()
    assert len(status) == 2
    names = [s["provider"] for s in status]
    assert "openai" in names
    assert "gemini" in names


def test_status_contains_required_fields():
    p = make_provider("openai")
    r = make_router(providers=[p])
    status = r.status()[0]
    for field in ["provider", "healthy", "latency_ms",
                  "cost_per_1k", "requests", "errors", "score"]:
        assert field in status
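The score tests constrain `Provider.score()` only by ordering: unhealthy or unconfigured providers score `inf`, the `latency` strategy orders by latency, `cost` orders by price, `balanced` blends both, and a provider's error rate pushes its score upward. One scoring function that satisfies all six tests; the specific weights and the error multiplier are assumptions, only the orderings are grounded in the assertions above:

from types import SimpleNamespace

def score(p, strategy="balanced"):
    # Unhealthy or unconfigured providers can never win selection.
    if not p.healthy or not p.is_configured:
        return float("inf")
    error_rate = p.error_count / p.request_count if p.request_count else 0.0
    if strategy == "latency":
        base = p.avg_latency_ms
    elif strategy == "cost":
        base = p.cost_per_1k_tokens * 1000
    else:  # balanced: blend normalized latency and cost
        base = p.avg_latency_ms / 1000 + p.cost_per_1k_tokens * 100
    return base * (1 + 10 * error_rate)  # error penalty keeps flaky providers last

fast = SimpleNamespace(healthy=True, is_configured=True, avg_latency_ms=50.0,
                       cost_per_1k_tokens=0.01, error_count=0, request_count=0)
slow = SimpleNamespace(healthy=True, is_configured=True, avg_latency_ms=500.0,
                       cost_per_1k_tokens=0.001, error_count=0, request_count=0)
assert score(fast, "latency") < score(slow, "latency")  # latency prefers faster
assert score(slow, "cost") < score(fast, "cost")        # cost prefers cheaper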