diff --git a/ollama_provider.py b/ollama_provider.py
new file mode 100644
index 00000000..de19265d
--- /dev/null
+++ b/ollama_provider.py
@@ -0,0 +1,153 @@
+"""
+ollama_provider.py
+------------------
+Adds native Ollama support to openclaude.
+Lets Claude Code route requests to any locally-running Ollama model
+(llama3, mistral, codellama, phi3, qwen2, deepseek-coder, etc.)
+without needing an API key.
+
+Usage (.env):
+    PREFERRED_PROVIDER=ollama
+    OLLAMA_BASE_URL=http://localhost:11434
+    BIG_MODEL=codellama:34b
+    SMALL_MODEL=llama3:8b
+"""
+
+import httpx
+import logging
+import os
+from typing import AsyncIterator
+
+logger = logging.getLogger(__name__)
+OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
+
+
+async def check_ollama_running() -> bool:
+    """Return True if an Ollama daemon answers at OLLAMA_BASE_URL."""
+    try:
+        async with httpx.AsyncClient(timeout=3.0) as client:
+            resp = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
+            return resp.status_code == 200
+    except Exception:
+        return False
+
+
+async def list_ollama_models() -> list[str]:
+    """Return the names of locally pulled models, or [] if Ollama is unreachable."""
+    try:
+        async with httpx.AsyncClient(timeout=5.0) as client:
+            resp = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
+            resp.raise_for_status()
+            data = resp.json()
+            return [m["name"] for m in data.get("models", [])]
+    except Exception as e:
+        logger.warning(f"Could not list Ollama models: {e}")
+        return []
+
+
+def normalize_ollama_model(model_name: str) -> str:
+    """Strip an optional 'ollama/' routing prefix so the name matches `ollama list`."""
+    if model_name.startswith("ollama/"):
+        return model_name[len("ollama/"):]
+    return model_name
+
+
+def anthropic_to_ollama_messages(messages: list[dict]) -> list[dict]:
+    """Flatten Anthropic messages (string or block-list content) into Ollama's
+    {role, content} shape. Image blocks become an "[image]" placeholder; other
+    non-text blocks are dropped."""
+    ollama_messages = []
+    for msg in messages:
+        role = msg.get("role", "user")
+        content = msg.get("content", "")
+        if isinstance(content, str):
+            ollama_messages.append({"role": role, "content": content})
+        elif isinstance(content, list):
+            text_parts = []
+            for block in content:
+                if isinstance(block, dict):
+                    if block.get("type") == "text":
+                        text_parts.append(block.get("text", ""))
+                    elif block.get("type") == "image":
+                        text_parts.append("[image]")
+                elif isinstance(block, str):
+                    text_parts.append(block)
+            ollama_messages.append({"role": role, "content": "\n".join(text_parts)})
+    return ollama_messages
+
+
+async def ollama_chat(
+    model: str,
+    messages: list[dict],
+    system: str | None = None,
+    max_tokens: int = 4096,
+    temperature: float = 1.0,
+) -> dict:
+    """Send a non-streaming chat request to Ollama and return an
+    Anthropic Messages-style response dict."""
+    model = normalize_ollama_model(model)
+    ollama_messages = anthropic_to_ollama_messages(messages)
+    if system:
+        ollama_messages.insert(0, {"role": "system", "content": system})
+    payload = {
+        "model": model,
+        "messages": ollama_messages,
+        "stream": False,
+        "options": {"num_predict": max_tokens, "temperature": temperature},
+    }
+    async with httpx.AsyncClient(timeout=120.0) as client:
+        resp = await client.post(f"{OLLAMA_BASE_URL}/api/chat", json=payload)
+        resp.raise_for_status()
+        data = resp.json()
+    assistant_text = data.get("message", {}).get("content", "")
+    return {
+        "id": f"msg_ollama_{data.get('created_at', 'unknown')}",
+        "type": "message",
+        "role": "assistant",
+        "content": [{"type": "text", "text": assistant_text}],
+        "model": model,
+        "stop_reason": "end_turn",
+        "stop_sequence": None,
+        "usage": {
+            "input_tokens": data.get("prompt_eval_count", 0),
+            "output_tokens": data.get("eval_count", 0),
+        },
+    }
+
+
+async def ollama_chat_stream(
+    model: str,
+    messages: list[dict],
+    system: str | None = None,
+    max_tokens: int = 4096,
+    temperature: float = 1.0,
+) -> AsyncIterator[str]:
+    """Stream a chat request to Ollama, yielding Anthropic-style SSE
+    event/data lines (message_start, content_block_delta, ..., message_stop)."""
+    import json
+    model = normalize_ollama_model(model)
+    ollama_messages = anthropic_to_ollama_messages(messages)
+    if system:
+        ollama_messages.insert(0, {"role": "system", "content": system})
+    payload = {
+        "model": model,
+        "messages": ollama_messages,
+        "stream": True,
+        "options": {"num_predict": max_tokens, "temperature": temperature},
+    }
+    yield "event: message_start\n"
+    yield f'data: {json.dumps({"type": "message_start", "message": {"id": "msg_ollama_stream", "type": "message", "role": "assistant", "content": [], "model": model, "stop_reason": None, "usage": {"input_tokens": 0, "output_tokens": 0}}})}\n\n'
+    yield "event: content_block_start\n"
+    yield f'data: {json.dumps({"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}})}\n\n'
+    async with httpx.AsyncClient(timeout=120.0) as client:
+        async with client.stream("POST", f"{OLLAMA_BASE_URL}/api/chat", json=payload) as resp:
+            resp.raise_for_status()
+            async for line in resp.aiter_lines():
+                if not line:
+                    continue
+                try:
+                    chunk = json.loads(line)
+                    delta_text = chunk.get("message", {}).get("content", "")
+                    if delta_text:
+                        yield "event: content_block_delta\n"
+                        yield f'data: {json.dumps({"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": delta_text}})}\n\n'
+                    if chunk.get("done"):
+                        yield "event: content_block_stop\n"
+                        yield f'data: {json.dumps({"type": "content_block_stop", "index": 0})}\n\n'
+                        yield "event: message_delta\n"
+                        yield f'data: {json.dumps({"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence": None}, "usage": {"output_tokens": chunk.get("eval_count", 0)}})}\n\n'
+                        yield "event: message_stop\n"
+                        yield f'data: {json.dumps({"type": "message_stop"})}\n\n'
+                        break
+                except json.JSONDecodeError:
+                    continue
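For manual smoke-testing outside the proxy, the module can be driven directly against a local Ollama daemon. A minimal sketch, not part of this diff; the model name is only an example, so substitute whatever `ollama list` reports on your machine:

import asyncio

from ollama_provider import check_ollama_running, list_ollama_models, ollama_chat


async def main() -> None:
    # Fail fast if no daemon is listening at OLLAMA_BASE_URL.
    if not await check_ollama_running():
        raise SystemExit("Ollama is not running at OLLAMA_BASE_URL")
    print("available models:", await list_ollama_models())
    reply = await ollama_chat(
        model="ollama/llama3:8b",  # the "ollama/" prefix is stripped by normalize_ollama_model()
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        system="Be concise.",
        max_tokens=128,
    )
    print(reply["content"][0]["text"])


if __name__ == "__main__":
    asyncio.run(main())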
diff --git a/test_ollama_provider.py b/test_ollama_provider.py
new file mode 100644
index 00000000..8028e761
--- /dev/null
+++ b/test_ollama_provider.py
@@ -0,0 +1,120 @@
+"""
+test_ollama_provider.py
+Run: pytest test_ollama_provider.py -v
+"""
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+from ollama_provider import (
+    normalize_ollama_model,
+    anthropic_to_ollama_messages,
+    ollama_chat,
+    list_ollama_models,
+    check_ollama_running,
+)
+
+def test_normalize_strips_prefix():
+    assert normalize_ollama_model("ollama/llama3:8b") == "llama3:8b"
+
+def test_normalize_no_prefix():
+    assert normalize_ollama_model("codellama:34b") == "codellama:34b"
+
+def test_normalize_empty():
+    assert normalize_ollama_model("") == ""
+
+def test_converts_string_content():
+    messages = [{"role": "user", "content": "Hello!"}]
+    result = anthropic_to_ollama_messages(messages)
+    assert result == [{"role": "user", "content": "Hello!"}]
+
+def test_converts_text_block_list():
+    messages = [{"role": "user", "content": [{"type": "text", "text": "What is Python?"}]}]
+    result = anthropic_to_ollama_messages(messages)
+    assert result[0]["content"] == "What is Python?"
+
+def test_converts_image_block_to_placeholder():
+    messages = [{"role": "user", "content": [{"type": "image", "source": {}}, {"type": "text", "text": "Describe this"}]}]
+    result = anthropic_to_ollama_messages(messages)
+    assert "[image]" in result[0]["content"]
+    assert "Describe this" in result[0]["content"]
+
+def test_converts_multi_turn():
+    messages = [
+        {"role": "user", "content": "Hi"},
+        {"role": "assistant", "content": "Hello!"},
+        {"role": "user", "content": "How are you?"},
+    ]
+    result = anthropic_to_ollama_messages(messages)
+    assert len(result) == 3
+    assert result[1]["role"] == "assistant"
+
+@pytest.mark.asyncio
+async def test_ollama_running_true():
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
+        result = await check_ollama_running()
+        assert result is True
+
+@pytest.mark.asyncio
+async def test_ollama_running_false_on_exception():
+    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.get = AsyncMock(side_effect=Exception("refused"))
+        result = await check_ollama_running()
+        assert result is False
+
+@pytest.mark.asyncio
+async def test_list_models_returns_names():
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = {"models": [{"name": "llama3:8b"}, {"name": "codellama:34b"}]}
+    mock_response.raise_for_status = MagicMock()
+    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
+        models = await list_ollama_models()
+        assert "llama3:8b" in models
+
+@pytest.mark.asyncio
+async def test_ollama_chat_returns_anthropic_format():
+    mock_response = MagicMock()
+    mock_response.raise_for_status = MagicMock()
+    mock_response.json.return_value = {
+        "message": {"content": "42 is the answer."},
+        "created_at": "2026-01-01T00:00:00Z",
+        "prompt_eval_count": 10,
+        "eval_count": 8,
+    }
+    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.post = AsyncMock(return_value=mock_response)
+        result = await ollama_chat(
+            model="llama3:8b",
+            messages=[{"role": "user", "content": "What is 6*7?"}],
+        )
+        assert result["type"] == "message"
+        assert result["role"] == "assistant"
+        assert "42" in result["content"][0]["text"]
+
+@pytest.mark.asyncio
+async def test_ollama_chat_prepends_system():
+    captured = {}
+
+    async def mock_post(url, json=None, **kwargs):
+        captured.update(json or {})
+        m = MagicMock()
+        m.raise_for_status = MagicMock()
+        m.json.return_value = {
+            "message": {"content": "ok"},
+            "created_at": "",
+            "prompt_eval_count": 1,
+            "eval_count": 1,
+        }
+        return m
+
+    with patch("ollama_provider.httpx.AsyncClient") as MockClient:
+        MockClient.return_value.__aenter__.return_value.post = mock_post
+        await ollama_chat(
+            model="llama3:8b",
+            messages=[{"role": "user", "content": "Hi"}],
+            system="Be helpful.",
+        )
+        assert captured["messages"][0]["role"] == "system"
+        assert "helpful" in captured["messages"][0]["content"]
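How openclaude's server selects this provider when PREFERRED_PROVIDER=ollama is not part of this diff. The sketch below only illustrates one way the two entry points could sit behind an Anthropic-style /v1/messages route; the FastAPI app, route path, and the assumption that the incoming `system` field is a plain string are all hypothetical and not existing openclaude code.

# Hypothetical wiring sketch; not part of this diff.
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

from ollama_provider import ollama_chat, ollama_chat_stream

app = FastAPI()  # assumed app object; the real server may differ


@app.post("/v1/messages")
async def messages(request: Request):
    body = await request.json()
    kwargs = dict(
        model=body["model"],
        messages=body["messages"],
        system=body.get("system"),  # assumes a plain-string system prompt
        max_tokens=body.get("max_tokens", 4096),
        temperature=body.get("temperature", 1.0),
    )
    if body.get("stream"):
        # ollama_chat_stream already yields Anthropic-style SSE frames.
        return StreamingResponse(ollama_chat_stream(**kwargs), media_type="text/event-stream")
    return await ollama_chat(**kwargs)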