Coverage for src/qdrant_loader_core/llm/providers/openai.py: 84%

137 statements  

coverage.py v7.10.6, created at 2025-09-08 06:01 +0000

from __future__ import annotations

from datetime import UTC, datetime
from typing import Any
from urllib.parse import urlparse

try:
    from openai import OpenAI  # type: ignore

    # New-style exception classes (OpenAI Python SDK >=1.x)
    try:  # nested to avoid failing entirely on older clients
        from openai import (  # type: ignore
            APIConnectionError,
            APIStatusError,
            APITimeoutError,
            AuthenticationError,
            BadRequestError,
            RateLimitError,
        )
    except Exception:  # pragma: no cover - optional dependency surface
        APIConnectionError = APIStatusError = APITimeoutError = AuthenticationError = BadRequestError = RateLimitError = ()  # type: ignore
except Exception:  # pragma: no cover - optional dependency at this phase
    OpenAI = None  # type: ignore
    APIConnectionError = APIStatusError = APITimeoutError = AuthenticationError = BadRequestError = RateLimitError = ()  # type: ignore

from ...logging import LoggingConfig
from ..errors import (
    AuthError,
    InvalidRequestError,
    LLMError,
    RateLimitedError,
    ServerError,
)
from ..errors import TimeoutError as LLMTimeoutError
from ..settings import LLMSettings
from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter

logger = LoggingConfig.get_logger(__name__)


def _safe_host(url: str | None) -> str | None:
    if not url:
        return None
    try:
        return urlparse(url).hostname or None
    except Exception:
        return None
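
# For example, _safe_host("https://api.openai.com/v1") returns "api.openai.com",
# so log records carry only the hostname, never paths, query strings, or
# credentials embedded in the URL.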



def _map_openai_exception(exc: Exception) -> LLMError:
    try:
        # Rate limit
        if RateLimitError and isinstance(exc, RateLimitError):  # type: ignore[arg-type]
            return RateLimitedError(str(exc))
        # Timeout
        if APITimeoutError and isinstance(exc, APITimeoutError):  # type: ignore[arg-type]
            return LLMTimeoutError(str(exc))
        # Auth
        if AuthenticationError and isinstance(exc, AuthenticationError):  # type: ignore[arg-type]
            return AuthError(str(exc))
        # Bad request / invalid params
        if BadRequestError and isinstance(exc, BadRequestError):  # type: ignore[arg-type]
            return InvalidRequestError(str(exc))
        # API status error (typically non-2xx)
        if APIStatusError and isinstance(exc, APIStatusError):  # type: ignore[arg-type]
            # Best-effort: check for status code
            status_code = getattr(exc, "status_code", None) or getattr(
                getattr(exc, "response", None), "status_code", None
            )
            if isinstance(status_code, int) and 400 <= status_code < 500:
                if status_code == 429:
                    return RateLimitedError(str(exc))
                if status_code in (401, 403):
                    return AuthError(str(exc))
                return InvalidRequestError(str(exc))
            return ServerError(str(exc))
        # Connection-level errors
        if APIConnectionError and isinstance(exc, APIConnectionError):  # type: ignore[arg-type]
            return ServerError(str(exc))
    except Exception:
        pass
    # Fallback
    return ServerError(str(exc))
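
# Minimal usage sketch: this module wraps every SDK call and re-raises the
# provider-agnostic equivalent, so callers can handle RateLimitedError,
# AuthError, etc. without importing the OpenAI SDK themselves:
#
#     try:
#         response = client.chat.completions.create(...)
#     except Exception as exc:
#         raise _map_openai_exception(exc)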



class _OpenAITokenCounter(TokenCounter):
    def __init__(self, tokenizer: str):
        self._tokenizer = tokenizer

    def count(self, text: str) -> int:
        # Phase 0: fall back to naive character length; a real tiktoken
        # implementation is to come later
        return len(text)
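
# A possible future count() implementation (a sketch only, assuming the
# configured tokenizer string names a tiktoken encoding such as "cl100k_base"):
#
#     import tiktoken
#
#     def count(self, text: str) -> int:
#         encoding = tiktoken.get_encoding(self._tokenizer)
#         return len(encoding.encode(text))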



class OpenAIEmbeddings(EmbeddingsClient):
    def __init__(
        self,
        client: Any,
        model: str,
        base_host: str | None,
        *,
        provider_label: str = "openai",
    ):
        self._client = client
        self._model = model
        self._base_host = base_host
        self._provider_label = provider_label

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        if not self._client:
            raise NotImplementedError("OpenAI client not available")
        # Use thread offloading to keep async interface consistent with sync client
        import asyncio

        started = datetime.now(UTC)
        try:
            response = await asyncio.to_thread(
                self._client.embeddings.create, model=self._model, input=inputs
            )
            duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)
            try:
                logger.info(
                    "LLM request",
                    provider=self._provider_label,
                    operation="embeddings",
                    model=self._model,
                    base_host=self._base_host,
                    inputs=len(inputs),
                    latency_ms=duration_ms,
                )
            except Exception:
                pass
            return [item.embedding for item in response.data]
        except Exception as exc:  # Normalize errors
            mapped = _map_openai_exception(exc)
            try:
                logger.warning(
                    "LLM error",
                    provider=self._provider_label,
                    operation="embeddings",
                    model=self._model,
                    base_host=self._base_host,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            raise mapped
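
# Usage sketch (assumes a configured settings object; must be awaited from
# async code, since the sync SDK call is offloaded to a thread):
#
#     embeddings = OpenAIProvider(settings).embeddings()
#     vectors = await embeddings.embed(["first text", "second text"])
#     # vectors is one list[float] per input, in input order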



class OpenAIChat(ChatClient):
    def __init__(
        self,
        client: Any,
        model: str,
        base_host: str | None,
        *,
        provider_label: str = "openai",
    ):
        self._client = client
        self._model = model
        self._base_host = base_host
        self._provider_label = provider_label

    async def chat(
        self, messages: list[dict[str, Any]], **kwargs: Any
    ) -> dict[str, Any]:
        if not self._client:
            raise NotImplementedError("OpenAI client not available")

        # Normalize kwargs to OpenAI python client parameters
        create_kwargs: dict[str, Any] = {}
        for key in (
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "stop",
            "seed",
            "response_format",
        ):
            if key in kwargs and kwargs[key] is not None:
                create_kwargs[key] = kwargs[key]

        # Allow model override per-call
        model_name = kwargs.pop("model", self._model)

        import asyncio

        # The OpenAI python client call is sync for chat.completions
        started = datetime.now(UTC)
        try:
            response = await asyncio.to_thread(
                self._client.chat.completions.create,
                model=model_name,
                messages=messages,
                **create_kwargs,
            )
            duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)
            try:
                logger.info(
                    "LLM request",
                    provider=self._provider_label,
                    operation="chat",
                    model=model_name,
                    base_host=self._base_host,
                    messages=len(messages),
                    latency_ms=duration_ms,
                )
            except Exception:
                pass

            # Normalize to provider-agnostic dict
            choice0 = (
                response.choices[0] if getattr(response, "choices", None) else None
            )
            text = ""
            if choice0 is not None:
                message = getattr(choice0, "message", None)
                if message is not None:
                    text = getattr(message, "content", "") or ""

            usage = getattr(response, "usage", None)
            normalized_usage = None
            if usage is not None:
                normalized_usage = {
                    "prompt_tokens": getattr(usage, "prompt_tokens", None),
                    "completion_tokens": getattr(usage, "completion_tokens", None),
                    "total_tokens": getattr(usage, "total_tokens", None),
                }

            return {
                "text": text,
                "raw": response,
                "usage": normalized_usage,
                "model": getattr(response, "model", model_name),
            }
        except Exception as exc:
            mapped = _map_openai_exception(exc)
            try:
                logger.warning(
                    "LLM error",
                    provider=self._provider_label,
                    operation="chat",
                    model=model_name,
                    base_host=self._base_host,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            raise mapped
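
# Usage sketch showing the normalized return shape (values illustrative):
#
#     chat = OpenAIProvider(settings).chat()
#     result = await chat.chat([{"role": "user", "content": "Hello"}], temperature=0)
#     result["text"]   # assistant message content, "" when absent
#     result["usage"]  # {"prompt_tokens": ..., "completion_tokens": ...,
#                      #  "total_tokens": ...} or None
#     result["model"]  # model reported by the API, falling back to the requested name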



class OpenAIProvider(LLMProvider):
    def __init__(self, settings: LLMSettings):
        self._settings = settings
        self._base_host = _safe_host(settings.base_url)
        if OpenAI is None:
            self._client = None
        else:
            kwargs: dict[str, Any] = {}
            if settings.base_url:
                kwargs["base_url"] = settings.base_url
            if settings.api_key:
                kwargs["api_key"] = settings.api_key
            self._client = OpenAI(**kwargs)

    def embeddings(self) -> EmbeddingsClient:
        model = self._settings.models.get("embeddings", "")
        return OpenAIEmbeddings(
            self._client, model, self._base_host, provider_label="openai"
        )

    def chat(self) -> ChatClient:
        model = self._settings.models.get("chat", "")
        return OpenAIChat(self._client, model, self._base_host, provider_label="openai")

    def tokenizer(self) -> TokenCounter:
        return _OpenAITokenCounter(self._settings.tokenizer)
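
# End-to-end sketch. The attribute names (base_url, api_key, models, tokenizer)
# are taken from how this module reads LLMSettings; the constructor call and
# the model names are assumptions for illustration only:
#
#     settings = LLMSettings(
#         base_url="https://api.openai.com/v1",
#         api_key="sk-...",
#         models={"embeddings": "text-embedding-3-small", "chat": "gpt-4o-mini"},
#         tokenizer="cl100k_base",
#     )
#     provider = OpenAIProvider(settings)
#     vectors = await provider.embeddings().embed(["hello"])
#     reply = await provider.chat().chat([{"role": "user", "content": "Hi"}])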