Coverage for src / qdrant_loader_core / llm / providers / gemini.py: 0%

189 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-06-11 09:34 +0000

1from __future__ import annotations 

2 

3from datetime import UTC, datetime 

4from typing import Any 

5 

6try: 

7 from google import genai # type: ignore 

8 from google.genai import types as genai_types # type: ignore 

9 

10 try: 

11 from google.genai import errors as genai_errors # type: ignore 

12 except Exception: # pragma: no cover - optional dependency surface 

13 genai_errors = None # type: ignore 

14except Exception: # pragma: no cover - optional dependency at this phase 

15 genai = None # type: ignore 

16 genai_types = None # type: ignore 

17 genai_errors = None # type: ignore 

18 

19from ...logging import LoggingConfig 

20from ..errors import ( 

21 AuthError, 

22 InvalidRequestError, 

23 LLMError, 

24 RateLimitedError, 

25 ServerError, 

26) 

27from ..errors import TimeoutError as LLMTimeoutError 

28from ..settings import LLMSettings 

29from ..types import ChatClient, EmbeddingsClient, LLMProvider, TokenCounter 

30 

31logger = LoggingConfig.get_logger(__name__) 

32 

33 

def _map_gemini_exception(exc: Exception) -> LLMError:
    """Translate an exception from a Gemini call into an ``LLMError``.

    A numeric status is read from the exception's ``code`` / ``status_code``
    attributes and mapped as follows:

    * 408, 504 -> timeout error
    * 429 -> rate-limited error
    * 401, 403 -> auth error
    * any other 4xx -> invalid-request error
    * 5xx and above -> server error

    When no status matches, a built-in ``TimeoutError`` becomes the timeout
    error and anything else is reported as a server error.

    Args:
        exc: The exception caught around the SDK call.

    Returns:
        A new ``LLMError`` subclass instance whose message is ``str(exc)``.
    """
    http_status: int | None = None

    # For SDK API errors prefer ``code``; any failure while probing the
    # optional SDK error module is ignored and the generic lookup is used.
    try:
        if genai_errors is not None and isinstance(
            exc, genai_errors.APIError  # type: ignore[attr-defined]
        ):
            http_status = getattr(exc, "code", None) or getattr(
                exc, "status_code", None
            )
    except Exception:
        pass

    if http_status is None:
        http_status = getattr(exc, "status_code", None) or getattr(exc, "code", None)

    detail = str(exc)

    if isinstance(http_status, int):
        # Exact codes win over the generic 4xx / 5xx ranges below.
        exact_matches = {
            408: LLMTimeoutError,
            504: LLMTimeoutError,
            429: RateLimitedError,
            401: AuthError,
            403: AuthError,
        }
        error_cls = exact_matches.get(http_status)
        if error_cls is None:
            if 400 <= http_status < 500:
                error_cls = InvalidRequestError
            elif http_status >= 500:
                error_cls = ServerError
        if error_cls is not None:
            return error_cls(detail)

    if isinstance(exc, TimeoutError):
        return LLMTimeoutError(detail)

    return ServerError(detail)

64 

65 

class _GeminiTokenCounter(TokenCounter):
    """Token counter that asks the Gemini client, falling back to ``len(text)``."""

    def __init__(self, client: Any, model: str):
        self._client, self._model = client, model

    def count(self, text: str) -> int:
        """Return the token count for *text*.

        The client's ``models.count_tokens`` result is used when it exposes an
        integer ``total_tokens``. If there is no client, the call fails, or the
        result has no integer total, the character length of *text* is
        returned instead.
        """
        if self._client is not None:
            try:
                result = self._client.models.count_tokens(
                    model=self._model, contents=text
                )
                tokens = getattr(result, "total_tokens", None)
            except Exception:
                # Best-effort: any failure falls through to the length fallback.
                tokens = None
            if isinstance(tokens, int):
                return tokens
        return len(text)

84 

85 

def _messages_to_contents(
    messages: list[dict[str, Any]],
) -> tuple[str | None, list[Any]]:
    """Split OpenAI-style messages into a Gemini system instruction and contents.

    Messages whose ``content`` is ``None`` are skipped, and non-string content
    is converted with ``str()``. ``system`` messages are collected and joined
    with newlines; ``assistant`` messages get the Gemini role ``"model"`` and
    every other role is sent as ``"user"``.

    Args:
        messages: Dicts with ``role`` and ``content`` keys.

    Returns:
        ``(system_instruction, contents)`` where ``system_instruction`` is
        ``None`` if no system message was present. ``contents`` holds
        ``genai_types.Content`` objects when the SDK types are importable and
        plain ``{"role": ..., "parts": [{"text": ...}]}`` dicts otherwise.
    """
    system_chunks: list[str] = []
    turns: list[Any] = []

    for message in messages:
        role = (message.get("role") or "").lower()
        body = message.get("content")
        if body is None:
            continue
        text = body if isinstance(body, str) else str(body)

        if role == "system":
            system_chunks.append(text)
            continue

        speaker = "user" if role != "assistant" else "model"
        if genai_types is None:
            # SDK types unavailable: fall back to the plain-dict shape.
            turns.append({"role": speaker, "parts": [{"text": text}]})
        else:
            part = genai_types.Part.from_text(text=text)
            turns.append(genai_types.Content(role=speaker, parts=[part]))

    if system_chunks:
        return "\n".join(system_chunks), turns
    return None, turns

115 

116 

class GeminiEmbeddings(EmbeddingsClient):
    """Embeddings client backed by the Gemini ``models.embed_content`` call."""

    def __init__(
        self,
        client: Any,
        model: str,
        *,
        provider_label: str = "gemini",
        output_dimensionality: int | None = None,
    ):
        self._client = client
        self._model = model
        self._provider_label = provider_label
        self._output_dimensionality = output_dimensionality

    def _build_config(self) -> Any:
        """Return an ``EmbedContentConfig`` for the requested size, or ``None``.

        ``None`` is returned when no output dimensionality was configured or
        the SDK types could not be imported.
        """
        if genai_types is not None and self._output_dimensionality is not None:
            return genai_types.EmbedContentConfig(
                output_dimensionality=self._output_dimensionality
            )
        return None

    def _wrap_inputs(self, inputs: list[str]) -> list[Any]:
        """Wrap each input text in its own ``Content`` so it is embedded separately."""
        # gemini-embedding-2 aggregates a list of raw strings into ONE embedding.
        # Wrapping each input in its own Content forces one embedding per input.
        if genai_types is None:
            return list(inputs)
        wrapped: list[Any] = []
        for text in inputs:
            part = genai_types.Part.from_text(text=text)
            wrapped.append(genai_types.Content(parts=[part]))
        return wrapped

    async def embed(self, inputs: list[str]) -> list[list[float]]:
        """Embed *inputs* and return one vector per input, in order.

        Args:
            inputs: Texts to embed. An empty list returns ``[]`` without
                calling the API.

        Returns:
            A list of float vectors with the same length as *inputs*.

        Raises:
            NotImplementedError: If no Gemini client is available.
            InvalidRequestError: If no embeddings model is configured.
            LLMError: Any failure during the call (including a vector count
                that does not match the number of inputs) is mapped through
                ``_map_gemini_exception`` and chained to the original error.
        """
        import asyncio

        if not self._client:
            raise NotImplementedError("Gemini client not available")
        if not self._model:
            raise InvalidRequestError(
                "Gemini embeddings model is not configured. "
                "Set global.llm.models.embeddings in your config "
                "(e.g. 'gemini-embedding-2' or 'gemini-embedding-001')."
            )
        if not inputs:
            return []

        embed_config = self._build_config()
        request: dict[str, Any] = {
            "model": self._model,
            "contents": self._wrap_inputs(inputs),
        }
        if embed_config is not None:
            request["config"] = embed_config

        t0 = datetime.now(UTC)
        try:
            # The SDK call is synchronous, so run it in a worker thread.
            response = await asyncio.to_thread(
                self._client.models.embed_content, **request
            )
            elapsed = datetime.now(UTC) - t0
            latency_ms = int(elapsed.total_seconds() * 1000)
            try:
                logger.info(
                    "LLM request",
                    provider=self._provider_label,
                    operation="embeddings",
                    model=self._model,
                    inputs=len(inputs),
                    latency_ms=latency_ms,
                )
            except Exception:
                # Logging must never break the request.
                pass

            results: list[list[float]] = []
            for item in getattr(response, "embeddings", None) or []:
                results.append(list(getattr(item, "values", []) or []))

            if len(results) == len(inputs):
                return results
            raise ServerError(
                f"Gemini embed_content returned {len(results)} embeddings "
                f"for {len(inputs)} inputs (expected 1:1). "
                f"Verify the model id {self._model!r} supports per-input embeddings."
            )
        except Exception as exc:
            translated = _map_gemini_exception(exc)
            try:
                logger.warning(
                    "LLM error",
                    provider=self._provider_label,
                    operation="embeddings",
                    model=self._model,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            raise translated from exc

209 

210 

class GeminiChat(ChatClient):
    """Chat client that sends OpenAI-style messages to a Gemini model."""

    def __init__(
        self,
        client: Any,
        model: str,
        *,
        provider_label: str = "gemini",
    ):
        self._client = client
        self._model = model
        self._provider_label = provider_label

    async def chat(
        self, messages: list[dict[str, Any]], **kwargs: Any
    ) -> dict[str, Any]:
        """Run one chat completion and return a normalized result dict.

        Args:
            messages: OpenAI-style ``{"role", "content"}`` dicts.
            **kwargs: Optional overrides. ``model`` replaces the configured
                model for this call. ``temperature``, ``top_p``, ``max_tokens``,
                ``stop`` and ``seed`` are forwarded to the Gemini generation
                config when not ``None`` (``max_tokens`` as
                ``max_output_tokens``, ``stop`` as ``stop_sequences``). A
                ``stop`` given as a single string is wrapped in a list.

        Returns:
            A dict with keys ``text`` (response text or ``""``), ``raw`` (the
            SDK response), ``usage`` (prompt/completion/total token counts, or
            ``None`` when the response carries no usage metadata) and ``model``.

        Raises:
            NotImplementedError: If no Gemini client is available.
            InvalidRequestError: If no chat model is configured or supplied.
            LLMError: Any failure during the call is mapped through
                ``_map_gemini_exception`` and chained to the original error.
        """
        if not self._client:
            raise NotImplementedError("Gemini client not available")

        model_name = kwargs.pop("model", self._model)
        if not model_name:
            raise InvalidRequestError(
                "Gemini chat model is not configured. "
                "Set global.llm.models.chat in your config "
                "(e.g. 'gemini-2.0-flash')."
            )
        system_instruction, contents = _messages_to_contents(messages)

        config_kwargs: dict[str, Any] = {}
        if system_instruction:
            config_kwargs["system_instruction"] = system_instruction
        for src_key, dst_key in (
            ("temperature", "temperature"),
            ("top_p", "top_p"),
            ("max_tokens", "max_output_tokens"),
            ("stop", "stop_sequences"),
            ("seed", "seed"),
        ):
            value = kwargs.get(src_key)
            if value is None:
                continue
            if src_key == "stop" and isinstance(value, str):
                # OpenAI-style callers may pass a single stop string, while
                # Gemini's stop_sequences is a list of strings.
                value = [value]
            config_kwargs[dst_key] = value

        config = None
        if config_kwargs and genai_types is not None:
            config = genai_types.GenerateContentConfig(**config_kwargs)
        import asyncio

        started = datetime.now(UTC)
        try:
            call_kwargs: dict[str, Any] = {
                "model": model_name,
                "contents": contents,
            }
            if config is not None:
                call_kwargs["config"] = config

            # The SDK call is synchronous, so run it in a worker thread.
            response = await asyncio.to_thread(
                self._client.models.generate_content, **call_kwargs
            )
            duration_ms = int((datetime.now(UTC) - started).total_seconds() * 1000)
            try:
                logger.info(
                    "LLM request",
                    provider=self._provider_label,
                    operation="chat",
                    model=model_name,
                    messages=len(messages),
                    latency_ms=duration_ms,
                )
            except Exception:
                # Logging must never break the request.
                pass

            text = getattr(response, "text", "") or ""

            usage_meta = getattr(response, "usage_metadata", None)
            normalized_usage = None
            if usage_meta is not None:
                normalized_usage = {
                    "prompt_tokens": getattr(usage_meta, "prompt_token_count", None),
                    "completion_tokens": getattr(
                        usage_meta, "candidates_token_count", None
                    ),
                    "total_tokens": getattr(usage_meta, "total_token_count", None),
                }

            return {
                "text": text,
                "raw": response,
                "usage": normalized_usage,
                "model": getattr(response, "model_version", model_name) or model_name,
            }
        except Exception as exc:
            mapped = _map_gemini_exception(exc)
            try:
                logger.warning(
                    "LLM error",
                    provider=self._provider_label,
                    operation="chat",
                    model=model_name,
                    error=type(exc).__name__,
                )
            except Exception:
                pass
            # Chain explicitly so the SDK error is recorded as the cause,
            # matching how the embeddings client re-raises.
            raise mapped from exc

313 

314 

class GeminiProvider(LLMProvider):
    """Provider that builds Gemini embeddings, chat and tokenizer clients."""

    def __init__(self, settings: LLMSettings):
        """Store *settings* and create the SDK client when the SDK is importable.

        The client receives ``api_key`` when set, plus ``vertexai``,
        ``project`` and ``location`` taken from ``settings.provider_options``
        when those entries are truthy. Without the SDK the client is ``None``.
        """
        self._settings = settings
        if genai is None:
            self._client = None
            return

        client_kwargs: dict[str, Any] = {}
        if settings.api_key:
            client_kwargs["api_key"] = settings.api_key

        options = settings.provider_options or {}
        if options.get("vertexai"):
            client_kwargs["vertexai"] = True
        for option_name in ("project", "location"):
            if options.get(option_name):
                client_kwargs[option_name] = options[option_name]

        self._client = genai.Client(**client_kwargs)

    def embeddings(self) -> EmbeddingsClient:
        """Return an embeddings client for the configured ``embeddings`` model."""
        embeddings_model = self._settings.models.get("embeddings", "")
        vector_size = self._settings.embeddings.vector_size
        return GeminiEmbeddings(
            self._client,
            embeddings_model,
            provider_label="gemini",
            output_dimensionality=vector_size,
        )

    def chat(self) -> ChatClient:
        """Return a chat client for the configured ``chat`` model."""
        chat_model = self._settings.models.get("chat", "")
        return GeminiChat(self._client, chat_model, provider_label="gemini")

    def tokenizer(self) -> TokenCounter:
        """Return a token counter using the chat model, else the embeddings model."""
        counter_model = self._settings.models.get("chat")
        if not counter_model:
            counter_model = self._settings.models.get("embeddings", "")
        return _GeminiTokenCounter(self._client, counter_model)