Coverage for src/qdrant_loader_core/llm/providers/ollama.py: 86%

170 statements  

coverage.py v7.10.6, created at 2025-09-08 06:01 +0000

from __future__ import annotations

from typing import Any

try:
    import httpx  # type: ignore
except Exception:  # pragma: no cover - optional dependency
    httpx = None  # type: ignore

from ...logging import LoggingConfig
from ..errors import (
    AuthError,
    InvalidRequestError,
    RateLimitedError,
    ServerError,
)
from ..errors import TimeoutError as LLMTimeoutError
from ..settings import LLMSettings
from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter

logger = LoggingConfig.get_logger(__name__)


def _join_url(base: str | None, path: str) -> str:
    base = (base or "").rstrip("/")
    path = path.lstrip("/")
    return f"{base}/{path}" if base else f"/{path}"

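# --- Illustrative sketch (not part of the original module) ---
# How _join_url composes request URLs, following directly from the code above:
# a populated base keeps its scheme and host, redundant slashes are collapsed,
# and an empty base yields a root-relative path.
def _example_join_url() -> None:
    assert _join_url("http://localhost:11434", "/api/chat") == (
        "http://localhost:11434/api/chat"
    )
    assert _join_url("http://localhost:11434/v1/", "embeddings") == (
        "http://localhost:11434/v1/embeddings"
    )
    assert _join_url(None, "/api/tags") == "/api/tags"
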

class OllamaEmbeddings(EmbeddingsClient):
    def __init__(
        self,
        base_url: str | None,
        model: str,
        headers: dict[str, str] | None,
        *,
        timeout_s: float | None = None,
        provider_options: dict[str, Any] | None = None,
    ):
        self._base_url = (base_url or "http://localhost:11434").rstrip("/")
        self._model = model
        self._headers = headers or {}
        self._timeout_s = float(timeout_s) if timeout_s is not None else 30.0
        self._provider_options = provider_options or {}

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        if httpx is None:
            raise NotImplementedError("httpx not available for Ollama embeddings")

        # Prefer OpenAI-compatible if base_url seems to expose /v1
        use_v1 = "/v1" in (self._base_url or "")
        async with httpx.AsyncClient(timeout=self._timeout_s) as client:
            try:
                if use_v1:
                    # OpenAI-compatible embeddings endpoint
                    url = _join_url(self._base_url, "/embeddings")
                    payload = {"model": self._model, "input": inputs}
                    resp = await client.post(url, json=payload, headers=self._headers)
                    resp.raise_for_status()
                    data = resp.json()
                    logger.info(
                        "LLM request",
                        provider="ollama",
                        operation="embeddings",
                        model=self._model,
                        base_host=self._base_url,
                        inputs=len(inputs),
                        # latency for v1 path hard to compute here; omitted for now
                    )
                    return [item["embedding"] for item in data.get("data", [])]
                else:
                    # Determine native endpoint preference: embed | embeddings | auto (default)
                    native_pref = str(
                        self._provider_options.get("native_endpoint", "auto")
                    ).lower()
                    prefer_embed = native_pref != "embeddings"

                    # Try batch embed first when preferred
                    if prefer_embed:
                        url = _join_url(self._base_url, "/api/embed")
                        payload = {"model": self._model, "input": inputs}
                        try:
                            resp = await client.post(
                                url, json=payload, headers=self._headers
                            )
                            resp.raise_for_status()
                            data = resp.json()
                            vectors = data.get("embeddings")
                            if not isinstance(vectors, list) or (
                                len(vectors) != len(inputs)
                            ):
                                raise ValueError(
                                    "Invalid embeddings response from /api/embed"
                                )
                            # Normalize to list[list[float]]
                            norm = [list(vec) for vec in vectors]
                            logger.info(
                                "LLM request",
                                provider="ollama",
                                operation="embeddings",
                                model=self._model,
                                base_host=self._base_url,
                                inputs=len(inputs),
                                # latency for native batch path not measured in this stub
                            )
                            return norm
                        except httpx.HTTPStatusError as exc:
                            status = exc.response.status_code if exc.response else None
                            # Fallback for servers that don't support /api/embed
                            if status not in (404, 405, 501):
                                raise

                    # Per-item embeddings endpoint fallback or preference
                    url = _join_url(self._base_url, "/api/embeddings")
                    vectors2: list[list[float]] = []
                    for text in inputs:
                        payload = {"model": self._model, "input": text}
                        resp = await client.post(
                            url, json=payload, headers=self._headers
                        )
                        resp.raise_for_status()
                        data = resp.json()
                        emb = data.get("embedding")
                        if emb is None and isinstance(data.get("data"), dict):
                            emb = data["data"].get("embedding")
                        if emb is None:
                            raise ValueError(
                                "Invalid embedding response from /api/embeddings"
                            )
                        vectors2.append(list(emb))
                    logger.info(
                        "LLM request",
                        provider="ollama",
                        operation="embeddings",
                        model=self._model,
                        base_host=self._base_url,
                        inputs=len(inputs),
                        # latency for per-item path not measured in this stub
                    )
                    return vectors2
            except httpx.TimeoutException as exc:
                raise LLMTimeoutError(str(exc))
            except httpx.HTTPStatusError as exc:
                status = exc.response.status_code if exc.response else None
                if status == 401:
                    raise AuthError(str(exc))
                if status == 429:
                    raise RateLimitedError(str(exc))
                if status and 400 <= status < 500:
                    raise InvalidRequestError(str(exc))
                raise ServerError(str(exc))
            except httpx.HTTPError as exc:
                raise ServerError(str(exc))

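# --- Illustrative usage sketch (not part of the original module) ---
# A hedged, minimal example of calling OllamaEmbeddings directly. It assumes
# httpx is installed, an Ollama server is reachable at the default
# http://localhost:11434, and that "nomic-embed-text" (a hypothetical model
# name here) has already been pulled; adjust to your environment. Run it
# manually with asyncio.run(_example_embed()).
async def _example_embed() -> None:
    embedder = OllamaEmbeddings(
        "http://localhost:11434",
        "nomic-embed-text",
        None,
        timeout_s=30.0,
        provider_options={"native_endpoint": "auto"},
    )
    vectors = await embedder.embed(["hello world", "qdrant loader"])
    # One vector per input; the vector length depends on the embedding model.
    print(len(vectors), len(vectors[0]))
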

class OllamaChat(ChatClient):
    def __init__(
        self, base_url: str | None, model: str, headers: dict[str, str] | None
    ):
        self._base_url = base_url or "http://localhost:11434"
        self._model = model
        self._headers = headers or {}

    async def chat(
        self, messages: list[dict[str, Any]], **kwargs: Any
    ) -> dict[str, Any]:
        if httpx is None:
            raise NotImplementedError("httpx not available for Ollama chat")

        # Prefer OpenAI-compatible if base_url exposes /v1
        use_v1 = "/v1" in (self._base_url or "")
        # Both paths accept the chat messages list as-is; only the endpoint differs
        if use_v1:
            url = _join_url(self._base_url, "/chat/completions")
            payload = {"model": self._model, "messages": messages}
            # Map common kwargs
            for k in ("temperature", "max_tokens", "top_p", "stop"):
                if k in kwargs and kwargs[k] is not None:
                    payload[k] = kwargs[k]
            async with httpx.AsyncClient(timeout=60.0) as client:
                try:
                    from datetime import UTC, datetime

                    started = datetime.now(UTC)
                    resp = await client.post(url, json=payload, headers=self._headers)
                    resp.raise_for_status()
                    data = resp.json()
                    text = ""
                    choices = data.get("choices") or []
                    if choices:
                        msg = (choices[0] or {}).get("message") or {}
                        text = msg.get("content", "") or ""
                    duration_ms = int(
                        (datetime.now(UTC) - started).total_seconds() * 1000
                    )
                    logger.info(
                        "LLM request",
                        provider="ollama",
                        operation="chat",
                        model=self._model,
                        base_host=self._base_url,
                        messages=len(messages),
                        latency_ms=duration_ms,
                    )
                    return {
                        "text": text,
                        "raw": data,
                        "usage": data.get("usage"),
                        "model": data.get("model", self._model),
                    }
                except httpx.TimeoutException as exc:
                    raise LLMTimeoutError(str(exc))
                except httpx.HTTPStatusError as exc:
                    status = exc.response.status_code if exc.response else None
                    if status == 401:
                        raise AuthError(str(exc))
                    if status == 429:
                        raise RateLimitedError(str(exc))
                    if status and 400 <= status < 500:
                        raise InvalidRequestError(str(exc))
                    raise ServerError(str(exc))
                except httpx.HTTPError as exc:
                    raise ServerError(str(exc))
        else:
            # Native API
            url = _join_url(self._base_url, "/api/chat")
            payload = {
                "model": self._model,
                "messages": messages,
                "stream": False,
            }
            if "temperature" in kwargs and kwargs["temperature"] is not None:
                payload["options"] = {"temperature": kwargs["temperature"]}
            async with httpx.AsyncClient(timeout=60.0) as client:
                try:
                    from datetime import UTC, datetime

                    started = datetime.now(UTC)
                    resp = await client.post(url, json=payload, headers=self._headers)
                    resp.raise_for_status()
                    data = resp.json()
                    # Ollama native returns {"message": {"content": "..."}, ...}
                    text = ""
                    if isinstance(data.get("message"), dict):
                        text = data["message"].get("content", "") or ""
                    duration_ms = int(
                        (datetime.now(UTC) - started).total_seconds() * 1000
                    )
                    logger.info(
                        "LLM request",
                        provider="ollama",
                        operation="chat",
                        model=self._model,
                        base_host=self._base_url,
                        messages=len(messages),
                        latency_ms=duration_ms,
                    )
                    return {
                        "text": text,
                        "raw": data,
                        "usage": None,
                        "model": self._model,
                    }
                except httpx.TimeoutException as exc:
                    raise LLMTimeoutError(str(exc))
                except httpx.HTTPStatusError as exc:
                    status = exc.response.status_code if exc.response else None
                    if status == 401:
                        raise AuthError(str(exc))
                    if status == 429:
                        raise RateLimitedError(str(exc))
                    if status and 400 <= status < 500:
                        raise InvalidRequestError(str(exc))
                    raise ServerError(str(exc))
                except httpx.HTTPError as exc:
                    raise ServerError(str(exc))

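# --- Illustrative usage sketch (not part of the original module) ---
# A hedged, minimal example of calling OllamaChat directly. It assumes httpx is
# installed and an Ollama server is reachable; "llama3" is a hypothetical model
# name. Point base_url at ".../v1" to exercise the OpenAI-compatible path
# instead of the native /api/chat path. Run with asyncio.run(_example_chat()).
async def _example_chat() -> None:
    chat = OllamaChat("http://localhost:11434", "llama3", None)
    result = await chat.chat(
        [{"role": "user", "content": "Say hello in one word."}],
        temperature=0.2,
    )
    # The normalized result always exposes "text", "raw", "usage" and "model".
    print(result["text"])
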

class OllamaTokenizer(TokenCounter):
    def count(self, text: str) -> int:
        return len(text)


class OllamaProvider(LLMProvider):
    def __init__(self, settings: LLMSettings):
        self._settings = settings

    def embeddings(self) -> EmbeddingsClient:
        model = self._settings.models.get("embeddings", "")
        timeout = (
            self._settings.request.timeout_s
            if self._settings and self._settings.request
            else 30.0
        )
        return OllamaEmbeddings(
            self._settings.base_url,
            model,
            self._settings.headers,
            timeout_s=timeout,
            provider_options=self._settings.provider_options,
        )

    def chat(self) -> ChatClient:
        model = self._settings.models.get("chat", "")
        return OllamaChat(self._settings.base_url, model, self._settings.headers)

    def tokenizer(self) -> TokenCounter:
        return OllamaTokenizer()
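
# --- Illustrative wiring sketch (not part of the original module) ---
# OllamaProvider only reads a handful of attributes from LLMSettings
# (base_url, headers, models, request.timeout_s, provider_options). The real
# LLMSettings constructor lives elsewhere, so this sketch uses a hypothetical
# stand-in object with just those fields to show how the provider builds its
# clients; it is an assumption, not the project's configuration API.
from types import SimpleNamespace


def _example_provider() -> None:
    settings = SimpleNamespace(  # hypothetical stand-in for LLMSettings
        base_url="http://localhost:11434",
        headers=None,
        models={"embeddings": "nomic-embed-text", "chat": "llama3"},
        request=SimpleNamespace(timeout_s=30.0),
        provider_options={"native_endpoint": "auto"},
    )
    provider = OllamaProvider(settings)  # type: ignore[arg-type]
    embedder = provider.embeddings()  # OllamaEmbeddings with timeout_s=30.0
    chat = provider.chat()  # OllamaChat bound to the "llama3" model
    tokens = provider.tokenizer().count("hello")  # 5: character-count heuristic
    print(type(embedder).__name__, type(chat).__name__, tokens)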