Merge pull request #1 from gnanam1990/feat/ollama-provider
feat: add native Ollama provider for local LLM support
ollama_provider.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""
ollama_provider.py
------------------
Adds native Ollama support to openclaude.
Lets Claude Code route requests to any locally-running Ollama model
(llama3, mistral, codellama, phi3, qwen2, deepseek-coder, etc.)
without needing an API key.

Usage (.env):
PREFERRED_PROVIDER=ollama
OLLAMA_BASE_URL=http://localhost:11434
BIG_MODEL=codellama:34b
SMALL_MODEL=llama3:8b
"""

import httpx
import logging
import os
from typing import AsyncIterator

logger = logging.getLogger(__name__)
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")


async def check_ollama_running() -> bool:
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
            return resp.status_code == 200
    except Exception:
        return False


async def list_ollama_models() -> list[str]:
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
            resp.raise_for_status()
            data = resp.json()
            return [m["name"] for m in data.get("models", [])]
    except Exception as e:
        logger.warning(f"Could not list Ollama models: {e}")
        return []


def normalize_ollama_model(model_name: str) -> str:
    if model_name.startswith("ollama/"):
        return model_name[len("ollama/"):]
    return model_name


def anthropic_to_ollama_messages(messages: list[dict]) -> list[dict]:
    ollama_messages = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if isinstance(content, str):
            ollama_messages.append({"role": role, "content": content})
        elif isinstance(content, list):
            text_parts = []
            for block in content:
                if isinstance(block, dict):
                    if block.get("type") == "text":
                        text_parts.append(block.get("text", ""))
                    elif block.get("type") == "image":
                        text_parts.append("[image]")
                elif isinstance(block, str):
                    text_parts.append(block)
            ollama_messages.append({"role": role, "content": "\n".join(text_parts)})
    return ollama_messages


async def ollama_chat(
    model: str,
    messages: list[dict],
    system: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 1.0,
) -> dict:
    model = normalize_ollama_model(model)
    ollama_messages = anthropic_to_ollama_messages(messages)
    if system:
        ollama_messages.insert(0, {"role": "system", "content": system})
    payload = {
        "model": model,
        "messages": ollama_messages,
        "stream": False,
        "options": {"num_predict": max_tokens, "temperature": temperature},
    }
    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.post(f"{OLLAMA_BASE_URL}/api/chat", json=payload)
        resp.raise_for_status()
        data = resp.json()
    assistant_text = data.get("message", {}).get("content", "")
    return {
        "id": f"msg_ollama_{data.get('created_at', 'unknown')}",
        "type": "message",
        "role": "assistant",
        "content": [{"type": "text", "text": assistant_text}],
        "model": model,
        "stop_reason": "end_turn",
        "stop_sequence": None,
        "usage": {
            "input_tokens": data.get("prompt_eval_count", 0),
            "output_tokens": data.get("eval_count", 0),
        },
    }


async def ollama_chat_stream(
    model: str,
    messages: list[dict],
    system: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 1.0,
) -> AsyncIterator[str]:
    import json
    model = normalize_ollama_model(model)
    ollama_messages = anthropic_to_ollama_messages(messages)
    if system:
        ollama_messages.insert(0, {"role": "system", "content": system})
    payload = {
        "model": model,
        "messages": ollama_messages,
        "stream": True,
        "options": {"num_predict": max_tokens, "temperature": temperature},
    }
    yield "event: message_start\n"
    yield f'data: {json.dumps({"type": "message_start", "message": {"id": "msg_ollama_stream", "type": "message", "role": "assistant", "content": [], "model": model, "stop_reason": None, "usage": {"input_tokens": 0, "output_tokens": 0}}})}\n\n'
    yield "event: content_block_start\n"
    yield f'data: {json.dumps({"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})}\n\n'
    async with httpx.AsyncClient(timeout=120.0) as client:
        async with client.stream("POST", f"{OLLAMA_BASE_URL}/api/chat", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line:
                    continue
                try:
                    chunk = json.loads(line)
                    delta_text = chunk.get("message", {}).get("content", "")
                    if delta_text:
                        yield "event: content_block_delta\n"
                        yield f'data: {json.dumps({"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": delta_text}})}\n\n'
                    if chunk.get("done"):
                        yield "event: content_block_stop\n"
                        yield f'data: {json.dumps({"type": "content_block_stop", "index": 0})}\n\n'
                        yield "event: message_delta\n"
                        yield f'data: {json.dumps({"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence": None}, "usage": {"output_tokens": chunk.get("eval_count", 0)}})}\n\n'
                        yield "event: message_stop\n"
                        yield f'data: {json.dumps({"type": "message_stop"})}\n\n'
                        break
                except json.JSONDecodeError:
                    continue
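For reviewers who want to try the branch end to end, here is a minimal smoke-test script (a sketch, not part of this diff): it assumes a local Ollama daemon on the default port with a model such as llama3:8b already pulled, and the file name and prompt are arbitrary.

# smoke_test_ollama.py — illustrative only, not included in this PR
import asyncio

from ollama_provider import check_ollama_running, list_ollama_models, ollama_chat


async def main() -> None:
    # Bail out early if nothing answers GET /api/tags.
    if not await check_ollama_running():
        raise SystemExit("Ollama is not reachable; is `ollama serve` running?")
    print("available models:", await list_ollama_models())
    # The "ollama/" prefix is optional; normalize_ollama_model() strips it.
    result = await ollama_chat(
        model="ollama/llama3:8b",
        messages=[{"role": "user", "content": "Say hello in five words."}],
        system="You are terse.",
        max_tokens=64,
    )
    print(result["content"][0]["text"])


asyncio.run(main())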
test_ollama_provider.py (new file, 120 lines)
@@ -0,0 +1,120 @@
"""
test_ollama_provider.py
Run: pytest test_ollama_provider.py -v
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from ollama_provider import (
    normalize_ollama_model,
    anthropic_to_ollama_messages,
    ollama_chat,
    list_ollama_models,
    check_ollama_running,
)


def test_normalize_strips_prefix():
    assert normalize_ollama_model("ollama/llama3:8b") == "llama3:8b"


def test_normalize_no_prefix():
    assert normalize_ollama_model("codellama:34b") == "codellama:34b"


def test_normalize_empty():
    assert normalize_ollama_model("") == ""


def test_converts_string_content():
    messages = [{"role": "user", "content": "Hello!"}]
    result = anthropic_to_ollama_messages(messages)
    assert result == [{"role": "user", "content": "Hello!"}]


def test_converts_text_block_list():
    messages = [{"role": "user", "content": [{"type": "text", "text": "What is Python?"}]}]
    result = anthropic_to_ollama_messages(messages)
    assert result[0]["content"] == "What is Python?"


def test_converts_image_block_to_placeholder():
    messages = [{"role": "user", "content": [{"type": "image", "source": {}}, {"type": "text", "text": "Describe this"}]}]
    result = anthropic_to_ollama_messages(messages)
    assert "[image]" in result[0]["content"]
    assert "Describe this" in result[0]["content"]


def test_converts_multi_turn():
    messages = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
        {"role": "user", "content": "How are you?"},
    ]
    result = anthropic_to_ollama_messages(messages)
    assert len(result) == 3
    assert result[1]["role"] == "assistant"


@pytest.mark.asyncio
async def test_ollama_running_true():
    mock_response = MagicMock()
    mock_response.status_code = 200
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        result = await check_ollama_running()
        assert result is True


@pytest.mark.asyncio
async def test_ollama_running_false_on_exception():
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("refused"))
        result = await check_ollama_running()
        assert result is False


@pytest.mark.asyncio
async def test_list_models_returns_names():
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = {"models": [{"name": "llama3:8b"}, {"name": "codellama:34b"}]}
    mock_response.raise_for_status = MagicMock()
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
        models = await list_ollama_models()
        assert "llama3:8b" in models


@pytest.mark.asyncio
async def test_ollama_chat_returns_anthropic_format():
    mock_response = MagicMock()
    mock_response.raise_for_status = MagicMock()
    mock_response.json.return_value = {
        "message": {"content": "42 is the answer."},
        "created_at": "2026-01-01T00:00:00Z",
        "prompt_eval_count": 10,
        "eval_count": 8,
    }
    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
        result = await ollama_chat(
            model="llama3:8b",
            messages=[{"role": "user", "content": "What is 6*7?"}]
        )
        assert result["type"] == "message"
        assert result["role"] == "assistant"
        assert "42" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_ollama_chat_prepends_system():
    captured = {}

    async def mock_post(url, json=None, **kwargs):
        captured.update(json or {})
        m = MagicMock()
        m.raise_for_status = MagicMock()
        m.json.return_value = {
            "message": {"content": "ok"},
            "created_at": "",
            "prompt_eval_count": 1,
            "eval_count": 1
        }
        return m

    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.post = mock_post
        await ollama_chat(
            model="llama3:8b",
            messages=[{"role": "user", "content": "Hi"}],
            system="Be helpful."
        )
        assert captured["messages"][0]["role"] == "system"
        assert "helpful" in captured["messages"][0]["content"]
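The suite covers the non-streaming path only. A natural follow-up, sketched below under the same mocking pattern (not part of this PR; the test name and fake-chunk plumbing are made up), would drive ollama_chat_stream through a mocked client.stream() with two NDJSON chunks and assert the Anthropic SSE event sequence:

# Illustrative streaming test — not included in this PR.
import json

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from ollama_provider import ollama_chat_stream


@pytest.mark.asyncio
async def test_ollama_chat_stream_emits_sse_sequence():
    # One content chunk, then a terminal "done" chunk, as Ollama's
    # /api/chat emits them line by line.
    chunks = [
        json.dumps({"message": {"content": "Hello"}, "done": False}),
        json.dumps({"message": {"content": ""}, "done": True, "eval_count": 2}),
    ]

    async def fake_aiter_lines():
        for line in chunks:
            yield line

    mock_resp = MagicMock()
    mock_resp.raise_for_status = MagicMock()
    mock_resp.aiter_lines = fake_aiter_lines

    # client.stream() returns an async context manager yielding the response.
    stream_cm = MagicMock()
    stream_cm.__aenter__ = AsyncMock(return_value=mock_resp)
    stream_cm.__aexit__ = AsyncMock(return_value=False)

    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
        MockClient.return_value.__aenter__.return_value.stream = MagicMock(return_value=stream_cm)
        events = [
            e
            async for e in ollama_chat_stream(
                model="llama3:8b",
                messages=[{"role": "user", "content": "Hi"}],
            )
        ]

    joined = "".join(events)
    # message_start and content_block_start, one text delta, then the stops.
    assert "message_start" in joined
    assert '"text_delta"' in joined and "Hello" in joined
    assert joined.rstrip().endswith('data: {"type": "message_stop"}')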