Coverage for src/qdrant_loader_core/llm/providers/openai.py: 84% (137 statements)
coverage.py v7.10.6, created at 2025-09-08 06:01 +0000
from __future__ import annotations

from datetime import UTC, datetime
from typing import Any
from urllib.parse import urlparse

try:
    from openai import OpenAI  # type: ignore

    # New-style exception classes (OpenAI Python SDK >=1.x)
    try:  # nested to avoid failing entirely on older clients
        from openai import (  # type: ignore
            APIConnectionError,
            APIStatusError,
            APITimeoutError,
            AuthenticationError,
            BadRequestError,
            RateLimitError,
        )
    except Exception:  # pragma: no cover - optional dependency surface
        APIConnectionError = APIStatusError = APITimeoutError = AuthenticationError = BadRequestError = RateLimitError = ()  # type: ignore
except Exception:  # pragma: no cover - optional dependency at this phase
    OpenAI = None  # type: ignore
    APIConnectionError = APIStatusError = APITimeoutError = AuthenticationError = BadRequestError = RateLimitError = ()  # type: ignore
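
# When the SDK (or its exception surface) is unavailable, the names above are
# bound to an empty tuple: it is falsy, and also a valid second argument to
# isinstance(), so guards of the form `RateLimitError and isinstance(exc,
# RateLimitError)` degrade to no-ops instead of raising NameError or TypeError.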

from ...logging import LoggingConfig
from ..errors import (
    AuthError,
    InvalidRequestError,
    LLMError,
    RateLimitedError,
    ServerError,
)
from ..errors import TimeoutError as LLMTimeoutError
from ..settings import LLMSettings
from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter

logger = LoggingConfig.get_logger(__name__)


def _safe_host(url: str | None) -> str | None:
    if not url:
        return None
    try:
        return urlparse(url).hostname or None
    except Exception:
        return None
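
# Illustrative behavior (hypothetical inputs):
#   _safe_host("https://api.openai.com/v1") -> "api.openai.com"
#   _safe_host("not a url") -> None   (urlparse finds no hostname)
#   _safe_host(None) -> None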


def _map_openai_exception(exc: Exception) -> LLMError:
    try:
        # Rate limit
        if RateLimitError and isinstance(exc, RateLimitError):  # type: ignore[arg-type]
            return RateLimitedError(str(exc))
        # Timeout
        if APITimeoutError and isinstance(exc, APITimeoutError):  # type: ignore[arg-type]
            return LLMTimeoutError(str(exc))
        # Auth
        if AuthenticationError and isinstance(exc, AuthenticationError):  # type: ignore[arg-type]
            return AuthError(str(exc))
        # Bad request / invalid params
        if BadRequestError and isinstance(exc, BadRequestError):  # type: ignore[arg-type]
            return InvalidRequestError(str(exc))
        # API status error (typically non-2xx)
        if APIStatusError and isinstance(exc, APIStatusError):  # type: ignore[arg-type]
            # Best-effort: check for a status code
            status_code = getattr(exc, "status_code", None) or getattr(
                getattr(exc, "response", None), "status_code", None
            )
            if isinstance(status_code, int) and 400 <= status_code < 500:
                if status_code == 429:
                    return RateLimitedError(str(exc))
                if status_code in (401, 403):
                    return AuthError(str(exc))
                return InvalidRequestError(str(exc))
            return ServerError(str(exc))
        # Connection-level errors
        if APIConnectionError and isinstance(exc, APIConnectionError):  # type: ignore[arg-type]
            return ServerError(str(exc))
    except Exception:
        pass
    # Fallback
    return ServerError(str(exc))
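
# Mapping sketch: an APIStatusError with status_code 429 becomes
# RateLimitedError, 401/403 become AuthError, any other 4xx becomes
# InvalidRequestError, and 5xx becomes ServerError; exceptions no guard
# recognizes fall through to the ServerError default.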


class _OpenAITokenCounter(TokenCounter):
    def __init__(self, tokenizer: str):
        self._tokenizer = tokenizer

    def count(self, text: str) -> int:
        # Phase 0: fall back to naive length; real tiktoken impl to come later
        return len(text)
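
# A later phase would likely delegate to tiktoken, roughly (a sketch, assuming
# the optional `tiktoken` package and that `tokenizer` names an encoding such
# as "cl100k_base"):
#
#     import tiktoken
#
#     def count(self, text: str) -> int:
#         try:
#             encoding = tiktoken.get_encoding(self._tokenizer)
#         except Exception:
#             return len(text)  # unknown encoding: keep the naive estimate
#         return len(encoding.encode(text))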


class OpenAIEmbeddings(EmbeddingsClient):
    def __init__(
        self,
        client: Any,
        model: str,
        base_host: str | None,
        *,
        provider_label: str = "openai",
    ):
        self._client = client
        self._model = model
        self._base_host = base_host
        self._provider_label = provider_label

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        if not self._client:
            raise NotImplementedError("OpenAI client not available")
        # Offload the sync client call to a thread so the async interface
        # stays non-blocking
        import asyncio

        started = datetime.now(UTC)
        try:
            response = await asyncio.to_thread(
                self._client.embeddings.create, model=self._model, input=inputs
            )
            duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)
            try:
                logger.info(
                    "LLM request",
                    provider=self._provider_label,
                    operation="embeddings",
                    model=self._model,
                    base_host=self._base_host,
                    inputs=len(inputs),
                    latency_ms=duration_ms,
                )
            except Exception:
                pass
            return [item.embedding for item in response.data]
        except Exception as exc:  # normalize SDK errors into LLMError subclasses
            mapped = _map_openai_exception(exc)
            try:
                logger.warning(
                    "LLM error",
                    provider=self._provider_label,
                    operation="embeddings",
                    model=self._model,
                    base_host=self._base_host,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            raise mapped from exc
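
# Usage sketch (hypothetical wiring; `client` is a constructed OpenAI() instance):
#
#     embeddings = OpenAIEmbeddings(client, "text-embedding-3-small", "api.openai.com")
#     vectors = await embeddings.embed(["first text", "second text"])
#     # vectors is one list of floats per input, in input order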


class OpenAIChat(ChatClient):
    def __init__(
        self,
        client: Any,
        model: str,
        base_host: str | None,
        *,
        provider_label: str = "openai",
    ):
        self._client = client
        self._model = model
        self._base_host = base_host
        self._provider_label = provider_label

    async def chat(
        self, messages: list[dict[str, Any]], **kwargs: Any
    ) -> dict[str, Any]:
        if not self._client:
            raise NotImplementedError("OpenAI client not available")

        # Normalize kwargs to OpenAI python client parameters
        create_kwargs: dict[str, Any] = {}
        for key in (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "stop",
            "seed",
            "response_format",
        ):
            if key in kwargs and kwargs[key] is not None:
                create_kwargs[key] = kwargs[key]

        # Allow model override per-call
        model_name = kwargs.pop("model", self._model)
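        # e.g. `await chat_client.chat(messages, model="gpt-4o-mini")` (a
        # hypothetical call) targets that model for this request only, leaving
        # the configured default untouched.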

        import asyncio

        # The OpenAI python client call is sync for chat.completions
        started = datetime.now(UTC)
        try:
            response = await asyncio.to_thread(
                self._client.chat.completions.create,
                model=model_name,
                messages=messages,
                **create_kwargs,
            )
            duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)
            try:
                logger.info(
                    "LLM request",
                    provider=self._provider_label,
                    operation="chat",
                    model=model_name,
                    base_host=self._base_host,
                    messages=len(messages),
                    latency_ms=duration_ms,
                )
            except Exception:
                pass

            # Normalize to a provider-agnostic dict
            choice0 = (
                response.choices[0] if getattr(response, "choices", None) else None
            )
            text = ""
            if choice0 is not None:
                message = getattr(choice0, "message", None)
                if message is not None:
                    text = getattr(message, "content", "") or ""

            usage = getattr(response, "usage", None)
            normalized_usage = None
            if usage is not None:
                normalized_usage = {
                    "prompt_tokens": getattr(usage, "prompt_tokens", None),
                    "completion_tokens": getattr(usage, "completion_tokens", None),
                    "total_tokens": getattr(usage, "total_tokens", None),
                }

            return {
                "text": text,
                "raw": response,
                "usage": normalized_usage,
                "model": getattr(response, "model", model_name),
            }
        except Exception as exc:  # normalize SDK errors into LLMError subclasses
            mapped = _map_openai_exception(exc)
            try:
                logger.warning(
                    "LLM error",
                    provider=self._provider_label,
                    operation="chat",
                    model=model_name,
                    base_host=self._base_host,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            raise mapped from exc
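
# Usage sketch (hypothetical wiring): the normalized return value keeps callers
# decoupled from the SDK's response types:
#
#     chat_client = OpenAIChat(client, "gpt-4o-mini", "api.openai.com")
#     result = await chat_client.chat([{"role": "user", "content": "Hello"}])
#     result["text"]   # assistant reply as a plain string
#     result["usage"]  # {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}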


class OpenAIProvider(LLMProvider):
    def __init__(self, settings: LLMSettings):
        self._settings = settings
        self._base_host = _safe_host(settings.base_url)
        if OpenAI is None:
            self._client = None
        else:
            kwargs: dict[str, Any] = {}
            if settings.base_url:
                kwargs["base_url"] = settings.base_url
            if settings.api_key:
                kwargs["api_key"] = settings.api_key
            self._client = OpenAI(**kwargs)

    def embeddings(self) -> EmbeddingsClient:
        model = self._settings.models.get("embeddings", "")
        return OpenAIEmbeddings(
            self._client, model, self._base_host, provider_label="openai"
        )

    def chat(self) -> ChatClient:
        model = self._settings.models.get("chat", "")
        return OpenAIChat(self._client, model, self._base_host, provider_label="openai")

    def tokenizer(self) -> TokenCounter:
        return _OpenAITokenCounter(self._settings.tokenizer)
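
# End-to-end sketch (assumes an `LLMSettings` built elsewhere; field values
# here are hypothetical):
#
#     settings = LLMSettings(
#         base_url="https://api.openai.com/v1",
#         api_key="sk-...",
#         models={"embeddings": "text-embedding-3-small", "chat": "gpt-4o-mini"},
#         tokenizer="cl100k_base",
#     )
#     provider = OpenAIProvider(settings)
#     vectors = asyncio.run(provider.embeddings().embed(["hello world"]))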