Coverage for src/qdrant_loader_core/llm/providers/ollama.py: 86%
170 statements
coverage.py v7.10.6, created at 2025-09-08 06:01 +0000
from __future__ import annotations

from typing import Any

try:
    import httpx  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    httpx = None  # type: ignore

from ...logging import LoggingConfig
from ..errors import (
    AuthError,
    InvalidRequestError,
    RateLimitedError,
    ServerError,
)
from ..errors import TimeoutError as LLMTimeoutError
from ..settings import LLMSettings
from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter

logger = LoggingConfig.get_logger(__name__)


def _join_url(base: str | None, path: str) -> str:
    base = (base or "").rstrip("/")
    path = path.lstrip("/")
    return f"{base}/{path}" if base else f"/{path}"
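
# Illustrative behavior of _join_url, following the implementation above:
#   _join_url("http://localhost:11434/", "/api/chat") -> "http://localhost:11434/api/chat"
#   _join_url(None, "embeddings") -> "/embeddings"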


class OllamaEmbeddings(EmbeddingsClient):
    def __init__(
        self,
        base_url: str | None,
        model: str,
        headers: dict[str, str] | None,
        *,
        timeout_s: float | None = None,
        provider_options: dict[str, Any] | None = None,
    ):
        self._base_url = (base_url or "http://localhost:11434").rstrip("/")
        self._model = model
        self._headers = headers or {}
        self._timeout_s = float(timeout_s) if timeout_s is not None else 30.0
        self._provider_options = provider_options or {}

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        if httpx is None:
            raise NotImplementedError("httpx not available for Ollama embeddings")

        # Prefer OpenAI-compatible if base_url seems to expose /v1
        use_v1 = "/v1" in (self._base_url or "")
        async with httpx.AsyncClient(timeout=self._timeout_s) as client:
            try:
                if use_v1:
                    # OpenAI-compatible embeddings endpoint
                    url = _join_url(self._base_url, "/embeddings")
                    payload = {"model": self._model, "input": inputs}
                    resp = await client.post(url, json=payload, headers=self._headers)
                    resp.raise_for_status()
                    data = resp.json()
                    logger.info(
                        "LLM request",
                        provider="ollama",
                        operation="embeddings",
                        model=self._model,
                        base_host=self._base_url,
                        inputs=len(inputs),
                        # latency for v1 path hard to compute here; omitted for now
                    )
                    return [item["embedding"] for item in data.get("data", [])]
                else:
                    # Determine native endpoint preference: embed | embeddings | auto (default)
                    native_pref = str(
                        self._provider_options.get("native_endpoint", "auto")
                    ).lower()
                    prefer_embed = native_pref != "embeddings"

                    # Try batch embed first when preferred
                    if prefer_embed:
                        url = _join_url(self._base_url, "/api/embed")
                        payload = {"model": self._model, "input": inputs}
                        try:
                            resp = await client.post(
                                url, json=payload, headers=self._headers
                            )
                            resp.raise_for_status()
                            data = resp.json()
                            vectors = data.get("embeddings")
                            if not isinstance(vectors, list) or (
                                len(vectors) != len(inputs)
                            ):
                                raise ValueError(
                                    "Invalid embeddings response from /api/embed"
                                )
                            # Normalize to list[list[float]]
                            norm = [list(vec) for vec in vectors]
                            logger.info(
                                "LLM request",
                                provider="ollama",
                                operation="embeddings",
                                model=self._model,
                                base_host=self._base_url,
                                inputs=len(inputs),
                                # latency for native batch path not measured in this stub
                            )
                            return norm
                        except httpx.HTTPStatusError as exc:
                            status = exc.response.status_code if exc.response else None
                            # Fallback for servers that don't support /api/embed
                            if status not in (404, 405, 501):
                                raise

                    # Per-item embeddings endpoint fallback or preference
                    url = _join_url(self._base_url, "/api/embeddings")
                    vectors2: list[list[float]] = []
                    for text in inputs:
                        # Native /api/embeddings expects "prompt" (not "input")
                        payload = {"model": self._model, "prompt": text}
                        resp = await client.post(
                            url, json=payload, headers=self._headers
                        )
                        resp.raise_for_status()
                        data = resp.json()
                        emb = data.get("embedding")
                        if emb is None and isinstance(data.get("data"), dict):
                            emb = data["data"].get("embedding")
                        if emb is None:
                            raise ValueError(
                                "Invalid embedding response from /api/embeddings"
                            )
                        vectors2.append(list(emb))
                    logger.info(
                        "LLM request",
                        provider="ollama",
                        operation="embeddings",
                        model=self._model,
                        base_host=self._base_url,
                        inputs=len(inputs),
                        # latency for per-item path not measured in this stub
                    )
                    return vectors2
            except httpx.TimeoutException as exc:
                raise LLMTimeoutError(str(exc))
            except httpx.HTTPStatusError as exc:
                status = exc.response.status_code if exc.response else None
                if status == 401:
                    raise AuthError(str(exc))
                if status == 429:
                    raise RateLimitedError(str(exc))
                if status and 400 <= status < 500:
                    raise InvalidRequestError(str(exc))
                raise ServerError(str(exc))
            except httpx.HTTPError as exc:
                raise ServerError(str(exc))
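
# Note on provider_options: {"native_endpoint": "embeddings"} skips the batch
# /api/embed attempt and uses the per-item /api/embeddings path directly; any
# other value (including the default "auto") tries /api/embed first and falls
# back to /api/embeddings only on 404/405/501 responses.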


class OllamaChat(ChatClient):
    def __init__(
        self, base_url: str | None, model: str, headers: dict[str, str] | None
    ):
        self._base_url = base_url or "http://localhost:11434"
        self._model = model
        self._headers = headers or {}

    async def chat(
        self, messages: list[dict[str, Any]], **kwargs: Any
    ) -> dict[str, Any]:
        if httpx is None:
            raise NotImplementedError("httpx not available for Ollama chat")

        # Prefer OpenAI-compatible if base_url exposes /v1
        use_v1 = "/v1" in (self._base_url or "")
        # Both branches forward the messages list as-is (native /api/chat accepts it directly)
        if use_v1:
            url = _join_url(self._base_url, "/chat/completions")
            payload = {"model": self._model, "messages": messages}
            # Map common kwargs
            for k in ("temperature", "max_tokens", "top_p", "stop"):
                if k in kwargs and kwargs[k] is not None:
                    payload[k] = kwargs[k]
            async with httpx.AsyncClient(timeout=60.0) as client:
                try:
                    from datetime import UTC, datetime

                    started = datetime.now(UTC)
                    resp = await client.post(url, json=payload, headers=self._headers)
                    resp.raise_for_status()
                    data = resp.json()
                    text = ""
                    choices = data.get("choices") or []
                    if choices:
                        msg = (choices[0] or {}).get("message") or {}
                        text = msg.get("content", "") or ""
                    duration_ms = int(
                        (datetime.now(UTC) - started).total_seconds() * 1000
                    )
                    logger.info(
                        "LLM request",
                        provider="ollama",
                        operation="chat",
                        model=self._model,
                        base_host=self._base_url,
                        messages=len(messages),
                        latency_ms=duration_ms,
                    )
                    return {
                        "text": text,
                        "raw": data,
                        "usage": data.get("usage"),
                        "model": data.get("model", self._model),
                    }
                except httpx.TimeoutException as exc:
                    raise LLMTimeoutError(str(exc))
                except httpx.HTTPStatusError as exc:
                    status = exc.response.status_code if exc.response else None
                    if status == 401:
                        raise AuthError(str(exc))
                    if status == 429:
                        raise RateLimitedError(str(exc))
                    if status and 400 <= status < 500:
                        raise InvalidRequestError(str(exc))
                    raise ServerError(str(exc))
                except httpx.HTTPError as exc:
                    raise ServerError(str(exc))
        else:
            # Native API
            url = _join_url(self._base_url, "/api/chat")
            payload = {
                "model": self._model,
                "messages": messages,
                "stream": False,
            }
            if "temperature" in kwargs and kwargs["temperature"] is not None:
                payload["options"] = {"temperature": kwargs["temperature"]}
            async with httpx.AsyncClient(timeout=60.0) as client:
                try:
                    from datetime import UTC, datetime

                    started = datetime.now(UTC)
                    resp = await client.post(url, json=payload, headers=self._headers)
                    resp.raise_for_status()
                    data = resp.json()
                    # Ollama native returns {"message": {"content": "..."}, ...}
                    text = ""
                    if isinstance(data.get("message"), dict):
                        text = data["message"].get("content", "") or ""
                    duration_ms = int(
                        (datetime.now(UTC) - started).total_seconds() * 1000
                    )
                    logger.info(
                        "LLM request",
                        provider="ollama",
                        operation="chat",
                        model=self._model,
                        base_host=self._base_url,
                        messages=len(messages),
                        latency_ms=duration_ms,
                    )
                    return {
                        "text": text,
                        "raw": data,
                        "usage": None,
                        "model": self._model,
                    }
                except httpx.TimeoutException as exc:
                    raise LLMTimeoutError(str(exc))
                except httpx.HTTPStatusError as exc:
                    status = exc.response.status_code if exc.response else None
                    if status == 401:
                        raise AuthError(str(exc))
                    if status == 429:
                        raise RateLimitedError(str(exc))
                    if status and 400 <= status < 500:
                        raise InvalidRequestError(str(exc))
                    raise ServerError(str(exc))
                except httpx.HTTPError as exc:
                    raise ServerError(str(exc))


class OllamaTokenizer(TokenCounter):
    def count(self, text: str) -> int:
        # Rough heuristic: character count stands in for token count
        return len(text)


class OllamaProvider(LLMProvider):
    def __init__(self, settings: LLMSettings):
        self._settings = settings

    def embeddings(self) -> EmbeddingsClient:
        model = self._settings.models.get("embeddings", "")
        timeout = (
            self._settings.request.timeout_s
            if self._settings and self._settings.request
            else 30.0
        )
        return OllamaEmbeddings(
            self._settings.base_url,
            model,
            self._settings.headers,
            timeout_s=timeout,
            provider_options=self._settings.provider_options,
        )

    def chat(self) -> ChatClient:
        model = self._settings.models.get("chat", "")
        return OllamaChat(self._settings.base_url, model, self._settings.headers)

    def tokenizer(self) -> TokenCounter:
        return OllamaTokenizer()
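
A minimal usage sketch follows. It assumes the module is importable as qdrant_loader_core.llm.providers.ollama and that LLMSettings accepts keyword arguments matching the attributes this provider reads (base_url, models, headers, provider_options); the actual constructor in ..settings may differ, and the model names are placeholders.

import asyncio

from qdrant_loader_core.llm.providers.ollama import OllamaProvider
from qdrant_loader_core.llm.settings import LLMSettings


async def main() -> None:
    # Hypothetical construction; field names mirror what OllamaProvider reads above.
    # Request/timeout config is omitted; the provider falls back to 30s when it is unset.
    settings = LLMSettings(
        base_url="http://localhost:11434",
        models={"embeddings": "nomic-embed-text", "chat": "llama3"},
        headers={},
        provider_options={"native_endpoint": "auto"},
    )
    provider = OllamaProvider(settings)

    vectors = await provider.embeddings().embed(["hello world"])
    reply = await provider.chat().chat([{"role": "user", "content": "Say hi"}])
    print(len(vectors[0]), reply["text"])


asyncio.run(main())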