diff --git a/atomic_chat_provider.py b/atomic_chat_provider.py
new file mode 100644
index 00000000..bf55155f
--- /dev/null
+++ b/atomic_chat_provider.py
@@ -0,0 +1,146 @@
+"""
+atomic_chat_provider.py
+-----------------------
+Adds native Atomic Chat support to openclaude.
+Lets Claude Code route requests to any locally running model via
+Atomic Chat (Apple Silicon only) at 127.0.0.1:1337.
+
+Atomic Chat exposes an OpenAI-compatible API, so requests are forwarded almost
+as-is and responses are converted to the Anthropic Messages format.
+
+Usage (.env):
+    PREFERRED_PROVIDER=atomic-chat
+    ATOMIC_CHAT_BASE_URL=http://127.0.0.1:1337
+"""
+
+import httpx
+import json
+import logging
+import os
+from typing import AsyncIterator
+
+logger = logging.getLogger(__name__)
+ATOMIC_CHAT_BASE_URL = os.getenv("ATOMIC_CHAT_BASE_URL", "http://127.0.0.1:1337")
+
+
+def _api_url(path: str) -> str:
+    return f"{ATOMIC_CHAT_BASE_URL}/v1{path}"
+
+
+async def check_atomic_chat_running() -> bool:
+    try:
+        async with httpx.AsyncClient(timeout=3.0) as client:
+            resp = await client.get(_api_url("/models"))
+            return resp.status_code == 200
+    except Exception:
+        return False
+
+
+async def list_atomic_chat_models() -> list[str]:
+    try:
+        async with httpx.AsyncClient(timeout=5.0) as client:
+            resp = await client.get(_api_url("/models"))
+            resp.raise_for_status()
+            data = resp.json()
+            return [m["id"] for m in data.get("data", [])]
+    except Exception as e:
+        logger.warning(f"Could not list Atomic Chat models: {e}")
+        return []
+
+
+async def atomic_chat(
+    model: str,
+    messages: list[dict],
+    system: str | None = None,
+    max_tokens: int = 4096,
+    temperature: float = 1.0,
+) -> dict:
+    chat_messages = list(messages)
+    if system:
+        chat_messages.insert(0, {"role": "system", "content": system})
+
+    payload = {
+        "model": model,
+        "messages": chat_messages,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "stream": False,
+    }
+
+    async with httpx.AsyncClient(timeout=120.0) as client:
+        resp = await client.post(_api_url("/chat/completions"), json=payload)
+        resp.raise_for_status()
+        data = resp.json()
+
+    choice = data.get("choices", [{}])[0]
+    assistant_text = choice.get("message", {}).get("content", "")
+    usage = data.get("usage", {})
+
+    return {
+        "id": data.get("id", "msg_atomic_chat"),
+        "type": "message",
+        "role": "assistant",
+        "content": [{"type": "text", "text": assistant_text}],
+        "model": model,
+        "stop_reason": "end_turn",
+        "stop_sequence": None,
+        "usage": {
+            "input_tokens": usage.get("prompt_tokens", 0),
+            "output_tokens": usage.get("completion_tokens", 0),
+        },
+    }
+
+
+async def atomic_chat_stream(
+    model: str,
+    messages: list[dict],
+    system: str | None = None,
+    max_tokens: int = 4096,
+    temperature: float = 1.0,
+) -> AsyncIterator[str]:
+    chat_messages = list(messages)
+    if system:
+        chat_messages.insert(0, {"role": "system", "content": system})
+
+    payload = {
+        "model": model,
+        "messages": chat_messages,
+        "max_tokens": max_tokens,
+        "temperature": temperature,
+        "stream": True,
+    }
+
+    yield "event: message_start\n"
+    yield f'data: {json.dumps({"type": "message_start", "message": {"id": "msg_atomic_chat_stream", "type": "message", "role": "assistant", "content": [], "model": model, "stop_reason": None, "usage": {"input_tokens": 0, "output_tokens": 0}}})}\n\n'
+    yield "event: content_block_start\n"
+    yield f'data: {json.dumps({"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})}\n\n'
+
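+    # The handshake above mirrors the Anthropic streaming wire format
+    # (message_start, then content_block_start). The loop below re-emits each
+    # OpenAI-style delta chunk as a content_block_delta event and closes the
+    # stream with content_block_stop / message_delta / message_stop once the
+    # upstream reports a finish_reason.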
+    async with httpx.AsyncClient(timeout=120.0) as client:
+        async with client.stream("POST", _api_url("/chat/completions"), json=payload) as resp:
+            resp.raise_for_status()
+            async for line in resp.aiter_lines():
+                if not line or not line.startswith("data: "):
+                    continue
+                raw = line[len("data: "):]
+                if raw.strip() == "[DONE]":
+                    break
+                try:
+                    chunk = json.loads(raw)
+                    delta = chunk.get("choices", [{}])[0].get("delta", {})
+                    delta_text = delta.get("content", "")
+                    if delta_text:
+                        yield "event: content_block_delta\n"
+                        yield f'data: {json.dumps({"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": delta_text}})}\n\n'
+
+                    finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
+                    if finish_reason:
+                        usage = chunk.get("usage") or {}
+                        yield "event: content_block_stop\n"
+                        yield f'data: {json.dumps({"type": "content_block_stop", "index": 0})}\n\n'
+                        yield "event: message_delta\n"
+                        yield f'data: {json.dumps({"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence": None}, "usage": {"output_tokens": usage.get("completion_tokens", 0)}})}\n\n'
+                        yield "event: message_stop\n"
+                        yield f'data: {json.dumps({"type": "message_stop"})}\n\n'
+                        break
+                except json.JSONDecodeError:
+                    continue
diff --git a/package.json b/package.json
index 47052352..03abde72 100644
--- a/package.json
+++ b/package.json
@@ -21,6 +21,7 @@
     "dev:gemini": "bun run scripts/provider-launch.ts gemini",
     "dev:ollama": "bun run scripts/provider-launch.ts ollama",
     "dev:ollama:fast": "bun run scripts/provider-launch.ts ollama --fast --bare",
+    "dev:atomic-chat": "bun run scripts/provider-launch.ts atomic-chat",
     "profile:init": "bun run scripts/provider-bootstrap.ts",
     "profile:recommend": "bun run scripts/provider-recommend.ts",
     "profile:auto": "bun run scripts/provider-recommend.ts --apply",
diff --git a/scripts/provider-bootstrap.ts b/scripts/provider-bootstrap.ts
index 82ebbbb6..f39e3e50 100644
--- a/scripts/provider-bootstrap.ts
+++ b/scripts/provider-bootstrap.ts
@@ -10,6 +10,7 @@ import {
   recommendOllamaModel,
 } from '../src/utils/providerRecommendation.ts'
 import {
+  buildAtomicChatProfileEnv,
   buildCodexProfileEnv,
   buildGeminiProfileEnv,
   buildOllamaProfileEnv,
@@ -20,8 +21,11 @@ import {
   type ProviderProfile,
 } from '../src/utils/providerProfile.ts'
 import {
+  getAtomicChatChatBaseUrl,
   getOllamaChatBaseUrl,
+  hasLocalAtomicChat,
   hasLocalOllama,
+  listAtomicChatModels,
   listOllamaModels,
 } from './provider-discovery.ts'
 
@@ -34,7 +38,7 @@ function parseArg(name: string): string | null {
 
 function parseProviderArg(): ProviderProfile | 'auto' {
   const p = parseArg('--provider')?.toLowerCase()
-  if (p === 'openai' || p === 'ollama' || p === 'codex' || p === 'gemini') return p
+  if (p === 'openai' || p === 'ollama' || p === 'codex' || p === 'gemini' || p === 'atomic-chat') return p
   return 'auto'
 }
 
@@ -102,6 +106,21 @@ async function main(): Promise<void> {
         getOllamaChatBaseUrl,
       },
     )
+  } else if (selected === 'atomic-chat') {
+    const model = argModel || (await listAtomicChatModels(argBaseUrl || undefined))[0]
+    if (!model) {
+      if (!(await hasLocalAtomicChat(argBaseUrl || undefined))) {
+        console.error('Atomic Chat is not running (could not connect to 127.0.0.1:1337).\n Download from https://atomic.chat/ and launch the application.')
+      } else {
+        console.error('Atomic Chat is running but no model is loaded. Open Atomic Chat and download or start a model first.')
+      }
+      process.exit(1)
+    }
+
+    env = buildAtomicChatProfileEnv(model, {
+      baseUrl: argBaseUrl,
+      getAtomicChatChatBaseUrl,
+    })
   } else if (selected === 'codex') {
     const builtEnv = buildCodexProfileEnv({
       model: argModel,
diff --git a/scripts/provider-discovery.ts b/scripts/provider-discovery.ts
index 9e3aacda..9c463f2f 100644
--- a/scripts/provider-discovery.ts
+++ b/scripts/provider-discovery.ts
@@ -1,6 +1,7 @@
 import type { OllamaModelDescriptor } from '../src/utils/providerRecommendation.ts'
 
 export const DEFAULT_OLLAMA_BASE_URL = 'http://localhost:11434'
+export const DEFAULT_ATOMIC_CHAT_BASE_URL = 'http://127.0.0.1:1337'
 
 function withTimeoutSignal(timeoutMs: number): {
   signal: AbortSignal
@@ -93,6 +94,61 @@ export async function listOllamaModels(
   }
 }
 
+// ── Atomic Chat discovery (Apple Silicon local LLMs at 127.0.0.1:1337) ──────
+
+export function getAtomicChatApiBaseUrl(baseUrl?: string): string {
+  const raw = baseUrl || process.env.ATOMIC_CHAT_BASE_URL || DEFAULT_ATOMIC_CHAT_BASE_URL
+  return trimTrailingSlash(raw)
+}
+
+export function getAtomicChatChatBaseUrl(baseUrl?: string): string {
+  return `${getAtomicChatApiBaseUrl(baseUrl)}/v1`
+}
+
+export async function hasLocalAtomicChat(baseUrl?: string): Promise<boolean> {
+  const { signal, clear } = withTimeoutSignal(1200)
+  try {
+    const response = await fetch(`${getAtomicChatChatBaseUrl(baseUrl)}/models`, {
+      method: 'GET',
+      signal,
+    })
+    return response.ok
+  } catch {
+    return false
+  } finally {
+    clear()
+  }
+}
+
+export async function listAtomicChatModels(
+  baseUrl?: string,
+): Promise<string[]> {
+  const { signal, clear } = withTimeoutSignal(5000)
+  try {
+    const response = await fetch(`${getAtomicChatChatBaseUrl(baseUrl)}/models`, {
+      method: 'GET',
+      signal,
+    })
+    if (!response.ok) {
+      return []
+    }
+
+    const data = await response.json() as {
+      data?: Array<{ id?: string }>
+    }
+
+    return (data.data ?? [])
+      .filter(model => Boolean(model.id))
+      .map(model => model.id!)
+  } catch {
+    return []
+  } finally {
+    clear()
+  }
+}
+
+// ── Ollama benchmarking ─────────────────────────────────────────────────────
+
 export async function benchmarkOllamaModel(
   modelName: string,
   baseUrl?: string,
diff --git a/scripts/provider-launch.ts b/scripts/provider-launch.ts
index 2859e9e8..17f11fb8 100644
--- a/scripts/provider-launch.ts
+++ b/scripts/provider-launch.ts
@@ -16,8 +16,11 @@ import {
   type ProviderProfile,
 } from '../src/utils/providerProfile.ts'
 import {
+  getAtomicChatChatBaseUrl,
   getOllamaChatBaseUrl,
+  hasLocalAtomicChat,
   hasLocalOllama,
+  listAtomicChatModels,
   listOllamaModels,
 } from './provider-discovery.ts'
 
@@ -48,7 +51,7 @@ function parseLaunchOptions(argv: string[]): LaunchOptions {
       continue
     }
 
-    if ((lower === 'auto' || lower === 'openai' || lower === 'ollama' || lower === 'codex' || lower === 'gemini') && requestedProfile === 'auto') {
+    if ((lower === 'auto' || lower === 'openai' || lower === 'ollama' || lower === 'codex' || lower === 'gemini' || lower === 'atomic-chat') && requestedProfile === 'auto') {
       requestedProfile = lower as ProviderProfile | 'auto'
       continue
     }
@@ -79,7 +82,7 @@ function loadPersistedProfile(): ProfileFile | null {
   if (!existsSync(path)) return null
   try {
     const parsed = JSON.parse(readFileSync(path, 'utf8')) as ProfileFile
-    if (parsed.profile === 'openai' || parsed.profile === 'ollama' || parsed.profile === 'codex' || parsed.profile === 'gemini') {
+    if (parsed.profile === 'openai' || parsed.profile === 'ollama' || parsed.profile === 'codex' || parsed.profile === 'gemini' || parsed.profile === 'atomic-chat') {
       return parsed
     }
     return null
@@ -96,6 +99,11 @@ async function resolveOllamaDefaultModel(
   return recommended?.name ?? null
 }
 
+async function resolveAtomicChatDefaultModel(): Promise<string | null> {
+  const models = await listAtomicChatModels()
+  return models[0] ?? null
+}
+
 function runCommand(command: string, env: NodeJS.ProcessEnv): Promise {
   return runProcess(command, [], env)
 }
@@ -132,6 +140,10 @@ function printSummary(profile: ProviderProfile, env: NodeJS.ProcessEnv): void {
     console.log(`OPENAI_BASE_URL=${env.OPENAI_BASE_URL}`)
     console.log(`OPENAI_MODEL=${env.OPENAI_MODEL}`)
     console.log(`CODEX_API_KEY_SET=${Boolean(resolveCodexApiCredentials(env).apiKey)}`)
+  } else if (profile === 'atomic-chat') {
+    console.log(`OPENAI_BASE_URL=${env.OPENAI_BASE_URL}`)
+    console.log(`OPENAI_MODEL=${env.OPENAI_MODEL}`)
+    console.log('OPENAI_API_KEY_SET=false (local provider, no key required)')
   } else {
     console.log(`OPENAI_BASE_URL=${env.OPENAI_BASE_URL}`)
     console.log(`OPENAI_MODEL=${env.OPENAI_MODEL}`)
@@ -143,7 +155,7 @@ async function main(): Promise<void> {
   const options = parseLaunchOptions(process.argv.slice(2))
   const requestedProfile = options.requestedProfile
   if (!requestedProfile) {
-    console.error('Usage: bun run scripts/provider-launch.ts [openai|ollama|codex|gemini|auto] [--fast] [--goal <goal>] [-- <command>]')
+    console.error('Usage: bun run scripts/provider-launch.ts [openai|ollama|codex|gemini|atomic-chat|auto] [--fast] [--goal <goal>] [-- <command>]')
     process.exit(1)
   }
 
@@ -175,12 +187,30 @@ async function main(): Promise<void> {
     }
   }
+  let resolvedAtomicChatModel: string | null = null
+  if (
+    profile === 'atomic-chat' &&
+    (persisted?.profile !== 'atomic-chat' || !persisted?.env?.OPENAI_MODEL)
+  ) {
+    if (!(await hasLocalAtomicChat())) {
+      console.error('Atomic Chat is not running (could not connect to 127.0.0.1:1337).\n Download from https://atomic.chat/ and launch the application.')
+      process.exit(1)
+    }
+    resolvedAtomicChatModel = await resolveAtomicChatDefaultModel()
+    if (!resolvedAtomicChatModel) {
+      console.error('Atomic Chat is running but no model is loaded. Open Atomic Chat and download or start a model first.')
+      process.exit(1)
+    }
+  }
+
   const env = await buildLaunchEnv({
     profile,
     persisted,
     goal: options.goal,
     getOllamaChatBaseUrl,
     resolveOllamaDefaultModel: async () => resolvedOllamaModel || 'llama3.1:8b',
+    getAtomicChatChatBaseUrl,
+    resolveAtomicChatDefaultModel: async () => resolvedAtomicChatModel,
   })
   if (options.fast) {
     applyFastFlags(env)
   }
diff --git a/smart_router.py b/smart_router.py
index 0a54a791..14b90c03 100644
--- a/smart_router.py
+++ b/smart_router.py
@@ -57,8 +57,8 @@ class Provider:
     @property
     def is_configured(self) -> bool:
        """True if the provider has an API key set."""
-        if self.name == "ollama":
-            return True  # Ollama needs no API key
+        if self.name in ("ollama", "atomic-chat"):
+            return True  # Local providers need no API key
         return bool(self.api_key)
 
     @property
@@ -93,6 +93,7 @@ def build_default_providers() -> list[Provider]:
     big = os.getenv("BIG_MODEL", "gpt-4.1")
     small = os.getenv("SMALL_MODEL", "gpt-4.1-mini")
     ollama_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
+    atomic_chat_url = os.getenv("ATOMIC_CHAT_BASE_URL", "http://127.0.0.1:1337")
 
     return [
         Provider(
@@ -119,6 +120,14 @@
             big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
             small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
         ),
+        Provider(
+            name="atomic-chat",
+            ping_url=f"{atomic_chat_url}/v1/models",
+            api_key_env="",
+            cost_per_1k_tokens=0.0,  # free — local (Apple Silicon)
+            big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
+            small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
+        ),
     ]
diff --git a/src/utils/providerProfile.test.ts b/src/utils/providerProfile.test.ts
index e90746c6..b953e1b6 100644
--- a/src/utils/providerProfile.test.ts
+++ b/src/utils/providerProfile.test.ts
@@ -5,6 +5,7 @@ import { join } from 'node:path'
 import test from 'node:test'
 
 import {
+  buildAtomicChatProfileEnv,
   buildCodexProfileEnv,
   buildGeminiProfileEnv,
   buildLaunchEnv,
@@ -381,3 +382,72 @@ test('auto profile falls back to openai when no viable ollama model exists', ()
   assert.equal(selectAutoProfile(null), 'openai')
   assert.equal(selectAutoProfile('qwen2.5-coder:7b'), 'ollama')
 })
+
+// ── Atomic Chat profile tests ────────────────────────────────────────────────
+
+test('atomic-chat profiles never persist openai api keys', () => {
+  const env = buildAtomicChatProfileEnv('some-local-model', {
+    getAtomicChatChatBaseUrl: () => 'http://127.0.0.1:1337/v1',
+  })
+
+  assert.deepEqual(env, {
+    OPENAI_BASE_URL: 'http://127.0.0.1:1337/v1',
+    OPENAI_MODEL: 'some-local-model',
+  })
+  assert.equal('OPENAI_API_KEY' in env, false)
+})
+
+test('atomic-chat profiles respect custom base url', () => {
+  const env = buildAtomicChatProfileEnv('my-model', {
+    baseUrl: 'http://192.168.1.100:1337',
+    getAtomicChatChatBaseUrl: (baseUrl?: string) =>
+      baseUrl ? `${baseUrl}/v1` : 'http://127.0.0.1:1337/v1',
+  })
+
+  assert.equal(env.OPENAI_BASE_URL, 'http://192.168.1.100:1337/v1')
+  assert.equal(env.OPENAI_MODEL, 'my-model')
+})
+
+test('matching persisted atomic-chat env is reused for atomic-chat launch', async () => {
+  const env = await buildLaunchEnv({
+    profile: 'atomic-chat',
+    persisted: profile('atomic-chat', {
+      OPENAI_BASE_URL: 'http://127.0.0.1:1337/v1',
+      OPENAI_MODEL: 'llama-3.1-8b',
+    }),
+    goal: 'balanced',
+    processEnv: {},
+    getAtomicChatChatBaseUrl: () => 'http://127.0.0.1:1337/v1',
+    resolveAtomicChatDefaultModel: async () => 'other-model',
+  })
+
+  assert.equal(env.OPENAI_BASE_URL, 'http://127.0.0.1:1337/v1')
+  assert.equal(env.OPENAI_MODEL, 'llama-3.1-8b')
+  assert.equal(env.OPENAI_API_KEY, undefined)
+  assert.equal(env.CODEX_API_KEY, undefined)
+})
+
+test('atomic-chat launch ignores mismatched persisted openai env', async () => {
+  const env = await buildLaunchEnv({
+    profile: 'atomic-chat',
+    persisted: profile('openai', {
+      OPENAI_BASE_URL: 'https://api.openai.com/v1',
+      OPENAI_MODEL: 'gpt-4o',
+      OPENAI_API_KEY: 'sk-persisted',
+    }),
+    goal: 'balanced',
+    processEnv: {
+      OPENAI_API_KEY: 'sk-live',
+      CODEX_API_KEY: 'codex-live',
+      CHATGPT_ACCOUNT_ID: 'acct_live',
+    },
+    getAtomicChatChatBaseUrl: () => 'http://127.0.0.1:1337/v1',
+    resolveAtomicChatDefaultModel: async () => 'local-model',
+  })
+
+  assert.equal(env.OPENAI_BASE_URL, 'http://127.0.0.1:1337/v1')
+  assert.equal(env.OPENAI_MODEL, 'local-model')
+  assert.equal(env.OPENAI_API_KEY, undefined)
+  assert.equal(env.CODEX_API_KEY, undefined)
+  assert.equal(env.CHATGPT_ACCOUNT_ID, undefined)
+})
diff --git a/src/utils/providerProfile.ts b/src/utils/providerProfile.ts
index 866c19c5..d85af0c6 100644
--- a/src/utils/providerProfile.ts
+++ b/src/utils/providerProfile.ts
@@ -13,7 +13,7 @@ import {
 const DEFAULT_GEMINI_BASE_URL = 'https://generativelanguage.googleapis.com/v1beta/openai'
 const DEFAULT_GEMINI_MODEL = 'gemini-2.0-flash'
 
-export type ProviderProfile = 'openai' | 'ollama' | 'codex' | 'gemini'
+export type ProviderProfile = 'openai' | 'ollama' | 'codex' | 'gemini' | 'atomic-chat'
 
 export type ProfileEnv = {
   OPENAI_BASE_URL?: string
@@ -53,6 +53,19 @@ export function buildOllamaProfileEnv(
   }
 }
 
+export function buildAtomicChatProfileEnv(
+  model: string,
+  options: {
+    baseUrl?: string | null
+    getAtomicChatChatBaseUrl: (baseUrl?: string) => string
+  },
+): ProfileEnv {
+  return {
+    OPENAI_BASE_URL: options.getAtomicChatChatBaseUrl(options.baseUrl ?? undefined),
+    OPENAI_MODEL: model,
+  }
+}
+
 export function buildGeminiProfileEnv(options: {
   model?: string | null
   baseUrl?: string | null
@@ -171,6 +184,8 @@ export async function buildLaunchEnv(options: {
   processEnv?: NodeJS.ProcessEnv
   getOllamaChatBaseUrl?: (baseUrl?: string) => string
   resolveOllamaDefaultModel?: (goal: RecommendationGoal) => Promise<string | null>
+  getAtomicChatChatBaseUrl?: (baseUrl?: string) => string
+  resolveAtomicChatDefaultModel?: () => Promise<string | null>
 }): Promise<ProfileEnv> {
   const processEnv = options.processEnv ?? process.env
   const persistedEnv =
@@ -248,6 +263,26 @@ export async function buildLaunchEnv(options: {
     return env
   }
 
+  if (options.profile === 'atomic-chat') {
+    const getAtomicChatBaseUrl =
+      options.getAtomicChatChatBaseUrl ?? (() => 'http://127.0.0.1:1337/v1')
+    const resolveModel =
+      options.resolveAtomicChatDefaultModel ?? (async () => null as string | null)
+
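+    // Local provider: reuse any persisted Atomic Chat env first, fall back to
+    // live model discovery, and never carry API keys or account IDs through.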
+    env.OPENAI_BASE_URL = persistedEnv.OPENAI_BASE_URL || getAtomicChatBaseUrl()
+    env.OPENAI_MODEL =
+      persistedEnv.OPENAI_MODEL ||
+      (await resolveModel()) ||
+      ''
+
+    delete env.OPENAI_API_KEY
+    delete env.CODEX_API_KEY
+    delete env.CHATGPT_ACCOUNT_ID
+    delete env.CODEX_ACCOUNT_ID
+
+    return env
+  }
+
   if (options.profile === 'codex') {
     env.OPENAI_BASE_URL =
       persistedEnv.OPENAI_BASE_URL && isCodexBaseUrl(persistedEnv.OPENAI_BASE_URL)
diff --git a/test_atomic_chat_provider.py b/test_atomic_chat_provider.py
new file mode 100644
index 00000000..819c610c
--- /dev/null
+++ b/test_atomic_chat_provider.py
@@ -0,0 +1,130 @@
+"""
+test_atomic_chat_provider.py
+Run: pytest test_atomic_chat_provider.py -v
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from atomic_chat_provider import (
+    atomic_chat,
+    list_atomic_chat_models,
+    check_atomic_chat_running,
+)
+
+
+@pytest.mark.asyncio
+async def test_atomic_chat_running_true():
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
+        result = await check_atomic_chat_running()
+        assert result is True
+
+
+@pytest.mark.asyncio
+async def test_atomic_chat_running_false_on_exception():
+    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("refused"))
+        result = await check_atomic_chat_running()
+        assert result is False
+
+
+@pytest.mark.asyncio
+async def test_list_models_returns_ids():
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "data": [{"id": "llama-3.1-8b"}, {"id": "mistral-7b"}],
+    }
+    mock_response.raise_for_status = MagicMock()
+    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
+        models = await list_atomic_chat_models()
+        assert "llama-3.1-8b" in models
+        assert "mistral-7b" in models
+
+
+@pytest.mark.asyncio
+async def test_list_models_empty_on_failure():
+    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("down"))
+        models = await list_atomic_chat_models()
+        assert models == []
+
+
+@pytest.mark.asyncio
+async def test_atomic_chat_returns_anthropic_format():
+    mock_response = MagicMock()
+    mock_response.raise_for_status = MagicMock()
+    mock_response.json.return_value = {
+        "id": "chatcmpl-abc123",
+        "choices": [{"message": {"content": "42 is the answer."}}],
+        "usage": {"prompt_tokens": 10, "completion_tokens": 8},
+    }
+    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
+        result = await atomic_chat(
+            model="llama-3.1-8b",
+            messages=[{"role": "user", "content": "What is 6*7?"}],
+        )
+        assert result["type"] == "message"
+        assert result["role"] == "assistant"
+        assert "42" in result["content"][0]["text"]
+        assert result["usage"]["input_tokens"] == 10
+        assert result["usage"]["output_tokens"] == 8
+
+
+@pytest.mark.asyncio
+async def test_atomic_chat_prepends_system():
+    captured = {}
+
+    async def mock_post(url, json=None, **kwargs):
+        captured.update(json or {})
+        m = MagicMock()
+        m.raise_for_status = MagicMock()
+        m.json.return_value = {
+            "id": "chatcmpl-xyz",
+            "choices": [{"message": {"content": "ok"}}],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+        }
+        return m
+
+    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.post = mock_post
+        await atomic_chat(
+            model="llama-3.1-8b",
+            messages=[{"role": "user", "content": "Hi"}],
+            system="Be helpful.",
+        )
+        assert captured["messages"][0]["role"] == "system"
+        assert "helpful" in captured["messages"][0]["content"]
+
+
+@pytest.mark.asyncio
+async def test_atomic_chat_sends_correct_payload():
+    captured = {}
+
+    async def mock_post(url, json=None, **kwargs):
+        captured.update(json or {})
+        m = MagicMock()
+        m.raise_for_status = MagicMock()
+        m.json.return_value = {
+            "id": "chatcmpl-xyz",
+            "choices": [{"message": {"content": "ok"}}],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+        }
+        return m
+
+    with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.post = mock_post
+        await atomic_chat(
+            model="test-model",
+            messages=[{"role": "user", "content": "Test"}],
+            max_tokens=2048,
+            temperature=0.5,
+        )
+        assert captured["model"] == "test-model"
+        assert captured["max_tokens"] == 2048
+        assert captured["temperature"] == 0.5
+        assert captured["stream"] is False
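A minimal manual smoke test for the new provider module (a sketch, not part of the
patch above): it assumes Atomic Chat is running on 127.0.0.1:1337 with at least one
model loaded, and it simply uses whichever model /v1/models lists first.

# smoke_atomic_chat.py: hypothetical helper script, not included in this diff
import asyncio

from atomic_chat_provider import (
    atomic_chat,
    check_atomic_chat_running,
    list_atomic_chat_models,
)


async def main() -> None:
    # Fail with a readable message instead of a raw connection error.
    if not await check_atomic_chat_running():
        raise SystemExit("Atomic Chat is not running on 127.0.0.1:1337")
    models = await list_atomic_chat_models()
    if not models:
        raise SystemExit("Atomic Chat is running but no model is loaded")
    # One round trip; the reply uses the Anthropic Messages shape built by atomic_chat().
    reply = await atomic_chat(
        model=models[0],
        messages=[{"role": "user", "content": "Say hello in five words."}],
        system="You are concise.",
        max_tokens=64,
    )
    print(reply["content"][0]["text"])


if __name__ == "__main__":
    asyncio.run(main())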