Merge pull request #74 from Vect0rM/feature/atomic-chat-integration

feat: add support for Atomic Chat provider

This commit touches README.md (30 changed lines) and adds atomic_chat_provider.py (146 lines, new file), plus the launcher, bootstrap, discovery, doctor, and profile scripts.
README.md:

````diff
@@ -2,7 +2,7 @@
 
 Use Claude Code with **any LLM** — not just Claude.
 
-OpenClaude is a fork of the [Claude Code source leak](https://gitlawb.com/node/repos/z6MkgKkb/instructkr-claude-code) (exposed via npm source maps on March 31, 2026). We added an OpenAI-compatible provider shim so you can plug in GPT-4o, DeepSeek, Gemini, Llama, Mistral, or any model that speaks the OpenAI chat completions API. It now also supports the ChatGPT Codex backend for `codexplan` and `codexspark`.
+OpenClaude is a fork of the [Claude Code source leak](https://gitlawb.com/node/repos/z6MkgKkb/instructkr-claude-code) (exposed via npm source maps on March 31, 2026). We added an OpenAI-compatible provider shim so you can plug in GPT-4o, DeepSeek, Gemini, Llama, Mistral, or any model that speaks the OpenAI chat completions API. It now also supports the ChatGPT Codex backend for `codexplan` and `codexspark`, and local inference via [Atomic Chat](https://atomic.chat/) on Apple Silicon.
 
 All of Claude Code's tools work — bash, file read/write/edit, grep, glob, agents, tasks, MCP — just powered by whatever model you choose.
 
@@ -140,6 +140,23 @@ export OPENAI_MODEL=llama3.3:70b
 # no API key needed for local models
 ```
 
+### Atomic Chat (local, Apple Silicon)
+
+```bash
+export CLAUDE_CODE_USE_OPENAI=1
+export OPENAI_BASE_URL=http://127.0.0.1:1337/v1
+export OPENAI_MODEL=your-model-name
+# no API key needed for local models
+```
+
+Or use the profile launcher:
+
+```bash
+bun run dev:atomic-chat
+```
+
+Download Atomic Chat from [atomic.chat](https://atomic.chat/). The app must be running with a model loaded before launching.
+
 ### LM Studio (local)
 
 ```bash
@@ -191,7 +208,7 @@ export OPENAI_MODEL=gpt-4o
 | Variable | Required | Description |
 |----------|----------|-------------|
 | `CLAUDE_CODE_USE_OPENAI` | Yes | Set to `1` to enable the OpenAI provider |
-| `OPENAI_API_KEY` | Yes* | Your API key (*not needed for local models like Ollama) |
+| `OPENAI_API_KEY` | Yes* | Your API key (*not needed for local models like Ollama/Atomic Chat) |
 | `OPENAI_MODEL` | Yes | Model name (e.g. `gpt-4o`, `deepseek-chat`, `llama3.3:70b`) |
 | `OPENAI_BASE_URL` | No | API endpoint (defaults to `https://api.openai.com/v1`) |
 | `CODEX_API_KEY` | Codex only | Codex/ChatGPT access token override |
@@ -254,6 +271,9 @@ bun run profile:codex
 # openai bootstrap with explicit key
 bun run profile:init -- --provider openai --api-key sk-...
 
+# atomic-chat bootstrap (auto-detects running model)
+bun run profile:init -- --provider atomic-chat
+
 # ollama bootstrap with custom model
 bun run profile:init -- --provider ollama --model llama3.1:8b
 
@@ -274,6 +294,9 @@ bun run dev:openai
 
 # Ollama profile (defaults: localhost:11434, llama3.1:8b)
 bun run dev:ollama
+
+# Atomic Chat profile (Apple Silicon local LLMs at 127.0.0.1:1337)
+bun run dev:atomic-chat
 ```
 
 `profile:recommend` ranks installed Ollama models for `latency`, `balanced`, or `coding`, and `profile:auto` can persist the recommendation directly.
@@ -284,8 +307,9 @@ Goal-based Ollama selection only recommends among models that are already instal
 
 Use `profile:codex` or `--provider codex` when you want the ChatGPT Codex backend.
 
-`dev:openai`, `dev:ollama`, and `dev:codex` run `doctor:runtime` first and only launch the app if checks pass.
+`dev:openai`, `dev:ollama`, `dev:atomic-chat`, and `dev:codex` run `doctor:runtime` first and only launch the app if checks pass.
 For `dev:ollama`, make sure Ollama is running locally before launch.
+For `dev:atomic-chat`, make sure Atomic Chat is running with a model loaded.
 
 ---
 
````
atomic_chat_provider.py (new file, 146 lines):

```python
"""
atomic_chat_provider.py
-----------------------
Adds native Atomic Chat support to openclaude.
Lets Claude Code route requests to any locally-running model via
Atomic Chat (Apple Silicon only) at 127.0.0.1:1337.

Atomic Chat exposes an OpenAI-compatible API, so messages are forwarded
directly without translation.

Usage (.env):
    PREFERRED_PROVIDER=atomic-chat
    ATOMIC_CHAT_BASE_URL=http://127.0.0.1:1337
"""

import httpx
import json
import logging
import os
from typing import AsyncIterator

logger = logging.getLogger(__name__)
ATOMIC_CHAT_BASE_URL = os.getenv("ATOMIC_CHAT_BASE_URL", "http://127.0.0.1:1337")


def _api_url(path: str) -> str:
    return f"{ATOMIC_CHAT_BASE_URL}/v1{path}"


async def check_atomic_chat_running() -> bool:
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(_api_url("/models"))
            return resp.status_code == 200
    except Exception:
        return False


async def list_atomic_chat_models() -> list[str]:
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(_api_url("/models"))
            resp.raise_for_status()
            data = resp.json()
            return [m["id"] for m in data.get("data", [])]
    except Exception as e:
        logger.warning(f"Could not list Atomic Chat models: {e}")
        return []


async def atomic_chat(
    model: str,
    messages: list[dict],
    system: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 1.0,
) -> dict:
    chat_messages = list(messages)
    if system:
        chat_messages.insert(0, {"role": "system", "content": system})

    payload = {
        "model": model,
        "messages": chat_messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": False,
    }

    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(_api_url("/chat/completions"), json=payload)
        resp.raise_for_status()
        data = resp.json()

    choice = data.get("choices", [{}])[0]
    assistant_text = choice.get("message", {}).get("content", "")
    usage = data.get("usage", {})

    return {
        "id": data.get("id", "msg_atomic_chat"),
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_text}],
        "model": model,
        "stop_reason": "end_turn",
        "stop_sequence": None,
        "usage": {
            "input_tokens": usage.get("prompt_tokens", 0),
            "output_tokens": usage.get("completion_tokens", 0),
        },
    }


async def atomic_chat_stream(
    model: str,
    messages: list[dict],
    system: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 1.0,
) -> AsyncIterator[str]:
    chat_messages = list(messages)
    if system:
        chat_messages.insert(0, {"role": "system", "content": system})

    payload = {
        "model": model,
        "messages": chat_messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True,
    }

    yield "event: message_start\n"
    yield f'data: {json.dumps({"type": "message_start", "message": {"id": "msg_atomic_chat_stream", "type": "message", "role": "assistant", "content": [], "model": model, "stop_reason": None, "usage": {"input_tokens": 0, "output_tokens": 0}}})}\n\n'
    yield "event: content_block_start\n"
    yield f'data: {json.dumps({"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})}\n\n'

    async with httpx.AsyncClient(timeout=120.0) as client:
        async with client.stream("POST", _api_url("/chat/completions"), json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line or not line.startswith("data: "):
                    continue
                raw = line[len("data: "):]
                if raw.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(raw)
                    delta = chunk.get("choices", [{}])[0].get("delta", {})
                    delta_text = delta.get("content", "")
                    if delta_text:
                        yield "event: content_block_delta\n"
                        yield f'data: {json.dumps({"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": delta_text}})}\n\n'

                    finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
                    if finish_reason:
                        usage = chunk.get("usage", {})
                        yield "event: content_block_stop\n"
                        yield f'data: {json.dumps({"type": "content_block_stop", "index": 0})}\n\n'
                        yield "event: message_delta\n"
                        yield f'data: {json.dumps({"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence": None}, "usage": {"output_tokens": usage.get("completion_tokens", 0)}})}\n\n'
                        yield "event: message_stop\n"
                        yield f'data: {json.dumps({"type": "message_stop"})}\n\n'
                        break
                except json.JSONDecodeError:
                    continue
```
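The non-streaming translation in `atomic_chat` is pure dict reshaping, so it can be exercised offline. A sketch of the same mapping (the helper name `to_anthropic_message` is mine, not from the file):

```python
def to_anthropic_message(data: dict, model: str) -> dict:
    # Mirror atomic_chat(): take the first choice from an OpenAI-style
    # chat completion and remap it onto the Anthropic-style message shape.
    choice = data.get("choices", [{}])[0]
    text = choice.get("message", {}).get("content", "")
    usage = data.get("usage", {})
    return {
        "id": data.get("id", "msg_atomic_chat"),
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": text}],
        "model": model,
        "stop_reason": "end_turn",
        "stop_sequence": None,
        "usage": {
            "input_tokens": usage.get("prompt_tokens", 0),
            "output_tokens": usage.get("completion_tokens", 0),
        },
    }

openai_resp = {
    "id": "chatcmpl-123",
    "choices": [{"message": {"role": "assistant", "content": "hi"}}],
    "usage": {"prompt_tokens": 7, "completion_tokens": 2},
}
msg = to_anthropic_message(openai_resp, "my-local-model")
print(msg["content"][0]["text"], msg["usage"]["output_tokens"])  # hi 2
```

Note how every lookup uses `.get` with a default, so a malformed or empty upstream response degrades to an empty text block rather than a KeyError.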
Package scripts:

```diff
@@ -21,6 +21,7 @@
     "dev:gemini": "bun run scripts/provider-launch.ts gemini",
     "dev:ollama": "bun run scripts/provider-launch.ts ollama",
     "dev:ollama:fast": "bun run scripts/provider-launch.ts ollama --fast --bare",
+    "dev:atomic-chat": "bun run scripts/provider-launch.ts atomic-chat",
    "profile:init": "bun run scripts/provider-bootstrap.ts",
    "profile:recommend": "bun run scripts/provider-recommend.ts",
    "profile:auto": "bun run scripts/provider-recommend.ts --apply",
```
Provider bootstrap script:

```diff
@@ -10,6 +10,7 @@ import {
   recommendOllamaModel,
 } from '../src/utils/providerRecommendation.ts'
 import {
+  buildAtomicChatProfileEnv,
   buildCodexProfileEnv,
   buildGeminiProfileEnv,
   buildOllamaProfileEnv,
@@ -20,8 +21,11 @@ import {
   type ProviderProfile,
 } from '../src/utils/providerProfile.ts'
 import {
+  getAtomicChatChatBaseUrl,
   getOllamaChatBaseUrl,
+  hasLocalAtomicChat,
   hasLocalOllama,
+  listAtomicChatModels,
   listOllamaModels,
 } from './provider-discovery.ts'
 
@@ -34,7 +38,7 @@ function parseArg(name: string): string | null {
 
 function parseProviderArg(): ProviderProfile | 'auto' {
   const p = parseArg('--provider')?.toLowerCase()
-  if (p === 'openai' || p === 'ollama' || p === 'codex' || p === 'gemini') return p
+  if (p === 'openai' || p === 'ollama' || p === 'codex' || p === 'gemini' || p === 'atomic-chat') return p
   return 'auto'
 }
 
@@ -102,6 +106,21 @@ async function main(): Promise<void> {
         getOllamaChatBaseUrl,
       },
     )
+  } else if (selected === 'atomic-chat') {
+    const model = argModel || (await listAtomicChatModels(argBaseUrl || undefined))[0]
+    if (!model) {
+      if (!(await hasLocalAtomicChat(argBaseUrl || undefined))) {
+        console.error('Atomic Chat is not running (could not connect to 127.0.0.1:1337).\n Download from https://atomic.chat/ and launch the application.')
+      } else {
+        console.error('Atomic Chat is running but no model is loaded. Open Atomic Chat and download or start a model first.')
+      }
+      process.exit(1)
+    }
+
+    env = buildAtomicChatProfileEnv(model, {
+      baseUrl: argBaseUrl,
+      getAtomicChatChatBaseUrl,
+    })
   } else if (selected === 'codex') {
     const builtEnv = buildCodexProfileEnv({
       model: argModel,
```
Provider discovery helpers (`provider-discovery.ts`):

```diff
@@ -1,6 +1,7 @@
 import type { OllamaModelDescriptor } from '../src/utils/providerRecommendation.ts'
 
 export const DEFAULT_OLLAMA_BASE_URL = 'http://localhost:11434'
+export const DEFAULT_ATOMIC_CHAT_BASE_URL = 'http://127.0.0.1:1337'
 
 function withTimeoutSignal(timeoutMs: number): {
   signal: AbortSignal
@@ -93,6 +94,69 @@ export async function listOllamaModels(
   }
 }
 
+// ── Atomic Chat discovery (Apple Silicon local LLMs at 127.0.0.1:1337) ──────
+
+export function getAtomicChatApiBaseUrl(baseUrl?: string): string {
+  const parsed = new URL(
+    baseUrl || process.env.ATOMIC_CHAT_BASE_URL || DEFAULT_ATOMIC_CHAT_BASE_URL,
+  )
+  const pathname = trimTrailingSlash(parsed.pathname)
+  parsed.pathname = pathname.endsWith('/v1')
+    ? pathname.slice(0, -3) || '/'
+    : pathname || '/'
+  parsed.search = ''
+  parsed.hash = ''
+  return trimTrailingSlash(parsed.toString())
+}
+
+export function getAtomicChatChatBaseUrl(baseUrl?: string): string {
+  return `${getAtomicChatApiBaseUrl(baseUrl)}/v1`
+}
+
+export async function hasLocalAtomicChat(baseUrl?: string): Promise<boolean> {
+  const { signal, clear } = withTimeoutSignal(1200)
+  try {
+    const response = await fetch(`${getAtomicChatChatBaseUrl(baseUrl)}/models`, {
+      method: 'GET',
+      signal,
+    })
+    return response.ok
+  } catch {
+    return false
+  } finally {
+    clear()
+  }
+}
+
+export async function listAtomicChatModels(
+  baseUrl?: string,
+): Promise<string[]> {
+  const { signal, clear } = withTimeoutSignal(5000)
+  try {
+    const response = await fetch(`${getAtomicChatChatBaseUrl(baseUrl)}/models`, {
+      method: 'GET',
+      signal,
+    })
+    if (!response.ok) {
+      return []
+    }
+
+    const data = await response.json() as {
+      data?: Array<{ id?: string }>
+    }
+
+    return (data.data ?? [])
+      .filter(model => Boolean(model.id))
+      .map(model => model.id!)
+  } catch {
+    return []
+  } finally {
+    clear()
+  }
+}
+
+// ── Ollama benchmarking ─────────────────────────────────────────────────────
+
 export async function benchmarkOllamaModel(
   modelName: string,
   baseUrl?: string,
```
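`getAtomicChatApiBaseUrl` tolerates a user-supplied base URL with or without a trailing slash or `/v1` segment, then the chat URL appends exactly one `/v1`. The same normalization rule in Python (a sketch, not the shipped code):

```python
from urllib.parse import urlsplit, urlunsplit

def api_base_url(base_url: str) -> str:
    """Strip query/fragment, a trailing slash, and a trailing /v1 segment."""
    parts = urlsplit(base_url)
    path = parts.path.rstrip("/")
    if path.endswith("/v1"):
        path = path[:-3].rstrip("/")
    return urlunsplit((parts.scheme, parts.netloc, path, "", "")).rstrip("/")

def chat_base_url(base_url: str) -> str:
    # Appending /v1 to the normalized root is now safe: no double /v1/v1.
    return api_base_url(base_url) + "/v1"

for u in ("http://127.0.0.1:1337", "http://127.0.0.1:1337/v1/"):
    print(chat_base_url(u))  # both normalize to http://127.0.0.1:1337/v1
```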
`scripts/provider-launch.ts`:

```diff
@@ -16,8 +16,11 @@ import {
   type ProviderProfile,
 } from '../src/utils/providerProfile.ts'
 import {
+  getAtomicChatChatBaseUrl,
   getOllamaChatBaseUrl,
+  hasLocalAtomicChat,
   hasLocalOllama,
+  listAtomicChatModels,
   listOllamaModels,
 } from './provider-discovery.ts'
 
@@ -48,7 +51,7 @@ function parseLaunchOptions(argv: string[]): LaunchOptions {
       continue
     }
 
-    if ((lower === 'auto' || lower === 'openai' || lower === 'ollama' || lower === 'codex' || lower === 'gemini') && requestedProfile === 'auto') {
+    if ((lower === 'auto' || lower === 'openai' || lower === 'ollama' || lower === 'codex' || lower === 'gemini' || lower === 'atomic-chat') && requestedProfile === 'auto') {
       requestedProfile = lower as ProviderProfile | 'auto'
       continue
     }
@@ -79,7 +82,7 @@ function loadPersistedProfile(): ProfileFile | null {
   if (!existsSync(path)) return null
   try {
     const parsed = JSON.parse(readFileSync(path, 'utf8')) as ProfileFile
-    if (parsed.profile === 'openai' || parsed.profile === 'ollama' || parsed.profile === 'codex' || parsed.profile === 'gemini') {
+    if (parsed.profile === 'openai' || parsed.profile === 'ollama' || parsed.profile === 'codex' || parsed.profile === 'gemini' || parsed.profile === 'atomic-chat') {
      return parsed
    }
    return null
@@ -96,6 +99,11 @@ async function resolveOllamaDefaultModel(
   return recommended?.name ?? null
 }
 
+async function resolveAtomicChatDefaultModel(): Promise<string | null> {
+  const models = await listAtomicChatModels()
+  return models[0] ?? null
+}
+
 function runCommand(command: string, env: NodeJS.ProcessEnv): Promise<number> {
   return runProcess(command, [], env)
 }
@@ -132,6 +140,10 @@ function printSummary(profile: ProviderProfile, env: NodeJS.ProcessEnv): void {
     console.log(`OPENAI_BASE_URL=${env.OPENAI_BASE_URL}`)
     console.log(`OPENAI_MODEL=${env.OPENAI_MODEL}`)
     console.log(`CODEX_API_KEY_SET=${Boolean(resolveCodexApiCredentials(env).apiKey)}`)
+  } else if (profile === 'atomic-chat') {
+    console.log(`OPENAI_BASE_URL=${env.OPENAI_BASE_URL}`)
+    console.log(`OPENAI_MODEL=${env.OPENAI_MODEL}`)
+    console.log('OPENAI_API_KEY_SET=false (local provider, no key required)')
   } else {
     console.log(`OPENAI_BASE_URL=${env.OPENAI_BASE_URL}`)
     console.log(`OPENAI_MODEL=${env.OPENAI_MODEL}`)
@@ -143,7 +155,7 @@ async function main(): Promise<void> {
   const options = parseLaunchOptions(process.argv.slice(2))
   const requestedProfile = options.requestedProfile
   if (!requestedProfile) {
-    console.error('Usage: bun run scripts/provider-launch.ts [openai|ollama|codex|gemini|auto] [--fast] [--goal <latency|balanced|coding>] [-- <cli args>]')
+    console.error('Usage: bun run scripts/provider-launch.ts [openai|ollama|codex|gemini|atomic-chat|auto] [--fast] [--goal <latency|balanced|coding>] [-- <cli args>]')
     process.exit(1)
   }
 
@@ -175,12 +187,30 @@ async function main(): Promise<void> {
     }
   }
 
+  let resolvedAtomicChatModel: string | null = null
+  if (
+    profile === 'atomic-chat' &&
+    (persisted?.profile !== 'atomic-chat' || !persisted?.env?.OPENAI_MODEL)
+  ) {
+    if (!(await hasLocalAtomicChat())) {
+      console.error('Atomic Chat is not running (could not connect to 127.0.0.1:1337).\n Download from https://atomic.chat/ and launch the application.')
+      process.exit(1)
+    }
+    resolvedAtomicChatModel = await resolveAtomicChatDefaultModel()
+    if (!resolvedAtomicChatModel) {
+      console.error('Atomic Chat is running but no model is loaded. Open Atomic Chat and download or start a model first.')
+      process.exit(1)
+    }
+  }
+
   const env = await buildLaunchEnv({
     profile,
     persisted,
     goal: options.goal,
     getOllamaChatBaseUrl,
     resolveOllamaDefaultModel: async () => resolvedOllamaModel || 'llama3.1:8b',
+    getAtomicChatChatBaseUrl,
+    resolveAtomicChatDefaultModel: async () => resolvedAtomicChatModel,
   })
   if (options.fast) {
     applyFastFlags(env)
```
Runtime doctor checks (`doctor:runtime`):

```diff
@@ -186,7 +186,7 @@ function checkOpenAIEnv(): CheckResult[] {
   } else if (!key && !isLocalBaseUrl(request.baseUrl)) {
     results.push(fail('OPENAI_API_KEY', 'Missing key for non-local provider URL.'))
   } else if (!key) {
-    results.push(pass('OPENAI_API_KEY', 'Not set (allowed for local providers like Ollama/LM Studio).'))
+    results.push(pass('OPENAI_API_KEY', 'Not set (allowed for local providers like Atomic Chat/Ollama/LM Studio).'))
   } else {
     results.push(pass('OPENAI_API_KEY', 'Configured.'))
   }
@@ -271,6 +271,15 @@ async function checkBaseUrlReachability(): Promise<CheckResult> {
   }
 }
 
+function isAtomicChatUrl(baseUrl: string): boolean {
+  try {
+    const parsed = new URL(baseUrl)
+    return parsed.port === '1337' && isLocalBaseUrl(baseUrl)
+  } catch {
+    return false
+  }
+}
+
 function checkOllamaProcessorMode(): CheckResult {
   if (!isTruthy(process.env.CLAUDE_CODE_USE_OPENAI) || isTruthy(process.env.CLAUDE_CODE_USE_GEMINI)) {
     return pass('Ollama processor mode', 'Skipped (OpenAI-compatible mode disabled).')
@@ -281,6 +290,10 @@ function checkOllamaProcessorMode(): CheckResult {
     return pass('Ollama processor mode', 'Skipped (provider URL is not local).')
   }
 
+  if (isAtomicChatUrl(baseUrl)) {
+    return pass('Ollama processor mode', 'Skipped (Atomic Chat local provider detected, not Ollama).')
+  }
+
   const result = spawnSync('ollama', ['ps'], {
     cwd: process.cwd(),
     encoding: 'utf8',
```
Python provider registry:

```diff
@@ -57,8 +57,8 @@ class Provider:
     @property
     def is_configured(self) -> bool:
         """True if the provider has an API key set."""
-        if self.name == "ollama":
-            return True  # Ollama needs no API key
+        if self.name in ("ollama", "atomic-chat"):
+            return True  # Local providers need no API key
         return bool(self.api_key)
 
     @property
@@ -93,6 +93,7 @@ def build_default_providers() -> list[Provider]:
     big = os.getenv("BIG_MODEL", "gpt-4.1")
     small = os.getenv("SMALL_MODEL", "gpt-4.1-mini")
     ollama_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
+    atomic_chat_url = os.getenv("ATOMIC_CHAT_BASE_URL", "http://127.0.0.1:1337")
 
     return [
         Provider(
@@ -119,6 +120,14 @@ def build_default_providers() -> list[Provider]:
             big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
             small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
         ),
+        Provider(
+            name="atomic-chat",
+            ping_url=f"{atomic_chat_url}/v1/models",
+            api_key_env="",
+            cost_per_1k_tokens=0.0,  # free — local (Apple Silicon)
+            big_model=big if "gemini" not in big and "gpt" not in big else "llama3:8b",
+            small_model=small if "gemini" not in small and "gpt" not in small else "llama3:8b",
+        ),
     ]
 
```
Profile launch tests:

```diff
@@ -5,6 +5,7 @@ import { join } from 'node:path'
 import test from 'node:test'
 
 import {
+  buildAtomicChatProfileEnv,
   buildCodexProfileEnv,
   buildGeminiProfileEnv,
   buildLaunchEnv,
@@ -381,3 +382,72 @@ test('auto profile falls back to openai when no viable ollama model exists', ()
   assert.equal(selectAutoProfile(null), 'openai')
   assert.equal(selectAutoProfile('qwen2.5-coder:7b'), 'ollama')
 })
+
+// ── Atomic Chat profile tests ────────────────────────────────────────────────
+
+test('atomic-chat profiles never persist openai api keys', () => {
+  const env = buildAtomicChatProfileEnv('some-local-model', {
+    getAtomicChatChatBaseUrl: () => 'http://127.0.0.1:1337/v1',
+  })
+
+  assert.deepEqual(env, {
+    OPENAI_BASE_URL: 'http://127.0.0.1:1337/v1',
+    OPENAI_MODEL: 'some-local-model',
+  })
+  assert.equal('OPENAI_API_KEY' in env, false)
+})
+
+test('atomic-chat profiles respect custom base url', () => {
+  const env = buildAtomicChatProfileEnv('my-model', {
+    baseUrl: 'http://192.168.1.100:1337',
+    getAtomicChatChatBaseUrl: (baseUrl?: string) =>
+      baseUrl ? `${baseUrl}/v1` : 'http://127.0.0.1:1337/v1',
+  })
+
+  assert.equal(env.OPENAI_BASE_URL, 'http://192.168.1.100:1337/v1')
+  assert.equal(env.OPENAI_MODEL, 'my-model')
+})
+
+test('matching persisted atomic-chat env is reused for atomic-chat launch', async () => {
+  const env = await buildLaunchEnv({
+    profile: 'atomic-chat',
+    persisted: profile('atomic-chat', {
+      OPENAI_BASE_URL: 'http://127.0.0.1:1337/v1',
+      OPENAI_MODEL: 'llama-3.1-8b',
+    }),
+    goal: 'balanced',
+    processEnv: {},
+    getAtomicChatChatBaseUrl: () => 'http://127.0.0.1:1337/v1',
+    resolveAtomicChatDefaultModel: async () => 'other-model',
+  })
+
+  assert.equal(env.OPENAI_BASE_URL, 'http://127.0.0.1:1337/v1')
+  assert.equal(env.OPENAI_MODEL, 'llama-3.1-8b')
+  assert.equal(env.OPENAI_API_KEY, undefined)
+  assert.equal(env.CODEX_API_KEY, undefined)
+})
+
+test('atomic-chat launch ignores mismatched persisted openai env', async () => {
+  const env = await buildLaunchEnv({
+    profile: 'atomic-chat',
+    persisted: profile('openai', {
+      OPENAI_BASE_URL: 'https://api.openai.com/v1',
+      OPENAI_MODEL: 'gpt-4o',
+      OPENAI_API_KEY: 'sk-persisted',
+    }),
+    goal: 'balanced',
+    processEnv: {
+      OPENAI_API_KEY: 'sk-live',
+      CODEX_API_KEY: 'codex-live',
+      CHATGPT_ACCOUNT_ID: 'acct_live',
+    },
+    getAtomicChatChatBaseUrl: () => 'http://127.0.0.1:1337/v1',
+    resolveAtomicChatDefaultModel: async () => 'local-model',
+  })
+
+  assert.equal(env.OPENAI_BASE_URL, 'http://127.0.0.1:1337/v1')
+  assert.equal(env.OPENAI_MODEL, 'local-model')
+  assert.equal(env.OPENAI_API_KEY, undefined)
+  assert.equal(env.CODEX_API_KEY, undefined)
+  assert.equal(env.CHATGPT_ACCOUNT_ID, undefined)
+})
```
@@ -13,7 +13,7 @@ import {
const DEFAULT_GEMINI_BASE_URL = 'https://generativelanguage.googleapis.com/v1beta/openai'
const DEFAULT_GEMINI_MODEL = 'gemini-2.0-flash'

export type ProviderProfile = 'openai' | 'ollama' | 'codex' | 'gemini'
export type ProviderProfile = 'openai' | 'ollama' | 'codex' | 'gemini' | 'atomic-chat'

export type ProfileEnv = {
  OPENAI_BASE_URL?: string
@@ -53,6 +53,19 @@ export function buildOllamaProfileEnv(
  }
}

export function buildAtomicChatProfileEnv(
  model: string,
  options: {
    baseUrl?: string | null
    getAtomicChatChatBaseUrl: (baseUrl?: string) => string
  },
): ProfileEnv {
  return {
    OPENAI_BASE_URL: options.getAtomicChatChatBaseUrl(options.baseUrl ?? undefined),
    OPENAI_MODEL: model,
  }
}

export function buildGeminiProfileEnv(options: {
  model?: string | null
  baseUrl?: string | null
@@ -171,6 +184,8 @@ export async function buildLaunchEnv(options: {
  processEnv?: NodeJS.ProcessEnv
  getOllamaChatBaseUrl?: (baseUrl?: string) => string
  resolveOllamaDefaultModel?: (goal: RecommendationGoal) => Promise<string>
  getAtomicChatChatBaseUrl?: (baseUrl?: string) => string
  resolveAtomicChatDefaultModel?: () => Promise<string | null>
}): Promise<NodeJS.ProcessEnv> {
  const processEnv = options.processEnv ?? process.env
  const persistedEnv =
@@ -248,6 +263,26 @@ export async function buildLaunchEnv(options: {
    return env
  }

  if (options.profile === 'atomic-chat') {
    const getAtomicChatBaseUrl =
      options.getAtomicChatChatBaseUrl ?? (() => 'http://127.0.0.1:1337/v1')
    const resolveModel =
      options.resolveAtomicChatDefaultModel ?? (async () => null as string | null)

    env.OPENAI_BASE_URL = persistedEnv.OPENAI_BASE_URL || getAtomicChatBaseUrl()
    env.OPENAI_MODEL =
      persistedEnv.OPENAI_MODEL ||
      (await resolveModel()) ||
      ''

    delete env.OPENAI_API_KEY
    delete env.CODEX_API_KEY
    delete env.CHATGPT_ACCOUNT_ID
    delete env.CODEX_ACCOUNT_ID

    return env
  }

  if (options.profile === 'codex') {
    env.OPENAI_BASE_URL =
      persistedEnv.OPENAI_BASE_URL && isCodexBaseUrl(persistedEnv.OPENAI_BASE_URL)
130
test_atomic_chat_provider.py
Normal file
@@ -0,0 +1,130 @@
"""
|
||||||
|
test_atomic_chat_provider.py
|
||||||
|
Run: pytest test_atomic_chat_provider.py -v
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from atomic_chat_provider import (
|
||||||
|
atomic_chat,
|
||||||
|
list_atomic_chat_models,
|
||||||
|
check_atomic_chat_running,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_atomic_chat_running_true():
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
|
||||||
|
MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
|
||||||
|
result = await check_atomic_chat_running()
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_atomic_chat_running_false_on_exception():
|
||||||
|
with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
|
||||||
|
MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("refused"))
|
||||||
|
result = await check_atomic_chat_running()
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_models_returns_ids():
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"data": [{"id": "llama-3.1-8b"}, {"id": "mistral-7b"}],
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
|
||||||
|
MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
|
||||||
|
models = await list_atomic_chat_models()
|
||||||
|
assert "llama-3.1-8b" in models
|
||||||
|
assert "mistral-7b" in models
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_models_empty_on_failure():
|
||||||
|
with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
|
||||||
|
MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("down"))
|
||||||
|
models = await list_atomic_chat_models()
|
||||||
|
assert models == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_atomic_chat_returns_anthropic_format():
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"id": "chatcmpl-abc123",
|
||||||
|
"choices": [{"message": {"content": "42 is the answer."}}],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 8},
|
||||||
|
}
|
||||||
|
with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
|
||||||
|
MockClient.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
|
||||||
|
result = await atomic_chat(
|
||||||
|
model="llama-3.1-8b",
|
||||||
|
messages=[{"role": "user", "content": "What is 6*7?"}],
|
||||||
|
)
|
||||||
|
assert result["type"] == "message"
|
||||||
|
assert result["role"] == "assistant"
|
||||||
|
assert "42" in result["content"][0]["text"]
|
||||||
|
assert result["usage"]["input_tokens"] == 10
|
||||||
|
assert result["usage"]["output_tokens"] == 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_atomic_chat_prepends_system():
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def mock_post(url, json=None, **kwargs):
|
||||||
|
captured.update(json or {})
|
||||||
|
m = MagicMock()
|
||||||
|
m.raise_for_status = MagicMock()
|
||||||
|
m.json.return_value = {
|
||||||
|
"id": "chatcmpl-xyz",
|
||||||
|
"choices": [{"message": {"content": "ok"}}],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
|
||||||
|
with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
|
||||||
|
MockClient.return_value.__aenter__.return_value.post = mock_post
|
||||||
|
await atomic_chat(
|
||||||
|
model="llama-3.1-8b",
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
system="Be helpful.",
|
||||||
|
)
|
||||||
|
assert captured["messages"][0]["role"] == "system"
|
||||||
|
assert "helpful" in captured["messages"][0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_atomic_chat_sends_correct_payload():
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def mock_post(url, json=None, **kwargs):
|
||||||
|
captured.update(json or {})
|
||||||
|
m = MagicMock()
|
||||||
|
m.raise_for_status = MagicMock()
|
||||||
|
m.json.return_value = {
|
||||||
|
"id": "chatcmpl-xyz",
|
||||||
|
"choices": [{"message": {"content": "ok"}}],
|
||||||
|
"usage": {"prompt_tokens": 1, "completion_tokens": 1},
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
|
||||||
|
with patch("atomic_chat_provider.httpx.AsyncClient") as MockClient:
|
||||||
|
MockClient.return_value.__aenter__.return_value.post = mock_post
|
||||||
|
await atomic_chat(
|
||||||
|
model="test-model",
|
||||||
|
messages=[{"role": "user", "content": "Test"}],
|
||||||
|
max_tokens=2048,
|
||||||
|
temperature=0.5,
|
||||||
|
)
|
||||||
|
assert captured["model"] == "test-model"
|
||||||
|
assert captured["max_tokens"] == 2048
|
||||||
|
assert captured["temperature"] == 0.5
|
||||||
|
assert captured["stream"] is False